From 15312f407845854a463bb6c065beef486f30f614 Mon Sep 17 00:00:00 2001 From: alikia2x Date: Sat, 22 Feb 2025 22:53:40 +0800 Subject: [PATCH] update: filter model benchmark --- lib/ml/quant_benchmark.ts | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/lib/ml/quant_benchmark.ts b/lib/ml/quant_benchmark.ts index d761792..07777c2 100644 --- a/lib/ml/quant_benchmark.ts +++ b/lib/ml/quant_benchmark.ts @@ -1,4 +1,4 @@ -import { AutoTokenizer } from "@huggingface/transformers"; +import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers"; import * as ort from "onnxruntime"; // 配置参数 @@ -14,7 +14,7 @@ const [sessionClassifier, sessionEmbeddingOriginal, sessionEmbeddingQuantized] = ort.InferenceSession.create(onnxEmbeddingQuantizedPath) ]); -let tokenizer: any; +let tokenizer: PreTrainedTokenizer; // 初始化分词器 async function loadTokenizer() { @@ -66,11 +66,12 @@ async function runClassification(embeddings: number[]): Promise { } // 指标计算函数 -function calculateMetrics(labels: number[], predictions: number[]): { +function calculateMetrics(labels: number[], predictions: number[], elapsedTime: number): { accuracy: number, precision: number, recall: number, - f1: number + f1: number, + speed: string } { // 初始化混淆矩阵 const classCount = Math.max(...labels, ...predictions) + 1; @@ -104,7 +105,8 @@ function calculateMetrics(labels: number[], predictions: number[]): { accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length, precision, recall, - f1 + f1, + speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec` }; } @@ -115,7 +117,7 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{ recall: number; f1: number; }> { - const data = await Deno.readTextFile("./data/filter/output.jsonl"); + const data = await Deno.readTextFile("./data/filter/test.jsonl"); const samples = data.split("\n") .map(line => { try { return JSON.parse(line); } @@ -125,7 +127,8 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{ const allPredictions: number[] = []; const allLabels: number[] = []; - + + const t = new Date().getTime(); for (const sample of samples) { try { const embeddings = await getONNXEmbeddings([ @@ -142,8 +145,9 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{ console.error("Processing error:", error); } } + const elapsed = new Date().getTime() - t; - return calculateMetrics(allLabels, allPredictions); + return calculateMetrics(allLabels, allPredictions, elapsed); } // 主函数 @@ -151,18 +155,14 @@ async function main() { await loadTokenizer(); // 评估原始模型 - const t = new Date().getTime(); const originalMetrics = await evaluateModel(sessionEmbeddingOriginal); console.log("Original Model Metrics:"); console.table(originalMetrics); - console.log(`Original Model Metrics: ${new Date().getTime() - t}ms`); // 评估量化模型 - const t2 = new Date().getTime(); const quantizedMetrics = await evaluateModel(sessionEmbeddingQuantized); console.log("Quantized Model Metrics:"); console.table(quantizedMetrics); - console.log(`Quantized Model Metrics: ${new Date().getTime() - t2}ms`); } await main(); \ No newline at end of file