add: provider in NetScheduler, missing await

This commit is contained in:
alikia2x (寒寒) 2025-02-26 00:55:48 +08:00
parent 15312f4078
commit 232585594a
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
8 changed files with 164 additions and 136 deletions

1
.gitignore vendored
View File

@ -76,7 +76,6 @@ node_modules/
# project specific
.env
logs/
__pycache__
filter/runs

View File

@ -1,7 +1,7 @@
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
import {AutoTokenizer, PreTrainedTokenizer} from "@huggingface/transformers";
import * as ort from "onnxruntime";
import logger from "lib/log/logger.ts";
import { WorkerError } from "../mq/schema.ts";
import {WorkerError} from "lib/mq/schema.ts";
const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
@ -29,12 +29,11 @@ export async function initializeModels() {
sessionEmbedding = embeddingSession;
logger.log("Filter models initialized", "ml");
} catch (error) {
const e = new WorkerError(error as Error, "ml", "fn:initializeModels");
throw e;
throw new WorkerError(error as Error, "ml", "fn:initializeModels");
}
}
function softmax(logits: Float32Array): number[] {
export function softmax(logits: Float32Array): number[] {
const maxLogit = Math.max(...logits);
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);

View File

@ -1,5 +1,6 @@
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
import {AutoTokenizer, PreTrainedTokenizer} from "@huggingface/transformers";
import * as ort from "onnxruntime";
import {softmax} from "lib/ml/filter_inference.ts";
// 配置参数
const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
@ -9,160 +10,156 @@ const onnxEmbeddingQuantizedPath = "./model/model.onnx";
// 初始化会话
const [sessionClassifier, sessionEmbeddingOriginal, sessionEmbeddingQuantized] = await Promise.all([
ort.InferenceSession.create(onnxClassifierPath),
ort.InferenceSession.create(onnxEmbeddingOriginalPath),
ort.InferenceSession.create(onnxEmbeddingQuantizedPath)
ort.InferenceSession.create(onnxClassifierPath),
ort.InferenceSession.create(onnxEmbeddingOriginalPath),
ort.InferenceSession.create(onnxEmbeddingQuantizedPath),
]);
let tokenizer: PreTrainedTokenizer;
// 初始化分词器
async function loadTokenizer() {
const tokenizerConfig = { local_files_only: true };
tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig);
const tokenizerConfig = { local_files_only: true };
tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig);
}
// 新的嵌入生成函数使用ONNX
async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
const { input_ids } = await tokenizer(texts, {
add_special_tokens: false,
return_tensor: false
});
const { input_ids } = await tokenizer(texts, {
add_special_tokens: false,
return_tensor: false,
});
// 构造输入参数
const cumsum = (arr: number[]): number[] =>
arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
const flattened_input_ids = input_ids.flat();
// 构造输入参数
const cumsum = (arr: number[]): number[] =>
arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
// 准备ONNX输入
const inputs = {
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [flattened_input_ids.length]),
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length])
};
const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
const flattened_input_ids = input_ids.flat();
// 执行推理
const { embeddings } = await session.run(inputs);
return Array.from(embeddings.data as Float32Array);
}
// 准备ONNX输入
const inputs = {
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
flattened_input_ids.length,
]),
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
};
function softmax(logits: Float32Array): number[] {
const maxLogit = Math.max(...logits);
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
return Array.from(exponents.map((exp) => exp / sumOfExponents));
// 执行推理
const { embeddings } = await session.run(inputs);
return Array.from(embeddings.data as Float32Array);
}
// 分类推理函数
async function runClassification(embeddings: number[]): Promise<number[]> {
const inputTensor = new ort.Tensor(
Float32Array.from(embeddings),
[1, 4, 1024]
);
const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
return softmax(logits.data as Float32Array);
const inputTensor = new ort.Tensor(
Float32Array.from(embeddings),
[1, 4, 1024],
);
const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
return softmax(logits.data as Float32Array);
}
// 指标计算函数
function calculateMetrics(labels: number[], predictions: number[], elapsedTime: number): {
accuracy: number,
precision: number,
recall: number,
f1: number,
speed: string
accuracy: number;
precision: number;
recall: number;
f1: number;
speed: string;
} {
// 初始化混淆矩阵
const classCount = Math.max(...labels, ...predictions) + 1;
const matrix = Array.from({ length: classCount }, () =>
Array.from({ length: classCount }, () => 0)
);
// 初始化混淆矩阵
const classCount = Math.max(...labels, ...predictions) + 1;
const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0));
// 填充矩阵
labels.forEach((trueLabel, i) => {
matrix[trueLabel][predictions[i]]++;
});
// 填充矩阵
labels.forEach((trueLabel, i) => {
matrix[trueLabel][predictions[i]]++;
});
// 计算各指标
let totalTP = 0, totalFP = 0, totalFN = 0;
for (let c = 0; c < classCount; c++) {
const TP = matrix[c][c];
const FP = matrix.flatMap((row, i) => i === c ? [] : [row[c]]).reduce((a, b) => a + b, 0);
const FN = matrix[c].filter((_, i) => i !== c).reduce((a, b) => a + b, 0);
// 计算各指标
let totalTP = 0, totalFP = 0, totalFN = 0;
totalTP += TP;
totalFP += FP;
totalFN += FN;
}
for (let c = 0; c < classCount; c++) {
const TP = matrix[c][c];
const FP = matrix.flatMap((row, i) => i === c ? [] : [row[c]]).reduce((a, b) => a + b, 0);
const FN = matrix[c].filter((_, i) => i !== c).reduce((a, b) => a + b, 0);
const precision = totalTP / (totalTP + totalFP);
const recall = totalTP / (totalTP + totalFN);
const f1 = 2 * (precision * recall) / (precision + recall) || 0;
totalTP += TP;
totalFP += FP;
totalFN += FN;
}
return {
accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length,
precision,
recall,
f1,
speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`
};
const precision = totalTP / (totalTP + totalFP);
const recall = totalTP / (totalTP + totalFN);
const f1 = 2 * (precision * recall) / (precision + recall) || 0;
return {
accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length,
precision,
recall,
f1,
speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`,
};
}
// 改造后的评估函数
async function evaluateModel(session: ort.InferenceSession): Promise<{
accuracy: number;
precision: number;
recall: number;
f1: number;
accuracy: number;
precision: number;
recall: number;
f1: number;
}> {
const data = await Deno.readTextFile("./data/filter/test.jsonl");
const samples = data.split("\n")
.map(line => {
try { return JSON.parse(line); }
catch { return null; }
})
.filter(Boolean);
const data = await Deno.readTextFile("./data/filter/test.jsonl");
const samples = data.split("\n")
.map((line) => {
try {
return JSON.parse(line);
} catch {
return null;
}
})
.filter(Boolean);
const allPredictions: number[] = [];
const allLabels: number[] = [];
const t = new Date().getTime();
for (const sample of samples) {
try {
const embeddings = await getONNXEmbeddings([
sample.title,
sample.description,
sample.tags.join(","),
sample.author_info
], session);
const allPredictions: number[] = [];
const allLabels: number[] = [];
const probabilities = await runClassification(embeddings);
allPredictions.push(probabilities.indexOf(Math.max(...probabilities)));
allLabels.push(sample.label);
} catch (error) {
console.error("Processing error:", error);
}
}
const elapsed = new Date().getTime() - t;
const t = new Date().getTime();
for (const sample of samples) {
try {
const embeddings = await getONNXEmbeddings([
sample.title,
sample.description,
sample.tags.join(","),
sample.author_info,
], session);
return calculateMetrics(allLabels, allPredictions, elapsed);
const probabilities = await runClassification(embeddings);
allPredictions.push(probabilities.indexOf(Math.max(...probabilities)));
allLabels.push(sample.label);
} catch (error) {
console.error("Processing error:", error);
}
}
const elapsed = new Date().getTime() - t;
return calculateMetrics(allLabels, allPredictions, elapsed);
}
// 主函数
async function main() {
await loadTokenizer();
await loadTokenizer();
// 评估原始模型
const originalMetrics = await evaluateModel(sessionEmbeddingOriginal);
console.log("Original Model Metrics:");
console.table(originalMetrics);
// 评估原始模型
const originalMetrics = await evaluateModel(sessionEmbeddingOriginal);
console.log("Original Model Metrics:");
console.table(originalMetrics);
// 评估量化模型
const quantizedMetrics = await evaluateModel(sessionEmbeddingQuantized);
console.log("Quantized Model Metrics:");
console.table(quantizedMetrics);
// 评估量化模型
const quantizedMetrics = await evaluateModel(sessionEmbeddingQuantized);
console.log("Quantized Model Metrics:");
console.table(quantizedMetrics);
}
await main();
await main();

