From 232585594a26d4508602790251aecc30ffcec1c1 Mon Sep 17 00:00:00 2001
From: alikia2x
Date: Wed, 26 Feb 2025 00:55:48 +0800
Subject: [PATCH] add: provider in NetScheduler, missing `await`

---
 .gitignore                     |   1 -
 lib/ml/filter_inference.ts     |   9 +-
 lib/ml/quant_benchmark.ts      | 225 ++++++++++++++++-----------------
 lib/mq/exec/classifyVideo.ts   |  10 +-
 lib/mq/exec/getLatestVideos.ts |   4 +-
 lib/mq/exec/getVideoTags.ts    |   2 +-
 lib/mq/init.ts                 |   6 +-
 lib/mq/scheduler.ts            |  43 ++++++-
 8 files changed, 164 insertions(+), 136 deletions(-)

diff --git a/.gitignore b/.gitignore
index fb075e0..b27a6b6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -76,7 +76,6 @@ node_modules/
 
 # project specific
-.env
 logs/
 __pycache__
 filter/runs
 
diff --git a/lib/ml/filter_inference.ts b/lib/ml/filter_inference.ts
index e615bcd..da9ed4a 100644
--- a/lib/ml/filter_inference.ts
+++ b/lib/ml/filter_inference.ts
@@ -1,7 +1,7 @@
-import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
+import {AutoTokenizer, PreTrainedTokenizer} from "@huggingface/transformers";
 import * as ort from "onnxruntime";
 import logger from "lib/log/logger.ts";
-import { WorkerError } from "../mq/schema.ts";
+import {WorkerError} from "lib/mq/schema.ts";
 
 const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
 const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
@@ -29,12 +29,11 @@ export async function initializeModels() {
 		sessionEmbedding = embeddingSession;
 		logger.log("Filter models initialized", "ml");
 	} catch (error) {
-		const e = new WorkerError(error as Error, "ml", "fn:initializeModels");
-		throw e;
+		throw new WorkerError(error as Error, "ml", "fn:initializeModels");
 	}
 }
 
-function softmax(logits: Float32Array): number[] {
+export function softmax(logits: Float32Array): number[] {
 	const maxLogit = Math.max(...logits);
 	const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
 	const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
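Note: `softmax` is now exported so the benchmark script below can reuse it instead of keeping a private copy. The function is the numerically stable variant: subtracting `maxLogit` before exponentiating rescales numerator and denominator by the same factor, so the result is unchanged while `Math.exp` can no longer overflow for large logits. A minimal usage sketch (the logit values are made up for illustration):

	import { softmax } from "lib/ml/filter_inference.ts";

	// A naive exp(1000) overflows to Infinity; exp(1000 - 1000) = 1 does not.
	const probs = softmax(new Float32Array([1000, 999]));
	console.log(probs);                          // ≈ [0.731, 0.269]
	console.log(probs.reduce((a, b) => a + b));  // ≈ 1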
diff --git a/lib/ml/quant_benchmark.ts b/lib/ml/quant_benchmark.ts
index 07777c2..f75bf9b 100644
--- a/lib/ml/quant_benchmark.ts
+++ b/lib/ml/quant_benchmark.ts
@@ -1,5 +1,6 @@
-import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
+import {AutoTokenizer, PreTrainedTokenizer} from "@huggingface/transformers";
 import * as ort from "onnxruntime";
+import {softmax} from "lib/ml/filter_inference.ts";
 
 // Configuration
 const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
@@ -9,160 +10,156 @@ const onnxEmbeddingQuantizedPath = "./model/model.onnx";
 
 // Initialize sessions
 const [sessionClassifier, sessionEmbeddingOriginal, sessionEmbeddingQuantized] = await Promise.all([
-    ort.InferenceSession.create(onnxClassifierPath),
-    ort.InferenceSession.create(onnxEmbeddingOriginalPath),
-    ort.InferenceSession.create(onnxEmbeddingQuantizedPath)
+	ort.InferenceSession.create(onnxClassifierPath),
+	ort.InferenceSession.create(onnxEmbeddingOriginalPath),
+	ort.InferenceSession.create(onnxEmbeddingQuantizedPath),
 ]);
 
 let tokenizer: PreTrainedTokenizer;
 
 // Initialize the tokenizer
 async function loadTokenizer() {
-    const tokenizerConfig = { local_files_only: true };
-    tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig);
+	const tokenizerConfig = { local_files_only: true };
+	tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig);
 }
 
 // New embedding generation function (using ONNX)
 async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
-    const { input_ids } = await tokenizer(texts, {
-        add_special_tokens: false,
-        return_tensor: false
-    });
+	const { input_ids } = await tokenizer(texts, {
+		add_special_tokens: false,
+		return_tensor: false,
+	});
 
-    // Build input parameters
-    const cumsum = (arr: number[]): number[] =>
-        arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
-
-    const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
-    const flattened_input_ids = input_ids.flat();
+	// Build input parameters
+	const cumsum = (arr: number[]): number[] =>
+		arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
 
-    // Prepare ONNX inputs
-    const inputs = {
-        input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [flattened_input_ids.length]),
-        offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length])
-    };
+	const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
+	const flattened_input_ids = input_ids.flat();
 
-    // Run inference
-    const { embeddings } = await session.run(inputs);
-    return Array.from(embeddings.data as Float32Array);
-}
+	// Prepare ONNX inputs
+	const inputs = {
+		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
+			flattened_input_ids.length,
+		]),
+		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
+	};
 
-function softmax(logits: Float32Array): number[] {
-    const maxLogit = Math.max(...logits);
-    const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
-    const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
-    return Array.from(exponents.map((exp) => exp / sumOfExponents));
+	// Run inference
+	const { embeddings } = await session.run(inputs);
+	return Array.from(embeddings.data as Float32Array);
 }
 
 // Classification inference function
async function runClassification(embeddings: number[]): Promise<number[]> {
-    const inputTensor = new ort.Tensor(
-        Float32Array.from(embeddings),
-        [1, 4, 1024]
-    );
-
-    const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
-    return softmax(logits.data as Float32Array);
+	const inputTensor = new ort.Tensor(
+		Float32Array.from(embeddings),
+		[1, 4, 1024],
+	);
+
+	const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
+	return softmax(logits.data as Float32Array);
 }
 
 // Metrics calculation function
 function calculateMetrics(labels: number[], predictions: number[], elapsedTime: number): {
-    accuracy: number,
-    precision: number,
-    recall: number,
-    f1: number,
-    speed: string
+	accuracy: number;
+	precision: number;
+	recall: number;
+	f1: number;
+	speed: string;
 } {
-    // Initialize the confusion matrix
-    const classCount = Math.max(...labels, ...predictions) + 1;
-    const matrix = Array.from({ length: classCount }, () =>
-        Array.from({ length: classCount }, () => 0)
-    );
+	// Initialize the confusion matrix
+	const classCount = Math.max(...labels, ...predictions) + 1;
+	const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0));
 
-    // Fill the matrix
-    labels.forEach((trueLabel, i) => {
-        matrix[trueLabel][predictions[i]]++;
-    });
+	// Fill the matrix
+	labels.forEach((trueLabel, i) => {
+		matrix[trueLabel][predictions[i]]++;
+	});
 
-    // Compute the metrics
-    let totalTP = 0, totalFP = 0, totalFN = 0;
-
-    for (let c = 0; c < classCount; c++) {
-        const TP = matrix[c][c];
-        const FP = matrix.flatMap((row, i) => i === c ? [] : [row[c]]).reduce((a, b) => a + b, 0);
-        const FN = matrix[c].filter((_, i) => i !== c).reduce((a, b) => a + b, 0);
+	// Compute the metrics
+	let totalTP = 0, totalFP = 0, totalFN = 0;
 
-        totalTP += TP;
-        totalFP += FP;
-        totalFN += FN;
-    }
+	for (let c = 0; c < classCount; c++) {
+		const TP = matrix[c][c];
+		const FP = matrix.flatMap((row, i) => i === c ? [] : [row[c]]).reduce((a, b) => a + b, 0);
+		const FN = matrix[c].filter((_, i) => i !== c).reduce((a, b) => a + b, 0);
 
-    const precision = totalTP / (totalTP + totalFP);
-    const recall = totalTP / (totalTP + totalFN);
-    const f1 = 2 * (precision * recall) / (precision + recall) || 0;
+		totalTP += TP;
+		totalFP += FP;
+		totalFN += FN;
+	}
 
-    return {
-        accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length,
-        precision,
-        recall,
-        f1,
-        speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`
-    };
+	const precision = totalTP / (totalTP + totalFP);
+	const recall = totalTP / (totalTP + totalFN);
+	const f1 = 2 * (precision * recall) / (precision + recall) || 0;
+
+	return {
+		accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length,
+		precision,
+		recall,
+		f1,
+		speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`,
+	};
 }
 
 // Refactored evaluation function
 async function evaluateModel(session: ort.InferenceSession): Promise<{
-    accuracy: number;
-    precision: number;
-    recall: number;
-    f1: number;
+	accuracy: number;
+	precision: number;
+	recall: number;
+	f1: number;
 }> {
-    const data = await Deno.readTextFile("./data/filter/test.jsonl");
-    const samples = data.split("\n")
-        .map(line => {
-            try { return JSON.parse(line); }
-            catch { return null; }
-        })
-        .filter(Boolean);
+	const data = await Deno.readTextFile("./data/filter/test.jsonl");
+	const samples = data.split("\n")
+		.map((line) => {
+			try {
+				return JSON.parse(line);
+			} catch {
+				return null;
+			}
+		})
+		.filter(Boolean);
 
-    const allPredictions: number[] = [];
-    const allLabels: number[] = [];
-
-    const t = new Date().getTime();
-    for (const sample of samples) {
-        try {
-            const embeddings = await getONNXEmbeddings([
-                sample.title,
-                sample.description,
-                sample.tags.join(","),
-                sample.author_info
-            ], session);
+	const allPredictions: number[] = [];
+	const allLabels: number[] = [];
 
-            const probabilities = await runClassification(embeddings);
-            allPredictions.push(probabilities.indexOf(Math.max(...probabilities)));
-            allLabels.push(sample.label);
-        } catch (error) {
-            console.error("Processing error:", error);
-        }
-    }
-    const elapsed = new Date().getTime() - t;
+	const t = new Date().getTime();
+	for (const sample of samples) {
+		try {
+			const embeddings = await getONNXEmbeddings([
+				sample.title,
+				sample.description,
+				sample.tags.join(","),
+				sample.author_info,
+			], session);
 
-    return calculateMetrics(allLabels, allPredictions, elapsed);
+			const probabilities = await runClassification(embeddings);
+			allPredictions.push(probabilities.indexOf(Math.max(...probabilities)));
+			allLabels.push(sample.label);
+		} catch (error) {
+			console.error("Processing error:", error);
+		}
+	}
+	const elapsed = new Date().getTime() - t;
+
+	return calculateMetrics(allLabels, allPredictions, elapsed);
 }
 
 // Main
 async function main() {
-    await loadTokenizer();
+	await loadTokenizer();
 
-    // Evaluate the original model
-    const originalMetrics = await evaluateModel(sessionEmbeddingOriginal);
-    console.log("Original Model Metrics:");
-    console.table(originalMetrics);
+	// Evaluate the original model
+	const originalMetrics = await evaluateModel(sessionEmbeddingOriginal);
+	console.log("Original Model Metrics:");
+	console.table(originalMetrics);
 
-    // Evaluate the quantized model
-    const quantizedMetrics = await evaluateModel(sessionEmbeddingQuantized);
-    console.log("Quantized Model Metrics:");
-    console.table(quantizedMetrics);
+	// Evaluate the quantized model
+	const quantizedMetrics = await evaluateModel(sessionEmbeddingQuantized);
+	console.log("Quantized Model Metrics:");
+	console.table(quantizedMetrics);
 }
 
-await main();
\ No newline at end of file
+await main();
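Note on `getONNXEmbeddings`: the embedding model consumes one flattened id list for the whole batch plus an `offsets` tensor marking where each text begins (EmbeddingBag-style input), so `offsets` is the exclusive prefix sum of the per-text token counts. A worked example with hypothetical token ids:

	const input_ids = [[5, 9, 2], [7, 4]]; // two texts, 3 and 2 tokens
	const cumsum = (arr: number[]): number[] =>
		arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
	const offsets = [0, ...cumsum(input_ids.slice(0, -1).map((x) => x.length))];
	console.log(input_ids.flat()); // [5, 9, 2, 7, 4]
	console.log(offsets);          // [0, 3]: text 0 starts at index 0, text 1 at index 3

Also worth keeping in mind when reading the benchmark tables: `calculateMetrics` micro-averages over the confusion matrix, and in single-label multiclass data every misclassification counts once as a false positive and once as a false negative, so totalFP equals totalFN and micro precision, recall, and F1 all collapse to accuracy.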
diff --git a/lib/mq/exec/classifyVideo.ts b/lib/mq/exec/classifyVideo.ts
index d2a9c70..df45def 100644
--- a/lib/mq/exec/classifyVideo.ts
+++ b/lib/mq/exec/classifyVideo.ts
@@ -22,11 +22,11 @@ export const classifyVideoWorker = async (job: Job) => {
 	if (label == -1) {
 		logger.warn(`Failed to classify video ${aid}`, "ml");
 	}
-	insertVideoLabel(client, aid, label);
+	await insertVideoLabel(client, aid, label);
 
 	client.release();
 
-	job.updateData({
+	await job.updateData({
 		...job.data,
 		label: label,
 	});
@@ -39,7 +39,7 @@ export const classifyVideosWorker = async () => {
 		return;
 	}
 
-	lockManager.acquireLock("classifyVideos");
+	await lockManager.acquireLock("classifyVideos");
 
 	const client = await db.connect();
 	const videos = await getUnlabelledVideos(client);
@@ -49,12 +49,12 @@ export const classifyVideosWorker = async () => {
 	let i = 0;
 	for (const aid of videos) {
 		if (i > 200) {
-			lockManager.releaseLock("classifyVideos");
+			await lockManager.releaseLock("classifyVideos");
 			return 10000 + i;
 		}
 		await ClassifyVideoQueue.add("classifyVideo", { aid: Number(aid) });
 		i++;
 	}
-	lockManager.releaseLock("classifyVideos");
+	await lockManager.releaseLock("classifyVideos");
 	return 0;
 };
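Note: the added `await`s in this file are correctness fixes, not style. Without `await`, `insertVideoLabel(client, aid, label)` is a floating promise: `client.release()` on the next line can hand the connection back to the pool while the INSERT is still in flight, and a rejection becomes an unhandled promise rejection instead of failing the job. Schematically:

	// Before: fire-and-forget — release may race the INSERT, errors are lost.
	insertVideoLabel(client, aid, label);
	client.release();

	// After: the INSERT finishes (or throws into this scope) before release.
	await insertVideoLabel(client, aid, label);
	client.release();

The same reasoning applies to `job.updateData` and the `lockManager` calls: an unawaited `acquireLock` lets the worker proceed before the lock is actually set, and an unawaited `releaseLock` may still be pending when the worker returns, so the next scheduled run can see the lock as held.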
console.log("Original Model Metrics:"); + console.table(originalMetrics); - // 评估量化模型 - const quantizedMetrics = await evaluateModel(sessionEmbeddingQuantized); - console.log("Quantized Model Metrics:"); - console.table(quantizedMetrics); + // 评估量化模型 + const quantizedMetrics = await evaluateModel(sessionEmbeddingQuantized); + console.log("Quantized Model Metrics:"); + console.table(quantizedMetrics); } -await main(); \ No newline at end of file +await main(); diff --git a/lib/mq/exec/classifyVideo.ts b/lib/mq/exec/classifyVideo.ts index d2a9c70..df45def 100644 --- a/lib/mq/exec/classifyVideo.ts +++ b/lib/mq/exec/classifyVideo.ts @@ -22,11 +22,11 @@ export const classifyVideoWorker = async (job: Job) => { if (label == -1) { logger.warn(`Failed to classify video ${aid}`, "ml"); } - insertVideoLabel(client, aid, label); + await insertVideoLabel(client, aid, label); client.release(); - job.updateData({ + await job.updateData({ ...job.data, label: label, }); @@ -39,7 +39,7 @@ export const classifyVideosWorker = async () => { return; } - lockManager.acquireLock("classifyVideos"); + await lockManager.acquireLock("classifyVideos"); const client = await db.connect(); const videos = await getUnlabelledVideos(client); @@ -49,12 +49,12 @@ export const classifyVideosWorker = async () => { let i = 0; for (const aid of videos) { if (i > 200) { - lockManager.releaseLock("classifyVideos"); + await lockManager.releaseLock("classifyVideos"); return 10000 + i; } await ClassifyVideoQueue.add("classifyVideo", { aid: Number(aid) }); i++; } - lockManager.releaseLock("classifyVideos"); + await lockManager.releaseLock("classifyVideos"); return 0; }; diff --git a/lib/mq/exec/getLatestVideos.ts b/lib/mq/exec/getLatestVideos.ts index 08bad1c..17d7677 100644 --- a/lib/mq/exec/getLatestVideos.ts +++ b/lib/mq/exec/getLatestVideos.ts @@ -37,7 +37,7 @@ export const getLatestVideosWorker = async (job: Job) => { return; } - lockManager.acquireLock("getLatestVideos"); + await lockManager.acquireLock("getLatestVideos"); const failedCount = (job.data.failedCount ?? 
diff --git a/lib/mq/exec/getVideoTags.ts b/lib/mq/exec/getVideoTags.ts
index 1608098..83fe26f 100644
--- a/lib/mq/exec/getVideoTags.ts
+++ b/lib/mq/exec/getVideoTags.ts
@@ -8,7 +8,7 @@ import logger from "lib/log/logger.ts";
 import { getNullVideoTagsList, updateVideoTags } from "lib/db/allData.ts";
 import { getVideoTags } from "lib/net/getVideoTags.ts";
 import { NetSchedulerError } from "lib/mq/scheduler.ts";
-import { WorkerError } from "../schema.ts";
+import { WorkerError } from "lib/mq/schema.ts";
 
 const delayMap = [0.5, 3, 5, 15, 30, 60];
 const getJobPriority = (diff: number) => {
diff --git a/lib/mq/init.ts b/lib/mq/init.ts
index 336d843..fbfaa54 100644
--- a/lib/mq/init.ts
+++ b/lib/mq/init.ts
@@ -1,4 +1,4 @@
-import { MINUTE, SECOND } from "$std/datetime/constants.ts";
+import { MINUTE } from "$std/datetime/constants.ts";
 import { ClassifyVideoQueue, LatestVideosQueue, VideoTagsQueue } from "lib/mq/index.ts";
 import logger from "lib/log/logger.ts";
 
@@ -7,11 +7,11 @@ export async function initMQ() {
 		every: 1 * MINUTE
 	});
 	await VideoTagsQueue.upsertJobScheduler("getVideosTags", {
-		every: 30 * SECOND,
+		every: 5 * MINUTE,
 		immediately: true,
 	});
 	await ClassifyVideoQueue.upsertJobScheduler("classifyVideos", {
-		every: 30 * SECOND,
+		every: 5 * MINUTE,
 		immediately: true,
 	})
 
diff --git a/lib/mq/scheduler.ts b/lib/mq/scheduler.ts
index 25e7705..14ed12f 100644
--- a/lib/mq/scheduler.ts
+++ b/lib/mq/scheduler.ts
@@ -7,6 +7,7 @@ import Redis from "ioredis";
 interface Proxy {
 	type: string;
 	task: string;
+	provider: string;
 	limiter?: RateLimiter;
 }
 
@@ -32,11 +33,16 @@ export class NetSchedulerError extends Error {
 	}
 }
 
+interface LimiterMap {
+	[name: string]: RateLimiter;
+}
+
 class NetScheduler {
 	private proxies: ProxiesMap = {};
+	private providerLimiters: LimiterMap = {};
 
-	addProxy(name: string, type: string, task: string): void {
-		this.proxies[name] = { type, task };
+	addProxy(name: string, type: string, task: string, provider: string): void {
+		this.proxies[name] = { type, task, provider };
 	}
 
 	removeProxy(name: string): void {
@@ -47,6 +53,10 @@ class NetScheduler {
 		this.proxies[name].limiter = limiter;
 	}
 
+	setProviderLimiter(name: string, limiter: RateLimiter): void {
+		this.providerLimiters[name] = limiter;
+	}
+
 	/*
 	 * Make a request to the specified URL with any available proxy
 	 * @param {string} url - The URL to request.
@@ -117,7 +127,15 @@ class NetScheduler {
 	private async getProxyAvailability(name: string): Promise<boolean> {
 		try {
 			const proxyConfig = this.proxies[name];
-			if (!proxyConfig || !proxyConfig.limiter) {
+			if (!proxyConfig) {
+				return true;
+			}
+			const provider = proxyConfig.provider;
+			const providerLimiter = this.providerLimiters[provider];
+			if (providerLimiter && !(await providerLimiter.getAvailability())) {
+				return false;
+			}
+			if (!proxyConfig.limiter) {
 				return true;
 			}
 			return await proxyConfig.limiter.getAvailability();
@@ -143,8 +161,8 @@ class NetScheduler {
 }
 
 const netScheduler = new NetScheduler();
-netScheduler.addProxy("default", "native", "default");
-netScheduler.addProxy("tags-native", "native", "getVideoTags");
+netScheduler.addProxy("default", "native", "default", "bilibili-native");
+netScheduler.addProxy("tags-native", "native", "getVideoTags", "bilibili-native");
 const tagsRateLimiter = new RateLimiter("getVideoTags", [
 	{
 		window: new SlidingWindow(redis, 1),
@@ -159,6 +177,21 @@ const tagsRateLimiter = new RateLimiter("getVideoTags", [
 		max: 50,
 	},
 ]);
+const biliLimiterNative = new RateLimiter("bilibili-native", [
+	{
+		window: new SlidingWindow(redis, 1),
+		max: 5
+	},
+	{
+		window: new SlidingWindow(redis, 30),
+		max: 100
+	},
+	{
+		window: new SlidingWindow(redis, 5 * 60),
+		max: 180
+	}
+]);
 netScheduler.setProxyLimiter("tags-native", tagsRateLimiter);
+netScheduler.setProviderLimiter("bilibili-native", biliLimiterNative);
 
 export default netScheduler;
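Note on the scheduler changes: availability is now checked in two layers. Every proxy carries a `provider` tag, and the provider-wide limiter (shared by all proxies of that provider — here both native proxies sit under "bilibili-native") must have headroom before the proxy's own limiter is consulted, so the combined request rate against one upstream stays bounded no matter how many proxies point at it. `getProxyAvailability` guards the lookup, so a proxy whose provider has no registered limiter still passes the provider check. A hedged usage sketch (the second provider and its limits are invented for illustration):

	// Hypothetical: proxies under a second provider get their own budget,
	// while proxies sharing a provider share that provider's budget.
	netScheduler.addProxy("intl-native", "native", "default", "bilibili-intl");
	netScheduler.setProviderLimiter(
		"bilibili-intl",
		new RateLimiter("bilibili-intl", [{ window: new SlidingWindow(redis, 1), max: 5 }]),
	);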