add: classifyVideo & classifyVideos implementation

alikia2x (寒寒) 2025-02-22 19:57:52 +08:00
parent 7946cb6e96
commit cecc1c1d2c
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
4 changed files with 81 additions and 20 deletions

lib/db/allData.ts

@@ -1,8 +1,10 @@
import { Client, Transaction } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import { AllDataType } from "lib/db/schema.d.ts";
import logger from "lib/log/logger.ts";
import { formatTimestampToPsql, parseTimestampFromPsql } from "lib/utils/formatTimestampToPostgre.ts";
import { VideoListVideo } from "lib/net/bilibili.d.ts";
import { HOUR, SECOND } from "$std/datetime/constants.ts";
import { modelVersion } from "lib/ml/filter_inference.ts";
export async function videoExistsInAllData(client: Client, aid: number) {
return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid])
@@ -20,7 +22,16 @@ export async function insertIntoAllData(client: Client, data: VideoListVideo) {
`INSERT INTO all_data (aid, bvid, description, uid, tags, title, published_at, duration)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (aid) DO NOTHING`,
[data.aid, data.bvid, data.desc, data.owner.mid, null, data.title, formatTimestampToPsql(data.pubdate * SECOND + 8 * HOUR), data.duration],
[
data.aid,
data.bvid,
data.desc,
data.owner.mid,
null,
data.title,
formatTimestampToPsql(data.pubdate * SECOND + 8 * HOUR),
data.duration,
],
);
}
@@ -65,3 +76,25 @@ export async function getNullVideoTagsList(client: Client) {
},
);
}
export async function getUnlabeledVideos(client: Client) {
const queryResult = await client.queryObject<{ aid: number }>(
`SELECT a.aid FROM all_data a LEFT JOIN labelling_result l ON a.aid = l.aid WHERE l.aid IS NULL`,
);
return queryResult.rows.map((row) => row.aid);
}
export async function insertVideoLabel(client: Client, aid: number, label: number) {
return await client.queryObject(
`INSERT INTO labelling_result (aid, label, model_version) VALUES ($1, $2, $3) ON CONFLICT (aid, model_version) DO NOTHING`,
[aid, label, modelVersion],
);
}
export async function getVideoInfoFromAllData(client: Client, aid: number) {
const queryResult = await client.queryObject<AllDataType>(
`SELECT * FROM all_data WHERE aid = $1`,
[aid],
);
return queryResult.rows[0];
}
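For reference, the three new helpers chain into the read → classify → write loop that the worker added later in this commit performs. A minimal sketch of that flow, assuming a pooled client from lib/db/init.ts and using a placeholder label value rather than a real model output:

import { db } from "lib/db/init.ts";
import { getUnlabeledVideos, getVideoInfoFromAllData, insertVideoLabel } from "lib/db/allData.ts";

// Sketch only: fetch the backlog and label the first entry with a placeholder value.
const client = await db.connect();
try {
	const pending = await getUnlabeledVideos(client); // aids in all_data with no row in labelling_result
	if (pending.length > 0) {
		const info = await getVideoInfoFromAllData(client, pending[0]);
		console.log(`would classify aid ${pending[0]}: ${info.title}`);
		await insertVideoLabel(client, pending[0], 0); // 0 is a placeholder, not a model prediction
	}
} finally {
	client.release();
}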

lib/ml/filter_inference.ts

@@ -1,21 +1,20 @@
import { AutoTokenizer } from "@huggingface/transformers";
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
import * as ort from "onnxruntime";
import logger from "lib/log/logger.ts";
import { WorkerError } from "src/worker.ts";
// Model paths and names
const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
const onnxEmbeddingOriginalPath = "./model/model.onnx";
export const modelVersion = "3.11";
// Global variables holding the models and the tokenizer
let sessionClassifier: ort.InferenceSession | null = null;
let sessionEmbedding: ort.InferenceSession | null = null;
let tokenizer: any | null = null;
let tokenizer: PreTrainedTokenizer | null = null;
// Initialize the tokenizer and ONNX sessions
async function initializeModels() {
export async function initializeModels() {
if (tokenizer && sessionClassifier && sessionEmbedding) {
return; // Models already loaded, skip re-initialization
return;
}
try {
@@ -30,8 +29,8 @@ async function initializeModels() {
sessionClassifier = classifierSession;
sessionEmbedding = embeddingSession;
} catch (error) {
console.error("Error initializing models:", error);
throw error; // Re-throw so the caller can handle it
const e = new WorkerError(error as Error, "ml", "fn:initializeModels");
throw e;
}
}
@@ -51,14 +50,12 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
return_tensor: false,
});
// Build the input parameters
const cumsum = (arr: number[]): number[] =>
arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
const flattened_input_ids = input_ids.flat();
// Prepare the ONNX inputs
const inputs = {
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
flattened_input_ids.length,
@@ -66,12 +63,11 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
};
// Run inference
const { embeddings } = await session.run(inputs);
return Array.from(embeddings.data as Float32Array);
}
// Classification inference function
async function runClassification(embeddings: number[]): Promise<number[]> {
if (!sessionClassifier) {
throw new Error("Classifier session is not initialized. Call initializeModels() first.");
@@ -85,13 +81,13 @@ async function runClassification(embeddings: number[]): Promise<number[]> {
return softmax(logits.data as Float32Array);
}
// Exported classification function
export async function classifyVideo(
title: string,
description: string,
tags: string,
author_info: string,
): Promise<number[]> {
aid: number
): Promise<number> {
if (!sessionEmbedding) {
throw new Error("Embedding session is not initialized. Call initializeModels() first.");
}
@@ -103,5 +99,6 @@ export async function classifyVideo(
], sessionEmbedding);
const probabilities = await runClassification(embeddings);
return probabilities;
logger.log(`Prediction result for aid ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml");
return probabilities.indexOf(Math.max(...probabilities));
}
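With this change, classifyVideo no longer returns the raw probability vector: it logs the softmax output and returns the index of the largest probability as the predicted label, and initializeModels is now exported so callers can warm up the tokenizer and ONNX sessions ahead of time. A hedged sketch of the new call shape (all argument values below are made up):

import { classifyVideo, initializeModels } from "lib/ml/filter_inference.ts";

await initializeModels(); // required first, otherwise the embedding session is still null
// Arguments: title, description, tags, author_info, aid. The empty string for author_info
// mirrors how the worker in this commit calls it; the other values are placeholders.
const label = await classifyVideo("Example title", "Example description", "tag1,tag2", "", 170001);
console.log(label); // predicted class index, i.e. the argmax over the softmax output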

lib/mq/exec/classifyVideo.ts

@@ -0,0 +1,30 @@
import { Job } from "bullmq";
import { db } from "lib/db/init.ts";
import { getUnlabeledVideos, getVideoInfoFromAllData, insertVideoLabel } from "lib/db/allData.ts";
import { classifyVideo, initializeModels } from "lib/ml/filter_inference.ts";
import { ClassifyVideoQueue } from "lib/mq/index.ts";
export const classifyVideoWorker = async (job: Job) => {
const client = await db.connect();
const aid = job.data.aid;
if (!aid) {
return 3;
}
const videoInfo = await getVideoInfoFromAllData(client, aid);
const label = await classifyVideo(videoInfo.title ?? "", videoInfo.description ?? "", videoInfo.tags ?? "", "", aid);
await insertVideoLabel(client, aid, label);
client.release();
return 0;
};
export const classifyVideosWorker = async () => {
await initializeModels();
const client = await db.connect();
const videos = await getUnlabeledVideos(client);
client.release();
for (const aid of videos) {
await ClassifyVideoQueue.add("classifyVideo", { aid });
}
};
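classifyVideosWorker only fans the backlog out into individual "classifyVideo" jobs; something still has to put a "classifyVideos" job on the queue in the first place. One way to do that is a BullMQ repeatable job, shown here purely as a sketch: the call site and the 5-minute interval are assumptions, not part of this commit.

import { ClassifyVideoQueue } from "lib/mq/index.ts";

// Assumed scheduling call: enqueue the batch job periodically so newly crawled
// rows in all_data eventually get labelled. The interval is arbitrary.
await ClassifyVideoQueue.add("classifyVideos", {}, { repeat: { every: 5 * 60 * 1000 } });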

View File

@@ -2,20 +2,21 @@ import { Job, Worker } from "bullmq";
import { redis } from "lib/db/redis.ts";
import logger from "lib/log/logger.ts";
import { WorkerError } from "src/worker.ts";
import { classifyVideosWorker, classifyVideoWorker } from "lib/mq/exec/classifyVideo.ts";
const filterWorker = new Worker(
"classifyVideo",
async (job: Job) => {
switch (job.name) {
case "classifyVideo":
return await getVideoTagsWorker(job);
return await classifyVideoWorker(job);
case "classifyVideos":
return await getVideoTagsInitializer();
return await classifyVideosWorker();
default:
break;
}
},
{ connection: redis, concurrency: 1, removeOnComplete: { count: 1440 } },
{ connection: redis, concurrency: 1, removeOnComplete: { count: 1000 } },
);
filterWorker.on("active", () => {