add: classifyVideo & classifyVideos implementation
commit cecc1c1d2c
parent 7946cb6e96
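
This commit adds a BullMQ-based labelling pipeline: classifyVideosWorker enqueues one "classifyVideo" job per unlabeled video (via the new getUnlabeledVideos helper), and classifyVideoWorker loads the video's metadata with getVideoInfoFromAllData, runs it through the reworked classifyVideo(), which now takes the aid, logs the softmax probabilities, and returns the argmax label index instead of the raw probability vector, and persists the result with insertVideoLabel.

A minimal sketch of how the sweep could be kicked off (the scheduler itself is not part of this commit, so treat the call site as an assumption; the job name follows the switch in the worker registration at the bottom of the diff):

    // Hypothetical trigger, not included in this commit.
    import { ClassifyVideoQueue } from "lib/mq/index.ts";

    // A single "classifyVideos" job fans out into one "classifyVideo" job per unlabeled aid.
    await ClassifyVideoQueue.add("classifyVideos", {});
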
lib/db/allData.ts
@@ -1,8 +1,10 @@
 import { Client, Transaction } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
+import { AllDataType } from "lib/db/schema.d.ts";
 import logger from "lib/log/logger.ts";
 import { formatTimestampToPsql, parseTimestampFromPsql } from "lib/utils/formatTimestampToPostgre.ts";
 import { VideoListVideo } from "lib/net/bilibili.d.ts";
 import { HOUR, SECOND } from "$std/datetime/constants.ts";
+import { modelVersion } from "lib/ml/filter_inference.ts";

 export async function videoExistsInAllData(client: Client, aid: number) {
 	return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid])
@@ -20,7 +22,16 @@ export async function insertIntoAllData(client: Client, data: VideoListVideo) {
 		`INSERT INTO all_data (aid, bvid, description, uid, tags, title, published_at, duration)
 		VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
 		ON CONFLICT (aid) DO NOTHING`,
-		[data.aid, data.bvid, data.desc, data.owner.mid, null, data.title, formatTimestampToPsql(data.pubdate * SECOND + 8 * HOUR), data.duration],
+		[
+			data.aid,
+			data.bvid,
+			data.desc,
+			data.owner.mid,
+			null,
+			data.title,
+			formatTimestampToPsql(data.pubdate * SECOND + 8 * HOUR),
+			data.duration,
+		],
 	);
 }

@@ -65,3 +76,25 @@ export async function getNullVideoTagsList(client: Client) {
 		},
 	);
 }
+
+export async function getUnlabeledVideos(client: Client) {
+	const queryResult = await client.queryObject<{ aid: number }>(
+		`SELECT a.aid FROM all_data a LEFT JOIN labelling_result l ON a.aid = l.aid WHERE l.aid IS NULL`,
+	);
+	return queryResult.rows.map((row) => row.aid);
+}
+
+export async function insertVideoLabel(client: Client, aid: number, label: number) {
+	return await client.queryObject(
+		`INSERT INTO labelling_result (aid, label, model_version) VALUES ($1, $2, $3) ON CONFLICT (aid, model_version) DO NOTHING`,
+		[aid, label, modelVersion],
+	);
+}
+
+export async function getVideoInfoFromAllData(client: Client, aid: number) {
+	const queryResult = await client.queryObject<AllDataType>(
+		`SELECT * FROM all_data WHERE aid = $1`,
+		[aid],
+	);
+	return queryResult.rows[0];
+}
lib/ml/filter_inference.ts
@@ -1,21 +1,20 @@
-import { AutoTokenizer } from "@huggingface/transformers";
+import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
 import * as ort from "onnxruntime";
+import logger from "lib/log/logger.ts";
+import { WorkerError } from "src/worker.ts";

-// 模型路径和名称
 const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
 const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
 const onnxEmbeddingOriginalPath = "./model/model.onnx";
 export const modelVersion = "3.11";

-// 全局变量,用于存储模型和分词器
 let sessionClassifier: ort.InferenceSession | null = null;
 let sessionEmbedding: ort.InferenceSession | null = null;
-let tokenizer: any | null = null;
+let tokenizer: PreTrainedTokenizer | null = null;

-// 初始化分词器和ONNX会话
-async function initializeModels() {
+export async function initializeModels() {
 	if (tokenizer && sessionClassifier && sessionEmbedding) {
-		return; // 模型已加载,无需重复加载
+		return;
 	}

 	try {
@@ -30,8 +29,8 @@ async function initializeModels() {
 		sessionClassifier = classifierSession;
 		sessionEmbedding = embeddingSession;
 	} catch (error) {
-		console.error("Error initializing models:", error);
-		throw error; // 重新抛出错误,以便调用方处理
+		const e = new WorkerError(error as Error, "ml", "fn:initializeModels");
+		throw e;
 	}
 }

@@ -51,14 +50,12 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
 		return_tensor: false,
 	});

-	// 构造输入参数
 	const cumsum = (arr: number[]): number[] =>
 		arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);

 	const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
 	const flattened_input_ids = input_ids.flat();

-	// 准备ONNX输入
 	const inputs = {
 		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
 			flattened_input_ids.length,
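
A note on the unchanged context lines above, since the offsets construction is easy to misread: getONNXEmbeddings flattens the per-text token ids into one array and passes the start index of each text as offsets. A tiny worked example with made-up token ids:

    // Illustrative values only.
    const input_ids = [[5, 9, 2], [7, 4]];        // tokenizer output for two texts
    const flattened_input_ids = input_ids.flat(); // [5, 9, 2, 7, 4]
    // offsets = [0, ...cumsum of all lengths except the last] = [0, 3],
    // i.e. where each text starts inside flattened_input_ids.
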
@@ -66,12 +63,11 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
 		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
 	};

-	// 执行推理
 	const { embeddings } = await session.run(inputs);
 	return Array.from(embeddings.data as Float32Array);
 }

-// 分类推理函数
 async function runClassification(embeddings: number[]): Promise<number[]> {
 	if (!sessionClassifier) {
 		throw new Error("Classifier session is not initialized. Call initializeModels() first.");
@@ -85,13 +81,13 @@ async function runClassification(embeddings: number[]): Promise<number[]> {
 	return softmax(logits.data as Float32Array);
 }

-// 导出分类函数
 export async function classifyVideo(
 	title: string,
 	description: string,
 	tags: string,
 	author_info: string,
-): Promise<number[]> {
+	aid: number
+): Promise<number> {
 	if (!sessionEmbedding) {
 		throw new Error("Embedding session is not initialized. Call initializeModels() first.");
 	}
@@ -103,5 +99,6 @@ export async function classifyVideo(
 	], sessionEmbedding);

 	const probabilities = await runClassification(embeddings);
-	return probabilities;
+	logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml")
+	return probabilities.indexOf(Math.max(...probabilities));
 }
lib/mq/exec/classifyVideo.ts (new file, 30 lines)
@@ -0,0 +1,30 @@
+import { Job } from "bullmq";
+import { db } from "lib/db/init.ts";
+import { getUnlabeledVideos, getVideoInfoFromAllData, insertVideoLabel} from "lib/db/allData.ts";
+import { classifyVideo, initializeModels } from "lib/ml/filter_inference.ts";
+import { ClassifyVideoQueue } from "lib/mq/index.ts";
+
+export const classifyVideoWorker = async (job: Job) => {
+	const client = await db.connect();
+	const aid = job.data.aid;
+	if (!aid) {
+		return 3;
+	}
+
+	const videoInfo = await getVideoInfoFromAllData(client, aid);
+	const label = await classifyVideo(videoInfo.title ?? "", videoInfo.description ?? "", videoInfo.tags ?? "", "", aid);
+	insertVideoLabel(client, aid, label);
+
+	client.release();
+	return 0;
+};
+
+export const classifyVideosWorker = async () => {
+	await initializeModels();
+	const client = await db.connect();
+	const videos = await getUnlabeledVideos(client);
+	client.release();
+	for (const aid of videos) {
+		await ClassifyVideoQueue.add("classifyVideo", { aid });
+	}
+};
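
Wiring note: ClassifyVideoQueue is imported from lib/mq/index.ts, which is not part of this diff. A minimal sketch of the queue declaration the code above presumably relies on (the options are an assumption; only the "classifyVideo" queue name is implied by the Worker registration below):

    // Assumed shape of lib/mq/index.ts; this file is not included in the commit.
    import { Queue } from "bullmq";
    import { redis } from "lib/db/redis.ts";

    export const ClassifyVideoQueue = new Queue("classifyVideo", { connection: redis });
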
@@ -2,20 +2,21 @@ import { Job, Worker } from "bullmq";
 import { redis } from "lib/db/redis.ts";
 import logger from "lib/log/logger.ts";
 import { WorkerError } from "src/worker.ts";
+import { classifyVideosWorker, classifyVideoWorker } from "lib/mq/exec/classifyVideo.ts";

 const filterWorker = new Worker(
 	"classifyVideo",
 	async (job: Job) => {
 		switch (job.name) {
 			case "classifyVideo":
-				return await getVideoTagsWorker(job);
+				return await classifyVideoWorker(job);
 			case "classifyVideos":
-				return await getVideoTagsInitializer();
+				return await classifyVideosWorker();
 			default:
 				break;
 		}
 	},
-	{ connection: redis, concurrency: 1, removeOnComplete: { count: 1440 } },
+	{ connection: redis, concurrency: 1, removeOnComplete: { count: 1000 } },
 );

 filterWorker.on("active", () => {