From fa414e89ceec1250f72685e7e28f10abb221ec3a Mon Sep 17 00:00:00 2001 From: alikia2x Date: Sat, 8 Mar 2025 00:55:29 +0800 Subject: [PATCH] add: insert labelled songs into songs table --- lib/db/allData.ts | 21 +-- lib/db/songs.ts | 29 ++++ lib/ml/benchmark.ts | 228 ++++++++++++++++---------------- lib/mq/exec/classifyVideo.ts | 7 + lib/mq/exec/getBiliUserInfo.ts | 0 lib/mq/exec/getLatestVideos.ts | 25 ++++ lib/mq/exec/getVideoInfo.ts | 17 --- lib/mq/init.ts | 4 + lib/mq/task/collectSongs.ts | 29 ++++ lib/mq/task/getVideoInfo.ts | 5 +- lib/mq/task/queueLatestVideo.ts | 7 +- src/worker.ts | 7 +- 12 files changed, 231 insertions(+), 148 deletions(-) create mode 100644 lib/db/songs.ts delete mode 100644 lib/mq/exec/getBiliUserInfo.ts delete mode 100644 lib/mq/exec/getVideoInfo.ts create mode 100644 lib/mq/task/collectSongs.ts diff --git a/lib/db/allData.ts b/lib/db/allData.ts index 7c4d990..8e30780 100644 --- a/lib/db/allData.ts +++ b/lib/db/allData.ts @@ -1,6 +1,6 @@ -import {Client} from "https://deno.land/x/postgres@v0.19.3/mod.ts"; -import {AllDataType, BiliUserType} from "lib/db/schema.d.ts"; -import {modelVersion} from "lib/ml/filter_inference.ts"; +import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts"; +import { AllDataType, BiliUserType } from "lib/db/schema.d.ts"; +import { modelVersion } from "lib/ml/filter_inference.ts"; export async function videoExistsInAllData(client: Client, aid: number) { return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid]) @@ -8,7 +8,9 @@ export async function videoExistsInAllData(client: Client, aid: number) { } export async function userExistsInBiliUsers(client: Client, uid: number) { - return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM bili_user WHERE uid = $1)`, [uid]) + return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM bili_user WHERE uid = $1)`, [ + uid, + ]); } export async function getUnlabelledVideos(client: Client) { @@ -36,28 +38,29 @@ export async function getVideoInfoFromAllData(client: Client, aid: number) { const q = await client.queryObject( `SELECT * FROM bili_user WHERE uid = $1`, [row.uid], - ) + ); const userRow = q.rows[0]; - if (userRow) + if (userRow) { authorInfo = userRow.desc; + } } return { title: row.title, description: row.description, tags: row.tags, - author_info: authorInfo + author_info: authorInfo, }; } export async function getUnArchivedBiliUsers(client: Client) { - const queryResult = await client.queryObject<{uid: number}>( + const queryResult = await client.queryObject<{ uid: number }>( ` SELECT ad.uid FROM all_data ad LEFT JOIN bili_user bu ON ad.uid = bu.uid WHERE bu.uid IS NULL; `, - [] + [], ); const rows = queryResult.rows; return rows.map((row) => row.uid); diff --git a/lib/db/songs.ts b/lib/db/songs.ts new file mode 100644 index 0000000..0d5a096 --- /dev/null +++ b/lib/db/songs.ts @@ -0,0 +1,29 @@ +import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts"; + +export async function getNotCollectedSongs(client: Client) { + const queryResult = await client.queryObject<{ aid: number }>(` + SELECT lr.aid + FROM labelling_result lr + WHERE lr.label != 0 + AND NOT EXISTS ( + SELECT 1 + FROM songs s + WHERE s.aid = lr.aid + ); + `); + return queryResult.rows.map((row) => row.aid); +} + +export async function aidExistsInSongs(client: Client, aid: number) { + const queryResult = await client.queryObject<{ exists: boolean }>( + ` + SELECT EXISTS ( + SELECT 1 + FROM songs + WHERE aid = $1 + ); + `, + [aid], + ); + return queryResult.rows[0].exists; +} diff --git a/lib/ml/benchmark.ts b/lib/ml/benchmark.ts index 478b224..0cfc193 100644 --- a/lib/ml/benchmark.ts +++ b/lib/ml/benchmark.ts @@ -10,164 +10,164 @@ const testDataPath = "./data/filter/test1.jsonl"; // 初始化会话 const [sessionClassifier, sessionEmbedding] = await Promise.all([ - ort.InferenceSession.create(onnxClassifierPath), - ort.InferenceSession.create(onnxEmbeddingPath), + ort.InferenceSession.create(onnxClassifierPath), + ort.InferenceSession.create(onnxEmbeddingPath), ]); let tokenizer: PreTrainedTokenizer; // 初始化分词器 async function loadTokenizer() { - const tokenizerConfig = { local_files_only: true }; - tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig); + const tokenizerConfig = { local_files_only: true }; + tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig); } // 新的嵌入生成函数(使用ONNX) async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise { - const { input_ids } = await tokenizer(texts, { - add_special_tokens: false, - return_tensor: false, - }); + const { input_ids } = await tokenizer(texts, { + add_special_tokens: false, + return_tensor: false, + }); - // 构造输入参数 - const cumsum = (arr: number[]): number[] => - arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []); + // 构造输入参数 + const cumsum = (arr: number[]): number[] => + arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []); - const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))]; - const flattened_input_ids = input_ids.flat(); + const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))]; + const flattened_input_ids = input_ids.flat(); - // 准备ONNX输入 - const inputs = { - input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [ - flattened_input_ids.length, - ]), - offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]), - }; + // 准备ONNX输入 + const inputs = { + input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [ + flattened_input_ids.length, + ]), + offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]), + }; - // 执行推理 - const { embeddings } = await session.run(inputs); - return Array.from(embeddings.data as Float32Array); + // 执行推理 + const { embeddings } = await session.run(inputs); + return Array.from(embeddings.data as Float32Array); } // 分类推理函数 async function runClassification(embeddings: number[]): Promise { - const inputTensor = new ort.Tensor( - Float32Array.from(embeddings), - [1, 3, 1024], - ); + const inputTensor = new ort.Tensor( + Float32Array.from(embeddings), + [1, 3, 1024], + ); - const { logits } = await sessionClassifier.run({ channel_features: inputTensor }); - return softmax(logits.data as Float32Array); + const { logits } = await sessionClassifier.run({ channel_features: inputTensor }); + return softmax(logits.data as Float32Array); } // 指标计算函数 function calculateMetrics(labels: number[], predictions: number[], elapsedTime: number): { - accuracy: number; - precision: number; - recall: number; - f1: number; - "Class 0 Prec": number; - speed: string; + accuracy: number; + precision: number; + recall: number; + f1: number; + "Class 0 Prec": number; + speed: string; } { - // 输出label和prediction不一样的index列表 - const arr = [] - for (let i = 0; i < labels.length; i++) { - if (labels[i] !== predictions[i] && predictions[i] == 0) { - arr.push([i + 1, labels[i], predictions[i]]) - } - } - console.log(arr) - // 初始化混淆矩阵 - const classCount = Math.max(...labels, ...predictions) + 1; - const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0)); + // 输出label和prediction不一样的index列表 + const arr = []; + for (let i = 0; i < labels.length; i++) { + if (labels[i] !== predictions[i] && predictions[i] == 0) { + arr.push([i + 1, labels[i], predictions[i]]); + } + } + console.log(arr); + // 初始化混淆矩阵 + const classCount = Math.max(...labels, ...predictions) + 1; + const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0)); - // 填充矩阵 - labels.forEach((trueLabel, i) => { - matrix[trueLabel][predictions[i]]++; - }); + // 填充矩阵 + labels.forEach((trueLabel, i) => { + matrix[trueLabel][predictions[i]]++; + }); - // 计算各指标 - let totalTP = 0, totalFP = 0, totalFN = 0; + // 计算各指标 + let totalTP = 0, totalFP = 0, totalFN = 0; - for (let c = 0; c < classCount; c++) { - const TP = matrix[c][c]; - const FP = matrix.flatMap((row, i) => i === c ? [] : [row[c]]).reduce((a, b) => a + b, 0); - const FN = matrix[c].filter((_, i) => i !== c).reduce((a, b) => a + b, 0); + for (let c = 0; c < classCount; c++) { + const TP = matrix[c][c]; + const FP = matrix.flatMap((row, i) => i === c ? [] : [row[c]]).reduce((a, b) => a + b, 0); + const FN = matrix[c].filter((_, i) => i !== c).reduce((a, b) => a + b, 0); - totalTP += TP; - totalFP += FP; - totalFN += FN; - } + totalTP += TP; + totalFP += FP; + totalFN += FN; + } - const precision = totalTP / (totalTP + totalFP); - const recall = totalTP / (totalTP + totalFN); - const f1 = 2 * (precision * recall) / (precision + recall) || 0; + const precision = totalTP / (totalTP + totalFP); + const recall = totalTP / (totalTP + totalFN); + const f1 = 2 * (precision * recall) / (precision + recall) || 0; - // 计算Class 0 Precision - const class0TP = matrix[0][0]; - const class0FP = matrix.flatMap((row, i) => i === 0 ? [] : [row[0]]).reduce((a, b) => a + b, 0); - const class0Precision = class0TP / (class0TP + class0FP) || 0; + // 计算Class 0 Precision + const class0TP = matrix[0][0]; + const class0FP = matrix.flatMap((row, i) => i === 0 ? [] : [row[0]]).reduce((a, b) => a + b, 0); + const class0Precision = class0TP / (class0TP + class0FP) || 0; - return { - accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length, - precision, - recall, - f1, - speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`, - "Class 0 Prec": class0Precision, - }; + return { + accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length, + precision, + recall, + f1, + speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`, + "Class 0 Prec": class0Precision, + }; } // 改造后的评估函数 async function evaluateModel(session: ort.InferenceSession): Promise<{ - accuracy: number; - precision: number; - recall: number; - f1: number; - "Class 0 Prec": number; + accuracy: number; + precision: number; + recall: number; + f1: number; + "Class 0 Prec": number; }> { - const data = await Deno.readTextFile(testDataPath); - const samples = data.split("\n") - .map((line) => { - try { - return JSON.parse(line); - } catch { - return null; - } - }) - .filter(Boolean); + const data = await Deno.readTextFile(testDataPath); + const samples = data.split("\n") + .map((line) => { + try { + return JSON.parse(line); + } catch { + return null; + } + }) + .filter(Boolean); - const allPredictions: number[] = []; - const allLabels: number[] = []; + const allPredictions: number[] = []; + const allLabels: number[] = []; - const t = new Date().getTime(); - for (const sample of samples) { - try { - const embeddings = await getONNXEmbeddings([ - sample.title, - sample.description, - sample.tags.join(",") - ], session); + const t = new Date().getTime(); + for (const sample of samples) { + try { + const embeddings = await getONNXEmbeddings([ + sample.title, + sample.description, + sample.tags.join(","), + ], session); - const probabilities = await runClassification(embeddings); - allPredictions.push(probabilities.indexOf(Math.max(...probabilities))); - allLabels.push(sample.label); - } catch (error) { - console.error("Processing error:", error); - } - } - const elapsed = new Date().getTime() - t; + const probabilities = await runClassification(embeddings); + allPredictions.push(probabilities.indexOf(Math.max(...probabilities))); + allLabels.push(sample.label); + } catch (error) { + console.error("Processing error:", error); + } + } + const elapsed = new Date().getTime() - t; - return calculateMetrics(allLabels, allPredictions, elapsed); + return calculateMetrics(allLabels, allPredictions, elapsed); } // 主函数 async function main() { - await loadTokenizer(); + await loadTokenizer(); - const metrics = await evaluateModel(sessionEmbedding); - console.log("Model Metrics:"); - console.table(metrics); + const metrics = await evaluateModel(sessionEmbedding); + console.log("Model Metrics:"); + console.table(metrics); } -await main(); \ No newline at end of file +await main(); diff --git a/lib/mq/exec/classifyVideo.ts b/lib/mq/exec/classifyVideo.ts index 26d7053..3541892 100644 --- a/lib/mq/exec/classifyVideo.ts +++ b/lib/mq/exec/classifyVideo.ts @@ -5,6 +5,8 @@ import { classifyVideo } from "lib/ml/filter_inference.ts"; import { ClassifyVideoQueue } from "lib/mq/index.ts"; import logger from "lib/log/logger.ts"; import { lockManager } from "lib/mq/lockManager.ts"; +import { aidExistsInSongs } from "lib/db/songs.ts"; +import { insertIntoSongs } from "lib/mq/task/collectSongs.ts"; export const classifyVideoWorker = async (job: Job) => { const client = await db.connect(); @@ -23,6 +25,11 @@ export const classifyVideoWorker = async (job: Job) => { } await insertVideoLabel(client, aid, label); + const exists = await aidExistsInSongs(client, aid); + if (!exists) { + await insertIntoSongs(client, aid); + } + client.release(); await job.updateData({ diff --git a/lib/mq/exec/getBiliUserInfo.ts b/lib/mq/exec/getBiliUserInfo.ts deleted file mode 100644 index e69de29..0000000 diff --git a/lib/mq/exec/getLatestVideos.ts b/lib/mq/exec/getLatestVideos.ts index 4f795e0..65067cd 100644 --- a/lib/mq/exec/getLatestVideos.ts +++ b/lib/mq/exec/getLatestVideos.ts @@ -1,6 +1,8 @@ import { Job } from "bullmq"; import { queueLatestVideos } from "lib/mq/task/queueLatestVideo.ts"; import { db } from "lib/db/init.ts"; +import { insertVideoInfo } from "lib/mq/task/getVideoInfo.ts"; +import { collectSongs } from "lib/mq/task/collectSongs.ts"; export const getLatestVideosWorker = async (_job: Job): Promise => { const client = await db.connect(); @@ -10,3 +12,26 @@ export const getLatestVideosWorker = async (_job: Job): Promise => { client.release(); } }; + +export const collectSongsWorker = async (_job: Job): Promise => { + const client = await db.connect(); + try { + await collectSongs(client); + } finally { + client.release(); + } +}; + +export const getVideoInfoWorker = async (job: Job): Promise => { + const client = await db.connect(); + try { + const aid = job.data.aid; + if (!aid) { + return 3; + } + await insertVideoInfo(client, aid); + return 0; + } finally { + client.release(); + } +}; diff --git a/lib/mq/exec/getVideoInfo.ts b/lib/mq/exec/getVideoInfo.ts deleted file mode 100644 index dfc5e89..0000000 --- a/lib/mq/exec/getVideoInfo.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { Job } from "bullmq"; -import { db } from "lib/db/init.ts"; -import { insertVideoInfo } from "lib/mq/task/getVideoInfo.ts"; - -export const getVideoInfoWorker = async (job: Job): Promise => { - const client = await db.connect(); - try { - const aid = job.data.aid; - if (!aid) { - return 3; - } - await insertVideoInfo(client, aid); - return 0; - } finally { - client.release(); - } -}; diff --git a/lib/mq/init.ts b/lib/mq/init.ts index 3eb2d81..1073471 100644 --- a/lib/mq/init.ts +++ b/lib/mq/init.ts @@ -11,6 +11,10 @@ export async function initMQ() { every: 5 * MINUTE, immediately: true, }); + await LatestVideosQueue.upsertJobScheduler("collectSongs", { + every: 3 * MINUTE, + immediately: true, + }); logger.log("Message queue initialized."); } diff --git a/lib/mq/task/collectSongs.ts b/lib/mq/task/collectSongs.ts new file mode 100644 index 0000000..04e033d --- /dev/null +++ b/lib/mq/task/collectSongs.ts @@ -0,0 +1,29 @@ +import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts"; +import { aidExistsInSongs, getNotCollectedSongs } from "lib/db/songs.ts"; +import logger from "lib/log/logger.ts"; + +export async function collectSongs(client: Client) { + const aids = await getNotCollectedSongs(client); + for (const aid of aids) { + const exists = await aidExistsInSongs(client, aid); + if (exists) continue; + await insertIntoSongs(client, aid); + logger.log(`Video ${aid} was added into the songs table.`, "mq", "fn:collectSongs"); + } +} + +export async function insertIntoSongs(client: Client, aid: number) { + await client.queryObject( + ` + INSERT INTO songs (aid, bvid, published_at, duration) + VALUES ( + $1, + (SELECT bvid FROM all_data WHERE aid = $1), + (SELECT published_at FROM all_data WHERE aid = $1), + (SELECT duration FROM all_data WHERE aid = $1) + ) + ON CONFLICT DO NOTHING + `, + [aid], + ); +} diff --git a/lib/mq/task/getVideoInfo.ts b/lib/mq/task/getVideoInfo.ts index 3a49628..6f0ba58 100644 --- a/lib/mq/task/getVideoInfo.ts +++ b/lib/mq/task/getVideoInfo.ts @@ -30,12 +30,11 @@ export async function insertVideoInfo(client: Client, aid: number) { ); const userExists = await userExistsInBiliUsers(client, aid); if (!userExists) { - await client.queryObject( + await client.queryObject( `INSERT INTO bili_user (uid, username, "desc", fans) VALUES ($1, $2, $3, $4)`, [uid, data.View.owner.name, data.Card.card.sign, data.Card.follower], ); - } - else { + } else { await client.queryObject( `UPDATE bili_user SET fans = $1 WHERE uid = $2`, [data.Card.follower, uid], diff --git a/lib/mq/task/queueLatestVideo.ts b/lib/mq/task/queueLatestVideo.ts index eac26b5..d2e938b 100644 --- a/lib/mq/task/queueLatestVideo.ts +++ b/lib/mq/task/queueLatestVideo.ts @@ -26,12 +26,13 @@ export async function queueLatestVideos( if (videoExists) { continue; } - await LatestVideosQueue.add("getVideoInfo", { aid }, { delay, + await LatestVideosQueue.add("getVideoInfo", { aid }, { + delay, attempts: 100, backoff: { type: "fixed", - delay: SECOND * 5 - } + delay: SECOND * 5, + }, }); videosFound.add(aid); allExists = false; diff --git a/src/worker.ts b/src/worker.ts index fbe791c..bc9af5b 100644 --- a/src/worker.ts +++ b/src/worker.ts @@ -1,10 +1,10 @@ import { Job, Worker } from "bullmq"; -import { getLatestVideosWorker } from "lib/mq/executors.ts"; +import { collectSongsWorker, getLatestVideosWorker } from "lib/mq/executors.ts"; import { redis } from "lib/db/redis.ts"; import logger from "lib/log/logger.ts"; import { lockManager } from "lib/mq/lockManager.ts"; import { WorkerError } from "lib/mq/schema.ts"; -import { getVideoInfoWorker } from "lib/mq/exec/getVideoInfo.ts"; +import { getVideoInfoWorker } from "lib/mq/exec/getLatestVideos.ts"; Deno.addSignalListener("SIGINT", async () => { logger.log("SIGINT Received: Shutting down workers...", "mq"); @@ -28,6 +28,9 @@ const latestVideoWorker = new Worker( case "getVideoInfo": await getVideoInfoWorker(job); break; + case "collectSongs": + await collectSongsWorker(job); + break; default: break; }