add: provider in NetScheduler, missing await
This commit is contained in:
parent
15312f4078
commit
232585594a
1
.gitignore
vendored
1
.gitignore
vendored
@ -76,7 +76,6 @@ node_modules/
|
||||
|
||||
|
||||
# project specific
|
||||
.env
|
||||
logs/
|
||||
__pycache__
|
||||
filter/runs
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
|
||||
import {AutoTokenizer, PreTrainedTokenizer} from "@huggingface/transformers";
|
||||
import * as ort from "onnxruntime";
|
||||
import logger from "lib/log/logger.ts";
|
||||
import { WorkerError } from "../mq/schema.ts";
|
||||
import {WorkerError} from "lib/mq/schema.ts";
|
||||
|
||||
const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
|
||||
const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
|
||||
@ -29,12 +29,11 @@ export async function initializeModels() {
|
||||
sessionEmbedding = embeddingSession;
|
||||
logger.log("Filter models initialized", "ml");
|
||||
} catch (error) {
|
||||
const e = new WorkerError(error as Error, "ml", "fn:initializeModels");
|
||||
throw e;
|
||||
throw new WorkerError(error as Error, "ml", "fn:initializeModels");
|
||||
}
|
||||
}
|
||||
|
||||
function softmax(logits: Float32Array): number[] {
|
||||
export function softmax(logits: Float32Array): number[] {
|
||||
const maxLogit = Math.max(...logits);
|
||||
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
|
||||
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
|
||||
import {AutoTokenizer, PreTrainedTokenizer} from "@huggingface/transformers";
|
||||
import * as ort from "onnxruntime";
|
||||
import {softmax} from "lib/ml/filter_inference.ts";
|
||||
|
||||
// 配置参数
|
||||
const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
|
||||
@ -11,7 +12,7 @@ const onnxEmbeddingQuantizedPath = "./model/model.onnx";
|
||||
const [sessionClassifier, sessionEmbeddingOriginal, sessionEmbeddingQuantized] = await Promise.all([
|
||||
ort.InferenceSession.create(onnxClassifierPath),
|
||||
ort.InferenceSession.create(onnxEmbeddingOriginalPath),
|
||||
ort.InferenceSession.create(onnxEmbeddingQuantizedPath)
|
||||
ort.InferenceSession.create(onnxEmbeddingQuantizedPath),
|
||||
]);
|
||||
|
||||
let tokenizer: PreTrainedTokenizer;
|
||||
@ -26,7 +27,7 @@ async function loadTokenizer() {
|
||||
async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
|
||||
const { input_ids } = await tokenizer(texts, {
|
||||
add_special_tokens: false,
|
||||
return_tensor: false
|
||||
return_tensor: false,
|
||||
});
|
||||
|
||||
// 构造输入参数
|
||||
@ -38,8 +39,10 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
|
||||
|
||||
// 准备ONNX输入
|
||||
const inputs = {
|
||||
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [flattened_input_ids.length]),
|
||||
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length])
|
||||
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
|
||||
flattened_input_ids.length,
|
||||
]),
|
||||
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
|
||||
};
|
||||
|
||||
// 执行推理
|
||||
@ -47,18 +50,11 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
|
||||
return Array.from(embeddings.data as Float32Array);
|
||||
}
|
||||
|
||||
function softmax(logits: Float32Array): number[] {
|
||||
const maxLogit = Math.max(...logits);
|
||||
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
|
||||
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
|
||||
return Array.from(exponents.map((exp) => exp / sumOfExponents));
|
||||
}
|
||||
|
||||
// 分类推理函数
|
||||
async function runClassification(embeddings: number[]): Promise<number[]> {
|
||||
const inputTensor = new ort.Tensor(
|
||||
Float32Array.from(embeddings),
|
||||
[1, 4, 1024]
|
||||
[1, 4, 1024],
|
||||
);
|
||||
|
||||
const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
|
||||
@ -67,17 +63,15 @@ async function runClassification(embeddings: number[]): Promise<number[]> {
|
||||
|
||||
// 指标计算函数
|
||||
function calculateMetrics(labels: number[], predictions: number[], elapsedTime: number): {
|
||||
accuracy: number,
|
||||
precision: number,
|
||||
recall: number,
|
||||
f1: number,
|
||||
speed: string
|
||||
accuracy: number;
|
||||
precision: number;
|
||||
recall: number;
|
||||
f1: number;
|
||||
speed: string;
|
||||
} {
|
||||
// 初始化混淆矩阵
|
||||
const classCount = Math.max(...labels, ...predictions) + 1;
|
||||
const matrix = Array.from({ length: classCount }, () =>
|
||||
Array.from({ length: classCount }, () => 0)
|
||||
);
|
||||
const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0));
|
||||
|
||||
// 填充矩阵
|
||||
labels.forEach((trueLabel, i) => {
|
||||
@ -106,7 +100,7 @@ function calculateMetrics(labels: number[], predictions: number[], elapsedTime:
|
||||
precision,
|
||||
recall,
|
||||
f1,
|
||||
speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`
|
||||
speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`,
|
||||
};
|
||||
}
|
||||
|
||||
@ -119,9 +113,12 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
|
||||
}> {
|
||||
const data = await Deno.readTextFile("./data/filter/test.jsonl");
|
||||
const samples = data.split("\n")
|
||||
.map(line => {
|
||||
try { return JSON.parse(line); }
|
||||
catch { return null; }
|
||||
.map((line) => {
|
||||
try {
|
||||
return JSON.parse(line);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.filter(Boolean);
|
||||
|
||||
@ -135,7 +132,7 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
|
||||
sample.title,
|
||||
sample.description,
|
||||
sample.tags.join(","),
|
||||
sample.author_info
|
||||
sample.author_info,
|
||||
], session);
|
||||
|
||||
const probabilities = await runClassification(embeddings);
|
||||
|
@ -22,11 +22,11 @@ export const classifyVideoWorker = async (job: Job) => {
|
||||
if (label == -1) {
|
||||
logger.warn(`Failed to classify video ${aid}`, "ml");
|
||||
}
|
||||
insertVideoLabel(client, aid, label);
|
||||
await insertVideoLabel(client, aid, label);
|
||||
|
||||
client.release();
|
||||
|
||||
job.updateData({
|
||||
await job.updateData({
|
||||
...job.data, label: label,
|
||||
});
|
||||
|
||||
@ -39,7 +39,7 @@ export const classifyVideosWorker = async () => {
|
||||
return;
|
||||
}
|
||||
|
||||
lockManager.acquireLock("classifyVideos");
|
||||
await lockManager.acquireLock("classifyVideos");
|
||||
|
||||
const client = await db.connect();
|
||||
const videos = await getUnlabelledVideos(client);
|
||||
@ -49,12 +49,12 @@ export const classifyVideosWorker = async () => {
|
||||
let i = 0;
|
||||
for (const aid of videos) {
|
||||
if (i > 200) {
|
||||
lockManager.releaseLock("classifyVideos");
|
||||
await lockManager.releaseLock("classifyVideos");
|
||||
return 10000 + i;
|
||||
}
|
||||
await ClassifyVideoQueue.add("classifyVideo", { aid: Number(aid) });
|
||||
i++;
|
||||
}
|
||||
lockManager.releaseLock("classifyVideos");
|
||||
await lockManager.releaseLock("classifyVideos");
|
||||
return 0;
|
||||
};
|
||||
|
@ -37,7 +37,7 @@ export const getLatestVideosWorker = async (job: Job) => {
|
||||
return;
|
||||
}
|
||||
|
||||
lockManager.acquireLock("getLatestVideos");
|
||||
await lockManager.acquireLock("getLatestVideos");
|
||||
|
||||
const failedCount = (job.data.failedCount ?? 0) as number;
|
||||
const client = await db.connect();
|
||||
@ -46,7 +46,7 @@ export const getLatestVideosWorker = async (job: Job) => {
|
||||
await executeTask(client, failedCount);
|
||||
} finally {
|
||||
client.release();
|
||||
lockManager.releaseLock("getLatestVideos");
|
||||
await lockManager.releaseLock("getLatestVideos");
|
||||
}
|
||||
return;
|
||||
};
|
||||
|
@ -8,7 +8,7 @@ import logger from "lib/log/logger.ts";
|
||||
import { getNullVideoTagsList, updateVideoTags } from "lib/db/allData.ts";
|
||||
import { getVideoTags } from "lib/net/getVideoTags.ts";
|
||||
import { NetSchedulerError } from "lib/mq/scheduler.ts";
|
||||
import { WorkerError } from "../schema.ts";
|
||||
import { WorkerError } from "lib/mq/schema.ts";
|
||||
|
||||
const delayMap = [0.5, 3, 5, 15, 30, 60];
|
||||
const getJobPriority = (diff: number) => {
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { MINUTE, SECOND } from "$std/datetime/constants.ts";
|
||||
import { MINUTE } from "$std/datetime/constants.ts";
|
||||
import { ClassifyVideoQueue, LatestVideosQueue, VideoTagsQueue } from "lib/mq/index.ts";
|
||||
import logger from "lib/log/logger.ts";
|
||||
|
||||
@ -7,11 +7,11 @@ export async function initMQ() {
|
||||
every: 1 * MINUTE
|
||||
});
|
||||
await VideoTagsQueue.upsertJobScheduler("getVideosTags", {
|
||||
every: 30 * SECOND,
|
||||
every: 5 * MINUTE,
|
||||
immediately: true,
|
||||
});
|
||||
await ClassifyVideoQueue.upsertJobScheduler("classifyVideos", {
|
||||
every: 30 * SECOND,
|
||||
every: 5 * MINUTE,
|
||||
immediately: true,
|
||||
})
|
||||
|
||||
|
@ -7,6 +7,7 @@ import Redis from "ioredis";
|
||||
interface Proxy {
|
||||
type: string;
|
||||
task: string;
|
||||
provider: string;
|
||||
limiter?: RateLimiter;
|
||||
}
|
||||
|
||||
@ -32,11 +33,16 @@ export class NetSchedulerError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
interface LimiterMap {
|
||||
[name: string]: RateLimiter;
|
||||
}
|
||||
|
||||
class NetScheduler {
|
||||
private proxies: ProxiesMap = {};
|
||||
private providerLimiters: LimiterMap = {};
|
||||
|
||||
addProxy(name: string, type: string, task: string): void {
|
||||
this.proxies[name] = { type, task };
|
||||
addProxy(name: string, type: string, task: string, provider: string): void {
|
||||
this.proxies[name] = { type, task, provider };
|
||||
}
|
||||
|
||||
removeProxy(name: string): void {
|
||||
@ -47,6 +53,10 @@ class NetScheduler {
|
||||
this.proxies[name].limiter = limiter;
|
||||
}
|
||||
|
||||
setProviderLimiter(name: string, limiter: RateLimiter): void {
|
||||
this.providerLimiters[name] = limiter;
|
||||
}
|
||||
|
||||
/*
|
||||
* Make a request to the specified URL with any available proxy
|
||||
* @param {string} url - The URL to request.
|
||||
@ -117,7 +127,15 @@ class NetScheduler {
|
||||
private async getProxyAvailability(name: string): Promise<boolean> {
|
||||
try {
|
||||
const proxyConfig = this.proxies[name];
|
||||
if (!proxyConfig || !proxyConfig.limiter) {
|
||||
if (!proxyConfig) {
|
||||
return true;
|
||||
}
|
||||
const provider = proxyConfig.provider;
|
||||
const providerLimiter = await this.providerLimiters[provider].getAvailability();
|
||||
if (!providerLimiter) {
|
||||
return false;
|
||||
}
|
||||
if (!proxyConfig.limiter) {
|
||||
return true;
|
||||
}
|
||||
return await proxyConfig.limiter.getAvailability();
|
||||
@ -143,8 +161,8 @@ class NetScheduler {
|
||||
}
|
||||
|
||||
const netScheduler = new NetScheduler();
|
||||
netScheduler.addProxy("default", "native", "default");
|
||||
netScheduler.addProxy("tags-native", "native", "getVideoTags");
|
||||
netScheduler.addProxy("default", "native", "default", "bilibili-native");
|
||||
netScheduler.addProxy("tags-native", "native", "getVideoTags", "bilibili-native");
|
||||
const tagsRateLimiter = new RateLimiter("getVideoTags", [
|
||||
{
|
||||
window: new SlidingWindow(redis, 1),
|
||||
@ -159,6 +177,21 @@ const tagsRateLimiter = new RateLimiter("getVideoTags", [
|
||||
max: 50,
|
||||
},
|
||||
]);
|
||||
const biliLimiterNative = new RateLimiter("bilibili-native", [
|
||||
{
|
||||
window: new SlidingWindow(redis, 1),
|
||||
max: 5
|
||||
},
|
||||
{
|
||||
window: new SlidingWindow(redis, 30),
|
||||
max: 100
|
||||
},
|
||||
{
|
||||
window: new SlidingWindow(redis, 5 * 60),
|
||||
max: 180
|
||||
}
|
||||
]);
|
||||
netScheduler.setProxyLimiter("tags-native", tagsRateLimiter);
|
||||
netScheduler.setProviderLimiter("bilibili-native", biliLimiterNative)
|
||||
|
||||
export default netScheduler;
|
||||
|
Loading…
Reference in New Issue
Block a user