update: translate

improve speed
This commit is contained in:
alikia2x (寒寒) 2024-09-07 23:00:15 +08:00
parent 12b9b910f4
commit bb0aa5b79b
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
3 changed files with 172 additions and 84 deletions

View File

@ -0,0 +1,48 @@
import json, random
from torch.utils.data import Dataset
max_dataset_size = 82000
class MultiTRANS19(Dataset):
def __init__(self, data_file):
self.data = self.load_data(data_file)
def load_data(self, data_file):
Data = []
file_lines = []
with open(data_file, "rt", encoding="utf-8") as f:
file_lines = f.readlines()
combine_number_list = []
for _ in range(max_dataset_size):
num = random.randint(2, 7)
combine_number_list.append(num)
file_lines = random.sample(file_lines, sum(combine_number_list))
total = 0
for combine_count in combine_number_list:
num_combination = combine_number_list[combine_count]
sample = {
"chinese": "",
"english": ""
}
for line in file_lines[total: total+num_combination]:
try:
line_sample = json.loads(line.strip())
sample["chinese"] += line_sample["chinese"]
sample["english"] += line_sample["english"]
except json.JSONDecodeError as e:
print(f"Error decoding line: {e}")
Data.append(sample)
total+=num_combination
return Data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]

View File

@ -2,23 +2,10 @@
"cells": [
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 19,
"id": "07b697c8-5cc2-4021-9ab8-e7e3c90065ee",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/models/marian/tokenization_marian.py:175: UserWarning: Recommended: pip install sacremoses.\n",
" warnings.warn(\"Recommended: pip install sacremoses.\")\n",
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n",
"/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_69064/1647496252.py:14: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))\n"
]
}
],
"outputs": [],
"source": [
"import time\n",
"import torch\n",
@ -26,37 +13,38 @@
"\n",
"# 定义参数\n",
"model_checkpoint = \"Helsinki-NLP/opus-mt-zh-en\"\n",
"checkpoint_path = \"./saves/step_74500_valid_bleu_30.28_model_weights.bin\" # 假设使用训练中的checkpoint\n",
"checkpoint_path = \"./saves/step_86500_bleu_29.87.bin\" # 假设使用训练中的checkpoint\n",
"\n",
"# 加载tokenizer和模型\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)\n",
"\n",
"# 加载checkpoint\n",
"model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))\n",
"#model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')[\"model_state_dict\"])\n",
"model.eval()\n",
"\n",
"# 将模型转移到设备\n",
"device = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 10,
"id": "ccfb5004-2bdd-4d64-88a3-2af96b87092c",
"metadata": {},
"outputs": [],
"source": [
"def infer_translation(input_text, model, tokenizer, max_length=128, num_beams=1, length_penalty=1.2):\n",
"def infer_translation_batch(input_texts, model, tokenizer, max_length=512, num_beams=1, length_penalty=1):\n",
" # 记录推理开始时间\n",
" start_time = time.time()\n",
"\n",
" # 预处理输入文本\n",
" # 预处理输入文本(批量处理)\n",
" inputs = tokenizer(\n",
" input_text,\n",
" input_texts,\n",
" return_tensors=\"pt\",\n",
" padding=\"max_length\",\n",
" padding=True, # 使用动态填充,对齐批量输入的长度\n",
" truncation=True,\n",
" max_length=max_length,\n",
" ).to(device)\n",
"\n",
@ -64,54 +52,41 @@
" with torch.no_grad():\n",
" output_tokens = model.generate(\n",
" inputs[\"input_ids\"],\n",
" max_length=max_length,\n",
" num_beams=num_beams,\n",
" length_penalty=length_penalty,\n",
" early_stopping=True,\n",
" no_repeat_ngram_size=2,\n",
" temperature = 0.3,\n",
" top_p = 0.85,\n",
" do_sample = False\n",
" early_stopping=False,\n",
" #temperature=0.5,\n",
" #top_p=0.90,\n",
" do_sample=False\n",
" )\n",
"\n",
" # 解码生成的tokens为文本\n",
" translation = tokenizer.decode(output_tokens[0], skip_special_tokens=True)\n",
" # 解码生成的tokens为文本批量处理\n",
" translations = [\n",
" tokenizer.decode(output, skip_special_tokens=True) for output in output_tokens\n",
" ]\n",
"\n",
" # 记录推理结束时间\n",
" end_time = time.time()\n",
" inference_time = end_time - start_time\n",
"\n",
" return translation, inference_time\n",
" return translations, inference_time\n",
"\n",
"def translate(input_text, model, tokenizer):\n",
"def translate(input_text, model, tokenizer, batch_size=16):\n",
" lines = input_text.splitlines()\n",
" \n",
" # 存储每一行的翻译结果\n",
" translations = []\n",
" total_time = 0 \n",
" total_time = 0\n",
" \n",
" # 对每一行进行翻译\n",
" for line in lines:\n",
" if line.strip() == \"\":\n",
" translations.append(\"\")\n",
" # 分批处理\n",
" for i in range(0, len(lines), batch_size):\n",
" batch_lines = [line for line in lines[i:i + batch_size] if line.strip()]\n",
" if not batch_lines:\n",
" translations.extend([\"\"] * len(batch_lines))\n",
" continue\n",
" #对于长行按句翻译\n",
" if len(line) > 64 and '。' in line:\n",
" sentences = line.split('。')\n",
" translated_sentences=[]\n",
" for sentence in sentences:\n",
" if sentence.strip() == \"\":\n",
" continue\n",
" translation, time_cost = infer_translation(sentence, model, tokenizer)\n",
" translated_sentences.append(translation)\n",
" total_time += time_cost\n",
" #print(sentence,translation)\n",
" translations.append(\" \".join(translated_sentences))\n",
" else:\n",
" translation, time_cost = infer_translation(line, model, tokenizer)\n",
" #print(line,translation)\n",
" translations.append(translation)\n",
" total_time += time_cost\n",
" batch_translations, time_cost = infer_translation_batch(batch_lines, model, tokenizer)\n",
" translations.extend(batch_translations)\n",
" total_time += time_cost\n",
" \n",
" final_translation = \"\\n\".join(translations)\n",
" \n",
@ -120,45 +95,100 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 20,
"id": "d5d35c96-3c4a-487c-ac26-d3d97f1208a6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.3` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
" warnings.warn(\n",
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:572: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.85` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
" warnings.warn(\n",
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:615: UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`.\n",
" warnings.warn(\n",
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:634: UserWarning: `num_beams` is set to 1. However, `length_penalty` is set to `1.2` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `length_penalty`.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original Text: \n",
"自2000年左右台湾的珍珠奶茶传入中国大陆市场规模逐步扩大。当地不断推出新口味的奶茶、水果茶和奶盖茶等创新饮品并提供多样化的配料选择统称为新式茶饮。2018年起奶茶品牌开始采用网红营销策略使得部分城市门店顾客络绎不绝。尽管消费者有多达两千种的搭配选择但销量最高的依旧是珍珠、红豆和布丁这三种经典配料。\n",
"面对激烈的市场竞争,茶饮品牌开始区分不同的档次,从使用红茶粉和奶精的低成本产品,到采用新鲜牛奶和现场煮制的高级奶茶,甚至高端茶叶如大红袍、龙井茶也成为一些品牌的选用。\n",
"\n",
"为了降低Transformer翻译模型如基于Helsinki-NLP的Opus模型的推理时间并提高性能以下是一些常见且有效的优化方法\n",
"\n",
"1. 模型量化\n",
"简介量化是通过使用低精度数值表示模型权重例如将32位浮点数转换为8位整数来减少模型的计算量和内存占用从而加快推理速度。\n",
"方法:\n",
"Post-training quantization (PTQ):模型训练后对权重进行量化。\n",
"Quantization-aware training (QAT)在训练时引入量化通常效果比PTQ更好。\n",
"2. 模型剪枝\n",
"简介:剪枝通过移除模型中对推理结果影响较小的权重和节点来减小模型规模,从而加速推理。\n",
"方法:\n",
"结构化剪枝:移除整个层、注意力头或神经元。\n",
"非结构化剪枝:移除个别的低权重参数。\n",
"3. 减少模型尺寸\n",
"简介:通过使用更小的模型架构(例如减少层数、隐藏层维度或注意力头的数量),可以减少计算量和推理时间。\n",
"方法使用较小版本的模型例如opus-mt-small或手动调整Transformer的超参数。\n",
"4. 启用混合精度推理\n",
"简介混合精度推理允许部分计算使用半精度浮点数FP16从而减少内存占用并提高推理速度。\n",
"工具:\n",
"NVIDIA的TensorRT和**AMP (Automatic Mixed Precision)**是常用的工具可以自动处理FP16的计算。\n",
"5. 使用高效的解码策略\n",
"简介解码策略的选择影响推理速度。常用的解码方式如Beam Search虽然精度较高但速度较慢。\n",
"方法:\n",
"降低beam size减小beam size可以显著加快解码速度虽然可能会略微牺牲翻译质量。\n",
"Top-k sampling和Nucleus Sampling (Top-p sampling):这些方法通过限制词汇选择的范围来加快推理速度。\n",
"\n",
"\n",
"\n",
"Translated Text: \n",
"Since about 2000, the Pearl Milk Tea of Taiwan has been spreading into the mainland, and the market has gradually expanded The new tea, fruit tea and milk tea are introduced in the local market, and the variety of ingredients is offered, collectively known as new-style tea. Since 2018, the milk tea brand has adopted a mesh marketing strategy, which has made some city stores more and more customers. Despite the fact that consumers have as many as 2, 000 combinations, the highest sales are still the three classic ingredients: pearls and red beans and pudding.\n",
"In the face of fierce market competition, tea and tea brands began to differentiate between low-cost products using red tea powder and cream, high-grade milk tea made from fresh milk and live, even high tea such as big red robes and dragon well tea have become a few brand selections.\n",
"To reduce the time of reasoning and improve performance of the Transformer translation model (e.g., the Opus model based on Helsinki-NLP), the following are common and effective methods of optimization:\n",
"Model quantification\n",
"Profile: Quantification reduces model computing and memory occupancy by using low precision values to indicate model weights (e.g., converting 32-digit float points to 8-digit integer values), thereby accelerating reasoning.\n",
"Methodology:\n",
"Post-training Quantisation (PTQ): Quantifying weights after model training.\n",
"Quantification-aware trading (QAT): Quantification is introduced in training, usually with better results than PTQ.\n",
"Model cutting\n",
"Profile: Cuts reduce the size of the model by removing weights and nodes in the model that influence the reasoning results less.\n",
"Methodology:\n",
"Structured cut-off: removes the whole layer, attention head or neuron.\n",
"Unstructured cut-off: removes individual low weight parameters.\n",
"3. Reduction of model size\n",
"Profile: The calculation and reasoning time can be reduced by using smaller model structures (e.g., reducing the number of layers, hidden layers or the number of attention points).\n",
"Method: Use smaller versions of models, such as opus-mt-small, or manually adjust Transformer's hyperparameters.\n",
"4. Enable mixed precision reasoning\n",
"Introduction: The mixed precision reasoning allows for partial calculation of semi-precision floats (FP16), thereby reducing memory occupancy and increasing the speed of reasoning.\n",
"Tools:\n",
"The NVIDIA TensorRT and **AMP (Automatic Mixed Precision)** are commonly used tools that can automatically process FP16 calculations.\n",
"Use of efficient decoding strategies\n",
"Profile: The selection of the decoding strategy affects the speed of reasoning. Common decoding methods such as BeamSearch are more precise but slow.\n",
"Methodology:\n",
"Lower beam size: Reduction of beam size can significantly accelerate decoding, although it may be at the expense of translation quality.\n",
"Top-k sampling and Nucleus Sampling (Top-p sampling): These methods accelerate reasoning by limiting the range of vocabulary selections.\n",
"\n",
"Inference Time: 3.8918 seconds\n"
"Inference Time: 2.8956 seconds\n"
]
}
],
"source": [
"# 用户输入\n",
"input_text = '''自2000年左右台湾的珍珠奶茶传入中国大陆市场规模逐步扩大。当地不断推出新口味的奶茶、水果茶和奶盖茶等创新饮品并提供多样化的配料选择统称为新式茶饮。2018年起奶茶品牌开始采用网红营销策略使得部分城市门店顾客络绎不绝。尽管消费者有多达两千种的搭配选择但销量最高的依旧是珍珠、红豆和布丁这三种经典配料。\n",
"面对激烈的市场竞争,茶饮品牌开始区分不同的档次,从使用红茶粉和奶精的低成本产品,到采用新鲜牛奶和现场煮制的高级奶茶,甚至高端茶叶如大红袍、龙井茶也成为一些品牌的选用。'''\n",
"input_text = '''\n",
"为了降低Transformer翻译模型如基于Helsinki-NLP的Opus模型的推理时间并提高性能以下是一些常见且有效的优化方法\n",
"\n",
"1. 模型量化\n",
"简介量化是通过使用低精度数值表示模型权重例如将32位浮点数转换为8位整数来减少模型的计算量和内存占用从而加快推理速度。\n",
"方法:\n",
"Post-training quantization (PTQ):模型训练后对权重进行量化。\n",
"Quantization-aware training (QAT)在训练时引入量化通常效果比PTQ更好。\n",
"2. 模型剪枝\n",
"简介:剪枝通过移除模型中对推理结果影响较小的权重和节点来减小模型规模,从而加速推理。\n",
"方法:\n",
"结构化剪枝:移除整个层、注意力头或神经元。\n",
"非结构化剪枝:移除个别的低权重参数。\n",
"3. 减少模型尺寸\n",
"简介:通过使用更小的模型架构(例如减少层数、隐藏层维度或注意力头的数量),可以减少计算量和推理时间。\n",
"方法使用较小版本的模型例如opus-mt-small或手动调整Transformer的超参数。\n",
"4. 启用混合精度推理\n",
"简介混合精度推理允许部分计算使用半精度浮点数FP16从而减少内存占用并提高推理速度。\n",
"工具:\n",
"NVIDIA的TensorRT和**AMP (Automatic Mixed Precision)**是常用的工具可以自动处理FP16的计算。\n",
"5. 使用高效的解码策略\n",
"简介解码策略的选择影响推理速度。常用的解码方式如Beam Search虽然精度较高但速度较慢。\n",
"方法:\n",
"降低beam size减小beam size可以显著加快解码速度虽然可能会略微牺牲翻译质量。\n",
"Top-k sampling和Nucleus Sampling (Top-p sampling):这些方法通过限制词汇选择的范围来加快推理速度。\n",
"'''\n",
"\n",
"# 进行推理并测量时间\n",
"translated_text, time_taken = translate(input_text, model, tokenizer)\n",
@ -168,6 +198,14 @@
"print(f\"Translated Text: \\n{translated_text}\\n\")\n",
"print(f\"Inference Time: {time_taken:.4f} seconds\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4a44b25-a8bb-4a82-964a-0811c34c256c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

16
translate/zh-en/train.py Normal file → Executable file
View File

@ -6,15 +6,15 @@ from transformers import AdamW, get_scheduler
from sacrebleu.metrics import BLEU
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
from dataloader.wikititle import Wikititle
from dataloader.multiTrans19 import MultiTRANS19
writer = SummaryWriter()
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
train_set_size = 95000
valid_set_size = 5000
train_set_size = 80000
valid_set_size = 2000
test_data_size = 0
last_1k_loss = []
@ -32,10 +32,12 @@ epoch_num = 1
# 检查点文件路径默认为None
# checkpoint_path = None
checkpoint_path = "./saves/checkpoint_74000.bin" # 如果要从检查点继续训练,设置此路径
checkpoint_path = "./saves/checkpoint_76500.bin" # 如果要从检查点继续训练,设置此路径
data = Wikititle("./data/wikititles-v3.zh-en.tsv")
#data = Wikititle("./data/wikititles-v3.zh-en.tsv")
data = MultiTRANS19("./data/translation2019zh/translation2019zh_train.json")
print(len(data))
train_data, valid_data, test_data = random_split(data, [train_set_size, valid_set_size, test_data_size])
# data = TRANS("./data/translation2019zh/translation2019zh_train.json")
@ -135,7 +137,7 @@ def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss, st
"kmean_loss": kmean_loss,
"step": step,
}
torch.save(checkpoint, f"checkpoint_{step}.bin")
torch.save(checkpoint, f"./saves/checkpoint_{step}.bin")
return total_loss, step
@ -195,6 +197,6 @@ for t in range(epoch_num):
"kmean_loss": kmean_loss,
"step": step,
}
torch.save(checkpoint, f"step_{step}_valid_bleu_{valid_bleu:0.2f}_model_weights.bin")
torch.save(checkpoint, f"./saves/step_{step}_bleu_{valid_bleu:0.2f}.bin")
print("Done!")