add: classifyVideo & classifyVideos implementation
This commit is contained in:
parent
7946cb6e96
commit
cecc1c1d2c
@ -1,8 +1,10 @@
|
||||
import { Client, Transaction } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
||||
import { AllDataType } from "lib/db/schema.d.ts";
|
||||
import logger from "lib/log/logger.ts";
|
||||
import { formatTimestampToPsql, parseTimestampFromPsql } from "lib/utils/formatTimestampToPostgre.ts";
|
||||
import { VideoListVideo } from "lib/net/bilibili.d.ts";
|
||||
import { HOUR, SECOND } from "$std/datetime/constants.ts";
|
||||
import { modelVersion } from "lib/ml/filter_inference.ts";
|
||||
|
||||
export async function videoExistsInAllData(client: Client, aid: number) {
|
||||
return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid])
|
||||
@ -20,7 +22,16 @@ export async function insertIntoAllData(client: Client, data: VideoListVideo) {
|
||||
`INSERT INTO all_data (aid, bvid, description, uid, tags, title, published_at, duration)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT (aid) DO NOTHING`,
|
||||
[data.aid, data.bvid, data.desc, data.owner.mid, null, data.title, formatTimestampToPsql(data.pubdate * SECOND + 8 * HOUR), data.duration],
|
||||
[
|
||||
data.aid,
|
||||
data.bvid,
|
||||
data.desc,
|
||||
data.owner.mid,
|
||||
null,
|
||||
data.title,
|
||||
formatTimestampToPsql(data.pubdate * SECOND + 8 * HOUR),
|
||||
data.duration,
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
@ -65,3 +76,25 @@ export async function getNullVideoTagsList(client: Client) {
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Lists the `aid` of every row in `all_data` that has no matching row in `labelling_result`.
 *
 * The join matches on `aid` only, so a label stored under any `model_version`
 * counts as "labelled" for the purpose of this query.
 *
 * @param client - Postgres client used to run the query.
 * @returns The `aid` values, in the order the query returns them.
 */
export async function getUnlabeledVideos(client: Client) {
    // Anti-join: keep only the all_data rows for which the LEFT JOIN found no partner.
    const { rows } = await client.queryObject<{ aid: number }>(
        `SELECT a.aid FROM all_data a LEFT JOIN labelling_result l ON a.aid = l.aid WHERE l.aid IS NULL`,
    );
    const aids: number[] = [];
    for (const { aid } of rows) {
        aids.push(aid);
    }
    return aids;
}
|
||||
|
||||
/**
 * Stores a label for a video under the current `modelVersion`.
 *
 * The statement uses `ON CONFLICT (aid, model_version) DO NOTHING`, so an
 * existing row for the same video and model version is left untouched.
 *
 * @param client - Postgres client used to run the statement.
 * @param aid - Video identifier written to `labelling_result.aid`.
 * @param label - Label value written to `labelling_result.label`.
 * @returns The query result of the INSERT.
 */
export async function insertVideoLabel(client: Client, aid: number, label: number) {
    const statement =
        `INSERT INTO labelling_result (aid, label, model_version) VALUES ($1, $2, $3) ON CONFLICT (aid, model_version) DO NOTHING`;
    const params = [aid, label, modelVersion];
    return await client.queryObject(statement, params);
}
|
||||
|
||||
/**
 * Fetches the `all_data` row for a given video.
 *
 * @param client - Postgres client used to run the query.
 * @param aid - Video identifier to look up.
 * @returns The first row returned by the query. At runtime this is `undefined`
 *          when no row has the given `aid`.
 */
export async function getVideoInfoFromAllData(client: Client, aid: number) {
    const { rows } = await client.queryObject<AllDataType>(
        `SELECT * FROM all_data WHERE aid = $1`,
        [aid],
    );
    const [firstRow] = rows;
    return firstRow;
}
|
||||
|
@ -1,21 +1,20 @@
|
||||
import { AutoTokenizer } from "@huggingface/transformers";
|
||||
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
|
||||
import * as ort from "onnxruntime";
|
||||
import logger from "lib/log/logger.ts";
|
||||
import { WorkerError } from "src/worker.ts";
|
||||
|
||||
// 模型路径和名称
|
||||
const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
|
||||
const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
|
||||
const onnxEmbeddingOriginalPath = "./model/model.onnx";
|
||||
export const modelVersion = "3.11";
|
||||
|
||||
// 全局变量,用于存储模型和分词器
|
||||
let sessionClassifier: ort.InferenceSession | null = null;
|
||||
let sessionEmbedding: ort.InferenceSession | null = null;
|
||||
let tokenizer: any | null = null;
|
||||
let tokenizer: PreTrainedTokenizer | null = null;
|
||||
|
||||
// 初始化分词器和ONNX会话
|
||||
async function initializeModels() {
|
||||
export async function initializeModels() {
|
||||
if (tokenizer && sessionClassifier && sessionEmbedding) {
|
||||
return; // 模型已加载,无需重复加载
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
@ -30,8 +29,8 @@ async function initializeModels() {
|
||||
sessionClassifier = classifierSession;
|
||||
sessionEmbedding = embeddingSession;
|
||||
} catch (error) {
|
||||
console.error("Error initializing models:", error);
|
||||
throw error; // 重新抛出错误,以便调用方处理
|
||||
const e = new WorkerError(error as Error, "ml", "fn:initializeModels");
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
@ -51,14 +50,12 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
|
||||
return_tensor: false,
|
||||
});
|
||||
|
||||
// 构造输入参数
|
||||
const cumsum = (arr: number[]): number[] =>
|
||||
arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
|
||||
|
||||
const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
|
||||
const flattened_input_ids = input_ids.flat();
|
||||
|
||||
// 准备ONNX输入
|
||||
const inputs = {
|
||||
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
|
||||
flattened_input_ids.length,
|
||||
@ -66,12 +63,11 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
|
||||
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
|
||||
};
|
||||
|
||||
// 执行推理
|
||||
const { embeddings } = await session.run(inputs);
|
||||
return Array.from(embeddings.data as Float32Array);
|
||||
}
|
||||
|
||||
// 分类推理函数
|
||||
|
||||
async function runClassification(embeddings: number[]): Promise<number[]> {
|
||||
if (!sessionClassifier) {
|
||||
throw new Error("Classifier session is not initialized. Call initializeModels() first.");
|
||||
@ -85,13 +81,13 @@ async function runClassification(embeddings: number[]): Promise<number[]> {
|
||||
return softmax(logits.data as Float32Array);
|
||||
}
|
||||
|
||||
// 导出分类函数
|
||||
export async function classifyVideo(
|
||||
title: string,
|
||||
description: string,
|
||||
tags: string,
|
||||
author_info: string,
|
||||
): Promise<number[]> {
|
||||
aid: number
|
||||
): Promise<number> {
|
||||
if (!sessionEmbedding) {
|
||||
throw new Error("Embedding session is not initialized. Call initializeModels() first.");
|
||||
}
|
||||
@ -103,5 +99,6 @@ export async function classifyVideo(
|
||||
], sessionEmbedding);
|
||||
|
||||
const probabilities = await runClassification(embeddings);
|
||||
return probabilities;
|
||||
logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml")
|
||||
return probabilities.indexOf(Math.max(...probabilities));
|
||||
}
|
||||
|
30
lib/mq/exec/classifyVideo.ts
Normal file
30
lib/mq/exec/classifyVideo.ts
Normal file
@ -0,0 +1,30 @@
|
||||
import { Job } from "bullmq";
|
||||
import { db } from "lib/db/init.ts";
|
||||
import { getUnlabeledVideos, getVideoInfoFromAllData, insertVideoLabel} from "lib/db/allData.ts";
|
||||
import { classifyVideo, initializeModels } from "lib/ml/filter_inference.ts";
|
||||
import { ClassifyVideoQueue } from "lib/mq/index.ts";
|
||||
|
||||
/**
 * Job handler: classifies a single video and stores the resulting label.
 *
 * Reads `aid` from `job.data`, loads the video's row from `all_data`, runs the
 * classifier on its title/description/tags and writes the label through
 * `insertVideoLabel`.
 *
 * @param job - Job whose `data.aid` identifies the video to classify.
 * @returns `3` when the job carries no `aid`, `0` once the label has been written.
 * @throws Error when no `all_data` row exists for `aid`; also propagates any
 *         error from model initialisation, inference or the database.
 */
export const classifyVideoWorker = async (job: Job) => {
    const aid = job.data.aid;
    // Validate before acquiring a connection so the early return cannot leak one.
    if (!aid) {
        return 3;
    }

    // classifyVideo throws if the sessions are not loaded. initializeModels returns
    // immediately when they already are, so calling it for every job is safe and
    // covers jobs picked up before any "classifyVideos" batch has run in this process.
    await initializeModels();

    const client = await db.connect();
    try {
        const videoInfo = await getVideoInfoFromAllData(client, aid);
        if (!videoInfo) {
            throw new Error(`Video with aid ${aid} not found in all_data`);
        }
        const label = await classifyVideo(
            videoInfo.title ?? "",
            videoInfo.description ?? "",
            videoInfo.tags ?? "",
            "",
            aid,
        );
        // Await the write: the client must not be released while the INSERT is
        // still in flight, and a failed INSERT has to fail the job.
        await insertVideoLabel(client, aid, label);
    } finally {
        client.release();
    }
    return 0;
};
|
||||
|
||||
/**
 * Job handler: enqueues one "classifyVideo" job for every video that has no label yet.
 *
 * Loads the models first, then reads the unlabeled `aid`s from the database and
 * adds a job per video to `ClassifyVideoQueue`.
 *
 * @throws Propagates any error from model initialisation, the database query
 *         or the queue.
 */
export const classifyVideosWorker = async () => {
    await initializeModels();

    const client = await db.connect();
    let videos: number[];
    try {
        videos = await getUnlabeledVideos(client);
    } finally {
        // Release even when the query fails, and before the (potentially long)
        // enqueue loop, so the connection is not held longer than needed.
        client.release();
    }

    for (const aid of videos) {
        await ClassifyVideoQueue.add("classifyVideo", { aid });
    }
};
|
@ -2,20 +2,21 @@ import { Job, Worker } from "bullmq";
|
||||
import { redis } from "lib/db/redis.ts";
|
||||
import logger from "lib/log/logger.ts";
|
||||
import { WorkerError } from "src/worker.ts";
|
||||
import { classifyVideosWorker, classifyVideoWorker } from "lib/mq/exec/classifyVideo.ts";
|
||||
|
||||
/**
 * Worker for the "classifyVideo" queue.
 *
 * Dispatches on the job name: "classifyVideo" classifies the single video named
 * in the job, "classifyVideos" enqueues jobs for all unlabeled videos. Jobs with
 * any other name complete without doing anything.
 */
const filterWorker = new Worker(
    "classifyVideo",
    async (job: Job) => {
        switch (job.name) {
            case "classifyVideo":
                return await classifyVideoWorker(job);
            case "classifyVideos":
                return await classifyVideosWorker();
            default:
                break;
        }
    },
    // One job at a time; keep only the most recent 1000 completed jobs.
    { connection: redis, concurrency: 1, removeOnComplete: { count: 1000 } },
);
|
||||
|
||||
filterWorker.on("active", () => {
|
||||
|
Loading…
Reference in New Issue
Block a user