update: translate
improve speed
This commit is contained in:
parent
12b9b910f4
commit
bb0aa5b79b
48
translate/zh-en/dataloader/multiTrans19.py
Executable file
48
translate/zh-en/dataloader/multiTrans19.py
Executable file
@ -0,0 +1,48 @@
|
||||
import json, random
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
max_dataset_size = 82000
|
||||
|
||||
class MultiTRANS19(Dataset):
|
||||
def __init__(self, data_file):
|
||||
self.data = self.load_data(data_file)
|
||||
|
||||
def load_data(self, data_file):
|
||||
Data = []
|
||||
file_lines = []
|
||||
|
||||
with open(data_file, "rt", encoding="utf-8") as f:
|
||||
file_lines = f.readlines()
|
||||
|
||||
combine_number_list = []
|
||||
for _ in range(max_dataset_size):
|
||||
num = random.randint(2, 7)
|
||||
combine_number_list.append(num)
|
||||
|
||||
file_lines = random.sample(file_lines, sum(combine_number_list))
|
||||
|
||||
total = 0
|
||||
for combine_count in combine_number_list:
|
||||
num_combination = combine_number_list[combine_count]
|
||||
sample = {
|
||||
"chinese": "",
|
||||
"english": ""
|
||||
}
|
||||
for line in file_lines[total: total+num_combination]:
|
||||
try:
|
||||
line_sample = json.loads(line.strip())
|
||||
sample["chinese"] += line_sample["chinese"]
|
||||
sample["english"] += line_sample["english"]
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error decoding line: {e}")
|
||||
|
||||
Data.append(sample)
|
||||
total+=num_combination
|
||||
|
||||
return Data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.data[idx]
|
@ -2,23 +2,10 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 19,
|
||||
"id": "07b697c8-5cc2-4021-9ab8-e7e3c90065ee",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/models/marian/tokenization_marian.py:175: UserWarning: Recommended: pip install sacremoses.\n",
|
||||
" warnings.warn(\"Recommended: pip install sacremoses.\")\n",
|
||||
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
|
||||
" warnings.warn(\n",
|
||||
"/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_69064/1647496252.py:14: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
|
||||
" model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import time\n",
|
||||
"import torch\n",
|
||||
@ -26,37 +13,38 @@
|
||||
"\n",
|
||||
"# 定义参数\n",
|
||||
"model_checkpoint = \"Helsinki-NLP/opus-mt-zh-en\"\n",
|
||||
"checkpoint_path = \"./saves/step_74500_valid_bleu_30.28_model_weights.bin\" # 假设使用训练中的checkpoint\n",
|
||||
"checkpoint_path = \"./saves/step_86500_bleu_29.87.bin\" # 假设使用训练中的checkpoint\n",
|
||||
"\n",
|
||||
"# 加载tokenizer和模型\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
|
||||
"model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)\n",
|
||||
"\n",
|
||||
"# 加载checkpoint\n",
|
||||
"model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))\n",
|
||||
"#model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')[\"model_state_dict\"])\n",
|
||||
"model.eval()\n",
|
||||
"\n",
|
||||
"# 将模型转移到设备\n",
|
||||
"device = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
|
||||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||
"model = model.to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 10,
|
||||
"id": "ccfb5004-2bdd-4d64-88a3-2af96b87092c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def infer_translation(input_text, model, tokenizer, max_length=128, num_beams=1, length_penalty=1.2):\n",
|
||||
"def infer_translation_batch(input_texts, model, tokenizer, max_length=512, num_beams=1, length_penalty=1):\n",
|
||||
" # 记录推理开始时间\n",
|
||||
" start_time = time.time()\n",
|
||||
"\n",
|
||||
" # 预处理输入文本\n",
|
||||
" # 预处理输入文本(批量处理)\n",
|
||||
" inputs = tokenizer(\n",
|
||||
" input_text,\n",
|
||||
" input_texts,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
" padding=\"max_length\",\n",
|
||||
" padding=True, # 使用动态填充,对齐批量输入的长度\n",
|
||||
" truncation=True,\n",
|
||||
" max_length=max_length,\n",
|
||||
" ).to(device)\n",
|
||||
"\n",
|
||||
@ -64,54 +52,41 @@
|
||||
" with torch.no_grad():\n",
|
||||
" output_tokens = model.generate(\n",
|
||||
" inputs[\"input_ids\"],\n",
|
||||
" max_length=max_length,\n",
|
||||
" num_beams=num_beams,\n",
|
||||
" length_penalty=length_penalty,\n",
|
||||
" early_stopping=True,\n",
|
||||
" no_repeat_ngram_size=2,\n",
|
||||
" temperature = 0.3,\n",
|
||||
" top_p = 0.85,\n",
|
||||
" do_sample = False\n",
|
||||
" early_stopping=False,\n",
|
||||
" #temperature=0.5,\n",
|
||||
" #top_p=0.90,\n",
|
||||
" do_sample=False\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" # 解码生成的tokens为文本\n",
|
||||
" translation = tokenizer.decode(output_tokens[0], skip_special_tokens=True)\n",
|
||||
" # 解码生成的tokens为文本(批量处理)\n",
|
||||
" translations = [\n",
|
||||
" tokenizer.decode(output, skip_special_tokens=True) for output in output_tokens\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" # 记录推理结束时间\n",
|
||||
" end_time = time.time()\n",
|
||||
" inference_time = end_time - start_time\n",
|
||||
"\n",
|
||||
" return translation, inference_time\n",
|
||||
" return translations, inference_time\n",
|
||||
"\n",
|
||||
"def translate(input_text, model, tokenizer):\n",
|
||||
"def translate(input_text, model, tokenizer, batch_size=16):\n",
|
||||
" lines = input_text.splitlines()\n",
|
||||
" \n",
|
||||
" # 存储每一行的翻译结果\n",
|
||||
" translations = []\n",
|
||||
" total_time = 0 \n",
|
||||
" total_time = 0\n",
|
||||
" \n",
|
||||
" # 对每一行进行翻译\n",
|
||||
" for line in lines:\n",
|
||||
" if line.strip() == \"\":\n",
|
||||
" translations.append(\"\")\n",
|
||||
" # 分批处理\n",
|
||||
" for i in range(0, len(lines), batch_size):\n",
|
||||
" batch_lines = [line for line in lines[i:i + batch_size] if line.strip()]\n",
|
||||
" if not batch_lines:\n",
|
||||
" translations.extend([\"\"] * len(batch_lines))\n",
|
||||
" continue\n",
|
||||
" #对于长行按句翻译\n",
|
||||
" if len(line) > 64 and '。' in line:\n",
|
||||
" sentences = line.split('。')\n",
|
||||
" translated_sentences=[]\n",
|
||||
" for sentence in sentences:\n",
|
||||
" if sentence.strip() == \"\":\n",
|
||||
" continue\n",
|
||||
" translation, time_cost = infer_translation(sentence, model, tokenizer)\n",
|
||||
" translated_sentences.append(translation)\n",
|
||||
" total_time += time_cost\n",
|
||||
" #print(sentence,translation)\n",
|
||||
" translations.append(\" \".join(translated_sentences))\n",
|
||||
" else:\n",
|
||||
" translation, time_cost = infer_translation(line, model, tokenizer)\n",
|
||||
" #print(line,translation)\n",
|
||||
" translations.append(translation)\n",
|
||||
" total_time += time_cost\n",
|
||||
" batch_translations, time_cost = infer_translation_batch(batch_lines, model, tokenizer)\n",
|
||||
" translations.extend(batch_translations)\n",
|
||||
" total_time += time_cost\n",
|
||||
" \n",
|
||||
" final_translation = \"\\n\".join(translations)\n",
|
||||
" \n",
|
||||
@ -120,45 +95,100 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 20,
|
||||
"id": "d5d35c96-3c4a-487c-ac26-d3d97f1208a6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.3` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
|
||||
" warnings.warn(\n",
|
||||
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:572: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.85` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
|
||||
" warnings.warn(\n",
|
||||
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:615: UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`.\n",
|
||||
" warnings.warn(\n",
|
||||
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:634: UserWarning: `num_beams` is set to 1. However, `length_penalty` is set to `1.2` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `length_penalty`.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Original Text: \n",
|
||||
"自2000年左右,台湾的珍珠奶茶传入中国大陆,市场规模逐步扩大。当地不断推出新口味的奶茶、水果茶和奶盖茶等创新饮品,并提供多样化的配料选择,统称为新式茶饮。2018年起,奶茶品牌开始采用网红营销策略,使得部分城市门店顾客络绎不绝。尽管消费者有多达两千种的搭配选择,但销量最高的依旧是珍珠、红豆和布丁这三种经典配料。\n",
|
||||
"面对激烈的市场竞争,茶饮品牌开始区分不同的档次,从使用红茶粉和奶精的低成本产品,到采用新鲜牛奶和现场煮制的高级奶茶,甚至高端茶叶如大红袍、龙井茶也成为一些品牌的选用。\n",
|
||||
"\n",
|
||||
"为了降低Transformer翻译模型(如基于Helsinki-NLP的Opus模型)的推理时间并提高性能,以下是一些常见且有效的优化方法:\n",
|
||||
"\n",
|
||||
"1. 模型量化\n",
|
||||
"简介:量化是通过使用低精度数值表示模型权重(例如将32位浮点数转换为8位整数)来减少模型的计算量和内存占用,从而加快推理速度。\n",
|
||||
"方法:\n",
|
||||
"Post-training quantization (PTQ):模型训练后对权重进行量化。\n",
|
||||
"Quantization-aware training (QAT):在训练时引入量化,通常效果比PTQ更好。\n",
|
||||
"2. 模型剪枝\n",
|
||||
"简介:剪枝通过移除模型中对推理结果影响较小的权重和节点来减小模型规模,从而加速推理。\n",
|
||||
"方法:\n",
|
||||
"结构化剪枝:移除整个层、注意力头或神经元。\n",
|
||||
"非结构化剪枝:移除个别的低权重参数。\n",
|
||||
"3. 减少模型尺寸\n",
|
||||
"简介:通过使用更小的模型架构(例如减少层数、隐藏层维度或注意力头的数量),可以减少计算量和推理时间。\n",
|
||||
"方法:使用较小版本的模型,例如opus-mt-small,或手动调整Transformer的超参数。\n",
|
||||
"4. 启用混合精度推理\n",
|
||||
"简介:混合精度推理允许部分计算使用半精度浮点数(FP16),从而减少内存占用并提高推理速度。\n",
|
||||
"工具:\n",
|
||||
"NVIDIA的TensorRT和**AMP (Automatic Mixed Precision)**是常用的工具,可以自动处理FP16的计算。\n",
|
||||
"5. 使用高效的解码策略\n",
|
||||
"简介:解码策略的选择影响推理速度。常用的解码方式如Beam Search虽然精度较高,但速度较慢。\n",
|
||||
"方法:\n",
|
||||
"降低beam size:减小beam size可以显著加快解码速度,虽然可能会略微牺牲翻译质量。\n",
|
||||
"Top-k sampling和Nucleus Sampling (Top-p sampling):这些方法通过限制词汇选择的范围来加快推理速度。\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Translated Text: \n",
|
||||
"Since about 2000, the Pearl Milk Tea of Taiwan has been spreading into the mainland, and the market has gradually expanded The new tea, fruit tea and milk tea are introduced in the local market, and the variety of ingredients is offered, collectively known as new-style tea. Since 2018, the milk tea brand has adopted a mesh marketing strategy, which has made some city stores more and more customers. Despite the fact that consumers have as many as 2, 000 combinations, the highest sales are still the three classic ingredients: pearls and red beans and pudding.\n",
|
||||
"In the face of fierce market competition, tea and tea brands began to differentiate between low-cost products using red tea powder and cream, high-grade milk tea made from fresh milk and live, even high tea such as big red robes and dragon well tea have become a few brand selections.\n",
|
||||
"To reduce the time of reasoning and improve performance of the Transformer translation model (e.g., the Opus model based on Helsinki-NLP), the following are common and effective methods of optimization:\n",
|
||||
"Model quantification\n",
|
||||
"Profile: Quantification reduces model computing and memory occupancy by using low precision values to indicate model weights (e.g., converting 32-digit float points to 8-digit integer values), thereby accelerating reasoning.\n",
|
||||
"Methodology:\n",
|
||||
"Post-training Quantisation (PTQ): Quantifying weights after model training.\n",
|
||||
"Quantification-aware trading (QAT): Quantification is introduced in training, usually with better results than PTQ.\n",
|
||||
"Model cutting\n",
|
||||
"Profile: Cuts reduce the size of the model by removing weights and nodes in the model that influence the reasoning results less.\n",
|
||||
"Methodology:\n",
|
||||
"Structured cut-off: removes the whole layer, attention head or neuron.\n",
|
||||
"Unstructured cut-off: removes individual low weight parameters.\n",
|
||||
"3. Reduction of model size\n",
|
||||
"Profile: The calculation and reasoning time can be reduced by using smaller model structures (e.g., reducing the number of layers, hidden layers or the number of attention points).\n",
|
||||
"Method: Use smaller versions of models, such as opus-mt-small, or manually adjust Transformer's hyperparameters.\n",
|
||||
"4. Enable mixed precision reasoning\n",
|
||||
"Introduction: The mixed precision reasoning allows for partial calculation of semi-precision floats (FP16), thereby reducing memory occupancy and increasing the speed of reasoning.\n",
|
||||
"Tools:\n",
|
||||
"The NVIDIA TensorRT and **AMP (Automatic Mixed Precision)** are commonly used tools that can automatically process FP16 calculations.\n",
|
||||
"Use of efficient decoding strategies\n",
|
||||
"Profile: The selection of the decoding strategy affects the speed of reasoning. Common decoding methods such as BeamSearch are more precise but slow.\n",
|
||||
"Methodology:\n",
|
||||
"Lower beam size: Reduction of beam size can significantly accelerate decoding, although it may be at the expense of translation quality.\n",
|
||||
"Top-k sampling and Nucleus Sampling (Top-p sampling): These methods accelerate reasoning by limiting the range of vocabulary selections.\n",
|
||||
"\n",
|
||||
"Inference Time: 3.8918 seconds\n"
|
||||
"Inference Time: 2.8956 seconds\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 用户输入\n",
|
||||
"input_text = '''自2000年左右,台湾的珍珠奶茶传入中国大陆,市场规模逐步扩大。当地不断推出新口味的奶茶、水果茶和奶盖茶等创新饮品,并提供多样化的配料选择,统称为新式茶饮。2018年起,奶茶品牌开始采用网红营销策略,使得部分城市门店顾客络绎不绝。尽管消费者有多达两千种的搭配选择,但销量最高的依旧是珍珠、红豆和布丁这三种经典配料。\n",
|
||||
"面对激烈的市场竞争,茶饮品牌开始区分不同的档次,从使用红茶粉和奶精的低成本产品,到采用新鲜牛奶和现场煮制的高级奶茶,甚至高端茶叶如大红袍、龙井茶也成为一些品牌的选用。'''\n",
|
||||
"input_text = '''\n",
|
||||
"为了降低Transformer翻译模型(如基于Helsinki-NLP的Opus模型)的推理时间并提高性能,以下是一些常见且有效的优化方法:\n",
|
||||
"\n",
|
||||
"1. 模型量化\n",
|
||||
"简介:量化是通过使用低精度数值表示模型权重(例如将32位浮点数转换为8位整数)来减少模型的计算量和内存占用,从而加快推理速度。\n",
|
||||
"方法:\n",
|
||||
"Post-training quantization (PTQ):模型训练后对权重进行量化。\n",
|
||||
"Quantization-aware training (QAT):在训练时引入量化,通常效果比PTQ更好。\n",
|
||||
"2. 模型剪枝\n",
|
||||
"简介:剪枝通过移除模型中对推理结果影响较小的权重和节点来减小模型规模,从而加速推理。\n",
|
||||
"方法:\n",
|
||||
"结构化剪枝:移除整个层、注意力头或神经元。\n",
|
||||
"非结构化剪枝:移除个别的低权重参数。\n",
|
||||
"3. 减少模型尺寸\n",
|
||||
"简介:通过使用更小的模型架构(例如减少层数、隐藏层维度或注意力头的数量),可以减少计算量和推理时间。\n",
|
||||
"方法:使用较小版本的模型,例如opus-mt-small,或手动调整Transformer的超参数。\n",
|
||||
"4. 启用混合精度推理\n",
|
||||
"简介:混合精度推理允许部分计算使用半精度浮点数(FP16),从而减少内存占用并提高推理速度。\n",
|
||||
"工具:\n",
|
||||
"NVIDIA的TensorRT和**AMP (Automatic Mixed Precision)**是常用的工具,可以自动处理FP16的计算。\n",
|
||||
"5. 使用高效的解码策略\n",
|
||||
"简介:解码策略的选择影响推理速度。常用的解码方式如Beam Search虽然精度较高,但速度较慢。\n",
|
||||
"方法:\n",
|
||||
"降低beam size:减小beam size可以显著加快解码速度,虽然可能会略微牺牲翻译质量。\n",
|
||||
"Top-k sampling和Nucleus Sampling (Top-p sampling):这些方法通过限制词汇选择的范围来加快推理速度。\n",
|
||||
"'''\n",
|
||||
"\n",
|
||||
"# 进行推理并测量时间\n",
|
||||
"translated_text, time_taken = translate(input_text, model, tokenizer)\n",
|
||||
@ -168,6 +198,14 @@
|
||||
"print(f\"Translated Text: \\n{translated_text}\\n\")\n",
|
||||
"print(f\"Inference Time: {time_taken:.4f} seconds\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e4a44b25-a8bb-4a82-964a-0811c34c256c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
16
translate/zh-en/train.py
Normal file → Executable file
16
translate/zh-en/train.py
Normal file → Executable file
@ -6,15 +6,15 @@ from transformers import AdamW, get_scheduler
|
||||
from sacrebleu.metrics import BLEU
|
||||
from tqdm.auto import tqdm
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from dataloader.wikititle import Wikititle
|
||||
from dataloader.multiTrans19 import MultiTRANS19
|
||||
|
||||
writer = SummaryWriter()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Using {device} device")
|
||||
|
||||
train_set_size = 95000
|
||||
valid_set_size = 5000
|
||||
train_set_size = 80000
|
||||
valid_set_size = 2000
|
||||
test_data_size = 0
|
||||
|
||||
last_1k_loss = []
|
||||
@ -32,10 +32,12 @@ epoch_num = 1
|
||||
|
||||
# 检查点文件路径,默认为None
|
||||
# checkpoint_path = None
|
||||
checkpoint_path = "./saves/checkpoint_74000.bin" # 如果要从检查点继续训练,设置此路径
|
||||
checkpoint_path = "./saves/checkpoint_76500.bin" # 如果要从检查点继续训练,设置此路径
|
||||
|
||||
|
||||
data = Wikititle("./data/wikititles-v3.zh-en.tsv")
|
||||
#data = Wikititle("./data/wikititles-v3.zh-en.tsv")
|
||||
data = MultiTRANS19("./data/translation2019zh/translation2019zh_train.json")
|
||||
print(len(data))
|
||||
train_data, valid_data, test_data = random_split(data, [train_set_size, valid_set_size, test_data_size])
|
||||
|
||||
# data = TRANS("./data/translation2019zh/translation2019zh_train.json")
|
||||
@ -135,7 +137,7 @@ def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss, st
|
||||
"kmean_loss": kmean_loss,
|
||||
"step": step,
|
||||
}
|
||||
torch.save(checkpoint, f"checkpoint_{step}.bin")
|
||||
torch.save(checkpoint, f"./saves/checkpoint_{step}.bin")
|
||||
return total_loss, step
|
||||
|
||||
|
||||
@ -195,6 +197,6 @@ for t in range(epoch_num):
|
||||
"kmean_loss": kmean_loss,
|
||||
"step": step,
|
||||
}
|
||||
torch.save(checkpoint, f"step_{step}_valid_bleu_{valid_bleu:0.2f}_model_weights.bin")
|
||||
torch.save(checkpoint, f"./saves/step_{step}_bleu_{valid_bleu:0.2f}.bin")
|
||||
|
||||
print("Done!")
|
||||
|
Loading…
Reference in New Issue
Block a user