ref: code structure related to AI

This commit is contained in:
alikia2x (寒寒) 2025-03-16 01:23:10 +08:00
parent 5af2236109
commit a6c8fd7f3f
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
10 changed files with 176 additions and 117 deletions

View File

@ -1,6 +1,6 @@
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import { AllDataType, BiliUserType } from "lib/db/schema.d.ts";
import { modelVersion } from "lib/ml/filter_inference.ts";
import Akari from "lib/ml/akari.ts";
export async function videoExistsInAllData(client: Client, aid: number) {
return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid])
@ -23,7 +23,7 @@ export async function getUnlabelledVideos(client: Client) {
export async function insertVideoLabel(client: Client, aid: number, label: number) {
return await client.queryObject(
`INSERT INTO labelling_result (aid, label, model_version) VALUES ($1, $2, $3) ON CONFLICT (aid, model_version) DO NOTHING`,
[aid, label, modelVersion],
[aid, label, Akari.getModelVersion()],
);
}

View File

@ -28,7 +28,8 @@ export async function getSongsNearMilestone(client: Client) {
max_views_per_aid
WHERE
(max_views >= 90000 AND max_views < 100000) OR
(max_views >= 900000 AND max_views < 1000000)
(max_views >= 900000 AND max_views < 1000000) OR
(max_views >= 9900000 AND max_views < 10000000)
)
--
SELECT

View File

@ -1,11 +1,12 @@
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
export async function getUnsnapshotedSongs(client: Client) {
const queryResult = await client.queryObject<{ aid: bigint }>(`
SELECT DISTINCT s.aid
FROM songs s
LEFT JOIN video_snapshot v ON s.aid = v.aid
WHERE v.aid IS NULL;
`);
return queryResult.rows.map((row) => Number(row.aid));
/*
Returns true if the specified `aid` has at least one record with "pending" or "processing" status.
*/
export async function videoHasActiveSchedule(client: Client, aid: number) {
const res = await client.queryObject<{ status: string }>(
`SELECT status FROM snapshot_schedule WHERE aid = $1 AND (status = 'pending' OR status = 'processing')`,
[aid],
);
return res.rows.length > 0;
}

106
lib/ml/akari.ts Normal file
View File

@ -0,0 +1,106 @@
import { AIManager } from "lib/ml/manager.ts";
import * as ort from "onnxruntime";
import logger from "lib/log/logger.ts";
import { WorkerError } from "lib/mq/schema.ts";
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
const onnxClassifierPath = "./model/video_classifier_v3_17.onnx";
const onnxEmbeddingPath = "./model/model.onnx";
class AkariProto extends AIManager {
private tokenizer: PreTrainedTokenizer | null = null;
private readonly modelVersion = "3.17";
constructor() {
super();
this.models = {
"classifier": onnxClassifierPath,
"embedding": onnxEmbeddingPath,
}
}
public override async init(): Promise<void> {
super.init();
await this.initJinaTokenizer();
}
private tokenizerInitialized(): boolean {
return this.tokenizer !== null;
}
private getTokenizer(): PreTrainedTokenizer {
if (!this.tokenizerInitialized()) {
throw new Error("Tokenizer is not initialized. Call init() first.");
}
return this.tokenizer!;
}
private async initJinaTokenizer(): Promise<void> {
if (this.tokenizerInitialized()) {
return;
}
try {
this.tokenizer = await AutoTokenizer.from_pretrained(tokenizerModel);
logger.log("Tokenizer initialized", "ml");
} catch (error) {
throw new WorkerError(error as Error, "ml", "fn:initTokenizer");
}
}
private async getJinaEmbeddings1024(texts: string[]): Promise<number[]> {
const tokenizer = this.getTokenizer();
const session = this.getModelSession("embedding");
const { input_ids } = await tokenizer(texts, {
add_special_tokens: false,
return_tensors: "js",
});
const cumsum = (arr: number[]): number[] =>
arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string[]) => x.length))];
const flattened_input_ids = input_ids.flat();
const inputs = {
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
flattened_input_ids.length,
]),
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
};
const { embeddings } = await session.run(inputs);
return Array.from(embeddings.data as Float32Array);
}
private async runClassification(embeddings: number[]): Promise<number[]> {
const session = this.getModelSession("classifier");
const inputTensor = new ort.Tensor(
Float32Array.from(embeddings),
[1, 3, 1024],
);
const { logits } = await session.run({ channel_features: inputTensor });
return this.softmax(logits.data as Float32Array);
}
public async classifyVideo(title: string, description: string, tags: string, aid: number): Promise<number> {
const embeddings = await this.getJinaEmbeddings1024([
title,
description,
tags,
]);
const probabilities = await this.runClassification(embeddings);
logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml");
return probabilities.indexOf(Math.max(...probabilities));
}
public getModelVersion(): string {
return this.modelVersion;
}
}
const Akari = new AkariProto();
export default Akari;

View File

@ -1,6 +1,13 @@
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
import * as ort from "onnxruntime";
import { softmax } from "lib/ml/filter_inference.ts";
function softmax(logits: Float32Array): number[] {
const maxLogit = Math.max(...logits);
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
return Array.from(exponents.map((exp) => exp / sumOfExponents));
}
// 配置参数
const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";

View File

