ref: remove data/ from git
update: inference code for filter model
This commit is contained in:
parent fd090a25c2
commit f78f7fabdd
3 .gitignore vendored
@@ -81,8 +81,7 @@ data/main.db
 logs/
 __pycache__
 filter/runs
-data/filter/eval*
-data/filter/train*
+data/
 filter/checkpoints
 data/filter/model_predicted*
 scripts
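Note that adding data/ to .gitignore does not untrack files already in the index; the mass deletion of data/2025010104_c30_aids.txt below suggests the tracked copies were dropped with something like git rm -r --cached data/ (assumed, not shown in this commit).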
649902 data/2025010104_c30_aids.txt
File diff suppressed because it is too large
@@ -1,3 +0,0 @@
-# The data
-
-Thanks to [天钿Daily](https://tdd.bunnyxt.com/) for providing the data.
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,19 +0,0 @@
-import { SentenceTransformer } from "./model.ts"; // Changed import path
-
-async function main() {
-	const sentenceTransformer = await SentenceTransformer.from_pretrained(
-		"mixedbread-ai/mxbai-embed-large-v1",
-	);
-	const outputs = await sentenceTransformer.encode([
-		"Hello world",
-		"How are you guys doing?",
-		"Today is Friday!",
-	]);
-
-	// @ts-ignore
-	console.log(outputs["last_hidden_state"]);
-
-	return outputs;
-}
-
-main(); // Keep main function call if you want this file to be runnable directly for testing.
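This removed smoke test drove the hand-rolled SentenceTransformer wrapper deleted below. For comparison, a minimal sketch of the same check using the feature-extraction pipeline built into @huggingface/transformers; whether this particular model resolves to ONNX weights on the Hub is an assumption here, and the output dimension is indicative only.

```ts
import { pipeline } from "@huggingface/transformers";

// Built-in pipeline: tokenization, forward pass, and pooling in one call.
const extractor = await pipeline("feature-extraction", "mixedbread-ai/mxbai-embed-large-v1");
const embeddings = await extractor(["Hello world", "How are you guys doing?"], {
	pooling: "mean",
	normalize: true,
});
console.log(embeddings.dims); // e.g. [2, 1024]
```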
@@ -1,40 +0,0 @@
-// lib/ml/sentence_transformer_model.ts
-import { AutoModel, AutoTokenizer, PretrainedOptions } from "@huggingface/transformers";
-
-export class SentenceTransformer {
-	constructor(
-		private readonly tokenizer: AutoTokenizer,
-		private readonly model: AutoModel,
-	) {}
-
-	static async from_pretrained(
-		modelName: string,
-		options?: PretrainedOptions,
-	): Promise<SentenceTransformer> {
-		if (!options) {
-			options = {
-				progress_callback: undefined,
-				cache_dir: undefined,
-				local_files_only: false,
-				revision: "main",
-			};
-		}
-		const tokenizer = await AutoTokenizer.from_pretrained(modelName, options);
-		const model = await AutoModel.from_pretrained(modelName, options);
-
-		return new SentenceTransformer(tokenizer, model);
-	}
-
-	async encode(sentences: string[]): Promise<any> { // Changed return type to 'any' for now to match console.log output
-		//@ts-ignore
-		const modelInputs = await this.tokenizer(sentences, {
-			padding: true,
-			truncation: true,
-		});
-
-		//@ts-ignore
-		const outputs = await this.model(modelInputs);
-
-		return outputs;
-	}
-}
@@ -1,34 +0,0 @@
-import { Tensor } from "@huggingface/transformers";
-//@ts-ignore
-import { Callable } from "@huggingface/transformers/src/utils/core.js"; // Keep as is for now, might need adjustment
-
-export interface PoolingConfig {
-	word_embedding_dimension: number;
-	pooling_mode_cls_token: boolean;
-	pooling_mode_mean_tokens: boolean;
-	pooling_mode_max_tokens: boolean;
-	pooling_mode_mean_sqrt_len_tokens: boolean;
-}
-
-export interface PoolingInput {
-	token_embeddings: Tensor;
-	attention_mask: Tensor;
-}
-
-export interface PoolingOutput {
-	sentence_embedding: Tensor;
-}
-
-export class Pooling extends Callable {
-	constructor(private readonly config: PoolingConfig) {
-		super();
-	}
-
-	// async _call(inputs: any) { // Keep if pooling functionality is needed
-	//	 return this.forward(inputs);
-	// }
-
-	// async forward(inputs: PoolingInput): PoolingOutput { // Keep if pooling functionality is needed
-
-	// }
-}
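Pooling was deleted with its forward pass still commented out. For the record, the mean pooling it stubbed (pooling_mode_mean_tokens) is a short loop; the sketch below works on plain arrays rather than Tensor, and the flattened [seqLen, hiddenDim] layout is an assumption.

```ts
// Mean pooling: average token vectors, counting only unmasked positions.
function meanPool(
	tokenEmbeddings: Float32Array, // flattened [seqLen, hiddenDim]
	attentionMask: number[], // [seqLen], 1 = real token, 0 = padding
	hiddenDim: number,
): Float32Array {
	const sentenceEmbedding = new Float32Array(hiddenDim);
	let tokens = 0;
	for (let t = 0; t < attentionMask.length; t++) {
		if (!attentionMask[t]) continue;
		tokens++;
		for (let d = 0; d < hiddenDim; d++) {
			sentenceEmbedding[d] += tokenEmbeddings[t * hiddenDim + d];
		}
	}
	for (let d = 0; d < hiddenDim; d++) {
		sentenceEmbedding[d] /= Math.max(tokens, 1);
	}
	return sentenceEmbedding;
}
```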
@@ -1,24 +1,38 @@
 import { AutoTokenizer } from "@huggingface/transformers";
 import * as ort from "onnxruntime";
 
-// Configuration parameters
-const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
+// Model paths and names
+const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
 const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
 const onnxEmbeddingOriginalPath = "./model/model.onnx";
+export const modelVersion = "3.11";
 
-
-// Initialize the sessions
-const [sessionClassifier, sessionEmbedding] = await Promise.all([
-	ort.InferenceSession.create(onnxClassifierPath),
-	ort.InferenceSession.create(onnxEmbeddingOriginalPath),
-]);
-
-let tokenizer: any;
-
-// Initialize the tokenizer
-async function loadTokenizer() {
-	const tokenizerConfig = { local_files_only: true };
-	tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig);
+// Global variables holding the models and the tokenizer
+let sessionClassifier: ort.InferenceSession | null = null;
+let sessionEmbedding: ort.InferenceSession | null = null;
+let tokenizer: any | null = null;
+
+// Initialize the tokenizer and the ONNX sessions
+async function initializeModels() {
+	if (tokenizer && sessionClassifier && sessionEmbedding) {
+		return; // Models are already loaded
+	}
+
+	try {
+		const tokenizerConfig = { local_files_only: true };
+		tokenizer = await AutoTokenizer.from_pretrained(tokenizerModel, tokenizerConfig);
+
+		const [classifierSession, embeddingSession] = await Promise.all([
+			ort.InferenceSession.create(onnxClassifierPath),
+			ort.InferenceSession.create(onnxEmbeddingOriginalPath),
+		]);
+
+		sessionClassifier = classifierSession;
+		sessionEmbedding = embeddingSession;
+	} catch (error) {
+		console.error("Error initializing models:", error);
+		throw error; // Re-throw so the caller can handle it
+	}
 }
 
 function softmax(logits: Float32Array): number[] {
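One caveat in the hunk above: the null check in initializeModels() does not guard against two concurrent callers both passing it before either finishes loading, so the models could be created twice. Caching the in-flight promise is a common fix; a sketch (ensureModels is a hypothetical helper, not part of this commit):

```ts
let initPromise: Promise<void> | null = null;

// Idempotent, concurrency-safe wrapper: every caller awaits the same load.
export function ensureModels(): Promise<void> {
	if (!initPromise) {
		initPromise = initializeModels().catch((error) => {
			initPromise = null; // allow a retry after a failed load
			throw error;
		});
	}
	return initPromise;
}
```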
@@ -29,75 +43,65 @@ function softmax(logits: Float32Array): number[] {
 }
 
 async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
-	const { input_ids } = await tokenizer(texts, {
-		add_special_tokens: false,
-		return_tensor: false
-	});
+	if (!tokenizer) {
+		throw new Error("Tokenizer is not initialized. Call initializeModels() first.");
+	}
+	const { input_ids } = await tokenizer(texts, {
+		add_special_tokens: false,
+		return_tensor: false,
+	});
 
-	// Build the input parameters
-	const cumsum = (arr: number[]): number[] =>
-		arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
+	// Build the input parameters
+	const cumsum = (arr: number[]): number[] =>
+		arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
 
-	const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
-	const flattened_input_ids = input_ids.flat();
+	const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
+	const flattened_input_ids = input_ids.flat();
 
-	// Prepare the ONNX inputs
-	const inputs = {
-		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [flattened_input_ids.length]),
-		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length])
-	};
+	// Prepare the ONNX inputs
+	const inputs = {
+		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
+			flattened_input_ids.length,
+		]),
+		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
+	};
 
-	// Run inference
-	const { embeddings } = await session.run(inputs);
-	return Array.from(embeddings.data as Float32Array);
+	// Run inference
+	const { embeddings } = await session.run(inputs);
+	return Array.from(embeddings.data as Float32Array);
 }
 
+// Classification inference
 async function runClassification(embeddings: number[]): Promise<number[]> {
-	const inputTensor = new ort.Tensor(
-		Float32Array.from(embeddings),
-		[1, 4, 1024]
-	);
+	if (!sessionClassifier) {
+		throw new Error("Classifier session is not initialized. Call initializeModels() first.");
+	}
+	const inputTensor = new ort.Tensor(
+		Float32Array.from(embeddings),
+		[1, 4, 1024],
+	);
 
-	const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
-	return softmax(logits.data as Float32Array);
+	const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
+	return softmax(logits.data as Float32Array);
 }
 
-async function processInputTexts(
+// Exported classification function
+export async function classifyVideo(
 	title: string,
 	description: string,
 	tags: string,
 	author_info: string,
 ): Promise<number[]> {
+	if (!sessionEmbedding) {
+		throw new Error("Embedding session is not initialized. Call initializeModels() first.");
+	}
 	const embeddings = await getONNXEmbeddings([
 		title,
 		description,
 		tags,
-		author_info
+		author_info,
 	], sessionEmbedding);
 
 	const probabilities = await runClassification(embeddings);
 	return probabilities;
 }
-
-async function main() {
-	await loadTokenizer();
-	const titleText = `【洛天依&乐正绫&心华原创】归一【时之歌Project】`
-	const descriptionText = " 《归一》Vocaloid ver\r\n出品:泛音堂 / 作词:冥凰 / 作曲:汤汤 / 编曲&混音:iAn / 调教:花之祭P\r\n后期:向南 / 人设:Pora / 场景:A舍长 / PV:Sung Hsu(麻薯映画) / 海报:易玄玑 \r\n唱:乐正绫 & 洛天依 & 心华\r\n时之歌Project东国世界观歌曲《归一》双本家VC版\r\nMP3:http://5sing.kugou.com/yc/3006072.html \r\n伴奏:http://5sing.kugou.com/bz/2";
-	const tagsText = '乐正绫,洛天依,心华,VOCALOID中文曲,时之歌,花之祭P';
-	const authorInfoText = "时之歌Project: 欢迎光临时之歌~\r\n官博:http://weibo.com/songoftime\r\n官网:http://www.songoftime.com/";
-
-	try {
-		const probabilities = await processInputTexts(titleText, descriptionText, tagsText, authorInfoText);
-		console.log("Class Probabilities:", probabilities);
-		console.log(`Class 0 Probability: ${probabilities[0]}`);
-		console.log(`Class 1 Probability: ${probabilities[1]}`);
-		console.log(`Class 2 Probability: ${probabilities[2]}`);
-		// Hold the session for 10s
-		await new Promise((resolve) => setTimeout(resolve, 10000));
-	} catch (error) {
-		console.error("Error processing texts:", error);
-	}
-}
-
-await main();
-
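The flattened input_ids plus offsets pair built in getONNXEmbeddings is an EmbeddingBag-style input layout: each offset marks where one text's tokens start in the flattened array, so the embedding model can emit one vector per text. A worked example of that construction with made-up token ids:

```ts
// Two texts tokenized to [101, 2054] and [101, 2129, 2017]:
const input_ids = [[101, 2054], [101, 2129, 2017]];

const cumsum = (arr: number[]): number[] =>
	arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);

const offsets = [0, ...cumsum(input_ids.slice(0, -1).map((x) => x.length))]; // [0, 2]
const flattened = input_ids.flat(); // [101, 2054, 101, 2129, 2017]
```

classifyVideo always passes four texts (title, description, tags, author_info), which is why runClassification can assume a fixed [1, 4, 1024] input shape.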
@@ -3,3 +3,5 @@ import { Queue } from "bullmq";
 export const LatestVideosQueue = new Queue("latestVideos");
 
 export const VideoTagsQueue = new Queue("videoTags");
+
+export const ClassifyVideoQueue = new Queue("classifyVideo");
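Producers reach the new worker through this queue by adding jobs whose name matches the switch in filterWorker.ts. A minimal sketch; the import path and payload shape are assumptions, since this commit shows no producer:

```ts
import { ClassifyVideoQueue } from "lib/mq/index.ts"; // import path assumed

// Enqueue a single-video classification job for the filter worker.
await ClassifyVideoQueue.add("classifyVideo", { aid: 170001 });
```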
28 src/filterWorker.ts Normal file
@@ -0,0 +1,28 @@
+import { Job, Worker } from "bullmq";
+import { redis } from "lib/db/redis.ts";
+import logger from "lib/log/logger.ts";
+import { WorkerError } from "src/worker.ts";
+
+const filterWorker = new Worker(
+	"classifyVideo",
+	async (job: Job) => {
+		switch (job.name) {
+			case "classifyVideo":
+				return await getVideoTagsWorker(job);
+			case "classifyVideos":
+				return await getVideoTagsInitializer();
+			default:
+				break;
+		}
+	},
+	{ connection: redis, concurrency: 1, removeOnComplete: { count: 1440 } },
+);
+
+filterWorker.on("active", () => {
+	logger.log("Worker activated.", "mq");
+});
+
+filterWorker.on("error", (err) => {
+	const e = err as WorkerError;
+	logger.error(e.rawError, e.service, e.codePath);
+});
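As committed, filterWorker.ts references getVideoTagsWorker and getVideoTagsInitializer without importing them, which reads like a leftover from the videoTags worker it was modeled on. A handler wired to the new inference module might look like this; the import path, the export of initializeModels, and the job payload are all assumptions:

```ts
import { Job } from "bullmq";
// Hypothetical path; initializeModels is assumed to be exported.
import { classifyVideo, initializeModels } from "lib/ml/filter_inference.ts";

// Load the models lazily on the first job, then classify one video per job.
export async function classifyVideoWorker(job: Job) {
	await initializeModels();
	const { title, description, tags, author_info } = job.data;
	return await classifyVideo(title, description, tags, author_info);
}
```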