ref: code structure related to AI

commit a6c8fd7f3f (parent 5af2236109)
lib/db/allData.ts
@@ -1,6 +1,6 @@
 import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
 import { AllDataType, BiliUserType } from "lib/db/schema.d.ts";
-import { modelVersion } from "lib/ml/filter_inference.ts";
+import Akari from "lib/ml/akari.ts";
 
 export async function videoExistsInAllData(client: Client, aid: number) {
 	return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid])
@@ -23,7 +23,7 @@ export async function getUnlabelledVideos(client: Client) {
 export async function insertVideoLabel(client: Client, aid: number, label: number) {
 	return await client.queryObject(
 		`INSERT INTO labelling_result (aid, label, model_version) VALUES ($1, $2, $3) ON CONFLICT (aid, model_version) DO NOTHING`,
-		[aid, label, modelVersion],
+		[aid, label, Akari.getModelVersion()],
 	);
 }
 
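Note: because the row is keyed on (aid, model_version), labelling is idempotent per model version, and bumping the version re-labels videos without touching older results. A sketch of the intended call site, as a fragment: `client`, `title`, `description`, `tags`, and `aid` are assumed bindings, and Akari.init() is assumed to have run.

const label = await Akari.classifyVideo(title, description, tags, aid);
await insertVideoLabel(client, aid, label);
// Re-running is a no-op for an (aid, model version) pair that is already labelled.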
@@ -28,7 +28,8 @@ export async function getSongsNearMilestone(client: Client) {
 		max_views_per_aid
 	WHERE
 		(max_views >= 90000 AND max_views < 100000) OR
-		(max_views >= 900000 AND max_views < 1000000)
+		(max_views >= 900000 AND max_views < 1000000) OR
+		(max_views >= 9900000 AND max_views < 10000000)
 	)
 	-- Fetch the full rows that match the conditions
 	SELECT
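Note: the added window extends the near-milestone query to videos approaching 10M views. A hypothetical TypeScript mirror of the SQL predicate, for illustration only:

// Mirrors the three SQL view-count windows above.
const isNearMilestone = (views: number): boolean =>
	(views >= 90_000 && views < 100_000) ||     // within 10% below 100k
	(views >= 900_000 && views < 1_000_000) ||  // within 10% below 1M
	(views >= 9_900_000 && views < 10_000_000); // within 1% below 10M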
@@ -1,11 +1,12 @@
 import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
 
-export async function getUnsnapshotedSongs(client: Client) {
-	const queryResult = await client.queryObject<{ aid: bigint }>(`
-		SELECT DISTINCT s.aid
-		FROM songs s
-		LEFT JOIN video_snapshot v ON s.aid = v.aid
-		WHERE v.aid IS NULL;
-	`);
-	return queryResult.rows.map((row) => Number(row.aid));
-}
+/*
+	Returns true if the specified `aid` has at least one record with "pending" or "processing" status.
+*/
+export async function videoHasActiveSchedule(client: Client, aid: number) {
+	const res = await client.queryObject<{ status: string }>(
+		`SELECT status FROM snapshot_schedule WHERE aid = $1 AND (status = 'pending' OR status = 'processing')`,
+		[aid],
+	);
+	return res.rows.length > 0;
+}
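Note: a usage sketch of the new guard, as a fragment with assumed `client` and `aid` bindings; the scheduling step it protects is hypothetical, not part of this commit.

// Avoid double-scheduling: bail out if a "pending" or "processing"
// snapshot_schedule row already exists for this aid.
if (await videoHasActiveSchedule(client, aid)) {
	return;
}
// ...otherwise insert a new snapshot_schedule row here (hypothetical follow-up).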
lib/ml/akari.ts (new file, 106 lines)
@@ -0,0 +1,106 @@
+import { AIManager } from "lib/ml/manager.ts";
+import * as ort from "onnxruntime";
+import logger from "lib/log/logger.ts";
+import { WorkerError } from "lib/mq/schema.ts";
+import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
+
+const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
+const onnxClassifierPath = "./model/video_classifier_v3_17.onnx";
+const onnxEmbeddingPath = "./model/model.onnx";
+
+class AkariProto extends AIManager {
+	private tokenizer: PreTrainedTokenizer | null = null;
+	private readonly modelVersion = "3.17";
+
+	constructor() {
+		super();
+		this.models = {
+			"classifier": onnxClassifierPath,
+			"embedding": onnxEmbeddingPath,
+		};
+	}
+
+	public override async init(): Promise<void> {
+		await super.init();
+		await this.initJinaTokenizer();
+	}
+
+	private tokenizerInitialized(): boolean {
+		return this.tokenizer !== null;
+	}
+
+	private getTokenizer(): PreTrainedTokenizer {
+		if (!this.tokenizerInitialized()) {
+			throw new Error("Tokenizer is not initialized. Call init() first.");
+		}
+		return this.tokenizer!;
+	}
+
+	private async initJinaTokenizer(): Promise<void> {
+		if (this.tokenizerInitialized()) {
+			return;
+		}
+		try {
+			this.tokenizer = await AutoTokenizer.from_pretrained(tokenizerModel);
+			logger.log("Tokenizer initialized", "ml");
+		} catch (error) {
+			throw new WorkerError(error as Error, "ml", "fn:initTokenizer");
+		}
+	}
+
+	private async getJinaEmbeddings1024(texts: string[]): Promise<number[]> {
+		const tokenizer = this.getTokenizer();
+		const session = this.getModelSession("embedding");
+
+		const { input_ids } = await tokenizer(texts, {
+			add_special_tokens: false,
+			return_tensor: false,
+		});
+
+		// Prefix sums of the per-text token counts give each text's start
+		// offset in the flattened id list expected by the embedding model.
+		const cumsum = (arr: number[]): number[] =>
+			arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
+
+		const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string[]) => x.length))];
+		const flattened_input_ids = input_ids.flat();
+
+		const inputs = {
+			input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
+				flattened_input_ids.length,
+			]),
+			offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
+		};
+
+		const { embeddings } = await session.run(inputs);
+		return Array.from(embeddings.data as Float32Array);
+	}
+
+	private async runClassification(embeddings: number[]): Promise<number[]> {
+		const session = this.getModelSession("classifier");
+		const inputTensor = new ort.Tensor(
+			Float32Array.from(embeddings),
+			[1, 3, 1024],
+		);
+
+		const { logits } = await session.run({ channel_features: inputTensor });
+		return this.softmax(logits.data as Float32Array);
+	}
+
+	public async classifyVideo(title: string, description: string, tags: string, aid: number): Promise<number> {
+		const embeddings = await this.getJinaEmbeddings1024([
+			title,
+			description,
+			tags,
+		]);
+		const probabilities = await this.runClassification(embeddings);
+		logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml");
+		return probabilities.indexOf(Math.max(...probabilities));
+	}
+
+	public getModelVersion(): string {
+		return this.modelVersion;
+	}
+}
+
+const Akari = new AkariProto();
+export default Akari;
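Note: AkariProto keeps the tokenizer and the two ONNX sessions behind one singleton, so callers no longer juggle module-level state. A minimal usage sketch, assuming the model files exist at the configured paths; the inputs here are placeholders.

import Akari from "lib/ml/akari.ts";

await Akari.init(); // loads both ONNX sessions (via AIManager) and the Jina tokenizer
const label = await Akari.classifyVideo("some title", "some description", "tag1,tag2", 170001);
// classifyVideo embeds the three texts into 1024-dim vectors, stacks them
// into a [1, 3, 1024] tensor for the classifier, and returns the argmax class index.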
@@ -1,6 +1,13 @@
 import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
 import * as ort from "onnxruntime";
-import { softmax } from "lib/ml/filter_inference.ts";
+
+function softmax(logits: Float32Array): number[] {
+	const maxLogit = Math.max(...logits);
+	const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
+	const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
+	return Array.from(exponents.map((exp) => exp / sumOfExponents));
+}
+
 
 // Configuration parameters
 const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
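Note: the now-local softmax subtracts the maximum logit before exponentiating. This is the standard overflow guard and does not change the result, since the numerator and denominator are scaled by the same factor. For example:

const probs = softmax(Float32Array.from([1, 2, 3]));
// e^(1-3) ≈ 0.1353, e^(2-3) ≈ 0.3679, e^(3-3) = 1; sum ≈ 1.5032
// → probs ≈ [0.0900, 0.2447, 0.6652], identical to the unshifted computation.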
lib/ml/filter_inference.ts (deleted file, 99 lines)
@@ -1,99 +0,0 @@
-import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
-import * as ort from "onnxruntime";
-import logger from "lib/log/logger.ts";
-import { WorkerError } from "lib/mq/schema.ts";
-
-const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
-const onnxClassifierPath = "./model/video_classifier_v3_17.onnx";
-const onnxEmbeddingOriginalPath = "./model/model.onnx";
-export const modelVersion = "3.17";
-
-let sessionClassifier: ort.InferenceSession | null = null;
-let sessionEmbedding: ort.InferenceSession | null = null;
-let tokenizer: PreTrainedTokenizer | null = null;
-
-export async function initializeModels() {
-	if (tokenizer && sessionClassifier && sessionEmbedding) {
-		return;
-	}
-
-	try {
-		tokenizer = await AutoTokenizer.from_pretrained(tokenizerModel);
-
-		const [classifierSession, embeddingSession] = await Promise.all([
-			ort.InferenceSession.create(onnxClassifierPath),
-			ort.InferenceSession.create(onnxEmbeddingOriginalPath),
-		]);
-
-		sessionClassifier = classifierSession;
-		sessionEmbedding = embeddingSession;
-		logger.log("Filter models initialized", "ml");
-	} catch (error) {
-		throw new WorkerError(error as Error, "ml", "fn:initializeModels");
-	}
-}
-
-export function softmax(logits: Float32Array): number[] {
-	const maxLogit = Math.max(...logits);
-	const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
-	const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
-	return Array.from(exponents.map((exp) => exp / sumOfExponents));
-}
-
-async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
-	if (!tokenizer) {
-		throw new Error("Tokenizer is not initialized. Call initializeModels() first.");
-	}
-	const { input_ids } = await tokenizer(texts, {
-		add_special_tokens: false,
-		return_tensor: false,
-	});
-
-	const cumsum = (arr: number[]): number[] =>
-		arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
-
-	const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))];
-	const flattened_input_ids = input_ids.flat();
-
-	const inputs = {
-		input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [
-			flattened_input_ids.length,
-		]),
-		offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]),
-	};
-
-	const { embeddings } = await session.run(inputs);
-	return Array.from(embeddings.data as Float32Array);
-}
-
-async function runClassification(embeddings: number[]): Promise<number[]> {
-	if (!sessionClassifier) {
-		throw new Error("Classifier session is not initialized. Call initializeModels() first.");
-	}
-	const inputTensor = new ort.Tensor(
-		Float32Array.from(embeddings),
-		[1, 3, 1024],
-	);
-
-	const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
-	return softmax(logits.data as Float32Array);
-}
-
-export async function classifyVideo(
-	title: string,
-	description: string,
-	tags: string,
-	aid: number,
-): Promise<number> {
-	if (!sessionEmbedding) {
-		throw new Error("Embedding session is not initialized. Call initializeModels() first.");
-	}
-	const embeddings = await getONNXEmbeddings([
-		title,
-		description,
-		tags,
-	], sessionEmbedding);
-	const probabilities = await runClassification(embeddings);
-	logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml");
-	return probabilities.indexOf(Math.max(...probabilities));
-}
lib/ml/manager.ts (new file, 37 lines)
@@ -0,0 +1,37 @@
+import * as ort from "onnxruntime";
+import logger from "lib/log/logger.ts";
+import { WorkerError } from "lib/mq/schema.ts";
+
+export class AIManager {
+	public sessions: { [key: string]: ort.InferenceSession } = {};
+	public models: { [key: string]: string } = {};
+
+	constructor() {
+	}
+
+	public async init() {
+		const modelKeys = Object.keys(this.models);
+		for (const key of modelKeys) {
+			try {
+				this.sessions[key] = await ort.InferenceSession.create(this.models[key]);
+				logger.log(`Model ${key} initialized`, "ml");
+			} catch (error) {
+				throw new WorkerError(error as Error, "ml", "fn:init");
+			}
+		}
+	}
+
+	public getModelSession(key: string): ort.InferenceSession {
+		if (!this.sessions[key]) {
+			throw new WorkerError(new Error(`Model ${key} not found / not initialized.`), "ml", "fn:getModelSession");
+		}
+		return this.sessions[key];
+	}
+
+	public softmax(logits: Float32Array): number[] {
+		const maxLogit = Math.max(...logits);
+		const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
+		const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
+		return Array.from(exponents.map((exp) => exp / sumOfExponents));
+	}
+}
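Note: the contract for AIManager subclasses is to populate this.models in the constructor and call init() once before fetching sessions, as akari.ts does above. A hypothetical minimal subclass; the key name and model path here are made up.

import { AIManager } from "lib/ml/manager.ts";

class DemoManager extends AIManager {
	constructor() {
		super();
		this.models = { "demo": "./model/demo.onnx" }; // hypothetical model path
	}
}

const demo = new DemoManager();
await demo.init(); // creates one InferenceSession per entry in this.models
const session = demo.getModelSession("demo"); // throws WorkerError if init() was skipped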
@@ -1,6 +1,12 @@
 import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
 import * as ort from "onnxruntime";
-import { softmax } from "lib/ml/filter_inference.ts";
+
+function softmax(logits: Float32Array): number[] {
+	const maxLogit = Math.max(...logits);
+	const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
+	const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
+	return Array.from(exponents.map((exp) => exp / sumOfExponents));
+}
 
 // Configuration parameters
 const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
lib/mq/exec/classifyVideo.ts
@@ -1,7 +1,7 @@
 import { Job } from "bullmq";
 import { db } from "lib/db/init.ts";
 import { getUnlabelledVideos, getVideoInfoFromAllData, insertVideoLabel } from "lib/db/allData.ts";
-import { classifyVideo } from "lib/ml/filter_inference.ts";
+import Akari from "lib/ml/akari.ts";
 import { ClassifyVideoQueue } from "lib/mq/index.ts";
 import logger from "lib/log/logger.ts";
 import { lockManager } from "lib/mq/lockManager.ts";
@@ -19,7 +19,7 @@ export const classifyVideoWorker = async (job: Job) => {
 	const title = videoInfo.title?.trim() || "untitled";
 	const description = videoInfo.description?.trim() || "N/A";
 	const tags = videoInfo.tags?.trim() || "empty";
-	const label = await classifyVideo(title, description, tags, aid);
+	const label = await Akari.classifyVideo(title, description, tags, aid);
 	if (label == -1) {
 		logger.warn(`Failed to classify video ${aid}`, "ml");
 	}
@@ -4,7 +4,7 @@ import logger from "lib/log/logger.ts";
 import { classifyVideosWorker, classifyVideoWorker } from "lib/mq/exec/classifyVideo.ts";
 import { WorkerError } from "lib/mq/schema.ts";
 import { lockManager } from "lib/mq/lockManager.ts";
-import { initializeModels } from "lib/ml/filter_inference.ts";
+import Akari from "lib/ml/akari.ts";
 
 Deno.addSignalListener("SIGINT", async () => {
 	logger.log("SIGINT Received: Shutting down workers...", "mq");
@@ -18,7 +18,7 @@ Deno.addSignalListener("SIGTERM", async () => {
 	Deno.exit();
 });
 
-await initializeModels();
+await Akari.init();
 
 const filterWorker = new Worker(
 	"classifyVideo",