add: insert labelled songs into songs table

This commit is contained in:
alikia2x (寒寒) 2025-03-08 00:55:29 +08:00
parent 2a2e65804f
commit fa414e89ce
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
12 changed files with 231 additions and 148 deletions

View File

@ -1,6 +1,6 @@
import {Client} from "https://deno.land/x/postgres@v0.19.3/mod.ts"; import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import {AllDataType, BiliUserType} from "lib/db/schema.d.ts"; import { AllDataType, BiliUserType } from "lib/db/schema.d.ts";
import {modelVersion} from "lib/ml/filter_inference.ts"; import { modelVersion } from "lib/ml/filter_inference.ts";
export async function videoExistsInAllData(client: Client, aid: number) { export async function videoExistsInAllData(client: Client, aid: number) {
return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid]) return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid])
@ -8,7 +8,9 @@ export async function videoExistsInAllData(client: Client, aid: number) {
} }
/**
 * Checks whether a row with the given uid exists in the `bili_user` table.
 *
 * The function resolves to a plain boolean. Previously the raw query result
 * object was returned, which is always truthy, so callers written as
 * `if (!userExists)` could never take the "user is missing" branch.
 *
 * @param client - Connected Postgres client used to run the query.
 * @param uid - User id to look up.
 * @returns `true` when a matching `bili_user` row exists, otherwise `false`.
 */
export async function userExistsInBiliUsers(client: Client, uid: number) {
	const result = await client.queryObject<{ exists: boolean }>(
		`SELECT EXISTS(SELECT 1 FROM bili_user WHERE uid = $1)`,
		[uid],
	);
	// SELECT EXISTS(...) always yields exactly one row with a boolean column.
	return result.rows[0].exists;
}
export async function getUnlabelledVideos(client: Client) { export async function getUnlabelledVideos(client: Client) {
@ -36,28 +38,29 @@ export async function getVideoInfoFromAllData(client: Client, aid: number) {
const q = await client.queryObject<BiliUserType>( const q = await client.queryObject<BiliUserType>(
`SELECT * FROM bili_user WHERE uid = $1`, `SELECT * FROM bili_user WHERE uid = $1`,
[row.uid], [row.uid],
) );
const userRow = q.rows[0]; const userRow = q.rows[0];
if (userRow) if (userRow) {
authorInfo = userRow.desc; authorInfo = userRow.desc;
}
} }
return { return {
title: row.title, title: row.title,
description: row.description, description: row.description,
tags: row.tags, tags: row.tags,
author_info: authorInfo author_info: authorInfo,
}; };
} }
export async function getUnArchivedBiliUsers(client: Client) { export async function getUnArchivedBiliUsers(client: Client) {
const queryResult = await client.queryObject<{uid: number}>( const queryResult = await client.queryObject<{ uid: number }>(
` `
SELECT ad.uid SELECT ad.uid
FROM all_data ad FROM all_data ad
LEFT JOIN bili_user bu ON ad.uid = bu.uid LEFT JOIN bili_user bu ON ad.uid = bu.uid
WHERE bu.uid IS NULL; WHERE bu.uid IS NULL;
`, `,
[] [],
); );
const rows = queryResult.rows; const rows = queryResult.rows;
return rows.map((row) => row.uid); return rows.map((row) => row.uid);

29
lib/db/songs.ts Normal file
View File

@ -0,0 +1,29 @@
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
/**
 * Lists the aids found in `labelling_result` with a non-zero label that have
 * no row with the same aid in `songs`.
 *
 * @param client - Connected Postgres client used to run the query.
 * @returns The aids, in the order the database returned them.
 */
export async function getNotCollectedSongs(client: Client) {
	const { rows } = await client.queryObject<{ aid: number }>(`
SELECT lr.aid
FROM labelling_result lr
WHERE lr.label != 0
AND NOT EXISTS (
SELECT 1
FROM songs s
WHERE s.aid = lr.aid
);
`);
	const aids: number[] = [];
	for (const { aid } of rows) {
		aids.push(aid);
	}
	return aids;
}
/**
 * Tells whether the `songs` table already contains a row for the given aid.
 *
 * @param client - Connected Postgres client used to run the query.
 * @param aid - Video id to look up.
 * @returns The boolean produced by the `SELECT EXISTS` query.
 */
export async function aidExistsInSongs(client: Client, aid: number) {
	const existsQuery = `
SELECT EXISTS (
SELECT 1
FROM songs
WHERE aid = $1
);
`;
	const { rows } = await client.queryObject<{ exists: boolean }>(existsQuery, [aid]);
	return rows[0].exists;
}

View File

@ -10,164 +10,164 @@ const testDataPath = "./data/filter/test1.jsonl";
// Initialize the ONNX inference sessions; both are created concurrently.
const classifierSessionPromise = ort.InferenceSession.create(onnxClassifierPath);
const embeddingSessionPromise = ort.InferenceSession.create(onnxEmbeddingPath);
const [sessionClassifier, sessionEmbedding] = await Promise.all([
	classifierSessionPromise,
	embeddingSessionPromise,
]);

// Assigned by loadTokenizer() before any embedding is computed.
let tokenizer: PreTrainedTokenizer;
// Initialize the tokenizer.
/**
 * Loads the tokenizer for `sentenceTransformerModelName` with
 * `local_files_only` enabled and stores it in the module-level `tokenizer`.
 */
async function loadTokenizer() {
	tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, {
		local_files_only: true,
	});
}
// Embedding generation backed by ONNX.
/**
 * Tokenizes the given texts and runs them through the supplied ONNX session.
 *
 * All token ids are concatenated into one flat `input_ids` tensor; a second
 * `offsets` tensor records the index at which each text starts in that flat
 * sequence.
 *
 * @param texts - Texts to embed.
 * @param session - ONNX session that exposes an `embeddings` output.
 * @returns The `embeddings` output data copied into a plain number array.
 */
async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
	const encoded = await tokenizer(texts, {
		add_special_tokens: false,
		return_tensor: false,
	});
	const input_ids = encoded.input_ids;

	// Build the inputs: every text except the last contributes a start offset
	// for the text that follows it. Only `.length` of each entry is used.
	const tokenCounts: number[] = input_ids.slice(0, -1).map((x: string) => x.length);
	const offsets: number[] = [0];
	for (const count of tokenCounts) {
		offsets.push(offsets[offsets.length - 1] + count);
	}
	const flattened_input_ids = input_ids.flat();

	// Prepare the ONNX inputs (both tensors are int64).
	const idData = new BigInt64Array(flattened_input_ids.map(BigInt));
	const offsetData = new BigInt64Array(offsets.map(BigInt));
	const inputs = {
		input_ids: new ort.Tensor("int64", idData, [flattened_input_ids.length]),
		offsets: new ort.Tensor("int64", offsetData, [offsets.length]),
	};

	// Run inference.
	const output = await session.run(inputs);
	return Array.from(output.embeddings.data as Float32Array);
}
// Classification inference.
/**
 * Feeds the embeddings to the classifier session as a float tensor of shape
 * [1, 3, 1024] under the input name `channel_features`.
 *
 * @param embeddings - Flat embedding values to classify.
 * @returns The result of applying `softmax` to the classifier's `logits`.
 */
