diff --git a/lib/db/snapshot.ts b/lib/db/snapshot.ts index c3f515b..81fe9a8 100644 --- a/lib/db/snapshot.ts +++ b/lib/db/snapshot.ts @@ -3,7 +3,7 @@ import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts"; import { VideoSnapshotType } from "lib/db/schema.d.ts"; import { parseTimestampFromPsql } from "lib/utils/formatTimestampToPostgre.ts"; -export async function getSongsNearMilestone(client: Client) { +export async function getVideosNearMilestone(client: Client) { const queryResult = await client.queryObject(` WITH max_views_per_aid AS ( -- 找出每个 aid 的最大 views 值,并确保 aid 存在于 songs 表中 diff --git a/lib/db/snapshotSchedule.ts b/lib/db/snapshotSchedule.ts index 111ffa1..3b77fce 100644 --- a/lib/db/snapshotSchedule.ts +++ b/lib/db/snapshotSchedule.ts @@ -1,3 +1,4 @@ +import { DAY, HOUR, MINUTE, SECOND } from "$std/datetime/constants.ts"; import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts"; /* @@ -9,4 +10,92 @@ export async function videoHasActiveSchedule(client: Client, aid: number) { [aid], ); return res.rows.length > 0; +} + +interface Snapshot { + created_at: Date; + views: number; +} + +export async function findClosestSnapshot( + client: Client, + aid: number, + targetTime: Date +): Promise { + const query = ` + SELECT created_at, views FROM video_snapshot + WHERE aid = $1 + ORDER BY ABS(EXTRACT(EPOCH FROM (created_at - $2::timestamptz))) ASC + LIMIT 1 + `; + const result = await client.queryObject<{ created_at: string; views: number }>( + query, + [aid, targetTime.toISOString()] + ); + if (result.rows.length === 0) return null; + const row = result.rows[0]; + return { + created_at: new Date(row.created_at), + views: row.views, + }; +} + +export async function getShortTermTimeFeaturesForVideo( + client: Client, + aid: number, + initialTimestampMiliseconds: number +): Promise { + const initialTime = new Date(initialTimestampMiliseconds); + const timeWindows = [ + [ 5 * MINUTE, 0 * MINUTE], + [ 15 * MINUTE, 0 * MINUTE], + [ 40 * MINUTE, 0 * MINUTE], + [ 1 * HOUR, 0 * HOUR], + [ 2 * HOUR, 1 * HOUR], + [ 3 * HOUR, 2 * HOUR], + [ 3 * HOUR, 0 * HOUR], + [ 6 * HOUR, 0 * HOUR], + [18 * HOUR, 12 * HOUR], + [ 1 * DAY, 0 * DAY], + [ 3 * DAY, 0 * DAY], + [ 7 * DAY, 0 * DAY] + ]; + + const results: number[] = []; + + for (const [windowStart, windowEnd] of timeWindows) { + const targetTimeStart = new Date(initialTime.getTime() - windowStart); + const targetTimeEnd = new Date(initialTime.getTime() - windowEnd); + + const startRecord = await findClosestSnapshot(client, aid, targetTimeStart); + const endRecord = await findClosestSnapshot(client, aid, targetTimeEnd); + + if (!startRecord || !endRecord) { + results.push(NaN); + continue; + } + + const timeDiffSeconds = + (endRecord.created_at.getTime() - startRecord.created_at.getTime()) / 1000; + const windowDuration = windowStart - windowEnd; + + let scale = 0; + if (windowDuration > 0) { + scale = timeDiffSeconds / windowDuration; + } + + const viewsDiff = endRecord.views - startRecord.views; + const adjustedViews = Math.max(viewsDiff, 1); + + let result: number; + if (scale > 0) { + result = Math.log2(adjustedViews / scale + 1); + } else { + result = Math.log2(adjustedViews + 1); + } + + results.push(result); + } + + return results; } \ No newline at end of file diff --git a/lib/ml/akari.ts b/lib/ml/akari.ts index 386bb56..d5ce9b2 100644 --- a/lib/ml/akari.ts +++ b/lib/ml/akari.ts @@ -5,102 +5,103 @@ import { WorkerError } from "lib/mq/schema.ts"; import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers"; const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024"; -const onnxClassifierPath = "./model/video_classifier_v3_17.onnx"; -const onnxEmbeddingPath = "./model/model.onnx"; +const onnxClassifierPath = "./model/akari/3.17.onnx"; +const onnxEmbeddingPath = "./model/embedding/model.onnx"; class AkariProto extends AIManager { - private tokenizer: PreTrainedTokenizer | null = null; - private readonly modelVersion = "3.17"; + private tokenizer: PreTrainedTokenizer | null = null; + private readonly modelVersion = "3.17"; constructor() { super(); - this.models = { - "classifier": onnxClassifierPath, - "embedding": onnxEmbeddingPath, - } + this.models = { + "classifier": onnxClassifierPath, + "embedding": onnxEmbeddingPath, + }; } - public override async init(): Promise { - super.init(); - await this.initJinaTokenizer(); - } + public override async init(): Promise { + await super.init(); + await this.initJinaTokenizer(); + } - private tokenizerInitialized(): boolean { - return this.tokenizer !== null; - } + private tokenizerInitialized(): boolean { + return this.tokenizer !== null; + } - private getTokenizer(): PreTrainedTokenizer { - if (!this.tokenizerInitialized()) { - throw new Error("Tokenizer is not initialized. Call init() first."); - } - return this.tokenizer!; - } + private getTokenizer(): PreTrainedTokenizer { + if (!this.tokenizerInitialized()) { + throw new Error("Tokenizer is not initialized. Call init() first."); + } + return this.tokenizer!; + } - private async initJinaTokenizer(): Promise { - if (this.tokenizerInitialized()) { - return; - } - try { - this.tokenizer = await AutoTokenizer.from_pretrained(tokenizerModel); - logger.log("Tokenizer initialized", "ml"); - } catch (error) { - throw new WorkerError(error as Error, "ml", "fn:initTokenizer"); - } - } + private async initJinaTokenizer(): Promise { + if (this.tokenizerInitialized()) { + return; + } + try { + this.tokenizer = await AutoTokenizer.from_pretrained(tokenizerModel); + logger.log("Tokenizer initialized", "ml"); + } catch (error) { + throw new WorkerError(error as Error, "ml", "fn:initTokenizer"); + } + } - private async getJinaEmbeddings1024(texts: string[]): Promise { - const tokenizer = this.getTokenizer(); - const session = this.getModelSession("embedding"); + private async getJinaEmbeddings1024(texts: string[]): Promise { + const tokenizer = this.getTokenizer(); + const session = this.getModelSession("embedding"); - const { input_ids } = await tokenizer(texts, { - add_special_tokens: false, - return_tensors: "js", - }); + const { input_ids } = await tokenizer(texts, { + add_special_tokens: false, + return_tensor: false, + }); - const cumsum = (arr: number[]): number[] => - arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []); + const cumsum = (arr: number[]): number[] => + arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []); - const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string[]) => x.length))]; - const flattened_input_ids = input_ids.flat(); + const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))]; + const flattened_input_ids = input_ids.flat(); - const inputs = { - input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [ - flattened_input_ids.length, - ]), - offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]), - }; + const inputs = { + input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [ + flattened_input_ids.length, + ]), + offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]), + }; - const { embeddings } = await session.run(inputs); - return Array.from(embeddings.data as Float32Array); - } + const { embeddings } = await session.run(inputs); + return Array.from(embeddings.data as Float32Array); + } - private async runClassification(embeddings: number[]): Promise { - const session = this.getModelSession("classifier"); - const inputTensor = new ort.Tensor( - Float32Array.from(embeddings), - [1, 3, 1024], - ); + private async runClassification(embeddings: number[]): Promise { + const session = this.getModelSession("classifier"); + const inputTensor = new ort.Tensor( + Float32Array.from(embeddings), + [1, 3, 1024], + ); - const { logits } = await session.run({ channel_features: inputTensor }); - return this.softmax(logits.data as Float32Array); - } + const { logits } = await session.run({ channel_features: inputTensor }); + return this.softmax(logits.data as Float32Array); + } - public async classifyVideo(title: string, description: string, tags: string, aid: number): Promise { - const embeddings = await this.getJinaEmbeddings1024([ - title, - description, - tags, - ]); - const probabilities = await this.runClassification(embeddings); - logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml"); - return probabilities.indexOf(Math.max(...probabilities)); - } + public async classifyVideo(title: string, description: string, tags: string, aid?: number): Promise { + const embeddings = await this.getJinaEmbeddings1024([ + title, + description, + tags, + ]); + const probabilities = await this.runClassification(embeddings); + if (aid) { + logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml"); + } + return probabilities.indexOf(Math.max(...probabilities)); + } - public getModelVersion(): string { - return this.modelVersion; - } + public getModelVersion(): string { + return this.modelVersion; + } } const Akari = new AkariProto(); export default Akari; - diff --git a/lib/ml/manager.ts b/lib/ml/manager.ts index 268985d..8f15513 100644 --- a/lib/ml/manager.ts +++ b/lib/ml/manager.ts @@ -22,7 +22,7 @@ export class AIManager { } public getModelSession(key: string): ort.InferenceSession { - if (!this.sessions[key]) { + if (this.sessions[key] === undefined) { throw new WorkerError(new Error(`Model ${key} not found / not initialized.`), "ml", "fn:getModelSession"); } return this.sessions[key]; diff --git a/lib/ml/mantis.ts b/lib/ml/mantis.ts new file mode 100644 index 0000000..59bc09a --- /dev/null +++ b/lib/ml/mantis.ts @@ -0,0 +1,25 @@ +import { AIManager } from "lib/ml/manager.ts"; +import * as ort from "onnxruntime"; +import logger from "lib/log/logger.ts"; +import { WorkerError } from "lib/mq/schema.ts"; + +const modelPath = "./model/model.onnx"; + +class MantisProto extends AIManager { + + constructor() { + super(); + this.models = { + "predictor": modelPath, + } + } + + public override async init(): Promise { + await super.init(); + } + + +} + +const Mantis = new MantisProto(); +export default Mantis; diff --git a/lib/mq/exec/snapshotTick.ts b/lib/mq/exec/snapshotTick.ts index 12443ff..65564d0 100644 --- a/lib/mq/exec/snapshotTick.ts +++ b/lib/mq/exec/snapshotTick.ts @@ -1,229 +1,31 @@ import { Job } from "bullmq"; -import { MINUTE, SECOND } from "$std/datetime/constants.ts"; -import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts"; import { db } from "lib/db/init.ts"; -import { - getShortTermEtaPrediction, - getSongsNearMilestone, - getUnsnapshotedSongs, - songEligibleForMilestoneSnapshot, -} from "lib/db/snapshot.ts"; -import { SnapshotQueue } from "lib/mq/index.ts"; -import { insertVideoStats } from "lib/mq/task/getVideoStats.ts"; -import { parseTimestampFromPsql } from "lib/utils/formatTimestampToPostgre.ts"; -import { redis } from "lib/db/redis.ts"; -import { NetSchedulerError } from "lib/mq/scheduler.ts"; -import logger from "lib/log/logger.ts"; -import { formatSeconds } from "lib/utils/formatSeconds.ts"; -import { truncate } from "lib/utils/truncate.ts"; - -async function snapshotScheduled(aid: number) { - try { - return await redis.exists(`cvsa:snapshot:${aid}`); - } catch { - logger.error(`Failed to check scheduled status for ${aid}`, "mq"); - return false; - } -} - -async function setSnapshotScheduled(aid: number, value: boolean, exp: number) { - try { - if (value) { - await redis.set(`cvsa:snapshot:${aid}`, 1, "EX", exp); - } else { - await redis.del(`cvsa:snapshot:${aid}`); - } - } catch { - logger.error(`Failed to set scheduled status to ${value} for ${aid}`, "mq"); - } -} - -interface SongNearMilestone { - aid: number; - id: number; - created_at: string; - views: number; - coins: number; - likes: number; - favorites: number; - shares: number; - danmakus: number; - replies: number; -} - -async function processMilestoneSnapshots(client: Client, vidoesNearMilestone: SongNearMilestone[]) { - let i = 0; - for (const snapshot of vidoesNearMilestone) { - if (await snapshotScheduled(snapshot.aid)) { - logger.silly( - `Video ${snapshot.aid} is already scheduled for snapshot`, - "mq", - "fn:processMilestoneSnapshots", - ); - continue; - } - if (await songEligibleForMilestoneSnapshot(client, snapshot.aid) === false) { - logger.silly( - `Video ${snapshot.aid} is not eligible for milestone snapshot`, - "mq", - "fn:processMilestoneSnapshots", - ); - continue; - } - const factor = Math.floor(i / 8); - const delayTime = factor * SECOND * 2; - await SnapshotQueue.add("snapshotMilestoneVideo", { - aid: snapshot.aid, - currentViews: snapshot.views, - snapshotedAt: snapshot.created_at, - }, { delay: delayTime, priority: 1 }); - await setSnapshotScheduled(snapshot.aid, true, 20 * 60); - i++; - } -} - -async function processUnsnapshotedVideos(unsnapshotedVideos: number[]) { - let i = 0; - for (const aid of unsnapshotedVideos) { - if (await snapshotScheduled(aid)) { - logger.silly(`Video ${aid} is already scheduled for snapshot`, "mq", "fn:processUnsnapshotedVideos"); - continue; - } - const factor = Math.floor(i / 5); - const delayTime = factor * SECOND * 4; - await SnapshotQueue.add("snapshotVideo", { - aid, - }, { delay: delayTime, priority: 3 }); - await setSnapshotScheduled(aid, true, 6 * 60 * 60); - i++; - } -} +import { getVideosNearMilestone } from "lib/db/snapshot.ts"; +import { videoHasActiveSchedule } from "lib/db/snapshotSchedule.ts"; export const snapshotTickWorker = async (_job: Job) => { const client = await db.connect(); try { - const vidoesNearMilestone = await getSongsNearMilestone(client); - await processMilestoneSnapshots(client, vidoesNearMilestone); - - const unsnapshotedVideos = await getUnsnapshotedSongs(client); - await processUnsnapshotedVideos(unsnapshotedVideos); + // TODO: implement } finally { client.release(); } }; -export const takeSnapshotForMilestoneVideoWorker = async (job: Job) => { +export const collectMilestoneSnapshotsWorker = async (_job: Job) => { const client = await db.connect(); - await setSnapshotScheduled(job.data.aid, true, 20 * 60); try { - const aid: number = job.data.aid; - const currentViews: number = job.data.currentViews; - const lastSnapshoted: string = job.data.snapshotedAt; - const stat = await insertVideoStats(client, aid, "snapshotMilestoneVideo"); - if (typeof stat === "number") { - if (stat === -404 || stat === 62002 || stat == 62012) { - await setSnapshotScheduled(aid, true, 6 * 60 * 60); - } else { - await setSnapshotScheduled(aid, false, 0); - } - return; + const videos = await getVideosNearMilestone(client); + for (const video of videos) { + if (await videoHasActiveSchedule(client, video.aid)) continue; } - const nextMilestone = currentViews >= 100000 ? 1000000 : 100000; - if (stat.views >= nextMilestone) { - await setSnapshotScheduled(aid, false, 0); - return; - } - let eta = await getShortTermEtaPrediction(client, aid); - if (eta === null) { - const DELTA = 0.001; - const intervalSeconds = (Date.now() - parseTimestampFromPsql(lastSnapshoted)) / SECOND; - const viewsIncrement = stat.views - currentViews; - const incrementSpeed = viewsIncrement / (intervalSeconds + DELTA); - const viewsToIncrease = nextMilestone - stat.views; - eta = viewsToIncrease / (incrementSpeed + DELTA); - } - const scheduledNextSnapshotDelay = eta * SECOND / 3; - const maxInterval = 20 * MINUTE; - const minInterval = 1 * SECOND; - const delay = truncate(scheduledNextSnapshotDelay, minInterval, maxInterval); - await SnapshotQueue.add("snapshotMilestoneVideo", { - aid, - currentViews: stat.views, - snapshotedAt: stat.time, - }, { delay, priority: 1 }); - await job.updateData({ - ...job.data, - updatedViews: stat.views, - updatedTime: new Date(stat.time).toISOString(), - etaInMins: eta / 60, - }); - logger.log( - `Scheduled next milestone snapshot for ${aid} in ${ - formatSeconds(delay / 1000) - }, current views: ${stat.views}`, - "mq", - ); - } catch (e) { - if (e instanceof NetSchedulerError && e.code === "NO_AVAILABLE_PROXY") { - logger.warn( - `No available proxy for aid ${job.data.aid}.`, - "mq", - "fn:takeSnapshotForMilestoneVideoWorker", - ); - await SnapshotQueue.add("snapshotMilestoneVideo", { - aid: job.data.aid, - currentViews: job.data.currentViews, - snapshotedAt: job.data.snapshotedAt, - }, { delay: 5 * SECOND, priority: 1 }); - return; - } - throw e; + } catch (_e) { + // } finally { client.release(); } }; -export const takeSnapshotForVideoWorker = async (job: Job) => { - const client = await db.connect(); - await setSnapshotScheduled(job.data.aid, true, 6 * 60 * 60); - try { - const { aid } = job.data; - const stat = await insertVideoStats(client, aid, "getVideoInfo"); - if (typeof stat === "number") { - if (stat === -404 || stat === 62002 || stat == 62012) { - await setSnapshotScheduled(aid, true, 6 * 60 * 60); - } else { - await setSnapshotScheduled(aid, false, 0); - } - return; - } - logger.log(`Taken snapshot for ${aid}`, "mq"); - if (stat == null) { - setSnapshotScheduled(aid, false, 0); - return; - } - await job.updateData({ - ...job.data, - updatedViews: stat.views, - updatedTime: new Date(stat.time).toISOString(), - }); - const nearMilestone = (stat.views >= 90000 && stat.views < 100000) || - (stat.views >= 900000 && stat.views < 1000000); - if (nearMilestone) { - await SnapshotQueue.add("snapshotMilestoneVideo", { - aid, - currentViews: stat.views, - snapshotedAt: stat.time, - }, { delay: 0, priority: 1 }); - } - await setSnapshotScheduled(aid, false, 0); - } catch (e) { - if (e instanceof NetSchedulerError && e.code === "NO_AVAILABLE_PROXY") { - await setSnapshotScheduled(job.data.aid, false, 0); - return; - } - throw e; - } finally { - client.release(); - } +export const takeSnapshotForVideoWorker = async (_job: Job) => { + // TODO: implement }; diff --git a/lib/mq/init.ts b/lib/mq/init.ts index 03a0aad..688dd4a 100644 --- a/lib/mq/init.ts +++ b/lib/mq/init.ts @@ -19,6 +19,10 @@ export async function initMQ() { every: 1 * SECOND, immediately: true, }); + await SnapshotQueue.upsertJobScheduler("collectMilestoneSnapshots", { + every: 5 * MINUTE, + immediately: true, + }); logger.log("Message queue initialized."); } diff --git a/pred/inference.py b/pred/inference.py index 9a3d678..cadb90f 100644 --- a/pred/inference.py +++ b/pred/inference.py @@ -4,20 +4,20 @@ from model import CompactPredictor import torch def main(): - model = CompactPredictor(16).to('cpu', dtype=torch.float32) - model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0530.pt')) + model = CompactPredictor(10).to('cpu', dtype=torch.float32) + model.load_state_dict(torch.load('./pred/checkpoints/long_term.pt')) model.eval() # inference - initial = 999269 + initial = 997029 last = initial - start_time = '2025-03-15 01:03:21' - for i in range(1, 48): + start_time = '2025-03-17 00:13:17' + for i in range(1, 120): hour = i / 0.5 sec = hour * 3600 time_d = np.log2(sec) data = [time_d, np.log2(initial+1), # time_delta, current_views - 2.801318, 3.455128, 3.903391, 3.995577, 4.641488, 5.75131, 6.723868, 6.105322, 8.141023, 9.576701, 10.665067, # grows_feat - 0.043993, 0.72057, 28.000902 # time_feat + 6.111542, 8.404707, 10.071566, 11.55888, 12.457823,# grows_feat + 0.009225, 0.001318, 28.001814# time_feat ] np_arr = np.array([data]) tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32) @@ -25,7 +25,7 @@ def main(): num = output.detach().numpy()[0][0] views_pred = int(np.exp2(num)) + initial current_time = datetime.datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S') + datetime.timedelta(hours=hour) - print(current_time.strftime('%m-%d %H:%M'), views_pred, views_pred - last) + print(current_time.strftime('%m-%d %H:%M:%S'), views_pred, views_pred - last) last = views_pred if __name__ == '__main__': diff --git a/src/worker.ts b/src/worker.ts index 1b59785..9523a42 100644 --- a/src/worker.ts +++ b/src/worker.ts @@ -5,7 +5,7 @@ import logger from "lib/log/logger.ts"; import { lockManager } from "lib/mq/lockManager.ts"; import { WorkerError } from "lib/mq/schema.ts"; import { getVideoInfoWorker } from "lib/mq/exec/getLatestVideos.ts"; -import { snapshotTickWorker, takeSnapshotForMilestoneVideoWorker, takeSnapshotForVideoWorker } from "lib/mq/exec/snapshotTick.ts"; +import { snapshotTickWorker, collectMilestoneSnapshotsWorker, takeSnapshotForVideoWorker } from "lib/mq/exec/snapshotTick.ts"; Deno.addSignalListener("SIGINT", async () => { logger.log("SIGINT Received: Shutting down workers...", "mq"); @@ -56,15 +56,15 @@ const snapshotWorker = new Worker( "snapshot", async (job: Job) => { switch (job.name) { - case "snapshotMilestoneVideo": - await takeSnapshotForMilestoneVideoWorker(job); - break; case "snapshotVideo": await takeSnapshotForVideoWorker(job); break; case "snapshotTick": await snapshotTickWorker(job); break; + case "collectMilestoneSnapshots": + await collectMilestoneSnapshotsWorker(job); + break; default: break; } diff --git a/test/db/snapshotSchedule.test.ts b/test/db/snapshotSchedule.test.ts new file mode 100644 index 0000000..a5e1d6a --- /dev/null +++ b/test/db/snapshotSchedule.test.ts @@ -0,0 +1,18 @@ +import { assertEquals, assertInstanceOf, assertNotEquals } from "@std/assert"; +import { findClosestSnapshot } from "lib/db/snapshotSchedule.ts"; +import { postgresConfig } from "lib/db/pgConfig.ts"; +import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts"; + +Deno.test("Snapshot Schedule - getShortTermTimeFeaturesForVideo", async () => { + const client = new Client(postgresConfig); + try { + const result = await findClosestSnapshot(client, 247308539, new Date(1741983383000)); + assertNotEquals(result, null); + const created_at = result!.created_at; + const views = result!.views; + assertInstanceOf(created_at, Date); + assertEquals(typeof views, "number"); + } finally { + client.end(); + } +}); diff --git a/test/ml/akari.json b/test/ml/akari.json new file mode 100644 index 0000000..7345078 --- /dev/null +++ b/test/ml/akari.json @@ -0,0 +1,22 @@ +{ + "test1": [ + { + "title": "【洛天依】《一花依世界》(2024重调版)|“抬头仰望,夜空多安详”【原创PV付】", + "desc": "本家:BV1Vs411H7JH\n作曲:LS\n作词:杏花包子\n调教:鬼面P\n混音:虎皮猫P\n演唱:洛天依\n曲绘:山下鸭鸭窝\n映像:阿妍\n——————————————————————\n本稿为同人二创,非本家重制", + "tags": "发现《一花依世界》, Vsinger创作激励计划, 洛天依, VOCALOID CHINA, 翻唱, 原创PV付, ACE虚拟歌姬, 中文VOCALOID, 国风电子, 一花依世界, ACE Studio, Vsinger创作激励计划2024冬季物语", + "label": 2 + }, + { + "title": "【鏡音レン】アカシア【VOCALOID Cover】", + "desc": "鏡音リン・レン 13th Anniversary\n\nMusic:BUMP OF CHICKEN https://youtu.be/BoZ0Zwab6Oc\nust:Maplestyle sm37853236\nOff Vocal: https://youtu.be/YMzrUzq1uX0\nSinger:鏡音レン\n\n氷雨ハルカ\nYoutube :https://t.co/8zuv6g7Acm\nniconico:https://t.co/C6DRfdYAp0\ntwitter :https://twitter.com/hisame_haruka\n\n転載禁止\nPlease do not reprint without my permission.", + "tags": "鏡音レン", + "label": 0 + }, + { + "title": "【洛天依原创曲】谪星【姆斯塔之谕】", + "desc": "谪星\n\n策划/世界观:听雨\n作词:听雨\n作曲/编曲:太白\n混音:虎皮猫\n人设:以木\n曲绘:Ar极光\n调校:哈士奇p\n视频:苏卿白", + "tags": "2025虚拟歌手贺岁纪, 洛天依, 原创歌曲, VOCALOID, 虚拟歌手, 原创音乐, 姆斯塔, 中文VOCALOID", + "label": 1 + } + ] +} diff --git a/test/ml/akari.test.ts b/test/ml/akari.test.ts new file mode 100644 index 0000000..958f34d --- /dev/null +++ b/test/ml/akari.test.ts @@ -0,0 +1,46 @@ +import Akari from "lib/ml/akari.ts"; +import { assertEquals, assertGreaterOrEqual } from "jsr:@std/assert"; +import { join } from "$std/path/join.ts"; +import { SECOND } from "$std/datetime/constants.ts"; + +Deno.test("Akari AI - normal cases accuracy test", async () => { + const path = import.meta.dirname!; + const dataPath = join(path, "akari.json"); + const rawData = await Deno.readTextFile(dataPath); + const data = JSON.parse(rawData); + await Akari.init(); + for (const testCase of data.test1) { + const result = await Akari.classifyVideo( + testCase.title, + testCase.desc, + testCase.tags + ); + assertEquals(result, testCase.label); + } +}); + +Deno.test("Akari AI - performance test", async () => { + const path = import.meta.dirname!; + const dataPath = join(path, "akari.json"); + const rawData = await Deno.readTextFile(dataPath); + const data = JSON.parse(rawData); + await Akari.init(); + const N = 200; + const testCase = data.test1[0]; + const title = testCase.title; + const desc = testCase.desc; + const tags = testCase.tags; + const time = performance.now(); + for (let i = 0; i < N; i++){ + await Akari.classifyVideo( + title, + desc, + tags + ); + } + const end = performance.now(); + const elapsed = (end - time) / SECOND; + const throughput = N / elapsed; + assertGreaterOrEqual(throughput, 100); + console.log(`Akari AI throughput: ${throughput.toFixed(1)} samples / sec`) +}); \ No newline at end of file