add: provider in NetScheduler, fix missing awaits

alikia2x (寒寒) 2025-02-26 00:55:48 +08:00
parent 15312f4078
commit 232585594a
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
8 changed files with 164 additions and 136 deletions

.gitignore (vendored)
View File

@@ -76,7 +76,6 @@ node_modules/
 # project specific
-.env
 logs/
 __pycache__
 filter/runs

View File

@@ -1,7 +1,7 @@
-import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
+import {AutoTokenizer, PreTrainedTokenizer} from "@huggingface/transformers";
 import * as ort from "onnxruntime";
 import logger from "lib/log/logger.ts";
-import { WorkerError } from "../mq/schema.ts";
+import {WorkerError} from "lib/mq/schema.ts";
 
 const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
 const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
@@ -29,12 +29,11 @@ export async function initializeModels() {
 		sessionEmbedding = embeddingSession;
 		logger.log("Filter models initialized", "ml");
 	} catch (error) {
-		const e = new WorkerError(error as Error, "ml", "fn:initializeModels");
-		throw e;
+		throw new WorkerError(error as Error, "ml", "fn:initializeModels");
 	}
 }
 
-function softmax(logits: Float32Array): number[] {
+export function softmax(logits: Float32Array): number[] {
 	const maxLogit = Math.max(...logits);
 	const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
 	const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
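
Exporting softmax lets the evaluation script in this commit reuse the exact normalization used by the filter worker. Subtracting the maximum logit before exponentiating is the standard numerical-stability trick: it keeps Math.exp from overflowing on large logits while leaving the resulting distribution unchanged. A quick sanity check, assuming the export above (the sample values are illustrative):

import { softmax } from "lib/ml/filter_inference.ts";

// softmax(x) equals softmax(x - max(x)) mathematically, so the shift is safe.
const probs = softmax(new Float32Array([2.0, 1.0, 0.1]));
console.log(probs); // ≈ [0.659, 0.242, 0.099]
console.log(probs.reduce((sum, p) => sum + p, 0)); // ≈ 1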

View File

@@ -1,5 +1,6 @@
-import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
+import {AutoTokenizer, PreTrainedTokenizer} from "@huggingface/transformers";
 import * as ort from "onnxruntime";
+import {softmax} from "lib/ml/filter_inference.ts";
 
 // Configuration parameters
 const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
@@ -9,160 +10,156 @@ const onnxEmbeddingQuantizedPath = "./model/model.onnx";
 // Initialize the sessions
 const [sessionClassifier, sessionEmbeddingOriginal, sessionEmbeddingQuantized] = await Promise.all([
 	ort.InferenceSession.create(onnxClassifierPath),
 	ort.InferenceSession.create(onnxEmbeddingOriginalPath),
-	ort.InferenceSession.create(onnxEmbeddingQuantizedPath)
+	ort.InferenceSession.create(onnxEmbeddingQuantizedPath),
 ]);
 
 let tokenizer: PreTrainedTokenizer;
 
 // Initialize the tokenizer
 async function loadTokenizer() {
 	const tokenizerConfig = { local_files_only: true };
 	tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig);
 }
 
 // New embedding generation function (using ONNX)
 async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
 	const { input_ids } = await tokenizer(texts, {
 		add_special_tokens: false,
-		return_tensor: false
+		return_tensor: false,
 	});
 
 	// Build the input parameters
 	const cumsum = (arr: number[]): number[] =>
 		arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
+
 	const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
 	const flattened_input_ids = input_ids.flat();
 
 	// Prepare the ONNX inputs
 	const inputs = {
-		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [flattened_input_ids.length]),
-		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length])
+		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
+			flattened_input_ids.length,
+		]),
+		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
 	};
 
 	// Run inference
 	const { embeddings } = await session.run(inputs);
 	return Array.from(embeddings.data as Float32Array);
 }
 
-function softmax(logits: Float32Array): number[] {
-	const maxLogit = Math.max(...logits);
-	const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
-	const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
-	return Array.from(exponents.map((exp) => exp / sumOfExponents));
-}
-
 // Classification inference function
 async function runClassification(embeddings: number[]): Promise<number[]> {
 	const inputTensor = new ort.Tensor(
 		Float32Array.from(embeddings),
-		[1, 4, 1024]
+		[1, 4, 1024],
 	);
 	const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
 	return softmax(logits.data as Float32Array);
 }
 
 // Metrics calculation function
 function calculateMetrics(labels: number[], predictions: number[], elapsedTime: number): {
-	accuracy: number,
-	precision: number,
-	recall: number,
-	f1: number,
-	speed: string
+	accuracy: number;
+	precision: number;
+	recall: number;
+	f1: number;
+	speed: string;
 } {
 	// Initialize the confusion matrix
 	const classCount = Math.max(...labels, ...predictions) + 1;
-	const matrix = Array.from({ length: classCount }, () =>
-		Array.from({ length: classCount }, () => 0)
-	);
+	const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0));
 
 	// Fill the matrix
 	labels.forEach((trueLabel, i) => {
 		matrix[trueLabel][predictions[i]]++;
 	});
 
 	// Compute the metrics
 	let totalTP = 0, totalFP = 0, totalFN = 0;
+
 	for (let c = 0; c < classCount; c++) {
 		const TP = matrix[c][c];
 		const FP = matrix.flatMap((row, i) => i === c ? [] : [row[c]]).reduce((a, b) => a + b, 0);
 		const FN = matrix[c].filter((_, i) => i !== c).reduce((a, b) => a + b, 0);
-
 		totalTP += TP;
 		totalFP += FP;
 		totalFN += FN;
 	}
 
 	const precision = totalTP / (totalTP + totalFP);
 	const recall = totalTP / (totalTP + totalFN);
 	const f1 = 2 * (precision * recall) / (precision + recall) || 0;
 
 	return {
 		accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length,
		precision,
 		recall,
 		f1,
-		speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`
+		speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`,
 	};
 }
 
 // Refactored evaluation function
 async function evaluateModel(session: ort.InferenceSession): Promise<{
 	accuracy: number;
 	precision: number;
 	recall: number;
 	f1: number;
 }> {
 	const data = await Deno.readTextFile("./data/filter/test.jsonl");
 	const samples = data.split("\n")
-		.map(line => {
-			try { return JSON.parse(line); }
-			catch { return null; }
-		})
+		.map((line) => {
+			try {
+				return JSON.parse(line);
+			} catch {
+				return null;
+			}
+		})
 		.filter(Boolean);
 
 	const allPredictions: number[] = [];
 	const allLabels: number[] = [];
 
 	const t = new Date().getTime();
 	for (const sample of samples) {
 		try {
 			const embeddings = await getONNXEmbeddings([
 				sample.title,
 				sample.description,
 				sample.tags.join(","),
-				sample.author_info
+				sample.author_info,
 			], session);
 			const probabilities = await runClassification(embeddings);
 			allPredictions.push(probabilities.indexOf(Math.max(...probabilities)));
 			allLabels.push(sample.label);
 		} catch (error) {
 			console.error("Processing error:", error);
 		}
 	}
 	const elapsed = new Date().getTime() - t;
 
 	return calculateMetrics(allLabels, allPredictions, elapsed);
 }
 
 // Main function
 async function main() {
 	await loadTokenizer();
 
 	// Evaluate the original model
 	const originalMetrics = await evaluateModel(sessionEmbeddingOriginal);
 	console.log("Original Model Metrics:");
 	console.table(originalMetrics);
 
 	// Evaluate the quantized model
 	const quantizedMetrics = await evaluateModel(sessionEmbeddingQuantized);
 	console.log("Quantized Model Metrics:");
 	console.table(quantizedMetrics);
 }
 
 await main();
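
For reference, calculateMetrics micro-averages over the confusion matrix $M$, where $M_{t,p}$ counts samples with true label $t$ predicted as $p$:

\mathrm{TP} = \sum_{c} M_{c,c}, \qquad \mathrm{FP} = \sum_{c} \sum_{i \neq c} M_{i,c}, \qquad \mathrm{FN} = \sum_{c} \sum_{i \neq c} M_{c,i}

\mathrm{precision} = \frac{\mathrm{TP}}{\mathrm{TP} + \mathrm{FP}}, \qquad \mathrm{recall} = \frac{\mathrm{TP}}{\mathrm{TP} + \mathrm{FN}}, \qquad F_1 = \frac{2 \cdot \mathrm{precision} \cdot \mathrm{recall}}{\mathrm{precision} + \mathrm{recall}}

Since each misclassified sample contributes exactly one FP (to the predicted class) and one FN (to the true class), micro-averaged precision, recall, and F1 all coincide with plain accuracy for single-label multi-class data like this; the extra columns mainly serve as a consistency check.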

View File

@@ -22,11 +22,11 @@ export const classifyVideoWorker = async (job: Job) => {
 	if (label == -1) {
 		logger.warn(`Failed to classify video ${aid}`, "ml");
 	}
-	insertVideoLabel(client, aid, label);
+	await insertVideoLabel(client, aid, label);
 	client.release();
-	job.updateData({
+	await job.updateData({
 		...job.data, label: label,
 	});
@@ -39,7 +39,7 @@ export const classifyVideosWorker = async () => {
 		return;
 	}
-	lockManager.acquireLock("classifyVideos");
+	await lockManager.acquireLock("classifyVideos");
 
 	const client = await db.connect();
 	const videos = await getUnlabelledVideos(client);
@@ -49,12 +49,12 @@ export const classifyVideosWorker = async () => {
 	let i = 0;
 	for (const aid of videos) {
 		if (i > 200) {
-			lockManager.releaseLock("classifyVideos");
+			await lockManager.releaseLock("classifyVideos");
 			return 10000 + i;
 		}
 		await ClassifyVideoQueue.add("classifyVideo", { aid: Number(aid) });
 		i++;
 	}
 
-	lockManager.releaseLock("classifyVideos");
+	await lockManager.releaseLock("classifyVideos");
 	return 0;
 };
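
These awaits matter because acquireLock and releaseLock return promises. A minimal sketch of the idea, assuming a Redis-backed LockManager (the actual implementation is not part of this diff):

import Redis from "ioredis";

class LockManager {
	constructor(private redis: Redis) {}

	// SET ... NX succeeds only if the key does not already exist.
	async acquireLock(id: string): Promise<boolean> {
		return await this.redis.set(`lock:${id}`, "1", "NX") === "OK";
	}

	async releaseLock(id: string): Promise<void> {
		await this.redis.del(`lock:${id}`);
	}
}

Without the awaits, the worker moves on before the lock state actually changes in Redis, so two overlapping runs can both pass the lock check, and a rejected promise becomes an unhandled rejection instead of flowing into the job's error handling.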

View File

@@ -37,7 +37,7 @@ export const getLatestVideosWorker = async (job: Job) => {
 		return;
 	}
-	lockManager.acquireLock("getLatestVideos");
+	await lockManager.acquireLock("getLatestVideos");
 
 	const failedCount = (job.data.failedCount ?? 0) as number;
 	const client = await db.connect();
@@ -46,7 +46,7 @@ export const getLatestVideosWorker = async (job: Job) => {
 		await executeTask(client, failedCount);
 	} finally {
 		client.release();
-		lockManager.releaseLock("getLatestVideos");
+		await lockManager.releaseLock("getLatestVideos");
 	}
 	return;
 };

View File

@@ -8,7 +8,7 @@ import logger from "lib/log/logger.ts";
 import { getNullVideoTagsList, updateVideoTags } from "lib/db/allData.ts";
 import { getVideoTags } from "lib/net/getVideoTags.ts";
 import { NetSchedulerError } from "lib/mq/scheduler.ts";
-import { WorkerError } from "../schema.ts";
+import { WorkerError } from "lib/mq/schema.ts";
 
 const delayMap = [0.5, 3, 5, 15, 30, 60];
 const getJobPriority = (diff: number) => {

View File

@@ -1,4 +1,4 @@
-import { MINUTE, SECOND } from "$std/datetime/constants.ts";
+import { MINUTE } from "$std/datetime/constants.ts";
 import { ClassifyVideoQueue, LatestVideosQueue, VideoTagsQueue } from "lib/mq/index.ts";
 import logger from "lib/log/logger.ts";
@@ -7,11 +7,11 @@ export async function initMQ() {
 		every: 1 * MINUTE
 	});
 	await VideoTagsQueue.upsertJobScheduler("getVideosTags", {
-		every: 30 * SECOND,
+		every: 5 * MINUTE,
 		immediately: true,
 	});
 	await ClassifyVideoQueue.upsertJobScheduler("classifyVideos", {
-		every: 30 * SECOND,
+		every: 5 * MINUTE,
 		immediately: true,
 	})
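
Stretching the two schedules from 30 * SECOND to 5 * MINUTE also drops the now-unused SECOND import. upsertJobScheduler is idempotent: it creates the repeatable job if absent and otherwise updates the existing schedule in place, so the new interval takes effect on the next initMQ run without deleting the old scheduler by hand. A standalone sketch (queue name and connection details are illustrative):

import { Queue } from "bullmq";

const queue = new Queue("exampleQueue", { connection: { host: "127.0.0.1", port: 6379 } });

// Re-running with a different `every` replaces the previous schedule.
await queue.upsertJobScheduler("exampleScheduler", {
	every: 5 * 60 * 1000, // 5 * MINUTE, in milliseconds
	immediately: true, // also fire once right away
});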

View File

@@ -7,6 +7,7 @@ import Redis from "ioredis";
 interface Proxy {
 	type: string;
 	task: string;
+	provider: string;
 	limiter?: RateLimiter;
 }
@@ -32,11 +33,16 @@ export class NetSchedulerError extends Error {
 	}
 }
 
+interface LimiterMap {
+	[name: string]: RateLimiter;
+}
+
 class NetScheduler {
 	private proxies: ProxiesMap = {};
+	private providerLimiters: LimiterMap = {};
 
-	addProxy(name: string, type: string, task: string): void {
-		this.proxies[name] = { type, task };
+	addProxy(name: string, type: string, task: string, provider: string): void {
+		this.proxies[name] = { type, task, provider };
 	}
 
 	removeProxy(name: string): void {
@@ -47,6 +53,10 @@ class NetScheduler {
 		this.proxies[name].limiter = limiter;
 	}
 
+	setProviderLimiter(name: string, limiter: RateLimiter): void {
+		this.providerLimiters[name] = limiter;
+	}
+
 	/*
 	 * Make a request to the specified URL with any available proxy
 	 * @param {string} url - The URL to request.
@@ -117,7 +127,15 @@ class NetScheduler {
 	private async getProxyAvailability(name: string): Promise<boolean> {
 		try {
 			const proxyConfig = this.proxies[name];
-			if (!proxyConfig || !proxyConfig.limiter) {
+			if (!proxyConfig) {
+				return true;
+			}
+			const provider = proxyConfig.provider;
+			const providerLimiter = await this.providerLimiters[provider].getAvailability();
+			if (!providerLimiter) {
+				return false;
+			}
+			if (!proxyConfig.limiter) {
 				return true;
 			}
 			return await proxyConfig.limiter.getAvailability();
@@ -143,8 +161,8 @@ class NetScheduler {
 }
 
 const netScheduler = new NetScheduler();
-netScheduler.addProxy("default", "native", "default");
-netScheduler.addProxy("tags-native", "native", "getVideoTags");
+netScheduler.addProxy("default", "native", "default", "bilibili-native");
+netScheduler.addProxy("tags-native", "native", "getVideoTags", "bilibili-native");
 const tagsRateLimiter = new RateLimiter("getVideoTags", [
 	{
 		window: new SlidingWindow(redis, 1),
@@ -159,6 +177,21 @@ const tagsRateLimiter = new RateLimiter("getVideoTags", [
 		max: 50,
 	},
 ]);
+const biliLimiterNative = new RateLimiter("bilibili-native", [
+	{
+		window: new SlidingWindow(redis, 1),
+		max: 5
+	},
+	{
+		window: new SlidingWindow(redis, 30),
+		max: 100
+	},
+	{
+		window: new SlidingWindow(redis, 5 * 60),
+		max: 180
+	}
+]);
 
 netScheduler.setProxyLimiter("tags-native", tagsRateLimiter);
+netScheduler.setProviderLimiter("bilibili-native", biliLimiterNative)
 
 export default netScheduler;
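
With the provider field in place, a request now has to clear two gates: the shared per-provider limiter (here bilibili-native, capping all native Bilibili traffic at roughly 5 requests per second, 100 per 30 s, and 180 per 5 min, assuming SlidingWindow takes a window size in seconds) and then the proxy's own limiter, if any. One caveat: getProxyAvailability indexes this.providerLimiters[provider] without a guard, so a proxy whose provider has no registered limiter would throw inside the try/catch rather than defaulting to available. A defensive variant of the check, sketched with hypothetical stand-in types rather than the committed code:

interface Limiter {
	getAvailability(): Promise<boolean>;
}

interface ProxyEntry {
	provider: string;
	limiter?: Limiter;
}

async function isProxyAvailable(
	proxy: ProxyEntry | undefined,
	providerLimiters: Record<string, Limiter>,
): Promise<boolean> {
	if (!proxy) return true;
	// Provider-level gate, shared by every proxy of the same provider.
	const providerLimiter = providerLimiters[proxy.provider];
	if (providerLimiter && !(await providerLimiter.getAvailability())) {
		return false;
	}
	// Proxy-level gate, specific to this proxy.
	return proxy.limiter ? await proxy.limiter.getAvailability() : true;
}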