ref: code structure related to AI
This commit is contained in:
parent
5af2236109
commit
a6c8fd7f3f
@ -1,6 +1,6 @@
|
||||
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
||||
import { AllDataType, BiliUserType } from "lib/db/schema.d.ts";
|
||||
import { modelVersion } from "lib/ml/filter_inference.ts";
|
||||
import Akari from "lib/ml/akari.ts";
|
||||
|
||||
export async function videoExistsInAllData(client: Client, aid: number) {
|
||||
return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid])
|
||||
@ -23,7 +23,7 @@ export async function getUnlabelledVideos(client: Client) {
|
||||
export async function insertVideoLabel(client: Client, aid: number, label: number) {
|
||||
return await client.queryObject(
|
||||
`INSERT INTO labelling_result (aid, label, model_version) VALUES ($1, $2, $3) ON CONFLICT (aid, model_version) DO NOTHING`,
|
||||
[aid, label, modelVersion],
|
||||
[aid, label, Akari.getModelVersion()],
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -28,7 +28,8 @@ export async function getSongsNearMilestone(client: Client) {
|
||||
max_views_per_aid
|
||||
WHERE
|
||||
(max_views >= 90000 AND max_views < 100000) OR
|
||||
(max_views >= 900000 AND max_views < 1000000)
|
||||
(max_views >= 900000 AND max_views < 1000000) OR
|
||||
(max_views >= 9900000 AND max_views < 10000000)
|
||||
)
|
||||
-- 获取符合条件的完整行数据
|
||||
SELECT
|
||||
|
@ -1,11 +1,12 @@
|
||||
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
||||
|
||||
export async function getUnsnapshotedSongs(client: Client) {
|
||||
const queryResult = await client.queryObject<{ aid: bigint }>(`
|
||||
SELECT DISTINCT s.aid
|
||||
FROM songs s
|
||||
LEFT JOIN video_snapshot v ON s.aid = v.aid
|
||||
WHERE v.aid IS NULL;
|
||||
`);
|
||||
return queryResult.rows.map((row) => Number(row.aid));
|
||||
/*
|
||||
Returns true if the specified `aid` has at least one record with "pending" or "processing" status.
|
||||
*/
|
||||
export async function videoHasActiveSchedule(client: Client, aid: number) {
|
||||
const res = await client.queryObject<{ status: string }>(
|
||||
`SELECT status FROM snapshot_schedule WHERE aid = $1 AND (status = 'pending' OR status = 'processing')`,
|
||||
[aid],
|
||||
);
|
||||
return res.rows.length > 0;
|
||||
}
|
106
lib/ml/akari.ts
Normal file
106
lib/ml/akari.ts
Normal file
@ -0,0 +1,106 @@
|
||||
import { AIManager } from "lib/ml/manager.ts";
|
||||
import * as ort from "onnxruntime";
|
||||
import logger from "lib/log/logger.ts";
|
||||
import { WorkerError } from "lib/mq/schema.ts";
|
||||
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
|
||||
|
||||
const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
|
||||
const onnxClassifierPath = "./model/video_classifier_v3_17.onnx";
|
||||
const onnxEmbeddingPath = "./model/model.onnx";
|
||||
|
||||
class AkariProto extends AIManager {
|
||||
private tokenizer: PreTrainedTokenizer | null = null;
|
||||
private readonly modelVersion = "3.17";
|
||||
|
||||
constructor() {
|
||||
super();
|
||||
this.models = {
|
||||
"classifier": onnxClassifierPath,
|
||||
"embedding": onnxEmbeddingPath,
|
||||
}
|
||||
}
|
||||
|
||||
public override async init(): Promise<void> {
|
||||
super.init();
|
||||
await this.initJinaTokenizer();
|
||||
}
|
||||
|
||||
private tokenizerInitialized(): boolean {
|
||||
return this.tokenizer !== null;
|
||||
}
|
||||
|
||||
private getTokenizer(): PreTrainedTokenizer {
|
||||
if (!this.tokenizerInitialized()) {
|
||||
throw new Error("Tokenizer is not initialized. Call init() first.");
|
||||
}
|
||||
return this.tokenizer!;
|
||||
}
|
||||
|
||||
private async initJinaTokenizer(): Promise<void> {
|
||||
if (this.tokenizerInitialized()) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
this.tokenizer = await AutoTokenizer.from_pretrained(tokenizerModel);
|
||||
logger.log("Tokenizer initialized", "ml");
|
||||
} catch (error) {
|
||||
throw new WorkerError(error as Error, "ml", "fn:initTokenizer");
|
||||
}
|
||||
}
|
||||
|
||||
private async getJinaEmbeddings1024(texts: string[]): Promise<number[]> {
|
||||
const tokenizer = this.getTokenizer();
|
||||
const session = this.getModelSession("embedding");
|
||||
|
||||
const { input_ids } = await tokenizer(texts, {
|
||||
add_special_tokens: false,
|
||||
return_tensors: "js",
|
||||
});
|
||||
|
||||
const cumsum = (arr: number[]): number[] =>
|
||||
arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
|
||||
|
||||
const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string[]) => x.length))];
|
||||
const flattened_input_ids = input_ids.flat();
|
||||
|
||||
const inputs = {
|
||||
input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
|
||||
flattened_input_ids.length,
|
||||
]),
|
||||
offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
|
||||
};
|
||||
|
||||
const { embeddings } = await session.run(inputs);
|
||||
return Array.from(embeddings.data as Float32Array);
|
||||
}
|
||||
|
||||
private async runClassification(embeddings: number[]): Promise<number[]> {
|
||||
const session = this.getModelSession("classifier");
|
||||
const inputTensor = new ort.Tensor(
|
||||
Float32Array.from(embeddings),
|
||||
[1, 3, 1024],
|
||||
);
|
||||
|
||||
const { logits } = await session.run({ channel_features: inputTensor });
|
||||
return this.softmax(logits.data as Float32Array);
|
||||
}
|
||||
|
||||
public async classifyVideo(title: string, description: string, tags: string, aid: number): Promise<number> {
|
||||
const embeddings = await this.getJinaEmbeddings1024([
|
||||
title,
|
||||
description,
|
||||
tags,
|
||||
]);
|
||||
const probabilities = await this.runClassification(embeddings);
|
||||
logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml");
|
||||
return probabilities.indexOf(Math.max(...probabilities));
|
||||
}
|
||||
|
||||
public getModelVersion(): string {
|
||||
return this.modelVersion;
|
||||
}
|
||||
}
|
||||
|
||||
const Akari = new AkariProto();
|
||||
export default Akari;
|
||||
|
@ -1,6 +1,13 @@
|
||||
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
|
||||
import * as ort from "onnxruntime";
|
||||
import { softmax } from "lib/ml/filter_inference.ts";
|
||||
|
||||
|
||||
function softmax(logits: Float32Array): number[] {
|
||||
const maxLogit = Math.max(...logits);
|
||||
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
|
||||
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
|
||||
return Array.from(exponents.map((exp) => exp / sumOfExponents));
|
||||
}
|
||||
|
||||
// 配置参数
|
||||
const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
|
||||
|
@ -1,99 +0,0 @@
|
||||
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
|
||||
import * as ort from "onnxruntime";
|
||||
import logger from "lib/log/logger.ts";
|
||||
import { WorkerError } from "lib/mq/schema.ts";
|
||||
|
||||
const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
|
||||
const onnxClassifierPath = "./model/video_classifier_v3_17.onnx";
|
||||
const onnxEmbeddingOriginalPath = "./model/model.onnx";
|
||||
export const modelVersion = "3.17";
|
||||
|
||||
let sessionClassifier: ort.InferenceSession | null = null;
|
||||
let sessionEmbedding: ort.InferenceSession | null = null;
|
||||
let tokenizer: PreTrainedTokenizer | null = null;
|
||||
|
||||
/**
 * Lazily loads the tokenizer and both ONNX inference sessions into the
 * module-level state. Safe to call more than once: returns early when
 * everything is already initialized.
 * @throws WorkerError wrapping the underlying failure, tagged "fn:initializeModels".
 */
