update: filter model benchmark
This commit is contained in:
parent
73b96e869d
commit
15312f4078
@ -1,4 +1,4 @@
|
||||
import { AutoTokenizer } from "@huggingface/transformers";
|
||||
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
|
||||
import * as ort from "onnxruntime";
|
||||
|
||||
// 配置参数
|
||||
@ -14,7 +14,7 @@ const [sessionClassifier, sessionEmbeddingOriginal, sessionEmbeddingQuantized] =
|
||||
ort.InferenceSession.create(onnxEmbeddingQuantizedPath)
|
||||
]);
|
||||
|
||||
let tokenizer: any;
|
||||
let tokenizer: PreTrainedTokenizer;
|
||||
|
||||
// 初始化分词器
|
||||
async function loadTokenizer() {
|
||||
@ -66,11 +66,12 @@ async function runClassification(embeddings: number[]): Promise<number[]> {
|
||||
}
|
||||
|
||||
// 指标计算函数
|
||||
function calculateMetrics(labels: number[], predictions: number[]): {
|
||||
function calculateMetrics(labels: number[], predictions: number[], elapsedTime: number): {
|
||||
accuracy: number,
|
||||
precision: number,
|
||||
recall: number,
|
||||
f1: number
|
||||
f1: number,
|
||||
speed: string
|
||||
} {
|
||||
// 初始化混淆矩阵
|
||||
const classCount = Math.max(...labels, ...predictions) + 1;
|
||||
@ -104,7 +105,8 @@ function calculateMetrics(labels: number[], predictions: number[]): {
|
||||
accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length,
|
||||
precision,
|
||||
recall,
|
||||
f1
|
||||
f1,
|
||||
speed: `${(labels.length / (elapsedTime / 1000)).toFixed(1)} samples/sec`
|
||||
};
|
||||
}
|
||||
|
||||
@ -115,7 +117,7 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
|
||||
recall: number;
|
||||
f1: number;
|
||||
}> {
|
||||
const data = await Deno.readTextFile("./data/filter/output.jsonl");
|
||||
const data = await Deno.readTextFile("./data/filter/test.jsonl");
|
||||
const samples = data.split("\n")
|
||||
.map(line => {
|
||||
try { return JSON.parse(line); }
|
||||
@ -125,7 +127,8 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
|
||||
|
||||
const allPredictions: number[] = [];
|
||||
const allLabels: number[] = [];
|
||||
|
||||
|
||||
const t = new Date().getTime();
|
||||
for (const sample of samples) {
|
||||
try {
|
||||
const embeddings = await getONNXEmbeddings([
|
||||
@ -142,8 +145,9 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
|
||||
console.error("Processing error:", error);
|
||||
}
|
||||
}
|
||||
const elapsed = new Date().getTime() - t;
|
||||
|
||||
return calculateMetrics(allLabels, allPredictions);
|
||||
return calculateMetrics(allLabels, allPredictions, elapsed);
|
||||
}
|
||||
|
||||
// 主函数
|
||||
@ -151,18 +155,14 @@ async function main() {
|
||||
await loadTokenizer();
|
||||
|
||||
// 评估原始模型
|
||||
const t = new Date().getTime();
|
||||
const originalMetrics = await evaluateModel(sessionEmbeddingOriginal);
|
||||
console.log("Original Model Metrics:");
|
||||
console.table(originalMetrics);
|
||||
console.log(`Original Model Metrics: ${new Date().getTime() - t}ms`);
|
||||
|
||||
// 评估量化模型
|
||||
const t2 = new Date().getTime();
|
||||
const quantizedMetrics = await evaluateModel(sessionEmbeddingQuantized);
|
||||
console.log("Quantized Model Metrics:");
|
||||
console.table(quantizedMetrics);
|
||||
console.log(`Quantized Model Metrics: ${new Date().getTime() - t2}ms`);
|
||||
}
|
||||
|
||||
await main();
|
||||
Loading…
Reference in New Issue
Block a user