
update: filter model benchmark

alikia2x (寒寒) 2025-02-22 22:53:40 +08:00
parent 73b96e869d
commit 15312f4078
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6


@@ -1,4 +1,4 @@
-import { AutoTokenizer } from "@huggingface/transformers";
+import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
 import * as ort from "onnxruntime";
 // Configuration parameters
@@ -14,7 +14,7 @@ const [sessionClassifier, sessionEmbeddingOriginal, sessionEmbeddingQuantized] =
     ort.InferenceSession.create(onnxEmbeddingQuantizedPath)
 ]);
-let tokenizer: any;
+let tokenizer: PreTrainedTokenizer;
 // Initialize the tokenizer
 async function loadTokenizer() {
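
For readers without the full file: typing the module-level tokenizer as PreTrainedTokenizer (instead of any) matches what AutoTokenizer.from_pretrained resolves to in @huggingface/transformers. A minimal sketch of loadTokenizer against that type; the model id below is a placeholder, not taken from this repository:

import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";

let tokenizer: PreTrainedTokenizer;

// Load the tokenizer once; later code reuses the module-level instance.
async function loadTokenizer() {
    tokenizer = await AutoTokenizer.from_pretrained("Xenova/jina-embeddings-v2-small-en"); // placeholder model id
}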
@@ -66,11 +66,12 @@ async function runClassification(embeddings: number[]): Promise<number[]> {
 }
 // Metric calculation function
-function calculateMetrics(labels: number[], predictions: number[]): {
+function calculateMetrics(labels: number[], predictions: number[], elapsedTime: number): {
     accuracy: number,
     precision: number,
     recall: number,
-    f1: number
+    f1: number,
+    speed: string
 } {
     // Initialize the confusion matrix
     const classCount = Math.max(...labels, ...predictions) + 1;
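
The hunk above shows only the new signature; per the comment, the body works from a confusion matrix. As a hedged illustration (not the repository's exact code), macro-averaged precision, recall and F1 from such a matrix look roughly like this:

function macroMetrics(labels: number[], predictions: number[]) {
    const classCount = Math.max(...labels, ...predictions) + 1;
    // confusionMatrix[trueLabel][predictedLabel] counts each (label, prediction) pair
    const confusionMatrix = Array.from({ length: classCount }, () => new Array(classCount).fill(0));
    labels.forEach((label, i) => confusionMatrix[label][predictions[i]]++);

    let precisionSum = 0, recallSum = 0;
    for (let c = 0; c < classCount; c++) {
        const tp = confusionMatrix[c][c];
        const predicted = confusionMatrix.reduce((sum, row) => sum + row[c], 0); // column sum: predicted as class c
        const actual = confusionMatrix[c].reduce((sum, v) => sum + v, 0);        // row sum: truly class c
        precisionSum += predicted ? tp / predicted : 0;
        recallSum += actual ? tp / actual : 0;
    }
    const precision = precisionSum / classCount;
    const recall = recallSum / classCount;
    const f1 = precision + recall ? (2 * precision * recall) / (precision + recall) : 0;
    return { precision, recall, f1 };
}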
@@ -104,7 +105,8 @@ function calculateMetrics(labels: number[], predictions: number[]): {
         accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length,
         precision,
         recall,
-        f1
+        f1,
+        speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`
     };
 }
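
The new speed field is plain throughput: sample count divided by elapsed wall-clock seconds. With made-up numbers for illustration:

// Hypothetical figures: 2000 samples evaluated in 8000 ms
const elapsedTime = 8000;
const sampleCount = 2000;
const speed = `${(sampleCount / (elapsedTime / 1000)).toFixed(1)} samples/sec`; // "250.0 samples/sec"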
@@ -115,7 +117,7 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
     recall: number;
     f1: number;
 }> {
-    const data = await Deno.readTextFile("./data/filter/output.jsonl");
+    const data = await Deno.readTextFile("./data/filter/test.jsonl");
     const samples = data.split("\n")
         .map(line => {
             try { return JSON.parse(line); }
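
The parsing chain is cut off by the hunk boundary. A minimal sketch of the tolerant JSONL-loading pattern it appears to follow; the text and label field names are assumptions, not confirmed by this diff:

const data = await Deno.readTextFile("./data/filter/test.jsonl");
const samples = data.split("\n")
    .map(line => {
        try { return JSON.parse(line); }   // parse each JSON line
        catch { return null; }             // tolerate blank or corrupt lines
    })
    .filter((s): s is { text: string; label: number } => s !== null);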
@@ -125,7 +127,8 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
     const allPredictions: number[] = [];
     const allLabels: number[] = [];
+    const t = new Date().getTime();
     for (const sample of samples) {
         try {
             const embeddings = await getONNXEmbeddings([
@@ -142,8 +145,9 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
             console.error("Processing error:", error);
         }
     }
+    const elapsed = new Date().getTime() - t;
-    return calculateMetrics(allLabels, allPredictions);
+    return calculateMetrics(allLabels, allPredictions, elapsed);
 }
 // Main function
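
Taken together, the two added lines bracket the whole inference loop, so the elapsed time covers every sample. A condensed, hedged sketch of the loop after this change (the getONNXEmbeddings arguments and the sample fields are simplified assumptions):

const t = new Date().getTime();              // start of the inference loop
for (const sample of samples) {
    try {
        const embeddings = await getONNXEmbeddings([ /* text fields of the sample */ ], session);
        const probabilities = await runClassification(embeddings);
        allPredictions.push(probabilities.indexOf(Math.max(...probabilities))); // assumed argmax over class scores
        allLabels.push(sample.label);                                           // assumed label field
    } catch (error) {
        console.error("Processing error:", error);
    }
}
const elapsed = new Date().getTime() - t;    // wall-clock time for all samples
return calculateMetrics(allLabels, allPredictions, elapsed);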
@@ -151,18 +155,14 @@ async function main() {
     await loadTokenizer();
     // Evaluate the original model
-    const t = new Date().getTime();
     const originalMetrics = await evaluateModel(sessionEmbeddingOriginal);
     console.log("Original Model Metrics:");
     console.table(originalMetrics);
-    console.log(`Original Model Metrics: ${new Date().getTime() - t}ms`);
     // Evaluate the quantized model
-    const t2 = new Date().getTime();
     const quantizedMetrics = await evaluateModel(sessionEmbeddingQuantized);
     console.log("Quantized Model Metrics:");
     console.table(quantizedMetrics);
-    console.log(`Quantized Model Metrics: ${new Date().getTime() - t2}ms`);
 }
 await main();
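
After this commit, main() no longer times each model itself; the samples/sec figure arrives through the metrics object and shows up in console.table. Reassembled from the hunk above, the updated function should read roughly:

async function main() {
    await loadTokenizer();

    // Evaluate the original embedding model
    const originalMetrics = await evaluateModel(sessionEmbeddingOriginal);
    console.log("Original Model Metrics:");
    console.table(originalMetrics);

    // Evaluate the quantized embedding model
    const quantizedMetrics = await evaluateModel(sessionEmbeddingQuantized);
    console.log("Quantized Model Metrics:");
    console.table(quantizedMetrics);
}

await main();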