export async function initializeModels() {
	if (tokenizer && sessionClassifier && sessionEmbedding) {
		return;
	}

	try {
		tokenizer = await AutoTokenizer.from_pretrained(tokenizerModel);

		// Load both ONNX models concurrently; assign only after both resolve
		// so the module state is never left half-initialized.
		const [classifierSession, embeddingSession] = await Promise.all([
			ort.InferenceSession.create(onnxClassifierPath),
			ort.InferenceSession.create(onnxEmbeddingOriginalPath),
		]);

		sessionClassifier = classifierSession;
		sessionEmbedding = embeddingSession;
		logger.log("Filter models initialized", "ml");
	} catch (error) {
		throw new WorkerError(error as Error, "ml", "fn:initializeModels");
	}
}
|
||||
|
||||
export function softmax(logits: Float32Array): number[] {
|
||||
const maxLogit = Math.max(...logits);
|
||||
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
|
||||
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
|
||||
return Array.from(exponents.map((exp) => exp / sumOfExponents));
|
||||
}
|
||||
|
||||
/**
 * Tokenizes `texts` and runs the embedding model, returning the flattened
 * embedding values for all inputs as one number array.
 * NOTE(review): assumes the model takes flattened int64 `input_ids` plus
 * per-text start `offsets` (EmbeddingBag-style input) — confirm against the
 * exported ONNX model.
 * @throws Error when initializeModels() has not populated the tokenizer.
 */
