ref: remove data/ from git

update: inference code for filter model
alikia2x (寒寒) 2025-02-14 02:04:07 +08:00
parent fd090a25c2
commit f78f7fabdd
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
13 changed files with 96 additions and 695713 deletions

.gitignore (vendored): 3 lines changed

@@ -81,8 +81,7 @@ data/main.db
 logs/
 __pycache__
 filter/runs
-data/filter/eval*
-data/filter/train*
+data/
 filter/checkpoints
 data/filter/model_predicted*
 scripts

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -1,3 +0,0 @@
-# The data
-Thanks to [天钿Daily](https://tdd.bunnyxt.com/) for providing the data.
-

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@@ -1,19 +0,0 @@
-import { SentenceTransformer } from "./model.ts"; // Changed import path
-
-async function main() {
-	const sentenceTransformer = await SentenceTransformer.from_pretrained(
-		"mixedbread-ai/mxbai-embed-large-v1",
-	);
-	const outputs = await sentenceTransformer.encode([
-		"Hello world",
-		"How are you guys doing?",
-		"Today is Friday!",
-	]);
-
-	// @ts-ignore
-	console.log(outputs["last_hidden_state"]);
-
-	return outputs;
-}
-
-main(); // Keep main function call if you want this file to be runnable directly for testing.


@@ -1,40 +0,0 @@
-// lib/ml/sentence_transformer_model.ts
-import { AutoModel, AutoTokenizer, PretrainedOptions } from "@huggingface/transformers";
-
-export class SentenceTransformer {
-	constructor(
-		private readonly tokenizer: AutoTokenizer,
-		private readonly model: AutoModel,
-	) {}
-
-	static async from_pretrained(
-		modelName: string,
-		options?: PretrainedOptions,
-	): Promise<SentenceTransformer> {
-		if (!options) {
-			options = {
-				progress_callback: undefined,
-				cache_dir: undefined,
-				local_files_only: false,
-				revision: "main",
-			};
-		}
-		const tokenizer = await AutoTokenizer.from_pretrained(modelName, options);
-		const model = await AutoModel.from_pretrained(modelName, options);
-
-		return new SentenceTransformer(tokenizer, model);
-	}
-
-	async encode(sentences: string[]): Promise<any> { // Changed return type to 'any' for now to match console.log output
-		//@ts-ignore
-		const modelInputs = await this.tokenizer(sentences, {
-			padding: true,
-			truncation: true,
-		});
-
-		//@ts-ignore
-		const outputs = await this.model(modelInputs);
-
-		return outputs;
-	}
-}


@@ -1,34 +0,0 @@
-import { Tensor } from "@huggingface/transformers";
-//@ts-ignore
-import { Callable } from "@huggingface/transformers/src/utils/core.js"; // Keep as is for now, might need adjustment
-
-export interface PoolingConfig {
-	word_embedding_dimension: number;
-	pooling_mode_cls_token: boolean;
-	pooling_mode_mean_tokens: boolean;
-	pooling_mode_max_tokens: boolean;
-	pooling_mode_mean_sqrt_len_tokens: boolean;
-}
-
-export interface PoolingInput {
-	token_embeddings: Tensor;
-	attention_mask: Tensor;
-}
-
-export interface PoolingOutput {
-	sentence_embedding: Tensor;
-}
-
-export class Pooling extends Callable {
-	constructor(private readonly config: PoolingConfig) {
-		super();
-	}
-
-	// async _call(inputs: any) { // Keep if pooling functionality is needed
-	//	 return this.forward(inputs);
-	// }
-
-	// async forward(inputs: PoolingInput): PoolingOutput { // Keep if pooling functionality is needed
-	// }
-}


@@ -1,24 +1,38 @@
 import { AutoTokenizer } from "@huggingface/transformers";
 import * as ort from "onnxruntime";
 
-// Configuration
-const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
+// Model paths and names
+const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
 const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
 const onnxEmbeddingOriginalPath = "./model/model.onnx";
+export const modelVersion = "3.11";
 
-// Initialize the ONNX sessions
-const [sessionClassifier, sessionEmbedding] = await Promise.all([
-	ort.InferenceSession.create(onnxClassifierPath),
-	ort.InferenceSession.create(onnxEmbeddingOriginalPath),
-]);
-
-let tokenizer: any;
+// Globals holding the ONNX sessions and the tokenizer
+let sessionClassifier: ort.InferenceSession | null = null;
let sessionEmbedding: ort.InferenceSession | null = null;
+let tokenizer: any | null = null;
 
-// Initialize the tokenizer
-async function loadTokenizer() {
-	const tokenizerConfig = { local_files_only: true };
-	tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig);
+// Initialize the tokenizer and the ONNX sessions
+async function initializeModels() {
+	if (tokenizer && sessionClassifier && sessionEmbedding) {
+		return; // Models already loaded; skip re-initialization
+	}
+
+	try {
+		const tokenizerConfig = { local_files_only: true };
+		tokenizer = await AutoTokenizer.from_pretrained(tokenizerModel, tokenizerConfig);
+
+		const [classifierSession, embeddingSession] = await Promise.all([
+			ort.InferenceSession.create(onnxClassifierPath),
+			ort.InferenceSession.create(onnxEmbeddingOriginalPath),
+		]);
+		sessionClassifier = classifierSession;
+		sessionEmbedding = embeddingSession;
+	} catch (error) {
+		console.error("Error initializing models:", error);
+		throw error; // Re-throw so the caller can handle the failure
+	}
 }
 
@@ -29,75 +43,65 @@ function softmax(logits: Float32Array): number[] {
 }
 
 async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
+	if (!tokenizer) {
+		throw new Error("Tokenizer is not initialized. Call initializeModels() first.");
+	}
 	const { input_ids } = await tokenizer(texts, {
 		add_special_tokens: false,
-		return_tensor: false
+		return_tensor: false,
 	});
 
 	// Build the input parameters
 	const cumsum = (arr: number[]): number[] =>
 		arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
+
 	const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
 	const flattened_input_ids = input_ids.flat();
 
 	// Prepare the ONNX inputs
 	const inputs = {
-		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [flattened_input_ids.length]),
-		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length])
+		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
+			flattened_input_ids.length,
+		]),
+		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
 	};
 
 	// Run inference
 	const { embeddings } = await session.run(inputs);
 	return Array.from(embeddings.data as Float32Array);
 }
 
 // Classification inference function
 async function runClassification(embeddings: number[]): Promise<number[]> {
+	if (!sessionClassifier) {
+		throw new Error("Classifier session is not initialized. Call initializeModels() first.");
+	}
 	const inputTensor = new ort.Tensor(
 		Float32Array.from(embeddings),
-		[1, 4, 1024]
+		[1, 4, 1024],
 	);
 
 	const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
 	return softmax(logits.data as Float32Array);
 }
 
-async function processInputTexts(
+// Exported classification entry point
+export async function classifyVideo(
 	title: string,
 	description: string,
 	tags: string,
 	author_info: string,
 ): Promise<number[]> {
+	if (!sessionEmbedding) {
+		throw new Error("Embedding session is not initialized. Call initializeModels() first.");
+	}
 	const embeddings = await getONNXEmbeddings([
 		title,
 		description,
 		tags,
-		author_info
+		author_info,
 	], sessionEmbedding);
 
 	const probabilities = await runClassification(embeddings);
 	return probabilities;
 }
-
-async function main() {
-	await loadTokenizer();
-	const titleText = `【洛天依&乐正绫&心华原创】归一【时之歌Project】`
-	const descriptionText = " 《归一》Vocaloid ver\r\n出品:泛音堂 / 作词:冥凰 / 作曲:汤汤 / 编曲&混音:iAn / 调教:花之祭P\r\n后期:向南 / 人设:Pora / 场景:A舍长 / PV:Sung Hsu(麻薯映画) / 海报:易玄玑 \r\n演唱:乐正绫 & 洛天依 & 心华\r\n时之歌Project东国世界观歌曲《归一》双本家VC版\r\nMP3:http://5sing.kugou.com/yc/3006072.html \r\n伴奏:http://5sing.kugou.com/bz/2";
-	const tagsText = '乐正绫,洛天依,心华,VOCALOID中文曲,时之歌,花之祭P';
-	const authorInfoText = "时之歌Project: 欢迎光临时之歌~\r\n官博:http://weibo.com/songoftime\r\n官网:http://www.songoftime.com/";
-
-	try {
-		const probabilities = await processInputTexts(titleText, descriptionText, tagsText, authorInfoText);
-		console.log("Class Probabilities:", probabilities);
-		console.log(`Class 0 Probability: ${probabilities[0]}`);
-		console.log(`Class 1 Probability: ${probabilities[1]}`);
-		console.log(`Class 2 Probability: ${probabilities[2]}`);
-		// Hold the session for 10s
-		await new Promise((resolve) => setTimeout(resolve, 10000));
-	} catch (error) {
-		console.error("Error processing texts:", error);
-	}
-}
-
-await main();
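With the top-level session creation and the ad-hoc main() test replaced by an explicit initializeModels(), callers now have to initialize before classifying. A minimal usage sketch of the new API, assuming the file above is importable as lib/ml/filter_inference.ts (the module path is not shown in this view):

import { classifyVideo, initializeModels, modelVersion } from "lib/ml/filter_inference.ts"; // path is an assumption

await initializeModels(); // safe to call repeatedly; returns early once models are loaded
const probabilities = await classifyVideo(
	"【洛天依&乐正绫&心华原创】归一【时之歌Project】", // title
	"《归一》Vocaloid ver / 出品:泛音堂 ...",          // description
	"乐正绫,洛天依,心华,VOCALOID中文曲,时之歌,花之祭P", // tags
	"时之歌Project: 欢迎光临时之歌~",                   // author_info
);
console.log(`Model ${modelVersion} class probabilities:`, probabilities);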


@@ -3,3 +3,5 @@ import { Queue } from "bullmq";
 
 export const LatestVideosQueue = new Queue("latestVideos");
 export const VideoTagsQueue = new Queue("videoTags");
+
+export const ClassifyVideoQueue = new Queue("classifyVideo");
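The new ClassifyVideoQueue is the queue that the filter worker below listens on. A hedged sketch of how a producer might enqueue jobs: only the queue name and the job names ("classifyVideo", "classifyVideos") appear in this commit, so the payload shape and import path are assumptions.

import { ClassifyVideoQueue } from "lib/mq/index.ts"; // module path is an assumption

// Enqueue one video for classification; the aid payload field is hypothetical.
await ClassifyVideoQueue.add("classifyVideo", { aid: 170001 });

// Trigger the batch path handled by the "classifyVideos" case in the worker.
await ClassifyVideoQueue.add("classifyVideos", {});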

src/filterWorker.ts (new file): 28 lines added

@@ -0,0 +1,28 @@
+import { Job, Worker } from "bullmq";
+import { redis } from "lib/db/redis.ts";
+import logger from "lib/log/logger.ts";
+import { WorkerError } from "src/worker.ts";
+
+const filterWorker = new Worker(
+	"classifyVideo",
+	async (job: Job) => {
+		switch (job.name) {
+			case "classifyVideo":
+				return await getVideoTagsWorker(job);
+			case "classifyVideos":
+				return await getVideoTagsInitializer();
+			default:
+				break;
+		}
+	},
+	{ connection: redis, concurrency: 1, removeOnComplete: { count: 1440 } },
+);
+
+filterWorker.on("active", () => {
+	logger.log("Worker activated.", "mq");
+});
+
+filterWorker.on("error", (err) => {
+	const e = err as WorkerError;
+	logger.error(e.rawError, e.service, e.codePath);
+});
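As committed, the worker references getVideoTagsWorker and getVideoTagsInitializer without importing them, so the actual job handlers live elsewhere in the codebase. A minimal sketch of what a dedicated classifyVideo handler could look like, under the assumption that it wires the classifyVideo export from the inference module into the job (the handler name, payload fields, and import path are all assumptions):

import { Job } from "bullmq";
import { classifyVideo, initializeModels } from "lib/ml/filter_inference.ts"; // assumed path

// Hypothetical handler: classify one video's metadata and return class probabilities.
export async function classifyVideoWorker(job: Job) {
	await initializeModels(); // idempotent: returns early once models are loaded
	const { title, description, tags, author_info } = job.data; // assumed payload shape
	return await classifyVideo(title, description, tags, author_info);
}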