add: insert labelled songs into songs table
This commit is contained in:
parent
2a2e65804f
commit
fa414e89ce
@ -1,6 +1,6 @@
|
||||
import {Client} from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
||||
import {AllDataType, BiliUserType} from "lib/db/schema.d.ts";
|
||||
import {modelVersion} from "lib/ml/filter_inference.ts";
|
||||
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
||||
import { AllDataType, BiliUserType } from "lib/db/schema.d.ts";
|
||||
import { modelVersion } from "lib/ml/filter_inference.ts";
|
||||
|
||||
export async function videoExistsInAllData(client: Client, aid: number) {
|
||||
return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid])
|
||||
@ -8,7 +8,9 @@ export async function videoExistsInAllData(client: Client, aid: number) {
|
||||
}
|
||||
|
||||
export async function userExistsInBiliUsers(client: Client, uid: number) {
|
||||
return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM bili_user WHERE uid = $1)`, [uid])
|
||||
return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM bili_user WHERE uid = $1)`, [
|
||||
uid,
|
||||
]);
|
||||
}
|
||||
|
||||
export async function getUnlabelledVideos(client: Client) {
|
||||
@ -36,28 +38,29 @@ export async function getVideoInfoFromAllData(client: Client, aid: number) {
|
||||
const q = await client.queryObject<BiliUserType>(
|
||||
`SELECT * FROM bili_user WHERE uid = $1`,
|
||||
[row.uid],
|
||||
)
|
||||
);
|
||||
const userRow = q.rows[0];
|
||||
if (userRow)
|
||||
if (userRow) {
|
||||
authorInfo = userRow.desc;
|
||||
}
|
||||
}
|
||||
return {
|
||||
title: row.title,
|
||||
description: row.description,
|
||||
tags: row.tags,
|
||||
author_info: authorInfo
|
||||
author_info: authorInfo,
|
||||
};
|
||||
}
|
||||
|
||||
export async function getUnArchivedBiliUsers(client: Client) {
|
||||
const queryResult = await client.queryObject<{uid: number}>(
|
||||
const queryResult = await client.queryObject<{ uid: number }>(
|
||||
`
|
||||
SELECT ad.uid
|
||||
FROM all_data ad
|
||||
LEFT JOIN bili_user bu ON ad.uid = bu.uid
|
||||
WHERE bu.uid IS NULL;
|
||||
`,
|
||||
[]
|
||||
[],
|
||||
);
|
||||
const rows = queryResult.rows;
|
||||
return rows.map((row) => row.uid);
|
||||
|
29
lib/db/songs.ts
Normal file
29
lib/db/songs.ts
Normal file
@ -0,0 +1,29 @@
|
||||
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
||||
|
||||
export async function getNotCollectedSongs(client: Client) {
|
||||
const queryResult = await client.queryObject<{ aid: number }>(`
|
||||
SELECT lr.aid
|
||||
FROM labelling_result lr
|
||||
WHERE lr.label != 0
|
||||
AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM songs s
|
||||
WHERE s.aid = lr.aid
|
||||
);
|
||||
`);
|
||||
return queryResult.rows.map((row) => row.aid);
|
||||
}
|
||||
|
||||
export async function aidExistsInSongs(client: Client, aid: number) {
|
||||
const queryResult = await client.queryObject<{ exists: boolean }>(
|
||||
`
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM songs
|
||||
WHERE aid = $1
|
||||
);
|
||||
`,
|
||||
[aid],
|
||||
);
|
||||
return queryResult.rows[0].exists;
|
||||
}
|
@ -10,164 +10,164 @@ const testDataPath = "./data/filter/test1.jsonl";
|
||||
|
||||
// 初始化会话
|
||||
const [sessionClassifier, sessionEmbedding] = await Promise.all([
|
||||
ort.InferenceSession.create(onnxClassifierPath),
|
||||
ort.InferenceSession.create(onnxEmbeddingPath),
|
||||
ort.InferenceSession.create(onnxClassifierPath),
|
||||
ort.InferenceSession.create(onnxEmbeddingPath),
|
||||
]);
|
||||
|
||||
let tokenizer: PreTrainedTokenizer;
|
||||
|
||||
// 初始化分词器
|
||||
async function loadTokenizer() {
|
||||
const tokenizerConfig = { local_files_only: true };
|
||||
tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig);
|
||||
const tokenizerConfig = { local_files_only: true };
|
||||
tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig);
|
||||
}
|
||||
|
||||
// 新的嵌入生成函数(使用ONNX)
|
||||
async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
|
||||
const { input_ids } = await tokenizer(texts, {
|
||||
add_special_tokens: false,
|
||||
return_tensor: false,
|
||||
});
|
||||
const { input_ids } = await tokenizer(texts, {
|
||||
add_special_tokens: false,
|
||||
return_tensor: false,
|
||||
});
|
||||
|
||||
// 构造输入参数
|
||||
const cumsum = (arr: number[]): number[] =>
|
||||
arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
|
||||
// 构造输入参数
|
||||
const cumsum = (arr: number[]): number[] =>
|
||||
arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
|
||||
|
||||
const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
|
||||
const flattened_input_ids = input_ids.flat();
|
||||
const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
|
||||
const flattened_input_ids = input_ids.flat();
|
||||
|
||||
// 准备ONNX输入
|
||||
const inputs = {
|
||||
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
|
||||
flattened_input_ids.length,
|
||||
]),
|
||||
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
|
||||
};
|
||||
// 准备ONNX输入
|
||||
const inputs = {
|
||||
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
|
||||
flattened_input_ids.length,
|
||||
]),
|
||||
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
|
||||
};
|
||||
|
||||
// 执行推理
|
||||
const { embeddings } = await session.run(inputs);
|
||||
return Array.from(embeddings.data as Float32Array);
|
||||
// 执行推理
|
||||
const { embeddings } = await session.run(inputs);
|
||||
return Array.from(embeddings.data as Float32Array);
|
||||
}
|
||||
|
||||
// 分类推理函数
|
||||
async function runClassification(embeddings: number[]): Promise<number[]> {
|
||||
const inputTensor = new ort.Tensor(
|
||||
Float32Array.from(embeddings),
|
||||
[1, 3, 1024],
|
||||
);
|
||||
const inputTensor = new ort.Tensor(
|
||||
Float32Array.from(embeddings),
|
||||
[1, 3, 1024],
|
||||
);
|
||||
|
||||
const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
|
||||
return softmax(logits.data as Float32Array);
|
||||
const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
|
||||
return softmax(logits.data as Float32Array);
|
||||
}
|
||||
|
||||
// 指标计算函数
|
||||
function calculateMetrics(labels: number[], predictions: number[], elapsedTime: number): {
|
||||
accuracy: number;
|
||||
precision: number;
|
||||
recall: number;
|
||||
f1: number;
|
||||
"Class 0 Prec": number;
|
||||
speed: string;
|
||||
accuracy: number;
|
||||
precision: number;
|
||||
recall: number;
|
||||
f1: number;
|
||||
"Class 0 Prec": number;
|
||||
speed: string;
|
||||
} {
|
||||
// 输出label和prediction不一样的index列表
|
||||
const arr = []
|
||||
for (let i = 0; i < labels.length; i++) {
|
||||
if (labels[i] !== predictions[i] && predictions[i] == 0) {
|
||||
arr.push([i + 1, labels[i], predictions[i]])
|
||||
}
|
||||
}
|
||||
console.log(arr)
|
||||
// 初始化混淆矩阵
|
||||
const classCount = Math.max(...labels, ...predictions) + 1;
|
||||
const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0));
|
||||
// 输出label和prediction不一样的index列表
|
||||
const arr = [];
|
||||
for (let i = 0; i < labels.length; i++) {
|
||||
if (labels[i] !== predictions[i] && predictions[i] == 0) {
|
||||
arr.push([i + 1, labels[i], predictions[i]]);
|
||||
}
|
||||
}
|
||||
console.log(arr);
|
||||
// 初始化混淆矩阵
|
||||
const classCount = Math.max(...labels, ...predictions) + 1;
|
||||
const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0));
|
||||
|
||||
// 填充矩阵
|
||||
labels.forEach((trueLabel, i) => {
|
||||
matrix[trueLabel][predictions[i]]++;
|
||||
});
|
||||
// 填充矩阵
|
||||
labels.forEach((trueLabel, i) => {
|
||||
matrix[trueLabel][predictions[i]]++;
|
||||
});
|
||||
|
||||
// 计算各指标
|
||||
let totalTP = 0, totalFP = 0, totalFN = 0;
|
||||
// 计算各指标
|
||||
let totalTP = 0, totalFP = 0, totalFN = 0;
|
||||
|
||||
for (let c = 0; c < classCount; c++) {
|
||||
const TP = matrix[c][c];
|
||||
const FP = matrix.flatMap((row, i) => i === c ? [] : [row[c]]).reduce((a, b) => a + b, 0);
|
||||
const FN = matrix[c].filter((_, i) => i !== c).reduce((a, b) => a + b, 0);
|
||||
for (let c = 0; c < classCount; c++) {
|
||||
const TP = matrix[c][c];
|
||||
const FP = matrix.flatMap((row, i) => i === c ? [] : [row[c]]).reduce((a, b) => a + b, 0);
|
||||
const FN = matrix[c].filter((_, i) => i !== c).reduce((a, b) => a + b, 0);
|
||||
|
||||
totalTP += TP;
|
||||
totalFP += FP;
|
||||
totalFN += FN;
|
||||
}
|
||||
totalTP += TP;
|
||||
totalFP += FP;
|
||||
totalFN += FN;
|
||||
}
|
||||
|
||||
const precision = totalTP / (totalTP + totalFP);
|
||||
const recall = totalTP / (totalTP + totalFN);
|
||||
const f1 = 2 * (precision * recall) / (precision + recall) || 0;
|
||||
const precision = totalTP / (totalTP + totalFP);
|
||||
const recall = totalTP / (totalTP + totalFN);
|
||||
const f1 = 2 * (precision * recall) / (precision + recall) || 0;
|
||||
|
||||
// 计算Class 0 Precision
|
||||
const class0TP = matrix[0][0];
|
||||
const class0FP = matrix.flatMap((row, i) => i === 0 ? [] : [row[0]]).reduce((a, b) => a + b, 0);
|
||||
const class0Precision = class0TP / (class0TP + class0FP) || 0;
|
||||
// 计算Class 0 Precision
|
||||
const class0TP = matrix[0][0];
|
||||
const class0FP = matrix.flatMap((row, i) => i === 0 ? [] : [row[0]]).reduce((a, b) => a + b, 0);
|
||||
const class0Precision = class0TP / (class0TP + class0FP) || 0;
|
||||
|
||||
return {
|
||||
accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length,
|
||||
precision,
|
||||
recall,
|
||||
f1,
|
||||
speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`,
|
||||
"Class 0 Prec": class0Precision,
|
||||
};
|
||||
return {
|
||||
accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length,
|
||||
precision,
|
||||
recall,
|
||||
f1,
|
||||
speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`,
|
||||
"Class 0 Prec": class0Precision,
|
||||
};
|
||||
}
|
||||
|
||||
// 改造后的评估函数
|
||||
async function evaluateModel(session: ort.InferenceSession): Promise<{
|
||||
accuracy: number;
|
||||
precision: number;
|
||||
recall: number;
|
||||
f1: number;
|
||||
"Class 0 Prec": number;
|
||||
accuracy: number;
|
||||
precision: number;
|
||||
recall: number;
|
||||
f1: number;
|
||||
"Class 0 Prec": number;
|
||||
}> {
|
||||
const data = await Deno.readTextFile(testDataPath);
|
||||
const samples = data.split("\n")
|
||||
.map((line) => {
|
||||
try {
|
||||
return JSON.parse(line);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.filter(Boolean);
|
||||
const data = await Deno.readTextFile(testDataPath);
|
||||
const samples = data.split("\n")
|
||||
.map((line) => {
|
||||
try {
|
||||
return JSON.parse(line);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.filter(Boolean);
|
||||
|
||||
const allPredictions: number[] = [];
|
||||
const allLabels: number[] = [];
|
||||
const allPredictions: number[] = [];
|
||||
const allLabels: number[] = [];
|
||||
|
||||
const t = new Date().getTime();
|
||||
for (const sample of samples) {
|
||||
try {
|
||||
const embeddings = await getONNXEmbeddings([
|
||||
sample.title,
|
||||
sample.description,
|
||||
sample.tags.join(",")
|
||||
], session);
|
||||
const t = new Date().getTime();
|
||||
for (const sample of samples) {
|
||||
try {
|
||||
const embeddings = await getONNXEmbeddings([
|
||||
sample.title,
|
||||
sample.description,
|
||||
sample.tags.join(","),
|
||||
], session);
|
||||
|
||||
const probabilities = await runClassification(embeddings);
|
||||
allPredictions.push(probabilities.indexOf(Math.max(...probabilities)));
|
||||
allLabels.push(sample.label);
|
||||
} catch (error) {
|
||||
console.error("Processing error:", error);
|
||||
}
|
||||
}
|
||||
const elapsed = new Date().getTime() - t;
|
||||
const probabilities = await runClassification(embeddings);
|
||||
allPredictions.push(probabilities.indexOf(Math.max(...probabilities)));
|
||||
allLabels.push(sample.label);
|
||||
} catch (error) {
|
||||
console.error("Processing error:", error);
|
||||
}
|
||||
}
|
||||
const elapsed = new Date().getTime() - t;
|
||||
|
||||
return calculateMetrics(allLabels, allPredictions, elapsed);
|
||||
return calculateMetrics(allLabels, allPredictions, elapsed);
|
||||
}
|
||||
|
||||
// 主函数
|
||||
async function main() {
|
||||
await loadTokenizer();
|
||||
await loadTokenizer();
|
||||
|
||||
const metrics = await evaluateModel(sessionEmbedding);
|
||||
console.log("Model Metrics:");
|
||||
console.table(metrics);
|
||||
const metrics = await evaluateModel(sessionEmbedding);
|
||||
console.log("Model Metrics:");
|
||||
console.table(metrics);
|
||||
}
|
||||
|
||||
await main();
|
||||
await main();
|
||||
|
@ -5,6 +5,8 @@ import { classifyVideo } from "lib/ml/filter_inference.ts";
|
||||
import { ClassifyVideoQueue } from "lib/mq/index.ts";
|
||||
import logger from "lib/log/logger.ts";
|
||||
import { lockManager } from "lib/mq/lockManager.ts";
|
||||
import { aidExistsInSongs } from "lib/db/songs.ts";
|
||||
import { insertIntoSongs } from "lib/mq/task/collectSongs.ts";
|
||||
|
||||
export const classifyVideoWorker = async (job: Job) => {
|
||||
const client = await db.connect();
|
||||
@ -23,6 +25,11 @@ export const classifyVideoWorker = async (job: Job) => {
|
||||
}
|
||||
await insertVideoLabel(client, aid, label);
|
||||
|
||||
const exists = await aidExistsInSongs(client, aid);
|
||||
if (!exists) {
|
||||
await insertIntoSongs(client, aid);
|
||||
}
|
||||
|
||||
client.release();
|
||||
|
||||
await job.updateData({
|
||||
|
@ -1,6 +1,8 @@
|
||||
import { Job } from "bullmq";
|
||||
import { queueLatestVideos } from "lib/mq/task/queueLatestVideo.ts";
|
||||
import { db } from "lib/db/init.ts";
|
||||
import { insertVideoInfo } from "lib/mq/task/getVideoInfo.ts";
|
||||
import { collectSongs } from "lib/mq/task/collectSongs.ts";
|
||||
|
||||
export const getLatestVideosWorker = async (_job: Job): Promise<void> => {
|
||||
const client = await db.connect();
|
||||
@ -10,3 +12,26 @@ export const getLatestVideosWorker = async (_job: Job): Promise<void> => {
|
||||
client.release();
|
||||
}
|
||||
};
|
||||
|
||||
export const collectSongsWorker = async (_job: Job): Promise<void> => {
|
||||
const client = await db.connect();
|
||||
try {
|
||||
await collectSongs(client);
|
||||
} finally {
|
||||
client.release();
|
||||
}
|
||||
};
|
||||
|
||||
export const getVideoInfoWorker = async (job: Job): Promise<number> => {
|
||||
const client = await db.connect();
|
||||
try {
|
||||
const aid = job.data.aid;
|
||||
if (!aid) {
|
||||
return 3;
|
||||
}
|
||||
await insertVideoInfo(client, aid);
|
||||
return 0;
|
||||
} finally {
|
||||
client.release();
|
||||
}
|
||||
};
|
||||
|
@ -1,17 +0,0 @@
|
||||
import { Job } from "bullmq";
|
||||
import { db } from "lib/db/init.ts";
|
||||
import { insertVideoInfo } from "lib/mq/task/getVideoInfo.ts";
|
||||
|
||||
export const getVideoInfoWorker = async (job: Job): Promise<number> => {
|
||||
const client = await db.connect();
|
||||
try {
|
||||
const aid = job.data.aid;
|
||||
if (!aid) {
|
||||
return 3;
|
||||
}
|
||||
await insertVideoInfo(client, aid);
|
||||
return 0;
|
||||
} finally {
|
||||
client.release();
|
||||
}
|
||||
};
|
@ -11,6 +11,10 @@ export async function initMQ() {
|
||||
every: 5 * MINUTE,
|
||||
immediately: true,
|
||||
});
|
||||
await LatestVideosQueue.upsertJobScheduler("collectSongs", {
|
||||
every: 3 * MINUTE,
|
||||
immediately: true,
|
||||
});
|
||||
|
||||
logger.log("Message queue initialized.");
|
||||
}
|
||||
|
29
lib/mq/task/collectSongs.ts
Normal file
29
lib/mq/task/collectSongs.ts
Normal file
@ -0,0 +1,29 @@
|
||||
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
||||
import { aidExistsInSongs, getNotCollectedSongs } from "lib/db/songs.ts";
|
||||
import logger from "lib/log/logger.ts";
|
||||
|
||||
export async function collectSongs(client: Client) {
|
||||
const aids = await getNotCollectedSongs(client);
|
||||
for (const aid of aids) {
|
||||
const exists = await aidExistsInSongs(client, aid);
|
||||
if (exists) continue;
|
||||
await insertIntoSongs(client, aid);
|
||||
logger.log(`Video ${aid} was added into the songs table.`, "mq", "fn:collectSongs");
|
||||
}
|
||||
}
|
||||
|
||||
export async function insertIntoSongs(client: Client, aid: number) {
|
||||
await client.queryObject(
|
||||
`
|
||||
INSERT INTO songs (aid, bvid, published_at, duration)
|
||||
VALUES (
|
||||
$1,
|
||||
(SELECT bvid FROM all_data WHERE aid = $1),
|
||||
(SELECT published_at FROM all_data WHERE aid = $1),
|
||||
(SELECT duration FROM all_data WHERE aid = $1)
|
||||
)
|
||||
ON CONFLICT DO NOTHING
|
||||
`,
|
||||
[aid],
|
||||
);
|
||||
}
|
@ -30,12 +30,11 @@ export async function insertVideoInfo(client: Client, aid: number) {
|
||||
);
|
||||
const userExists = await userExistsInBiliUsers(client, aid);
|
||||
if (!userExists) {
|
||||
await client.queryObject(
|
||||
await client.queryObject(
|
||||
`INSERT INTO bili_user (uid, username, "desc", fans) VALUES ($1, $2, $3, $4)`,
|
||||
[uid, data.View.owner.name, data.Card.card.sign, data.Card.follower],
|
||||
);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
await client.queryObject(
|
||||
`UPDATE bili_user SET fans = $1 WHERE uid = $2`,
|
||||
[data.Card.follower, uid],
|
||||
|
@ -26,12 +26,13 @@ export async function queueLatestVideos(
|
||||
if (videoExists) {
|
||||
continue;
|
||||
}
|
||||
await LatestVideosQueue.add("getVideoInfo", { aid }, { delay,
|
||||
await LatestVideosQueue.add("getVideoInfo", { aid }, {
|
||||
delay,
|
||||
attempts: 100,
|
||||
backoff: {
|
||||
type: "fixed",
|
||||
delay: SECOND * 5
|
||||
}
|
||||
delay: SECOND * 5,
|
||||
},
|
||||
});
|
||||
videosFound.add(aid);
|
||||
allExists = false;
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { Job, Worker } from "bullmq";
|
||||
import { getLatestVideosWorker } from "lib/mq/executors.ts";
|
||||
import { collectSongsWorker, getLatestVideosWorker } from "lib/mq/executors.ts";
|
||||
import { redis } from "lib/db/redis.ts";
|
||||
import logger from "lib/log/logger.ts";
|
||||
import { lockManager } from "lib/mq/lockManager.ts";
|
||||
import { WorkerError } from "lib/mq/schema.ts";
|
||||
import { getVideoInfoWorker } from "lib/mq/exec/getVideoInfo.ts";
|
||||
import { getVideoInfoWorker } from "lib/mq/exec/getLatestVideos.ts";
|
||||
|
||||
Deno.addSignalListener("SIGINT", async () => {
|
||||
logger.log("SIGINT Received: Shutting down workers...", "mq");
|
||||
@ -28,6 +28,9 @@ const latestVideoWorker = new Worker(
|
||||
case "getVideoInfo":
|
||||
await getVideoInfoWorker(job);
|
||||
break;
|
||||
case "collectSongs":
|
||||
await collectSongsWorker(job);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user