async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
	if (!tokenizer) {
		throw new Error("Tokenizer is not initialized. Call initializeModels() first.");
	}
	// return_tensor: false → plain JS arrays instead of tensor objects.
	const { input_ids } = await tokenizer(texts, {
		add_special_tokens: false,
		return_tensor: false,
	});

	// Running totals of per-text token counts.
	const cumsum = (arr: number[]): number[] =>
		arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);

	// Start offset of each text inside the flattened id list (first is 0).
	const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
	const flattened_input_ids = input_ids.flat();

	const inputs = {
		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
			flattened_input_ids.length,
		]),
		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
	};

	const { embeddings } = await session.run(inputs);
	return Array.from(embeddings.data as Float32Array);
}
|
||||
|
||||
async function runClassification(embeddings: number[]): Promise<number[]> {
|
||||
if (!sessionClassifier) {
|
||||
throw new Error("Classifier session is not initialized. Call initializeModels() first.");
|
||||
}
|
||||
const inputTensor = new ort.Tensor(
|
||||
Float32Array.from(embeddings),
|
||||
[1, 3, 1024],
|
||||
);
|
||||
|
||||
const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
|
||||
return softmax(logits.data as Float32Array);
|
||||
}
|
||||
|
||||
/**
 * Classifies a video from its title, description and tags.
 * Each of the three texts is embedded separately (three "channels"), then the
 * stacked embeddings are fed to the classifier.
 * @returns the index of the highest-probability class.
 * @throws Error when initializeModels() has not populated the sessions.
 */