async function runClassification(embeddings: number[]): Promise<number[]> {
	const channelFeatures = new ort.Tensor(Float32Array.from(embeddings), [1, 3, 1024]);
	const output = await sessionClassifier.run({ channel_features: channelFeatures });
	return softmax(output.logits.data as Float32Array);
}
// Metric computation.
/**
 * Computes evaluation metrics from true labels and predicted labels.
 *
 * Precision, recall and F1 are micro-averaged: true positives, false positives
 * and false negatives are summed over all classes before dividing. Any ratio
 * whose denominator is zero is reported as 0 instead of NaN, and an empty
 * input yields all-zero metrics instead of throwing.
 *
 * As a debugging aid, the samples that were wrongly predicted as class 0 are
 * printed as `[1-based index, label, prediction]` triples.
 *
 * @param labels - Ground-truth class indices (non-negative integers).
 * @param predictions - Predicted class indices, aligned with `labels`.
 * @param elapsedTime - Time spent on inference, in milliseconds.
 * @returns Accuracy, micro precision/recall/F1, the precision of class 0 and
 *          a human-readable throughput string.
 */
function calculateMetrics(labels: number[], predictions: number[], elapsedTime: number): {
	accuracy: number;
	precision: number;
	recall: number;
	f1: number;
	"Class 0 Prec": number;
	speed: string;
} {
	// Division that reports 0 for an empty denominator instead of NaN.
	const ratio = (numerator: number, denominator: number): number =>
		denominator === 0 ? 0 : numerator / denominator;

	const speed = `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`;

	// Nothing to evaluate: there is no confusion matrix to index into.
	if (labels.length === 0) {
		return { accuracy: 0, precision: 0, recall: 0, f1: 0, speed, "Class 0 Prec": 0 };
	}

	// Print the samples that were wrongly predicted as class 0.
	const wronglyClassZero: number[][] = [];
	for (let i = 0; i < labels.length; i++) {
		if (labels[i] !== predictions[i] && predictions[i] === 0) {
			wronglyClassZero.push([i + 1, labels[i], predictions[i]]);
		}
	}
	console.log(wronglyClassZero);

	// Find the highest class index with a loop; spreading large arrays into
	// Math.max can exceed the engine's argument limit.
	let highestClass = 0;
	for (const value of labels) {
		if (value > highestClass) highestClass = value;
	}
	for (const value of predictions) {
		if (value > highestClass) highestClass = value;
	}
	const classCount = highestClass + 1;

	// Confusion matrix: rows are true labels, columns are predictions.
	const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0));
	labels.forEach((trueLabel, i) => {
		matrix[trueLabel][predictions[i]]++;
	});

	// Off-diagonal sum of a column = false positives for that class.
	const falsePositives = (c: number): number =>
		matrix.reduce((sum, row, i) => (i === c ? sum : sum + row[c]), 0);
	// Off-diagonal sum of a row = false negatives for that class.
	const falseNegatives = (c: number): number =>
		matrix[c].reduce((sum, count, i) => (i === c ? sum : sum + count), 0);

	let totalTP = 0, totalFP = 0, totalFN = 0;
	for (let c = 0; c < classCount; c++) {
		totalTP += matrix[c][c];
		totalFP += falsePositives(c);
		totalFN += falseNegatives(c);
	}

	const precision = ratio(totalTP, totalTP + totalFP);
	const recall = ratio(totalTP, totalTP + totalFN);
	const f1 = ratio(2 * (precision * recall), precision + recall);

	// Precision of class 0 on its own.
	const class0TP = matrix[0][0];
	const class0Precision = ratio(class0TP, class0TP + falsePositives(0));

	const correct = labels.filter((label, i) => label === predictions[i]).length;

	return {
		accuracy: correct / labels.length,
		precision,
		recall,
		f1,
		speed,
		"Class 0 Prec": class0Precision,
	};
}
// Reworked evaluation routine.
/**
 * Evaluates the pipeline on the JSONL file at `testDataPath`.
 *
 * Each line is parsed as JSON; lines that fail to parse, or parse to a falsy
 * value, are dropped. For every remaining sample the title, description and
 * comma-joined tags are embedded and classified, and the index of the highest
 * probability is taken as the prediction. A sample that throws while being
 * processed is logged and left out of the metrics.
 *
 * @param session - ONNX session used to compute the embeddings.
 * @returns The metrics computed by `calculateMetrics`.
 */
