diff --git a/lib/db/allData.ts b/lib/db/allData.ts index 8e30780..26840fb 100644 --- a/lib/db/allData.ts +++ b/lib/db/allData.ts @@ -1,6 +1,6 @@ import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts"; import { AllDataType, BiliUserType } from "lib/db/schema.d.ts"; -import { modelVersion } from "lib/ml/filter_inference.ts"; +import Akari from "lib/ml/akari.ts"; export async function videoExistsInAllData(client: Client, aid: number) { return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid]) @@ -23,7 +23,7 @@ export async function getUnlabelledVideos(client: Client) { export async function insertVideoLabel(client: Client, aid: number, label: number) { return await client.queryObject( `INSERT INTO labelling_result (aid, label, model_version) VALUES ($1, $2, $3) ON CONFLICT (aid, model_version) DO NOTHING`, - [aid, label, modelVersion], + [aid, label, Akari.getModelVersion()], ); } diff --git a/lib/db/snapshot.ts b/lib/db/snapshot.ts index 663a628..c3f515b 100644 --- a/lib/db/snapshot.ts +++ b/lib/db/snapshot.ts @@ -28,7 +28,8 @@ export async function getSongsNearMilestone(client: Client) { max_views_per_aid WHERE (max_views >= 90000 AND max_views < 100000) OR - (max_views >= 900000 AND max_views < 1000000) + (max_views >= 900000 AND max_views < 1000000) OR + (max_views >= 9900000 AND max_views < 10000000) ) -- 获取符合条件的完整行数据 SELECT diff --git a/lib/db/snapshotSchedule.ts b/lib/db/snapshotSchedule.ts index c719eb6..111ffa1 100644 --- a/lib/db/snapshotSchedule.ts +++ b/lib/db/snapshotSchedule.ts @@ -1,11 +1,12 @@ import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts"; -export async function getUnsnapshotedSongs(client: Client) { - const queryResult = await client.queryObject<{ aid: bigint }>(` - SELECT DISTINCT s.aid - FROM songs s - LEFT JOIN video_snapshot v ON s.aid = v.aid - WHERE v.aid IS NULL; - `); - return queryResult.rows.map((row) => Number(row.aid)); -} +/* + Returns true if the specified `aid` has at least one record with "pending" or "processing" status. +*/ +export async function videoHasActiveSchedule(client: Client, aid: number) { + const res = await client.queryObject<{ status: string }>( + `SELECT status FROM snapshot_schedule WHERE aid = $1 AND (status = 'pending' OR status = 'processing')`, + [aid], + ); + return res.rows.length > 0; +} \ No newline at end of file diff --git a/lib/ml/akari.ts b/lib/ml/akari.ts new file mode 100644 index 0000000..386bb56 --- /dev/null +++ b/lib/ml/akari.ts @@ -0,0 +1,106 @@ +import { AIManager } from "lib/ml/manager.ts"; +import * as ort from "onnxruntime"; +import logger from "lib/log/logger.ts"; +import { WorkerError } from "lib/mq/schema.ts"; +import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers"; + +const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024"; +const onnxClassifierPath = "./model/video_classifier_v3_17.onnx"; +const onnxEmbeddingPath = "./model/model.onnx"; + +class AkariProto extends AIManager { + private tokenizer: PreTrainedTokenizer | null = null; + private readonly modelVersion = "3.17"; + + constructor() { + super(); + this.models = { + "classifier": onnxClassifierPath, + "embedding": onnxEmbeddingPath, + } + } + + public override async init(): Promise { + super.init(); + await this.initJinaTokenizer(); + } + + private tokenizerInitialized(): boolean { + return this.tokenizer !== null; + } + + private getTokenizer(): PreTrainedTokenizer { + if (!this.tokenizerInitialized()) { + throw new Error("Tokenizer is not initialized. Call init() first."); + } + return this.tokenizer!; + } + + private async initJinaTokenizer(): Promise { + if (this.tokenizerInitialized()) { + return; + } + try { + this.tokenizer = await AutoTokenizer.from_pretrained(tokenizerModel); + logger.log("Tokenizer initialized", "ml"); + } catch (error) { + throw new WorkerError(error as Error, "ml", "fn:initTokenizer"); + } + } + + private async getJinaEmbeddings1024(texts: string[]): Promise { + const tokenizer = this.getTokenizer(); + const session = this.getModelSession("embedding"); + + const { input_ids } = await tokenizer(texts, { + add_special_tokens: false, + return_tensors: "js", + }); + + const cumsum = (arr: number[]): number[] => + arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []); + + const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string[]) => x.length))]; + const flattened_input_ids = input_ids.flat(); + + const inputs = { + input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [ + flattened_input_ids.length, + ]), + offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]), + }; + + const { embeddings } = await session.run(inputs); + return Array.from(embeddings.data as Float32Array); + } + + private async runClassification(embeddings: number[]): Promise { + const session = this.getModelSession("classifier"); + const inputTensor = new ort.Tensor( + Float32Array.from(embeddings), + [1, 3, 1024], + ); + + const { logits } = await session.run({ channel_features: inputTensor }); + return this.softmax(logits.data as Float32Array); + } + + public async classifyVideo(title: string, description: string, tags: string, aid: number): Promise { + const embeddings = await this.getJinaEmbeddings1024([ + title, + description, + tags, + ]); + const probabilities = await this.runClassification(embeddings); + logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml"); + return probabilities.indexOf(Math.max(...probabilities)); + } + + public getModelVersion(): string { + return this.modelVersion; + } +} + +const Akari = new AkariProto(); +export default Akari; + diff --git a/lib/ml/benchmark.ts b/lib/ml/benchmark.ts index 0cfc193..3911c31 100644 --- a/lib/ml/benchmark.ts +++ b/lib/ml/benchmark.ts @@ -1,6 +1,13 @@ import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers"; import * as ort from "onnxruntime"; -import { softmax } from "lib/ml/filter_inference.ts"; + + +function softmax(logits: Float32Array): number[] { + const maxLogit = Math.max(...logits); + const exponents = logits.map((logit) => Math.exp(logit - maxLogit)); + const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0); + return Array.from(exponents.map((exp) => exp / sumOfExponents)); +} // 配置参数 const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024"; diff --git a/lib/ml/filter_inference.ts b/lib/ml/filter_inference.ts deleted file mode 100644 index 019061f..0000000 --- a/lib/ml/filter_inference.ts +++ /dev/null @@ -1,99 +0,0 @@ -import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers"; -import * as ort from "onnxruntime"; -import logger from "lib/log/logger.ts"; -import { WorkerError } from "lib/mq/schema.ts"; - -const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024"; -const onnxClassifierPath = "./model/video_classifier_v3_17.onnx"; -const onnxEmbeddingOriginalPath = "./model/model.onnx"; -export const modelVersion = "3.17"; - -let sessionClassifier: ort.InferenceSession | null = null; -let sessionEmbedding: ort.InferenceSession | null = null; -let tokenizer: PreTrainedTokenizer | null = null; - -export async function initializeModels() { - if (tokenizer && sessionClassifier && sessionEmbedding) { - return; - } - - try { - tokenizer = await AutoTokenizer.from_pretrained(tokenizerModel); - - const [classifierSession, embeddingSession] = await Promise.all([ - ort.InferenceSession.create(onnxClassifierPath), - ort.InferenceSession.create(onnxEmbeddingOriginalPath), - ]); - - sessionClassifier = classifierSession; - sessionEmbedding = embeddingSession; - logger.log("Filter models initialized", "ml"); - } catch (error) { - throw new WorkerError(error as Error, "ml", "fn:initializeModels"); - } -} - -export function softmax(logits: Float32Array): number[] { - const maxLogit = Math.max(...logits); - const exponents = logits.map((logit) => Math.exp(logit - maxLogit)); - const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0); - return Array.from(exponents.map((exp) => exp / sumOfExponents)); -} - -async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise { - if (!tokenizer) { - throw new Error("Tokenizer is not initialized. Call initializeModels() first."); - } - const { input_ids } = await tokenizer(texts, { - add_special_tokens: false, - return_tensor: false, - }); - - const cumsum = (arr: number[]): number[] => - arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []); - - const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))]; - const flattened_input_ids = input_ids.flat(); - - const inputs = { - input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [ - flattened_input_ids.length, - ]), - offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]), - }; - - const { embeddings } = await session.run(inputs); - return Array.from(embeddings.data as Float32Array); -} - -async function runClassification(embeddings: number[]): Promise { - if (!sessionClassifier) { - throw new Error("Classifier session is not initialized. Call initializeModels() first."); - } - const inputTensor = new ort.Tensor( - Float32Array.from(embeddings), - [1, 3, 1024], - ); - - const { logits } = await sessionClassifier.run({ channel_features: inputTensor }); - return softmax(logits.data as Float32Array); -} - -export async function classifyVideo( - title: string, - description: string, - tags: string, - aid: number, -): Promise { - if (!sessionEmbedding) { - throw new Error("Embedding session is not initialized. Call initializeModels() first."); - } - const embeddings = await getONNXEmbeddings([ - title, - description, - tags, - ], sessionEmbedding); - const probabilities = await runClassification(embeddings); - logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml"); - return probabilities.indexOf(Math.max(...probabilities)); -} diff --git a/lib/ml/manager.ts b/lib/ml/manager.ts new file mode 100644 index 0000000..268985d --- /dev/null +++ b/lib/ml/manager.ts @@ -0,0 +1,37 @@ +import * as ort from "onnxruntime"; +import logger from "lib/log/logger.ts"; +import { WorkerError } from "lib/mq/schema.ts"; + +export class AIManager { + public sessions: { [key: string]: ort.InferenceSession } = {}; + public models: { [key: string]: string } = {}; + + constructor() { + } + + public async init() { + const modelKeys = Object.keys(this.models); + for (const key of modelKeys) { + try { + this.sessions[key] = await ort.InferenceSession.create(this.models[key]); + logger.log(`Model ${key} initialized`, "ml"); + } catch (error) { + throw new WorkerError(error as Error, "ml", "fn:init"); + } + } + } + + public getModelSession(key: string): ort.InferenceSession { + if (!this.sessions[key]) { + throw new WorkerError(new Error(`Model ${key} not found / not initialized.`), "ml", "fn:getModelSession"); + } + return this.sessions[key]; + } + + public softmax(logits: Float32Array): number[] { + const maxLogit = Math.max(...logits); + const exponents = logits.map((logit) => Math.exp(logit - maxLogit)); + const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0); + return Array.from(exponents.map((exp) => exp / sumOfExponents)); + } +} diff --git a/lib/ml/quant_benchmark.ts b/lib/ml/quant_benchmark.ts index bcc5044..aab6308 100644 --- a/lib/ml/quant_benchmark.ts +++ b/lib/ml/quant_benchmark.ts @@ -1,6 +1,12 @@ import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers"; import * as ort from "onnxruntime"; -import { softmax } from "lib/ml/filter_inference.ts"; + +function softmax(logits: Float32Array): number[] { + const maxLogit = Math.max(...logits); + const exponents = logits.map((logit) => Math.exp(logit - maxLogit)); + const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0); + return Array.from(exponents.map((exp) => exp / sumOfExponents)); +} // 配置参数 const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024"; diff --git a/lib/mq/exec/classifyVideo.ts b/lib/mq/exec/classifyVideo.ts index 3541892..6649931 100644 --- a/lib/mq/exec/classifyVideo.ts +++ b/lib/mq/exec/classifyVideo.ts @@ -1,7 +1,7 @@ import { Job } from "bullmq"; import { db } from "lib/db/init.ts"; import { getUnlabelledVideos, getVideoInfoFromAllData, insertVideoLabel } from "lib/db/allData.ts"; -import { classifyVideo } from "lib/ml/filter_inference.ts"; +import Akari from "lib/ml/akari.ts"; import { ClassifyVideoQueue } from "lib/mq/index.ts"; import logger from "lib/log/logger.ts"; import { lockManager } from "lib/mq/lockManager.ts"; @@ -19,7 +19,7 @@ export const classifyVideoWorker = async (job: Job) => { const title = videoInfo.title?.trim() || "untitled"; const description = videoInfo.description?.trim() || "N/A"; const tags = videoInfo.tags?.trim() || "empty"; - const label = await classifyVideo(title, description, tags, aid); + const label = await Akari.classifyVideo(title, description, tags, aid); if (label == -1) { logger.warn(`Failed to classify video ${aid}`, "ml"); } diff --git a/src/filterWorker.ts b/src/filterWorker.ts index 8eb43d4..cb42048 100644 --- a/src/filterWorker.ts +++ b/src/filterWorker.ts @@ -4,7 +4,7 @@ import logger from "lib/log/logger.ts"; import { classifyVideosWorker, classifyVideoWorker } from "lib/mq/exec/classifyVideo.ts"; import { WorkerError } from "lib/mq/schema.ts"; import { lockManager } from "lib/mq/lockManager.ts"; -import { initializeModels } from "lib/ml/filter_inference.ts"; +import Akari from "lib/ml/akari.ts"; Deno.addSignalListener("SIGINT", async () => { logger.log("SIGINT Received: Shutting down workers...", "mq"); @@ -18,7 +18,7 @@ Deno.addSignalListener("SIGTERM", async () => { Deno.exit(); }); -await initializeModels(); +Akari.init(); const filterWorker = new Worker( "classifyVideo",