add: provider in NetScheduler, missing await

This commit is contained in:
alikia2x (寒寒) 2025-02-26 00:55:48 +08:00
parent 15312f4078
commit 232585594a
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
8 changed files with 164 additions and 136 deletions

1
.gitignore vendored
View File

@ -76,7 +76,6 @@ node_modules/
# project specific
.env
logs/
__pycache__
filter/runs

View File

@ -1,7 +1,7 @@
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
import {AutoTokenizer, PreTrainedTokenizer} from "@huggingface/transformers";
import * as ort from "onnxruntime";
import logger from "lib/log/logger.ts";
import { WorkerError } from "../mq/schema.ts";
import {WorkerError} from "lib/mq/schema.ts";
const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
@ -29,12 +29,11 @@ export async function initializeModels() {
sessionEmbedding = embeddingSession;
logger.log("Filter models initialized", "ml");
} catch (error) {
const e = new WorkerError(error as Error, "ml", "fn:initializeModels");
throw e;
throw new WorkerError(error as Error, "ml", "fn:initializeModels");
}
}
function softmax(logits: Float32Array): number[] {
export function softmax(logits: Float32Array): number[] {
const maxLogit = Math.max(...logits);
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);

View File

@ -1,5 +1,6 @@
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
import {AutoTokenizer, PreTrainedTokenizer} from "@huggingface/transformers";
import * as ort from "onnxruntime";
import {softmax} from "lib/ml/filter_inference.ts";
// 配置参数
const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
@ -11,7 +12,7 @@ const onnxEmbeddingQuantizedPath = "./model/model.onnx";
const [sessionClassifier, sessionEmbeddingOriginal, sessionEmbeddingQuantized] = await Promise.all([
ort.InferenceSession.create(onnxClassifierPath),
ort.InferenceSession.create(onnxEmbeddingOriginalPath),
ort.InferenceSession.create(onnxEmbeddingQuantizedPath)
ort.InferenceSession.create(onnxEmbeddingQuantizedPath),
]);
let tokenizer: PreTrainedTokenizer;
@ -26,7 +27,7 @@ async function loadTokenizer() {
async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
const { input_ids } = await tokenizer(texts, {
add_special_tokens: false,
return_tensor: false
return_tensor: false,
});
// 构造输入参数
@ -38,8 +39,10 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
// 准备ONNX输入
const inputs = {
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [flattened_input_ids.length]),
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length])
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
flattened_input_ids.length,
]),
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
};
// 执行推理
@ -47,18 +50,11 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
return Array.from(embeddings.data as Float32Array);
}
function softmax(logits: Float32Array): number[] {
const maxLogit = Math.max(...logits);
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
return Array.from(exponents.map((exp) => exp / sumOfExponents));
}
// 分类推理函数
async function runClassification(embeddings: number[]): Promise<number[]> {
const inputTensor = new ort.Tensor(
Float32Array.from(embeddings),
[1, 4, 1024]
[1, 4, 1024],
);
const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
@ -67,17 +63,15 @@ async function runClassification(embeddings: number[]): Promise<number[]> {
// 指标计算函数
function calculateMetrics(labels: number[], predictions: number[], elapsedTime: number): {
accuracy: number,
precision: number,
recall: number,
f1: number,
speed: string
accuracy: number;
precision: number;
recall: number;
f1: number;
speed: string;
} {
// 初始化混淆矩阵
const classCount = Math.max(...labels, ...predictions) + 1;
const matrix = Array.from({ length: classCount }, () =>
Array.from({ length: classCount }, () => 0)
);
const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0));
// 填充矩阵
labels.forEach((trueLabel, i) => {
@ -106,7 +100,7 @@ function calculateMetrics(labels: number[], predictions: number[], elapsedTime:
precision,
recall,
f1,
speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`
speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`,
};
}
@ -119,9 +113,12 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
}> {
const data = await Deno.readTextFile("./data/filter/test.jsonl");
const samples = data.split("\n")
.map(line => {
try { return JSON.parse(line); }
catch { return null; }
.map((line) => {
try {
return JSON.parse(line);
} catch {
return null;
}
})
.filter(Boolean);
@ -135,7 +132,7 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
sample.title,
sample.description,
sample.tags.join(","),
sample.author_info
sample.author_info,
], session);
const probabilities = await runClassification(embeddings);

View File

@ -22,11 +22,11 @@ export const classifyVideoWorker = async (job: Job) => {
if (label == -1) {
logger.warn(`Failed to classify video ${aid}`, "ml");
}
insertVideoLabel(client, aid, label);
await insertVideoLabel(client, aid, label);
client.release();
job.updateData({
await job.updateData({
...job.data, label: label,
});
@ -39,7 +39,7 @@ export const classifyVideosWorker = async () => {
return;
}
lockManager.acquireLock("classifyVideos");
await lockManager.acquireLock("classifyVideos");
const client = await db.connect();
const videos = await getUnlabelledVideos(client);
@ -49,12 +49,12 @@ export const classifyVideosWorker = async () => {
let i = 0;
for (const aid of videos) {
if (i > 200) {
lockManager.releaseLock("classifyVideos");
await lockManager.releaseLock("classifyVideos");
return 10000 + i;
}
await ClassifyVideoQueue.add("classifyVideo", { aid: Number(aid) });
i++;
}
lockManager.releaseLock("classifyVideos");
await lockManager.releaseLock("classifyVideos");
return 0;
};

View File

@ -37,7 +37,7 @@ export const getLatestVideosWorker = async (job: Job) => {
return;
}
lockManager.acquireLock("getLatestVideos");
await lockManager.acquireLock("getLatestVideos");
const failedCount = (job.data.failedCount ?? 0) as number;
const client = await db.connect();
@ -46,7 +46,7 @@ export const getLatestVideosWorker = async (job: Job) => {
await executeTask(client, failedCount);
} finally {
client.release();
lockManager.releaseLock("getLatestVideos");
await lockManager.releaseLock("getLatestVideos");
}
return;
};

View File

@ -8,7 +8,7 @@ import logger from "lib/log/logger.ts";
import { getNullVideoTagsList, updateVideoTags } from "lib/db/allData.ts";
import { getVideoTags } from "lib/net/getVideoTags.ts";
import { NetSchedulerError } from "lib/mq/scheduler.ts";
import { WorkerError } from "../schema.ts";
import { WorkerError } from "lib/mq/schema.ts";
const delayMap = [0.5, 3, 5, 15, 30, 60];
const getJobPriority = (diff: number) => {

View File

@ -1,4 +1,4 @@
import { MINUTE, SECOND } from "$std/datetime/constants.ts";
import { MINUTE } from "$std/datetime/constants.ts";
import { ClassifyVideoQueue, LatestVideosQueue, VideoTagsQueue } from "lib/mq/index.ts";
import logger from "lib/log/logger.ts";
@ -7,11 +7,11 @@ export async function initMQ() {
every: 1 * MINUTE
});
await VideoTagsQueue.upsertJobScheduler("getVideosTags", {
every: 30 * SECOND,
every: 5 * MINUTE,
immediately: true,
});
await ClassifyVideoQueue.upsertJobScheduler("classifyVideos", {
every: 30 * SECOND,
every: 5 * MINUTE,
immediately: true,
})

View File

@ -7,6 +7,7 @@ import Redis from "ioredis";
interface Proxy {
type: string;
task: string;
provider: string;
limiter?: RateLimiter;
}
@ -32,11 +33,16 @@ export class NetSchedulerError extends Error {
}
}
interface LimiterMap {
[name: string]: RateLimiter;
}
class NetScheduler {
private proxies: ProxiesMap = {};
private providerLimiters: LimiterMap = {};
addProxy(name: string, type: string, task: string): void {
this.proxies[name] = { type, task };
addProxy(name: string, type: string, task: string, provider: string): void {
this.proxies[name] = { type, task, provider };
}
removeProxy(name: string): void {
@ -47,6 +53,10 @@ class NetScheduler {
this.proxies[name].limiter = limiter;
}
setProviderLimiter(name: string, limiter: RateLimiter): void {
this.providerLimiters[name] = limiter;
}
/*
* Make a request to the specified URL with any available proxy
* @param {string} url - The URL to request.
@ -117,7 +127,15 @@ class NetScheduler {
private async getProxyAvailability(name: string): Promise<boolean> {
try {
const proxyConfig = this.proxies[name];
if (!proxyConfig || !proxyConfig.limiter) {
if (!proxyConfig) {
return true;
}
const provider = proxyConfig.provider;
const providerLimiter = await this.providerLimiters[provider].getAvailability();
if (!providerLimiter) {
return false;
}
if (!proxyConfig.limiter) {
return true;
}
return await proxyConfig.limiter.getAvailability();
@ -143,8 +161,8 @@ class NetScheduler {
}
const netScheduler = new NetScheduler();
netScheduler.addProxy("default", "native", "default");
netScheduler.addProxy("tags-native", "native", "getVideoTags");
netScheduler.addProxy("default", "native", "default", "bilibili-native");
netScheduler.addProxy("tags-native", "native", "getVideoTags", "bilibili-native");
const tagsRateLimiter = new RateLimiter("getVideoTags", [
{
window: new SlidingWindow(redis, 1),
@ -159,6 +177,21 @@ const tagsRateLimiter = new RateLimiter("getVideoTags", [
max: 50,
},
]);
const biliLimiterNative = new RateLimiter("bilibili-native", [
{
window: new SlidingWindow(redis, 1),
max: 5
},
{
window: new SlidingWindow(redis, 30),
max: 100
},
{
window: new SlidingWindow(redis, 5 * 60),
max: 180
}
]);
netScheduler.setProxyLimiter("tags-native", tagsRateLimiter);
netScheduler.setProviderLimiter("bilibili-native", biliLimiterNative)
export default netScheduler;