async function evaluateModel(session: ort.InferenceSession): Promise<{
	accuracy: number;
	precision: number;
	recall: number;
	f1: number;
	"Class 0 Prec": number;
}> {
	const rawText = await Deno.readTextFile(testDataPath);

	const samples = [];
	for (const line of rawText.split("\n")) {
		let parsed;
		try {
			parsed = JSON.parse(line);
		} catch {
			parsed = null;
		}
		if (parsed) {
			samples.push(parsed);
		}
	}

	const allPredictions: number[] = [];
	const allLabels: number[] = [];

	const startedAt = Date.now();
	for (const sample of samples) {
		try {
			const texts = [sample.title, sample.description, sample.tags.join(",")];
			const embeddings = await getONNXEmbeddings(texts, session);
			const probabilities = await runClassification(embeddings);
			const highest = Math.max(...probabilities);
			allPredictions.push(probabilities.indexOf(highest));
			allLabels.push(sample.label);
		} catch (error) {
			console.error("Processing error:", error);
		}
	}
	const elapsed = Date.now() - startedAt;

	return calculateMetrics(allLabels, allPredictions, elapsed);
}
// Entry point.
/**
 * Loads the tokenizer, evaluates the model with the embedding session and
 * prints the resulting metrics as a table.
 */
async function main() {
	await loadTokenizer();
	const report = await evaluateModel(sessionEmbedding);
	console.log("Model Metrics:");
	console.table(report);
}

await main();

View File

