cvsa/filter/quantize.py

36 lines
1.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from safetensors import safe_open
from safetensors.torch import save_file
import torch
# Affine (min/max) quantization of the "embeddings" tensor to int8.
#
# Every float value x is mapped to
#     q = round((x - min_val) / scale) - 128,   with q in [-128, 127]
# and can be recovered (approximately) with
#     x ~= (q + 128) * scale + min_val

# Input and output paths.
model_path = "./model/embedding/model.safetensors"
save_path = "./model/embedding/int8_model.safetensors"

# Load the original embedding tensor.
with safe_open(model_path, framework="pt") as f:
    embeddings_tensor = f.get_tensor("embeddings")

# Value range of the tensor.
min_val = torch.min(embeddings_tensor)
max_val = torch.max(embeddings_tensor)

# Quantization step: int8 has 256 representable values (-128 to 127),
# i.e. 255 intervals spanning [min_val, max_val].
scale = (max_val - min_val) / 255
if scale.item() == 0:
    # Constant tensor: avoid 0/0 = NaN below. With a scale of 1 every element
    # maps to -128, and the inverse formula reproduces min_val exactly.
    scale = torch.ones_like(scale)

# Map the floats onto the int8 range. The shift by 128 is applied BEFORE the
# cast: casting first would push values in (127, 255] through an out-of-range
# float -> int8 conversion, whose result is not guaranteed. The clamp guards
# against floating-point rounding landing just outside the int8 range.
int8_tensor = (
    (torch.round((embeddings_tensor - min_val) / scale) - 128)
    .clamp(-128, 127)
    .to(torch.int8)
)

# Sanity check: quantization must preserve the tensor shape.
assert int8_tensor.shape == embeddings_tensor.shape

# Save the quantized int8 tensor under the same key.
save_file({"embeddings": int8_tensor}, save_path)

# Print the de-quantization formula. The offset is printed as a sign plus an
# absolute value so that a negative minimum reads as "- |min|", not "+ -min".
print("int8 反映射公式:")
m = min_val.item()
am = abs(m)
sign = "-" if m < 0 else "+"
print(f"int8_tensor = (int8_value + 128) × {scale.item()} {sign} {am}")
print("int8 映射完成!")