From cecc1c1d2ca9f242f72935723f814a1a1d70f26b Mon Sep 17 00:00:00 2001
From: alikia2x
Date: Sat, 22 Feb 2025 19:57:52 +0800
Subject: [PATCH] add: classifyVideo & classifyVideos implementation

---
 lib/db/allData.ts            | 35 ++++++++++++++++++++++++++++++++++-
 lib/ml/filter_inference.ts   | 29 +++++++++++++----------------
 lib/mq/exec/classifyVideo.ts | 30 ++++++++++++++++++++++++++++++
 src/filterWorker.ts          |  7 ++++---
 4 files changed, 81 insertions(+), 20 deletions(-)
 create mode 100644 lib/mq/exec/classifyVideo.ts

diff --git a/lib/db/allData.ts b/lib/db/allData.ts
index ab3f6e6..6c39256 100644
--- a/lib/db/allData.ts
+++ b/lib/db/allData.ts
@@ -1,8 +1,10 @@
 import { Client, Transaction } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
+import { AllDataType } from "lib/db/schema.d.ts";
 import logger from "lib/log/logger.ts";
 import { formatTimestampToPsql, parseTimestampFromPsql } from "lib/utils/formatTimestampToPostgre.ts";
 import { VideoListVideo } from "lib/net/bilibili.d.ts";
 import { HOUR, SECOND } from "$std/datetime/constants.ts";
+import { modelVersion } from "lib/ml/filter_inference.ts";
 
 export async function videoExistsInAllData(client: Client, aid: number) {
 	return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid])
@@ -20,7 +22,16 @@ export async function insertIntoAllData(client: Client, data: VideoListVideo) {
 		`INSERT INTO all_data (aid, bvid, description, uid, tags, title, published_at, duration)
 			VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
 			ON CONFLICT (aid) DO NOTHING`,
-		[data.aid, data.bvid, data.desc, data.owner.mid, null, data.title, formatTimestampToPsql(data.pubdate * SECOND + 8 * HOUR), data.duration],
+		[
+			data.aid,
+			data.bvid,
+			data.desc,
+			data.owner.mid,
+			null,
+			data.title,
+			formatTimestampToPsql(data.pubdate * SECOND + 8 * HOUR),
+			data.duration,
+		],
 	);
 }
 
@@ -65,3 +76,25 @@ export async function getNullVideoTagsList(client: Client) {
 		},
 	);
 }
+
+export async function getUnlabeledVideos(client: Client) {
+	const queryResult = await client.queryObject<{ aid: number }>(
+		`SELECT a.aid FROM all_data a LEFT JOIN labelling_result l ON a.aid = l.aid WHERE l.aid IS NULL`,
+	);
+	return queryResult.rows.map((row) => row.aid);
+}
+
+export async function insertVideoLabel(client: Client, aid: number, label: number) {
+	return await client.queryObject(
+		`INSERT INTO labelling_result (aid, label, model_version) VALUES ($1, $2, $3) ON CONFLICT (aid, model_version) DO NOTHING`,
+		[aid, label, modelVersion],
+	);
+}
+
+export async function getVideoInfoFromAllData(client: Client, aid: number) {
+	const queryResult = await client.queryObject<AllDataType>(
+		`SELECT * FROM all_data WHERE aid = $1`,
+		[aid],
+	);
+	return queryResult.rows[0];
+}
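Note: getUnlabeledVideos and insertVideoLabel assume a labelling_result table with a unique key on (aid, model_version); its migration is not part of this patch. A minimal sketch of the assumed schema, written with the same deno-postgres client (the table and column names come from the queries above, the column types are assumptions):

import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";

// Hypothetical DDL: the composite primary key is what lets
// "ON CONFLICT (aid, model_version) DO NOTHING" in insertVideoLabel take effect.
export async function ensureLabellingResultTable(client: Client) {
	await client.queryObject(`
		CREATE TABLE IF NOT EXISTS labelling_result (
			aid           BIGINT NOT NULL,
			label         SMALLINT NOT NULL,
			model_version TEXT NOT NULL,
			PRIMARY KEY (aid, model_version)
		)
	`);
}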
diff --git a/lib/ml/filter_inference.ts b/lib/ml/filter_inference.ts
index d59c798..4f7cb25 100644
--- a/lib/ml/filter_inference.ts
+++ b/lib/ml/filter_inference.ts
@@ -1,21 +1,20 @@
-import { AutoTokenizer } from "@huggingface/transformers";
+import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
 import * as ort from "onnxruntime";
+import logger from "lib/log/logger.ts";
+import { WorkerError } from "src/worker.ts";
 
-// 模型路径和名称
 const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
 const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
 const onnxEmbeddingOriginalPath = "./model/model.onnx";
 export const modelVersion = "3.11";
 
-// 全局变量,用于存储模型和分词器
 let sessionClassifier: ort.InferenceSession | null = null;
 let sessionEmbedding: ort.InferenceSession | null = null;
-let tokenizer: any | null = null;
+let tokenizer: PreTrainedTokenizer | null = null;
 
-// 初始化分词器和ONNX会话
-async function initializeModels() {
+export async function initializeModels() {
 	if (tokenizer && sessionClassifier && sessionEmbedding) {
-		return; // 模型已加载,无需重复加载
+		return;
 	}
 
 	try {
@@ -30,8 +29,8 @@ async function initializeModels() {
 		sessionClassifier = classifierSession;
 		sessionEmbedding = embeddingSession;
 	} catch (error) {
-		console.error("Error initializing models:", error);
-		throw error; // 重新抛出错误,以便调用方处理
+		const e = new WorkerError(error as Error, "ml", "fn:initializeModels");
+		throw e;
 	}
 }
 
@@ -51,14 +50,12 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
 		return_tensor: false,
 	});
 
-	// 构造输入参数
 	const cumsum = (arr: number[]): number[] =>
 		arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
 
 	const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
 	const flattened_input_ids = input_ids.flat();
 
-	// 准备ONNX输入
 	const inputs = {
 		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
 			flattened_input_ids.length,
@@ -66,12 +63,11 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
 		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
 	};
 
-	// 执行推理
 	const { embeddings } = await session.run(inputs);
 	return Array.from(embeddings.data as Float32Array);
 }
 
-// 分类推理函数
+
 async function runClassification(embeddings: number[]): Promise<number[]> {
 	if (!sessionClassifier) {
 		throw new Error("Classifier session is not initialized. Call initializeModels() first.");
@@ -85,13 +81,13 @@ async function runClassification(embeddings: number[]): Promise<number[]> {
 	return softmax(logits.data as Float32Array);
 }
 
-// 导出分类函数
 export async function classifyVideo(
 	title: string,
 	description: string,
 	tags: string,
 	author_info: string,
-): Promise<number[]> {
+	aid: number
+): Promise<number> {
 	if (!sessionEmbedding) {
 		throw new Error("Embedding session is not initialized. Call initializeModels() first.");
 	}
@@ -103,5 +99,6 @@ export async function classifyVideo(
 	], sessionEmbedding);
 
 	const probabilities = await runClassification(embeddings);
-	return probabilities;
+	logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml");
+	return probabilities.indexOf(Math.max(...probabilities));
 }
diff --git a/lib/mq/exec/classifyVideo.ts b/lib/mq/exec/classifyVideo.ts
new file mode 100644
index 0000000..848dd7b
--- /dev/null
+++ b/lib/mq/exec/classifyVideo.ts
@@ -0,0 +1,30 @@
+import { Job } from "bullmq";
+import { db } from "lib/db/init.ts";
+import { getUnlabeledVideos, getVideoInfoFromAllData, insertVideoLabel } from "lib/db/allData.ts";
+import { classifyVideo, initializeModels } from "lib/ml/filter_inference.ts";
+import { ClassifyVideoQueue } from "lib/mq/index.ts";
+
+export const classifyVideoWorker = async (job: Job) => {
+	const client = await db.connect();
+	const aid = job.data.aid;
+	if (!aid) {
+		return 3;
+	}
+
+	const videoInfo = await getVideoInfoFromAllData(client, aid);
+	const label = await classifyVideo(videoInfo.title ?? "", videoInfo.description ?? "", videoInfo.tags ?? "", "", aid);
+	await insertVideoLabel(client, aid, label);
+
+	client.release();
+	return 0;
+};
+
+export const classifyVideosWorker = async () => {
+	await initializeModels();
+	const client = await db.connect();
+	const videos = await getUnlabeledVideos(client);
+	client.release();
+	for (const aid of videos) {
+		await ClassifyVideoQueue.add("classifyVideo", { aid });
+	}
+};
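Note: with this change classifyVideo resolves to the argmax class index rather than the raw probability vector, so callers such as classifyVideoWorker store one label per (aid, model_version). A minimal standalone usage sketch (the title/description/tags/aid values are placeholders; author_info is passed as an empty string, mirroring classifyVideoWorker):

import { classifyVideo, initializeModels } from "lib/ml/filter_inference.ts";

// Load the tokenizer and both ONNX sessions once before running any inference.
await initializeModels();

// Returns the index of the most probable class for the given video metadata.
const label = await classifyVideo("some title", "some description", "tag1,tag2", "", 12345);
console.log(`predicted label index: ${label}`);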
"", "", aid); + insertVideoLabel(client, aid, label); + + client.release(); + return 0; +}; + +export const classifyVideosWorker = async () => { + await initializeModels(); + const client = await db.connect(); + const videos = await getUnlabeledVideos(client); + client.release(); + for (const aid of videos) { + await ClassifyVideoQueue.add("classifyVideo", { aid }); + } +}; diff --git a/src/filterWorker.ts b/src/filterWorker.ts index d0b7e2a..79a06db 100644 --- a/src/filterWorker.ts +++ b/src/filterWorker.ts @@ -2,20 +2,21 @@ import { Job, Worker } from "bullmq"; import { redis } from "lib/db/redis.ts"; import logger from "lib/log/logger.ts"; import { WorkerError } from "src/worker.ts"; +import { classifyVideosWorker, classifyVideoWorker } from "lib/mq/exec/classifyVideo.ts"; const filterWorker = new Worker( "classifyVideo", async (job: Job) => { switch (job.name) { case "classifyVideo": - return await getVideoTagsWorker(job); + return await classifyVideoWorker(job); case "classifyVideos": - return await getVideoTagsInitializer(); + return await classifyVideosWorker(); default: break; } }, - { connection: redis, concurrency: 1, removeOnComplete: { count: 1440 } }, + { connection: redis, concurrency: 1, removeOnComplete: { count: 1000 } }, ); filterWorker.on("active", () => {