View File

@ -22,11 +22,11 @@ export const classifyVideoWorker = async (job: Job) => {
if (label == -1) {
logger.warn(`Failed to classify video ${aid}`, "ml");
}
insertVideoLabel(client, aid, label);
await insertVideoLabel(client, aid, label);
client.release();
job.updateData({
await job.updateData({
...job.data, label: label,
});
@ -39,7 +39,7 @@ export const classifyVideosWorker = async () => {
return;
}
lockManager.acquireLock("classifyVideos");
await lockManager.acquireLock("classifyVideos");
const client = await db.connect();
const videos = await getUnlabelledVideos(client);
@ -49,12 +49,12 @@ export const classifyVideosWorker = async () => {
let i = 0;
for (const aid of videos) {
if (i > 200) {
lockManager.releaseLock("classifyVideos");
await lockManager.releaseLock("classifyVideos");
return 10000 + i;
}
await ClassifyVideoQueue.add("classifyVideo", { aid: Number(aid) });
i++;
}
lockManager.releaseLock("classifyVideos");
await lockManager.releaseLock("classifyVideos");
return 0;
};

View File

@ -37,7 +37,7 @@ export const getLatestVideosWorker = async (job: Job) => {
return;
}
lockManager.acquireLock("getLatestVideos");
await lockManager.acquireLock("getLatestVideos");
const failedCount = (job.data.failedCount ?? 0) as number;
const client = await db.connect();
@ -46,7 +46,7 @@ export const getLatestVideosWorker = async (job: Job) => {
await executeTask(client, failedCount);
} finally {
client.release();
lockManager.releaseLock("getLatestVideos");
await lockManager.releaseLock("getLatestVideos");
}
return;
};

View File

@ -8,7 +8,7 @@ import logger from "lib/log/logger.ts";
import { getNullVideoTagsList, updateVideoTags } from "lib/db/allData.ts";
import { getVideoTags } from "lib/net/getVideoTags.ts";
import { NetSchedulerError } from "lib/mq/scheduler.ts";
import { WorkerError } from "../schema.ts";
import { WorkerError } from "lib/mq/schema.ts";
const delayMap = [0.5, 3, 5, 15, 30, 60];
const getJobPriority = (diff: number) => {

View File

@ -1,4 +1,4 @@
import { MINUTE, SECOND } from "$std/datetime/constants.ts";
import { MINUTE } from "$std/datetime/constants.ts";
import { ClassifyVideoQueue, LatestVideosQueue, VideoTagsQueue } from "lib/mq/index.ts";
import logger from "lib/log/logger.ts";
@ -7,11 +7,11 @@ export async function initMQ() {
every: 1 * MINUTE
});
await VideoTagsQueue.upsertJobScheduler("getVideosTags", {
every: 30 * SECOND,
every: 5 * MINUTE,
immediately: true,
});
await ClassifyVideoQueue.upsertJobScheduler("classifyVideos", {
every: 30 * SECOND,
every: 5 * MINUTE,
immediately: true,
})

View File

@ -7,6 +7,7 @@ import Redis from "ioredis";
interface Proxy {
type: string;
task: string;
provider: string;
limiter?: RateLimiter;
}
@ -32,11 +33,16 @@ export class NetSchedulerError extends Error {
}
}
interface LimiterMap {
[name: string]: RateLimiter;
}
class NetScheduler {
private proxies: ProxiesMap = {};
private providerLimiters: LimiterMap = {};
addProxy(name: string, type: string, task: string): void {
this.proxies[name] = { type, task };
addProxy(name: string, type: string, task: string, provider: string): void {
this.proxies[name] = { type, task, provider };
}
removeProxy(name: string): void {
@ -47,6 +53,10 @@ class NetScheduler {
this.proxies[name].limiter = limiter;
}
setProviderLimiter(name: string, limiter: RateLimiter): void {
this.providerLimiters[name] = limiter;
}
/*
* Make a request to the specified URL with any available proxy
* @param {string} url - The URL to request.
@ -117,7 +127,15 @@ class NetScheduler {
private async getProxyAvailability(name: string): Promise<boolean> {
try {
const proxyConfig = this.proxies[name];
if (!proxyConfig || !proxyConfig.limiter) {
if (!proxyConfig) {
return true;
}
const provider = proxyConfig.provider;
const providerLimiter = await this.providerLimiters[provider].getAvailability();
if (!providerLimiter) {
return false;
}
if (!proxyConfig.limiter) {
return true;
}
return await proxyConfig.limiter.getAvailability();
@ -143,8 +161,8 @@ class NetScheduler {
}
const netScheduler = new NetScheduler();
netScheduler.addProxy("default", "native", "default");
netScheduler.addProxy("tags-native", "native", "getVideoTags");
netScheduler.addProxy("default", "native", "default", "bilibili-native");
netScheduler.addProxy("tags-native", "native", "getVideoTags", "bilibili-native");
const tagsRateLimiter = new RateLimiter("getVideoTags", [
{
window: new SlidingWindow(redis, 1),
@ -159,6 +177,21 @@ const tagsRateLimiter = new RateLimiter("getVideoTags", [
max: 50,
},
]);
const biliLimiterNative = new RateLimiter("bilibili-native", [
{
window: new SlidingWindow(redis, 1),
max: 5
},
{
window: new SlidingWindow(redis, 30),
max: 100
},
{
window: new SlidingWindow(redis, 5 * 60),
max: 180
}
]);
netScheduler.setProxyLimiter("tags-native", tagsRateLimiter);
netScheduler.setProviderLimiter("bilibili-native", biliLimiterNative)
export default netScheduler;