36 lines
1.0 KiB
Python
36 lines
1.0 KiB
Python
from safetensors import safe_open
|
||
from safetensors.torch import save_file
|
||
import torch
|
||
|
||
# 配置路径
|
||
model_path = "./model/embedding/model.safetensors"
|
||
save_path = "./model/embedding/int8_model.safetensors"
|
||
|
||
# 加载原始嵌入层
|
||
with safe_open(model_path, framework="pt") as f:
|
||
embeddings_tensor = f.get_tensor("embeddings")
|
||
|
||
# 计算极值
|
||
min_val = torch.min(embeddings_tensor)
|
||
max_val = torch.max(embeddings_tensor)
|
||
|
||
# 计算量化参数
|
||
scale = (max_val - min_val) / 255 # int8 的范围是 256 个值(-128 到 127)
|
||
|
||
# 将浮点数映射到 int8 范围
|
||
int8_tensor = torch.round((embeddings_tensor - min_val) / scale).to(torch.int8) - 128
|
||
|
||
# 确保与原张量形状一致
|
||
assert int8_tensor.shape == embeddings_tensor.shape
|
||
|
||
# 保存映射后的 int8 张量
|
||
save_file({"embeddings": int8_tensor}, save_path)
|
||
|
||
# 输出反映射公式
|
||
print("int8 反映射公式:")
|
||
m = min_val.item()
|
||
am = abs(min_val.item())
|
||
sign = "-" if m < 0 else "+"
|
||
print(f"int8_tensor = (int8_value + 128) × {scale.item()} {sign} {am}")
|
||
|
||
print("int8 映射完成!") |