export async function classifyVideo(
	title: string,
	description: string,
	tags: string,
	aid: number,
): Promise<number> {
	if (!sessionEmbedding) {
		throw new Error("Embedding session is not initialized. Call initializeModels() first.");
	}
	const embeddings = await getONNXEmbeddings([
		title,
		description,
		tags,
	], sessionEmbedding);
	const probabilities = await runClassification(embeddings);
	logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml");
	// Argmax over the probability distribution.
	return probabilities.indexOf(Math.max(...probabilities));
}
|
37
lib/ml/manager.ts
Normal file
37
lib/ml/manager.ts
Normal file
@ -0,0 +1,37 @@
|
||||
import * as ort from "onnxruntime";
|
||||
import logger from "lib/log/logger.ts";
|
||||
import { WorkerError } from "lib/mq/schema.ts";
|
||||
|
||||
export class AIManager {
|
||||
public sessions: { [key: string]: ort.InferenceSession } = {};
|
||||
public models: { [key: string]: string } = {};
|
||||
|
||||
constructor() {
|
||||
}
|
||||
|
||||
public async init() {
|
||||
const modelKeys = Object.keys(this.models);
|
||||
for (const key of modelKeys) {
|
||||
try {
|
||||
this.sessions[key] = await ort.InferenceSession.create(this.models[key]);
|
||||
logger.log(`Model ${key} initialized`, "ml");
|
||||
} catch (error) {
|
||||
throw new WorkerError(error as Error, "ml", "fn:init");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public getModelSession(key: string): ort.InferenceSession {
|
||||
if (!this.sessions[key]) {
|
||||
throw new WorkerError(new Error(`Model ${key} not found / not initialized.`), "ml", "fn:getModelSession");
|
||||
}
|
||||
return this.sessions[key];
|
||||
}
|
||||
|
||||
public softmax(logits: Float32Array): number[] {
|
||||
const maxLogit = Math.max(...logits);
|
||||
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
|
||||
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
|
||||
return Array.from(exponents.map((exp) => exp / sumOfExponents));
|
||||
}
|
||||
}
|
@ -1,6 +1,12 @@
|
||||
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
|
||||
import * as ort from "onnxruntime";
|
||||
import { softmax } from "lib/ml/filter_inference.ts";
|
||||
|
||||
function softmax(logits: Float32Array): number[] {
|
||||
const maxLogit = Math.max(...logits);
|
||||
const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
|
||||
const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
|
||||
return Array.from(exponents.map((exp) => exp / sumOfExponents));
|
||||
}
|
||||
|
||||
// 配置参数
|
||||
const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { Job } from "bullmq";
|
||||
import { db } from "lib/db/init.ts";
|
||||
import { getUnlabelledVideos, getVideoInfoFromAllData, insertVideoLabel } from "lib/db/allData.ts";
|
||||
import { classifyVideo } from "lib/ml/filter_inference.ts";
|
||||
import Akari from "lib/ml/akari.ts";
|
||||
import { ClassifyVideoQueue } from "lib/mq/index.ts";
|
||||
import logger from "lib/log/logger.ts";
|
||||
import { lockManager } from "lib/mq/lockManager.ts";
|
||||
@ -19,7 +19,7 @@ export const classifyVideoWorker = async (job: Job) => {
|
||||
const title = videoInfo.title?.trim() || "untitled";
|
||||
const description = videoInfo.description?.trim() || "N/A";
|
||||
const tags = videoInfo.tags?.trim() || "empty";
|
||||
const label = await classifyVideo(title, description, tags, aid);
|
||||
const label = await Akari.classifyVideo(title, description, tags, aid);
|
||||
if (label == -1) {
|
||||
logger.warn(`Failed to classify video ${aid}`, "ml");
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ import logger from "lib/log/logger.ts";
|
||||
import { classifyVideosWorker, classifyVideoWorker } from "lib/mq/exec/classifyVideo.ts";
|
||||
import { WorkerError } from "lib/mq/schema.ts";
|
||||
import { lockManager } from "lib/mq/lockManager.ts";
|
||||
import { initializeModels } from "lib/ml/filter_inference.ts";
|
||||
import Akari from "lib/ml/akari.ts";
|
||||
|
||||
Deno.addSignalListener("SIGINT", async () => {
|
||||
logger.log("SIGINT Received: Shutting down workers...", "mq");
|
||||
@ -18,7 +18,7 @@ Deno.addSignalListener("SIGTERM", async () => {
|
||||
Deno.exit();
|
||||
});
|
||||
|
||||
await initializeModels();
|
||||
Akari.init();
|
||||
|
||||
const filterWorker = new Worker(
|
||||
"classifyVideo",
|
||||
|
Loading…
Reference in New Issue
Block a user