@ -5,6 +5,8 @@ import { classifyVideo } from "lib/ml/filter_inference.ts";
import { ClassifyVideoQueue } from "lib/mq/index.ts"; import { ClassifyVideoQueue } from "lib/mq/index.ts";
import logger from "lib/log/logger.ts"; import logger from "lib/log/logger.ts";
import { lockManager } from "lib/mq/lockManager.ts"; import { lockManager } from "lib/mq/lockManager.ts";
import { aidExistsInSongs } from "lib/db/songs.ts";
import { insertIntoSongs } from "lib/mq/task/collectSongs.ts";
export const classifyVideoWorker = async (job: Job) => { export const classifyVideoWorker = async (job: Job) => {
const client = await db.connect(); const client = await db.connect();
@ -23,6 +25,11 @@ export const classifyVideoWorker = async (job: Job) => {
} }
await insertVideoLabel(client, aid, label); await insertVideoLabel(client, aid, label);
const exists = await aidExistsInSongs(client, aid);
if (!exists) {
await insertIntoSongs(client, aid);
}
client.release(); client.release();
await job.updateData({ await job.updateData({

View File

@ -1,6 +1,8 @@
import { Job } from "bullmq"; import { Job } from "bullmq";
import { queueLatestVideos } from "lib/mq/task/queueLatestVideo.ts"; import { queueLatestVideos } from "lib/mq/task/queueLatestVideo.ts";
import { db } from "lib/db/init.ts"; import { db } from "lib/db/init.ts";
import { insertVideoInfo } from "lib/mq/task/getVideoInfo.ts";
import { collectSongs } from "lib/mq/task/collectSongs.ts";
export const getLatestVideosWorker = async (_job: Job): Promise<void> => { export const getLatestVideosWorker = async (_job: Job): Promise<void> => {
const client = await db.connect(); const client = await db.connect();
@ -10,3 +12,26 @@ export const getLatestVideosWorker = async (_job: Job): Promise<void> => {
client.release(); client.release();
} }
}; };
/**
 * Job handler that runs `collectSongs` on a database client obtained from
 * `db.connect()`.
 *
 * The client is released in a `finally` block, so it is handed back whether
 * or not `collectSongs` throws.
 *
 * @param _job - The triggering job; its contents are not used.
 */
export const collectSongsWorker = async (_job: Job): Promise<void> => {
	const connection = await db.connect();
	try {
		await collectSongs(connection);
	} finally {
		connection.release();
	}
};
/**
 * Job handler that stores the info of the video named by `job.data.aid`.
 *
 * The database client is released in a `finally` block on every path.
 *
 * @param job - Job whose `data.aid` holds the id of the video to process.
 * @returns `0` after `insertVideoInfo` completes, or `3` when `aid` is
 *          missing or falsy.
 */
export const getVideoInfoWorker = async (job: Job): Promise<number> => {
	const connection = await db.connect();
	try {
		const { aid } = job.data;
		if (aid) {
			await insertVideoInfo(connection, aid);
			return 0;
		}
		return 3;
	} finally {
		connection.release();
	}
};

View File

@ -1,17 +0,0 @@
import { Job } from "bullmq";
import { db } from "lib/db/init.ts";
import { insertVideoInfo } from "lib/mq/task/getVideoInfo.ts";
/**
 * Job handler that stores the info of the video named by `job.data.aid`.
 *
 * Returns 0 after `insertVideoInfo` completes and 3 when `aid` is missing or
 * falsy. The database client is released in the `finally` block on every path.
 */
export const getVideoInfoWorker = async (job: Job): Promise<number> => {
const client = await db.connect();
try {
const aid = job.data.aid;
// A missing or falsy aid means there is nothing to process.
if (!aid) {
return 3;
}
await insertVideoInfo(client, aid);
return 0;
} finally {
// Runs on both return paths and when insertVideoInfo throws.
client.release();
}
};

View File

@ -11,6 +11,10 @@ export async function initMQ() {
every: 5 * MINUTE, every: 5 * MINUTE,
immediately: true, immediately: true,
}); });
await LatestVideosQueue.upsertJobScheduler("collectSongs", {
every: 3 * MINUTE,
immediately: true,
});
logger.log("Message queue initialized."); logger.log("Message queue initialized.");
} }

View File

@ -0,0 +1,29 @@
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import { aidExistsInSongs, getNotCollectedSongs } from "lib/db/songs.ts";
import logger from "lib/log/logger.ts";
/**
 * Inserts into `songs` every aid reported by `getNotCollectedSongs`.
 *
 * Each aid is checked against `songs` again right before inserting; aids that
 * are already present are skipped without logging. Every insertion attempt is
 * followed by a log entry.
 *
 * @param client - Connected Postgres client used for all queries.
 */