@ -1,99 +0,0 @@
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
import * as ort from "onnxruntime";
import logger from "lib/log/logger.ts";
import { WorkerError } from "lib/mq/schema.ts";
const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
const onnxClassifierPath = "./model/video_classifier_v3_17.onnx";
const onnxEmbeddingOriginalPath = "./model/model.onnx";
export const modelVersion = "3.17";
let sessionClassifier: ort.InferenceSession | null = null;
let sessionEmbedding: ort.InferenceSession | null = null;
let tokenizer: PreTrainedTokenizer | null = null;
export async function initializeModels() {
if (tokenizer && sessionClassifier && sessionEmbedding) {
return;
}
try {
tokenizer = await AutoTokenizer.from_pretrained(tokenizerModel);
const [classifierSession, embeddingSession] = await Promise.all([
ort.InferenceSession.create(onnxClassifierPath),
ort.InferenceSession.create(onnxEmbeddingOriginalPath),
]);
sessionClassifier = classifierSession;
sessionEmbedding = embeddingSession;
logger.log("Filter models initialized", "ml");
} catch (error) {
throw new WorkerError(error as Error, "ml", "fn:initializeModels");
}
}
export function softmax(logits: Float32Array): number[] {
const maxLogit = Math.max(...logits);
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
return Array.from(exponents.map((exp) => exp / sumOfExponents));
}
async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
if (!tokenizer) {
throw new Error("Tokenizer is not initialized. Call initializeModels() first.");
}
const { input_ids } = await tokenizer(texts, {
add_special_tokens: false,
return_tensor: false,
});
const cumsum = (arr: number[]): number[] =>
arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
const flattened_input_ids = input_ids.flat();
const inputs = {
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
flattened_input_ids.length,
]),
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
};
const { embeddings } = await session.run(inputs);
return Array.from(embeddings.data as Float32Array);
}
async function runClassification(embeddings: number[]): Promise<number[]> {
if (!sessionClassifier) {
throw new Error("Classifier session is not initialized. Call initializeModels() first.");
}
const inputTensor = new ort.Tensor(
Float32Array.from(embeddings),
[1, 3, 1024],
);
const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
return softmax(logits.data as Float32Array);
}
export async function classifyVideo(
title: string,
description: string,
tags: string,
aid: number,
): Promise<number> {
if (!sessionEmbedding) {
throw new Error("Embedding session is not initialized. Call initializeModels() first.");
}
const embeddings = await getONNXEmbeddings([
title,
description,
tags,
], sessionEmbedding);
const probabilities = await runClassification(embeddings);
logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml");
return probabilities.indexOf(Math.max(...probabilities));
}

37
lib/ml/manager.ts Normal file
View File

@ -0,0 +1,37 @@
import * as ort from "onnxruntime";
import logger from "lib/log/logger.ts";
import { WorkerError } from "lib/mq/schema.ts";
export class AIManager {
public sessions: { [key: string]: ort.InferenceSession } = {};
public models: { [key: string]: string } = {};
constructor() {
}
public async init() {
const modelKeys = Object.keys(this.models);
for (const key of modelKeys) {
try {
this.sessions[key] = await ort.InferenceSession.create(this.models[key]);
logger.log(`Model ${key} initialized`, "ml");
} catch (error) {
throw new WorkerError(error as Error, "ml", "fn:init");
}
}
}
public getModelSession(key: string): ort.InferenceSession {
if (!this.sessions[key]) {
throw new WorkerError(new Error(`Model ${key} not found / not initialized.`), "ml", "fn:getModelSession");
}
return this.sessions[key];
}
public softmax(logits: Float32Array): number[] {
const maxLogit = Math.max(...logits);
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
return Array.from(exponents.map((exp) => exp / sumOfExponents));
}
}

View File

@ -1,6 +1,12 @@
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
import * as ort from "onnxruntime";
import { softmax } from "lib/ml/filter_inference.ts";
function softmax(logits: Float32Array): number[] {
const maxLogit = Math.max(...logits);
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
return Array.from(exponents.map((exp) => exp / sumOfExponents));
}
// 配置参数
const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";

View File

@ -1,7 +1,7 @@
import { Job } from "bullmq";
import { db } from "lib/db/init.ts";
import { getUnlabelledVideos, getVideoInfoFromAllData, insertVideoLabel } from "lib/db/allData.ts";
import { classifyVideo } from "lib/ml/filter_inference.ts";
import Akari from "lib/ml/akari.ts";
import { ClassifyVideoQueue } from "lib/mq/index.ts";
import logger from "lib/log/logger.ts";
import { lockManager } from "lib/mq/lockManager.ts";
@ -19,7 +19,7 @@ export const classifyVideoWorker = async (job: Job) => {
const title = videoInfo.title?.trim() || "untitled";
const description = videoInfo.description?.trim() || "N/A";
const tags = videoInfo.tags?.trim() || "empty";
const label = await classifyVideo(title, description, tags, aid);
const label = await Akari.classifyVideo(title, description, tags, aid);
if (label == -1) {
logger.warn(`Failed to classify video ${aid}`, "ml");
}

View File

@ -4,7 +4,7 @@ import logger from "lib/log/logger.ts";
import { classifyVideosWorker, classifyVideoWorker } from "lib/mq/exec/classifyVideo.ts";
import { WorkerError } from "lib/mq/schema.ts";
import { lockManager } from "lib/mq/lockManager.ts";
import { initializeModels } from "lib/ml/filter_inference.ts";
import Akari from "lib/ml/akari.ts";
Deno.addSignalListener("SIGINT", async () => {
logger.log("SIGINT Received: Shutting down workers...", "mq");
@ -18,7 +18,7 @@ Deno.addSignalListener("SIGTERM", async () => {
Deno.exit();
});
await initializeModels();
Akari.init();
const filterWorker = new Worker(
"classifyVideo",