add: provider in NetScheduler, missing await

alikia2x (寒寒) 2025-02-26 00:55:48 +08:00
parent 15312f4078
commit 232585594a
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
8 changed files with 164 additions and 136 deletions

.gitignore (vendored, 1 deletion)

@@ -76,7 +76,6 @@ node_modules/
 # project specific
-.env
 logs/
 __pycache__
 filter/runs

lib/ml/filter_inference.ts

@@ -1,7 +1,7 @@
-import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
+import {AutoTokenizer, PreTrainedTokenizer} from "@huggingface/transformers";
 import * as ort from "onnxruntime";
 import logger from "lib/log/logger.ts";
-import { WorkerError } from "../mq/schema.ts";
+import {WorkerError} from "lib/mq/schema.ts";
 
 const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
 const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
@@ -29,12 +29,11 @@ export async function initializeModels() {
 		sessionEmbedding = embeddingSession;
 		logger.log("Filter models initialized", "ml");
 	} catch (error) {
-		const e = new WorkerError(error as Error, "ml", "fn:initializeModels");
-		throw e;
+		throw new WorkerError(error as Error, "ml", "fn:initializeModels");
 	}
 }
 
-function softmax(logits: Float32Array): number[] {
+export function softmax(logits: Float32Array): number[] {
 	const maxLogit = Math.max(...logits);
 	const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
 	const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
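Making softmax an export lets the benchmark script in the next file reuse the same implementation instead of carrying its own copy (the duplicate is deleted below). A minimal usage sketch, with made-up logits:

    import { softmax } from "lib/ml/filter_inference.ts";

    // Hypothetical classifier logits; softmax maps them to probabilities
    // that sum to 1, and the argmax is the predicted label.
    const logits = new Float32Array([2.1, -0.3, 0.8]);
    const probabilities = softmax(logits); // ~[0.73, 0.07, 0.20]
    const label = probabilities.indexOf(Math.max(...probabilities)); // 0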

@@ -1,5 +1,6 @@
-import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
+import {AutoTokenizer, PreTrainedTokenizer} from "@huggingface/transformers";
 import * as ort from "onnxruntime";
+import {softmax} from "lib/ml/filter_inference.ts";
 
 // Configuration parameters
 const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
@@ -11,7 +12,7 @@ const onnxEmbeddingQuantizedPath = "./model/model.onnx";
 const [sessionClassifier, sessionEmbeddingOriginal, sessionEmbeddingQuantized] = await Promise.all([
 	ort.InferenceSession.create(onnxClassifierPath),
 	ort.InferenceSession.create(onnxEmbeddingOriginalPath),
-	ort.InferenceSession.create(onnxEmbeddingQuantizedPath)
+	ort.InferenceSession.create(onnxEmbeddingQuantizedPath),
 ]);
 
 let tokenizer: PreTrainedTokenizer;
@@ -26,7 +27,7 @@ async function loadTokenizer() {
 async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
 	const { input_ids } = await tokenizer(texts, {
 		add_special_tokens: false,
-		return_tensor: false
+		return_tensor: false,
 	});
 
 	// Build the input parameters
@@ -38,8 +39,10 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
 	// Prepare the ONNX inputs
 	const inputs = {
-		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [flattened_input_ids.length]),
-		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length])
+		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
+			flattened_input_ids.length,
+		]),
+		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
 	};
 
 	// Run inference
@@ -47,18 +50,11 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
 	return Array.from(embeddings.data as Float32Array);
 }
 
-function softmax(logits: Float32Array): number[] {
-	const maxLogit = Math.max(...logits);
-	const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
-	const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
-	return Array.from(exponents.map((exp) => exp / sumOfExponents));
-}
-
 // Classification inference function
 async function runClassification(embeddings: number[]): Promise<number[]> {
 	const inputTensor = new ort.Tensor(
 		Float32Array.from(embeddings),
-		[1, 4, 1024]
+		[1, 4, 1024],
 	);
 
 	const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
@@ -67,17 +63,15 @@ async function runClassification(embeddings: number[]): Promise<number[]> {
 // Metrics calculation function
 function calculateMetrics(labels: number[], predictions: number[], elapsedTime: number): {
-	accuracy: number,
-	precision: number,
-	recall: number,
-	f1: number,
-	speed: string
+	accuracy: number;
+	precision: number;
+	recall: number;
+	f1: number;
+	speed: string;
 } {
 	// Initialize the confusion matrix
 	const classCount = Math.max(...labels, ...predictions) + 1;
-	const matrix = Array.from({ length: classCount }, () =>
-		Array.from({ length: classCount }, () => 0)
-	);
+	const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0));
 
 	// Fill the matrix
 	labels.forEach((trueLabel, i) => {
@@ -106,7 +100,7 @@ function calculateMetrics(labels: number[], predictions: number[], elapsedTime:
 		precision,
 		recall,
 		f1,
-		speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`
+		speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`,
 	};
 }
@@ -119,9 +113,12 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
 }> {
 	const data = await Deno.readTextFile("./data/filter/test.jsonl");
 	const samples = data.split("\n")
-		.map(line => {
-			try { return JSON.parse(line); }
-			catch { return null; }
+		.map((line) => {
+			try {
+				return JSON.parse(line);
+			} catch {
+				return null;
+			}
 		})
 		.filter(Boolean);
@@ -135,7 +132,7 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
 			sample.title,
 			sample.description,
 			sample.tags.join(","),
-			sample.author_info
+			sample.author_info,
 		], session);
 
 		const probabilities = await runClassification(embeddings);
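The getONNXEmbeddings hunk references flattened_input_ids and offsets that are built outside the changed lines, so their construction is not shown here. For an EmbeddingBag-style ONNX input, the flattening presumably looks like the sketch below (the variable names follow the hunk; the logic itself is an assumption, since those lines are outside the diff context):

    // input_ids, as returned by the tokenizer with return_tensor: false,
    // is assumed to be one array of token ids per input text.
    const input_ids: number[][] = [[101, 2054], [101, 2129, 2003]];

    // Concatenate all ids into one flat array; offsets[i] records where
    // text i starts, which is how EmbeddingBag-style graphs delimit bags.
    const flattened_input_ids: number[] = input_ids.flat();
    const offsets: number[] = [];
    let cursor = 0;
    for (const ids of input_ids) {
        offsets.push(cursor);
        cursor += ids.length;
    }
    // flattened_input_ids = [101, 2054, 101, 2129, 2003]; offsets = [0, 2]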

@@ -22,11 +22,11 @@ export const classifyVideoWorker = async (job: Job) => {
 	if (label == -1) {
 		logger.warn(`Failed to classify video ${aid}`, "ml");
 	}
-	insertVideoLabel(client, aid, label);
+	await insertVideoLabel(client, aid, label);
 
 	client.release();
 
-	job.updateData({
+	await job.updateData({
 		...job.data, label: label,
 	});
@@ -39,7 +39,7 @@ export const classifyVideosWorker = async () => {
 		return;
 	}
-	lockManager.acquireLock("classifyVideos");
+	await lockManager.acquireLock("classifyVideos");
 
 	const client = await db.connect();
 	const videos = await getUnlabelledVideos(client);
@@ -49,12 +49,12 @@ export const classifyVideosWorker = async () => {
 	let i = 0;
 	for (const aid of videos) {
 		if (i > 200) {
-			lockManager.releaseLock("classifyVideos");
+			await lockManager.releaseLock("classifyVideos");
 			return 10000 + i;
 		}
 		await ClassifyVideoQueue.add("classifyVideo", { aid: Number(aid) });
 		i++;
 	}
 
-	lockManager.releaseLock("classifyVideos");
+	await lockManager.releaseLock("classifyVideos");
 	return 0;
 };
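Each await added above fixes a floating promise: insertVideoLabel, job.updateData, and the lock calls are all async, so without await the worker could return (and the next scheduled run could start) before the database write or the lock update actually lands, and any rejection would go unhandled. A minimal sketch of the hazard, assuming a Redis-backed lock along the lines of the repo's lockManager:

    import Redis from "ioredis";

    const redis = new Redis();

    // Stand-in for the repo's lockManager (assumed to be one Redis call).
    async function releaseLock(id: string): Promise<void> {
        await redis.del(`lock:${id}`);
    }

    // Fire-and-forget: the DEL may still be in flight when this worker
    // returns, so the next run can see a stale lock and bail out.
    releaseLock("classifyVideos");

    // Awaited: execution resumes only after Redis confirms the delete.
    await releaseLock("classifyVideos");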

@@ -37,7 +37,7 @@ export const getLatestVideosWorker = async (job: Job) => {
 		return;
 	}
-	lockManager.acquireLock("getLatestVideos");
+	await lockManager.acquireLock("getLatestVideos");
 
 	const failedCount = (job.data.failedCount ?? 0) as number;
 	const client = await db.connect();
@@ -46,7 +46,7 @@ export const getLatestVideosWorker = async (job: Job) => {
 		await executeTask(client, failedCount);
 	} finally {
 		client.release();
-		lockManager.releaseLock("getLatestVideos");
+		await lockManager.releaseLock("getLatestVideos");
 	}
 	return;
 };

@@ -8,7 +8,7 @@ import logger from "lib/log/logger.ts";
 import { getNullVideoTagsList, updateVideoTags } from "lib/db/allData.ts";
 import { getVideoTags } from "lib/net/getVideoTags.ts";
 import { NetSchedulerError } from "lib/mq/scheduler.ts";
-import { WorkerError } from "../schema.ts";
+import { WorkerError } from "lib/mq/schema.ts";
 
 const delayMap = [0.5, 3, 5, 15, 30, 60];
 
 const getJobPriority = (diff: number) => {

@@ -1,4 +1,4 @@
-import { MINUTE, SECOND } from "$std/datetime/constants.ts";
+import { MINUTE } from "$std/datetime/constants.ts";
 import { ClassifyVideoQueue, LatestVideosQueue, VideoTagsQueue } from "lib/mq/index.ts";
 import logger from "lib/log/logger.ts";
@@ -7,11 +7,11 @@ export async function initMQ() {
 		every: 1 * MINUTE
 	});
 
 	await VideoTagsQueue.upsertJobScheduler("getVideosTags", {
-		every: 30 * SECOND,
+		every: 5 * MINUTE,
 		immediately: true,
 	});
 
 	await ClassifyVideoQueue.upsertJobScheduler("classifyVideos", {
-		every: 30 * SECOND,
+		every: 5 * MINUTE,
 		immediately: true,
 	})

lib/mq/scheduler.ts

@@ -7,6 +7,7 @@ import Redis from "ioredis";
 interface Proxy {
 	type: string;
 	task: string;
+	provider: string;
 	limiter?: RateLimiter;
 }
@@ -32,11 +33,16 @@ export class NetSchedulerError extends Error {
 	}
 }
 
+interface LimiterMap {
+	[name: string]: RateLimiter;
+}
+
 class NetScheduler {
 	private proxies: ProxiesMap = {};
+	private providerLimiters: LimiterMap = {};
 
-	addProxy(name: string, type: string, task: string): void {
-		this.proxies[name] = { type, task };
+	addProxy(name: string, type: string, task: string, provider: string): void {
+		this.proxies[name] = { type, task, provider };
 	}
 
 	removeProxy(name: string): void {
@@ -47,6 +53,10 @@ class NetScheduler {
 		this.proxies[name].limiter = limiter;
 	}
 
+	setProviderLimiter(name: string, limiter: RateLimiter): void {
+		this.providerLimiters[name] = limiter;
+	}
+
 	/*
 	 * Make a request to the specified URL with any available proxy
 	 * @param {string} url - The URL to request.
@@ -117,7 +127,15 @@ class NetScheduler {
 	private async getProxyAvailability(name: string): Promise<boolean> {
 		try {
 			const proxyConfig = this.proxies[name];
-			if (!proxyConfig || !proxyConfig.limiter) {
+			if (!proxyConfig) {
+				return true;
+			}
+			const provider = proxyConfig.provider;
+			const providerLimiter = await this.providerLimiters[provider].getAvailability();
+			if (!providerLimiter) {
+				return false;
+			}
+			if (!proxyConfig.limiter) {
 				return true;
 			}
 			return await proxyConfig.limiter.getAvailability();
@@ -143,8 +161,8 @@ class NetScheduler {
 }
 
 const netScheduler = new NetScheduler();
-netScheduler.addProxy("default", "native", "default");
-netScheduler.addProxy("tags-native", "native", "getVideoTags");
+netScheduler.addProxy("default", "native", "default", "bilibili-native");
+netScheduler.addProxy("tags-native", "native", "getVideoTags", "bilibili-native");
 const tagsRateLimiter = new RateLimiter("getVideoTags", [
 	{
 		window: new SlidingWindow(redis, 1),
@@ -159,6 +177,21 @@ const tagsRateLimiter = new RateLimiter("getVideoTags", [
 		max: 50,
 	},
 ]);
+const biliLimiterNative = new RateLimiter("bilibili-native", [
+	{
+		window: new SlidingWindow(redis, 1),
+		max: 5
+	},
+	{
+		window: new SlidingWindow(redis, 30),
+		max: 100
+	},
+	{
+		window: new SlidingWindow(redis, 5 * 60),
+		max: 180
+	}
+]);
 
 netScheduler.setProxyLimiter("tags-native", tagsRateLimiter);
+netScheduler.setProviderLimiter("bilibili-native", biliLimiterNative)
 export default netScheduler;
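With the provider tier in place, every proxy registered under the same provider draws from one shared budget (here 5 requests/s, 100 per 30 s, and 180 per 5 min for "bilibili-native") in addition to its own optional per-task limiter; getProxyAvailability checks the provider limiter first and only then the proxy's own. Note that it indexes this.providerLimiters[provider] without a guard, so a proxy whose provider has no registered limiter will throw inside the try block rather than fall through to its own limiter. A hypothetical registration showing how a second proxy would share the budget (the "info-native" and "getVideoInfo" names are illustrative, not from the repo):

    netScheduler.addProxy("info-native", "native", "getVideoInfo", "bilibili-native");
    // The per-proxy limiter can be stricter than the shared provider budget.
    netScheduler.setProxyLimiter(
        "info-native",
        new RateLimiter("getVideoInfo", [{ window: new SlidingWindow(redis, 1), max: 2 }]),
    );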