export async function collectSongs(client: Client) {
	const pendingAids = await getNotCollectedSongs(client);
	for (const aid of pendingAids) {
		const alreadyCollected = await aidExistsInSongs(client, aid);
		if (!alreadyCollected) {
			await insertIntoSongs(client, aid);
			logger.log(`Video ${aid} was added into the songs table.`, "mq", "fn:collectSongs");
		}
	}
}
/**
 * Inserts a row for the given aid into `songs`, taking `bvid`, `published_at`
 * and `duration` from the `all_data` row with the same aid.
 *
 * The statement uses `ON CONFLICT DO NOTHING`, so a conflicting existing row
 * is left untouched.
 *
 * @param client - Connected Postgres client used to run the statement.
 * @param aid - Video id to insert.
 */
export async function insertIntoSongs(client: Client, aid: number) {
	const insertStatement = `
INSERT INTO songs (aid, bvid, published_at, duration)
VALUES (
$1,
(SELECT bvid FROM all_data WHERE aid = $1),
(SELECT published_at FROM all_data WHERE aid = $1),
(SELECT duration FROM all_data WHERE aid = $1)
)
ON CONFLICT DO NOTHING
`;
	const params = [aid];
	await client.queryObject(insertStatement, params);
}

View File

@ -30,12 +30,11 @@ export async function insertVideoInfo(client: Client, aid: number) {
); );
const userExists = await userExistsInBiliUsers(client, aid); const userExists = await userExistsInBiliUsers(client, aid);
if (!userExists) { if (!userExists) {
await client.queryObject( await client.queryObject(
`INSERT INTO bili_user (uid, username, "desc", fans) VALUES ($1, $2, $3, $4)`, `INSERT INTO bili_user (uid, username, "desc", fans) VALUES ($1, $2, $3, $4)`,
[uid, data.View.owner.name, data.Card.card.sign, data.Card.follower], [uid, data.View.owner.name, data.Card.card.sign, data.Card.follower],
); );
} } else {
else {
await client.queryObject( await client.queryObject(
`UPDATE bili_user SET fans = $1 WHERE uid = $2`, `UPDATE bili_user SET fans = $1 WHERE uid = $2`,
[data.Card.follower, uid], [data.Card.follower, uid],

View File

@ -26,12 +26,13 @@ export async function queueLatestVideos(
if (videoExists) { if (videoExists) {
continue; continue;
} }
await LatestVideosQueue.add("getVideoInfo", { aid }, { delay, await LatestVideosQueue.add("getVideoInfo", { aid }, {
delay,
attempts: 100, attempts: 100,
backoff: { backoff: {
type: "fixed", type: "fixed",
delay: SECOND * 5 delay: SECOND * 5,
} },
}); });
videosFound.add(aid); videosFound.add(aid);
allExists = false; allExists = false;

View File

@ -1,10 +1,10 @@
import { Job, Worker } from "bullmq"; import { Job, Worker } from "bullmq";
import { getLatestVideosWorker } from "lib/mq/executors.ts"; import { collectSongsWorker, getLatestVideosWorker } from "lib/mq/executors.ts";
import { redis } from "lib/db/redis.ts"; import { redis } from "lib/db/redis.ts";
import logger from "lib/log/logger.ts"; import logger from "lib/log/logger.ts";
import { lockManager } from "lib/mq/lockManager.ts"; import { lockManager } from "lib/mq/lockManager.ts";
import { WorkerError } from "lib/mq/schema.ts"; import { WorkerError } from "lib/mq/schema.ts";
import { getVideoInfoWorker } from "lib/mq/exec/getVideoInfo.ts"; import { getVideoInfoWorker } from "lib/mq/exec/getLatestVideos.ts";
Deno.addSignalListener("SIGINT", async () => { Deno.addSignalListener("SIGINT", async () => {
logger.log("SIGINT Received: Shutting down workers...", "mq"); logger.log("SIGINT Received: Shutting down workers...", "mq");
@ -28,6 +28,9 @@ const latestVideoWorker = new Worker(
case "getVideoInfo": case "getVideoInfo":
await getVideoInfoWorker(job); await getVideoInfoWorker(job);
break; break;
case "collectSongs":
await collectSongsWorker(job);
break;
default: default:
break; break;
} }