add: quantized filter model inference

alikia2x (寒寒) 2025-02-13 07:42:16 +08:00
parent 7e6854db00
commit fd090a25c2
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
11 changed files with 410 additions and 295 deletions

View File

@@ -40,7 +40,9 @@
     "@bull-board/api": "npm:@bull-board/api",
     "@bull-board/express": "npm:@bull-board/express",
     "express": "npm:express",
-    "src/": "./src/"
+    "src/": "./src/",
+    "onnxruntime": "npm:onnxruntime-node",
+    "chalk": "npm:chalk"
   },
   "compilerOptions": {
     "jsx": "react-jsx",

54
filter/embedding_range.py Normal file
View File

@@ -0,0 +1,54 @@
import json
import torch
from embedding import prepare_batch
from tqdm import tqdm
import numpy as np

file_path = './data/filter/model_predicted.jsonl'

class Dataset:
    def __init__(self, file_path):
        all_examples = self.load_data(file_path)
        self.examples = all_examples

    def load_data(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f]

    def __getitem__(self, idx):
        # Slice out the idx-th batch; batch_size is set by get_batch()
        end_idx = min((idx + 1) * self.batch_size, len(self.examples))
        texts = {
            'title': [ex['title'] for ex in self.examples[idx * self.batch_size:end_idx]],
            'description': [ex['description'] for ex in self.examples[idx * self.batch_size:end_idx]],
            'tags': [",".join(ex['tags']) for ex in self.examples[idx * self.batch_size:end_idx]],
            'author_info': [ex['author_info'] for ex in self.examples[idx * self.batch_size:end_idx]]
        }
        return texts

    def __len__(self):
        return len(self.examples)

    def get_batch(self, idx, batch_size):
        self.batch_size = batch_size
        return self.__getitem__(idx)

total = 600000
batch_size = 512
batch_num = total // batch_size
dataset = Dataset(file_path)

# Each batch yields batch_size * 4 channels * 1024 dims of embedding values;
# keep a 10% random sample of them per batch.
arr_len = batch_size * 4 * 1024
sample_rate = 0.1
sample_num = int(arr_len * sample_rate)

data = np.array([])
for i in tqdm(range(batch_num)):
    batch = dataset.get_batch(i, batch_size)
    batch = prepare_batch(batch, device="cpu")
    arr = batch.flatten().numpy()
    sampled = np.random.choice(arr.shape[0], size=sample_num, replace=False)
    data = np.concatenate((data, arr[sampled]), axis=0) if data.size else arr[sampled]
    if i % 10 == 0:
        np.save('embedding_range.npy', data)

np.save('embedding_range.npy', data)

View File

@@ -0,0 +1,43 @@
import numpy as np
import matplotlib.pyplot as plt

# Load the sampled embedding values
data = np.load("1.npy")

# Draw a histogram once to obtain the bin counts and edges
n, bins, patches = plt.hist(data, bins=32, density=False, alpha=0.7, color='skyblue')

# Convert counts to frequencies
total_data = len(data)
frequencies = n / total_data

# Summary statistics
max_val = np.max(data)
min_val = np.min(data)
std_dev = np.std(data)

# Redraw the histogram, using frequency as the bar height
plt.cla()  # clear the current axes
plt.bar([(bins[i] + bins[i+1])/2 for i in range(len(bins)-1)], frequencies,
        width=[bins[i+1]-bins[i] for i in range(len(bins)-1)], alpha=0.7, color='skyblue')

# Set the title and labels after cla(), which would otherwise clear them
plt.title('Frequency Distribution Histogram')
plt.xlabel('Value')
plt.ylabel('Frequency')

# Annotate each bar with its frequency
for i in range(len(frequencies)):
    plt.text(bins[i]+(bins[i+1]-bins[i])/2, frequencies[i], f'{frequencies[i]:.2e}',
             ha='center', va='bottom', fontsize=6)

# Show summary statistics in a corner of the plot
stats_text = f"Max: {max_val:.6f}\nMin: {min_val:.6f}\nStd: {std_dev:.4e}"
plt.text(0.95, 0.95, stats_text, transform=plt.gca().transAxes,
         ha='right', va='top', bbox=dict(facecolor='white', edgecolor='black', alpha=0.8))

# Align x-axis ticks with the bin edges
plt.xticks(bins, fontsize=6)

# Show the figure
plt.show()

View File

@@ -1,111 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class VideoClassifierV3_4(nn.Module):
    def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3):
        super().__init__()
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']

        # Learnable temperature coefficient
        self.temperature = nn.Parameter(torch.tensor(1.7))

        # Constrained channel weights (Sigmoid instead of Softmax)
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))

        # Deeper non-linear classification head
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
            nn.BatchNorm1d(hidden_dim*2),
            nn.Dropout(0.3),
            nn.GELU(),
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.2),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim)
        )

        # Weight initialization
        self._init_weights()

    def _init_weights(self):
        for layer in self.fc:
            if isinstance(layer, nn.Linear):
                # Kaiming init with ReLU parameters (an approximation for GELU)
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
                # Alternatively, Xavier init suits the general case:
                # nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('relu'))
                nn.init.zeros_(layer.bias)

    def forward(self, input_texts, sentence_transformer):
        # Concatenate all texts for batched encoding
        all_texts = [text for channel in self.channel_names for text in input_texts[channel]]

        # Frozen text encoder
        with torch.no_grad():
            embeddings = torch.tensor(
                sentence_transformer.encode(all_texts),
                device=next(self.parameters()).device
            )

        # Split into per-channel features and weight them
        split_sizes = [len(input_texts[name]) for name in self.channel_names]
        channel_features = torch.split(embeddings, split_sizes, dim=0)
        channel_features = torch.stack(channel_features, dim=1)

        # Adaptive channel weights, constrained to [0, 1] via Sigmoid
        weights = torch.sigmoid(self.channel_weights)
        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)

        # Flatten and classify
        combined = weighted_features.view(weighted_features.size(0), -1)
        return self.fc(combined)

    def get_channel_weights(self):
        """Return the per-channel weights (temperature-scaled)."""
        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()


class AdaptiveRecallLoss(nn.Module):
    def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
        """
        Args:
            class_weights (torch.Tensor): per-class weights
            alpha (float): recall adjustment factor (0-1)
            gamma (float): Focal Loss parameter
            fp_penalty (float): penalty strength for class-0 false positives
        """
        super().__init__()
        self.class_weights = class_weights
        self.alpha = alpha
        self.gamma = gamma
        self.fp_penalty = fp_penalty

    def forward(self, logits, targets):
        # Base cross-entropy loss
        ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')

        # Focal Loss component
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        # Recall boosting (up-weight hard examples)
        class_mask = F.one_hot(targets, num_classes=len(self.class_weights))
        class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask
        recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1)

        # Penalty for class-0 false positives
        probs = F.softmax(logits, dim=1)
        fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
        fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()

        # Total loss
        total_loss = recall_loss.mean() + fp_loss / len(targets)
        return total_loss

View File

@@ -1,148 +0,0 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
import json
from torch.utils.data import Dataset, DataLoader
import numpy as np

class VideoDataset(Dataset):
    def __init__(self, data_path, sentence_transformer):
        self.data = []
        self.sentence_transformer = sentence_transformer
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        title = item["title"]
        description = item["description"]
        tags = item["tags"]
        label = item["label"]

        # Embed each feature separately
        title_embedding = self.get_embedding(title)
        description_embedding = self.get_embedding(description)
        tags_embedding = self.get_embedding(" ".join(tags))

        # Concatenate the embeddings
        combined_embedding = torch.cat([title_embedding, description_embedding, tags_embedding], dim=0)
        return combined_embedding, label

    def get_embedding(self, text):
        # Generate the embedding with SentenceTransformer
        embedding = self.sentence_transformer.encode(text)
        return torch.tensor(embedding)

class VideoClassifier(nn.Module):
    def __init__(self, embedding_dim=768, hidden_dim=256, output_dim=3):
        super(VideoClassifier, self).__init__()
        # Each of the 3 features contributes embedding_dim dimensions
        total_embedding_dim = embedding_dim * 3

        # Fully connected layers
        self.fc1 = nn.Linear(total_embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, embedding_features):
        x = torch.relu(self.fc1(embedding_features))
        output = self.fc2(x)
        output = self.log_softmax(output)
        return output

def train(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for embedding_features, labels in dataloader:
        embedding_features = embedding_features.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(embedding_features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return avg_loss, accuracy

def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for embedding_features, labels in dataloader:
            embedding_features = embedding_features.to(device)
            labels = labels.to(device)
            outputs = model(embedding_features)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return avg_loss, accuracy

# Hyperparameters
hidden_dim = 256
output_dim = 3
batch_size = 32
num_epochs = 10
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the dataset
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3")
sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
dataset = VideoDataset("labeled_data.jsonl", sentence_transformer=sentence_transformer)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the model
model = VideoClassifier(embedding_dim=768, hidden_dim=256, output_dim=3).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

num_epochs = 5

# Train and validate
for epoch in range(num_epochs):
    train_loss, train_acc = train(model, dataloader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

# Save the model
torch.save(model.state_dict(), "video_classifier.pth")
model.eval()  # switch to evaluation mode

# 2. Define the inference function
def predict(model, sentence_transformer, title, description, tags, device):
    # Convert the inputs to embeddings
    title_embedding = torch.tensor(sentence_transformer.encode(title)).to(device)
    description_embedding = torch.tensor(sentence_transformer.encode(description)).to(device)
    tags_embedding = torch.tensor(sentence_transformer.encode(" ".join(tags))).to(device)

    # Concatenate the embeddings
    combined_embedding = torch.cat([title_embedding, description_embedding, tags_embedding], dim=0).unsqueeze(0)

    # Run inference
    with torch.no_grad():
        output = model(combined_embedding)
        _, predicted = torch.max(output, 1)
    return predicted.item()

View File

@@ -2,8 +2,8 @@ import torch
 from modelV3_10 import VideoClassifierV3_10

-def export_onnx(model_path="./filter/checkpoints/best_model_V3.10.pt",
-                onnx_path="./model/video_classifier_v3_10.onnx"):
+def export_onnx(model_path="./filter/checkpoints/best_model_V3.11.pt",
+                onnx_path="./model/video_classifier_v3_11.onnx"):
     # Initialize the model
     model = VideoClassifierV3_10()
     model.load_state_dict(torch.load(model_path))

36
filter/quantize.py Normal file
View File

@@ -0,0 +1,36 @@
from safetensors import safe_open
from safetensors.torch import save_file
import torch

# Paths
model_path = "./model/embedding/model.safetensors"
save_path = "./model/embedding/int8_model.safetensors"

# Load the original embedding layer
with safe_open(model_path, framework="pt") as f:
    embeddings_tensor = f.get_tensor("embeddings")

# Find the extremes
min_val = torch.min(embeddings_tensor)
max_val = torch.max(embeddings_tensor)

# Quantization parameters: int8 spans 256 values (-128 to 127)
scale = (max_val - min_val) / 255

# Map the floats onto the int8 range.
# Subtract 128 before casting; casting the 0-255 values to int8 first would overflow.
int8_tensor = (torch.round((embeddings_tensor - min_val) / scale) - 128).to(torch.int8)

# Shape must match the original tensor
assert int8_tensor.shape == embeddings_tensor.shape

# Save the quantized int8 tensor
save_file({"embeddings": int8_tensor}, save_path)

# Print the dequantization formula
print("int8 dequantization formula:")
m = min_val.item()
am = abs(min_val.item())
sign = "-" if m < 0 else "+"
print(f"float_value = (int8_value + 128) × {scale.item()} {sign} {am}")
print("int8 quantization complete!")

View File

@@ -1,6 +1,6 @@
 import winston, { format, transports } from "npm:winston";
 import { TransformableInfo } from "npm:logform";
-import chalk from "npm:chalk";
+import chalk from "chalk";

 const customFormat = format.printf((info: TransformableInfo) => {
     const { timestamp, level, message, service, codePath, error } = info;

View File

@@ -1,32 +0,0 @@
import { AutoModel, AutoTokenizer, Tensor } from '@huggingface/transformers';

const modelName = "alikia2x/jina-embedding-v3-m2v-1024";

const modelConfig = {
    config: { model_type: 'model2vec' },
    dtype: 'fp32',
    revision: 'refs/pr/1',
    cache_dir: undefined,
    local_files_only: true,
};
const tokenizerConfig = {
    revision: 'refs/pr/2'
};

const model = await AutoModel.from_pretrained(modelName, modelConfig);
const tokenizer = await AutoTokenizer.from_pretrained(modelName, tokenizerConfig);

const texts = ['hello', 'hello world'];
const { input_ids } = await tokenizer(texts, { add_special_tokens: false, return_tensor: false });

const cumsum = arr => arr.reduce((acc, num, i) => [...acc, num + (acc[i - 1] || 0)], []);
const offsets = [0, ...cumsum(input_ids.slice(0, -1).map(x => x.length))];

const flattened_input_ids = input_ids.flat();
const modelInputs = {
    input_ids: new Tensor('int64', flattened_input_ids, [flattened_input_ids.length]),
    offsets: new Tensor('int64', offsets, [offsets.length])
};

const { embeddings } = await model(modelInputs);
console.log(embeddings.tolist()); // output matches python version

103
lib/ml/filter_inference.ts Normal file
View File

@@ -0,0 +1,103 @@
import { AutoTokenizer } from "@huggingface/transformers";
import * as ort from "onnxruntime";

// Configuration
const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
const onnxEmbeddingOriginalPath = "./model/model.onnx";

// Initialize the sessions
const [sessionClassifier, sessionEmbedding] = await Promise.all([
    ort.InferenceSession.create(onnxClassifierPath),
    ort.InferenceSession.create(onnxEmbeddingOriginalPath),
]);

let tokenizer: any;

// Initialize the tokenizer
async function loadTokenizer() {
    const tokenizerConfig = { local_files_only: true };
    tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig);
}

function softmax(logits: Float32Array): number[] {
    const maxLogit = Math.max(...logits);
    const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
    const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
    return Array.from(exponents.map((exp) => exp / sumOfExponents));
}

async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
    const { input_ids } = await tokenizer(texts, {
        add_special_tokens: false,
        return_tensor: false
    });

    // Build the offsets input: where each text's tokens start in the flattened id list
    const cumsum = (arr: number[]): number[] =>
        arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
    const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: number[]) => x.length))];
    const flattened_input_ids = input_ids.flat();

    // Prepare the ONNX inputs
    const inputs = {
        input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [flattened_input_ids.length]),
        offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length])
    };

    // Run inference
    const { embeddings } = await session.run(inputs);
    return Array.from(embeddings.data as Float32Array);
}

// Classification inference
async function runClassification(embeddings: number[]): Promise<number[]> {
    const inputTensor = new ort.Tensor(
        Float32Array.from(embeddings),
        [1, 4, 1024]
    );
    const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
    return softmax(logits.data as Float32Array);
}

async function processInputTexts(
    title: string,
    description: string,
    tags: string,
    author_info: string,
): Promise<number[]> {
    const embeddings = await getONNXEmbeddings([
        title,
        description,
        tags,
        author_info
    ], sessionEmbedding);
    const probabilities = await runClassification(embeddings);
    return probabilities;
}

async function main() {
    await loadTokenizer();

    const titleText = `【洛天依&乐正绫&心华原创】归一【时之歌Project】`;
    const descriptionText = " 《归一》Vocaloid ver\r\n出品:泛音堂 / 作词:冥凰 / 作曲:汤汤 / 编曲&amp;混音:iAn / 调教:花之祭P\r\n后期:向南 / 人设:Pora / 场景:A舍长 / PV:Sung Hsu(麻薯映画) / 海报:易玄玑 \r\n唱:乐正绫 &amp; 洛天依 &amp; 心华\r\n时之歌Project东国世界观歌曲《归一》双本家VC版\r\nMP3:http://5sing.kugou.com/yc/3006072.html \r\n伴奏:http://5sing.kugou.com/bz/2";
    const tagsText = '乐正绫,洛天依,心华,VOCALOID中文曲,时之歌,花之祭P';
    const authorInfoText = "时之歌Project: 欢迎光临时之歌~\r\n官博:http://weibo.com/songoftime\r\n官网:http://www.songoftime.com/";

    try {
        const probabilities = await processInputTexts(titleText, descriptionText, tagsText, authorInfoText);
        console.log("Class Probabilities:", probabilities);
        console.log(`Class 0 Probability: ${probabilities[0]}`);
        console.log(`Class 1 Probability: ${probabilities[1]}`);
        console.log(`Class 2 Probability: ${probabilities[2]}`);

        // Hold the session for 10s
        await new Promise((resolve) => setTimeout(resolve, 10000));
    } catch (error) {
        console.error("Error processing texts:", error);
    }
}

await main();
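
The offsets input built in getONNXEmbeddings is the bookkeeping a model2vec-style embedding model needs to pool a flat list of token ids per text, much like PyTorch's EmbeddingBag. A small illustration of the same arithmetic with hypothetical token ids (real ids depend on the tokenizer):

# Hypothetical token ids for ["hello", "hello world"]
input_ids = [[7592], [7592, 2088]]

# Prepend 0, then take the running sum of all lengths except the last:
# each offset marks where a text's tokens begin in the flattened list.
offsets = [0]
for ids in input_ids[:-1]:
    offsets.append(offsets[-1] + len(ids))

flattened = [tok for ids in input_ids for tok in ids]
print(flattened)  # [7592, 7592, 2088]
print(offsets)    # [0, 1]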

168
lib/ml/quant_benchmark.ts Normal file
View File

@@ -0,0 +1,168 @@
import { AutoTokenizer } from "@huggingface/transformers";
import * as ort from "onnxruntime";

// Configuration
const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
const onnxEmbeddingOriginalPath = "./model/embedding_original.onnx";
const onnxEmbeddingQuantizedPath = "./model/model.onnx";

// Initialize the sessions
const [sessionClassifier, sessionEmbeddingOriginal, sessionEmbeddingQuantized] = await Promise.all([
    ort.InferenceSession.create(onnxClassifierPath),
    ort.InferenceSession.create(onnxEmbeddingOriginalPath),
    ort.InferenceSession.create(onnxEmbeddingQuantizedPath)
]);

let tokenizer: any;

// Initialize the tokenizer
async function loadTokenizer() {
    const tokenizerConfig = { local_files_only: true };
    tokenizer = await AutoTokenizer.from_pretrained(sentenceTransformerModelName, tokenizerConfig);
}

// New embedding generation function (ONNX-based)
async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession): Promise<number[]> {
    const { input_ids } = await tokenizer(texts, {
        add_special_tokens: false,
        return_tensor: false
    });

    // Build the offsets input
    const cumsum = (arr: number[]): number[] =>
        arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []);
    const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: number[]) => x.length))];
    const flattened_input_ids = input_ids.flat();

    // Prepare the ONNX inputs
    const inputs = {
        input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [flattened_input_ids.length]),
        offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length])
    };

    // Run inference
    const { embeddings } = await session.run(inputs);
    return Array.from(embeddings.data as Float32Array);
}

function softmax(logits: Float32Array): number[] {
    const maxLogit = Math.max(...logits);
    const exponents = logits.map((logit) => Math.exp(logit - maxLogit));
    const sumOfExponents = exponents.reduce((sum, exp) => sum + exp, 0);
    return Array.from(exponents.map((exp) => exp / sumOfExponents));
}

// Classification inference
async function runClassification(embeddings: number[]): Promise<number[]> {
    const inputTensor = new ort.Tensor(
        Float32Array.from(embeddings),
        [1, 4, 1024]
    );
    const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
    return softmax(logits.data as Float32Array);
}

// Metrics computation
function calculateMetrics(labels: number[], predictions: number[]): {
    accuracy: number,
    precision: number,
    recall: number,
    f1: number
} {
    // Build the confusion matrix
    const classCount = Math.max(...labels, ...predictions) + 1;
    const matrix = Array.from({ length: classCount }, () =>
        Array.from({ length: classCount }, () => 0)
    );
    labels.forEach((trueLabel, i) => {
        matrix[trueLabel][predictions[i]]++;
    });

    // Micro-averaged precision/recall/F1
    let totalTP = 0, totalFP = 0, totalFN = 0;
    for (let c = 0; c < classCount; c++) {
        const TP = matrix[c][c];
        const FP = matrix.flatMap((row, i) => i === c ? [] : [row[c]]).reduce((a, b) => a + b, 0);
        const FN = matrix[c].filter((_, i) => i !== c).reduce((a, b) => a + b, 0);
        totalTP += TP;
        totalFP += FP;
        totalFN += FN;
    }

    const precision = totalTP / (totalTP + totalFP);
    const recall = totalTP / (totalTP + totalFN);
    const f1 = 2 * (precision * recall) / (precision + recall) || 0;

    return {
        accuracy: labels.filter((l, i) => l === predictions[i]).length / labels.length,
        precision,
        recall,
        f1
    };
}

// Refactored evaluation function
async function evaluateModel(session: ort.InferenceSession): Promise<{
    accuracy: number;
    precision: number;
    recall: number;
    f1: number;
}> {
    const data = await Deno.readTextFile("./data/filter/output.jsonl");
    const samples = data.split("\n")
        .map(line => {
            try { return JSON.parse(line); }
            catch { return null; }
        })
        .filter(Boolean);

    const allPredictions: number[] = [];
    const allLabels: number[] = [];

    for (const sample of samples) {
        try {
            const embeddings = await getONNXEmbeddings([
                sample.title,
                sample.description,
                sample.tags.join(","),
                sample.author_info
            ], session);
            const probabilities = await runClassification(embeddings);
            allPredictions.push(probabilities.indexOf(Math.max(...probabilities)));
            allLabels.push(sample.label);
        } catch (error) {
            console.error("Processing error:", error);
        }
    }

    return calculateMetrics(allLabels, allPredictions);
}

// Main function
async function main() {
    await loadTokenizer();

    // Evaluate the original model
    const t = new Date().getTime();
    const originalMetrics = await evaluateModel(sessionEmbeddingOriginal);
    console.log("Original Model Metrics:");
    console.table(originalMetrics);
    console.log(`Original model evaluation time: ${new Date().getTime() - t}ms`);

    // Evaluate the quantized model
    const t2 = new Date().getTime();
    const quantizedMetrics = await evaluateModel(sessionEmbeddingQuantized);
    console.log("Quantized Model Metrics:");
    console.table(quantizedMetrics);
    console.log(`Quantized model evaluation time: ${new Date().getTime() - t2}ms`);
}

await main();
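
One property of calculateMetrics worth noting: because each sample receives exactly one label and one prediction, every false positive for one class is simultaneously a false negative for another, so the micro-averaged precision, recall, and F1 it reports all collapse to plain accuracy. A quick check of that identity with made-up labels:

# Made-up labels and predictions for a 3-class problem
labels      = [0, 1, 2, 2, 1, 0, 2]
predictions = [0, 2, 2, 1, 1, 0, 0]

classes = sorted(set(labels) | set(predictions))
pairs = list(zip(labels, predictions))
tp = sum(l == p for l, p in pairs)
fp = sum(p == c and l != c for c in classes for l, p in pairs)
fn = sum(l == c and p != c for c in classes for l, p in pairs)

precision = tp / (tp + fp)   # 4/7
recall = tp / (tp + fn)      # 4/7
accuracy = tp / len(labels)  # 4/7
assert precision == recall == accuracy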