sparkastML/translate-old/zh-en/train.py
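"""Fine-tune the Helsinki-NLP/opus-mt-zh-en model on the translation2019zh corpus.

Trains with AdamW and a linear warmup schedule, logs a smoothed loss to
TensorBoard, saves a checkpoint every 250 steps, and reports a corpus-level
BLEU score on the validation split.
"""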


import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import AdamW, get_scheduler
from sacrebleu.metrics import BLEU
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
from dataloader.multiTrans19 import MultiTRANS19
writer = SummaryWriter()
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
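
# Training hyperparameters and running statistics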
train_set_size = 80000
valid_set_size = 2000
test_data_size = 0
last_1k_loss = []
kmean_loss = 0.0
total_loss = 0.0
best_bleu = 0.0
step = 0
max_input_length = 128
max_target_length = 128
batch_size = 8
learning_rate = 1e-5
epoch_num = 1
# Checkpoint file path; defaults to None
# checkpoint_path = None
checkpoint_path = "./saves/checkpoint_76500.bin"  # set this path to resume training from a checkpoint
#data = Wikititle("./data/wikititles-v3.zh-en.tsv")
data = MultiTRANS19("./data/translation2019zh/translation2019zh_train.json")
print(len(data))
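# Split the corpus into train/validation/test subsets (the test split is empty here)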
train_data, valid_data, test_data = random_split(data, [train_set_size, valid_set_size, test_data_size])
# data = TRANS("./data/translation2019zh/translation2019zh_train.json")
# train_data, valid_data = random_split(data, [train_set_size, valid_set_size])
# test_data = TRANS("./data/translation2019zh/translation2019zh_valid.json")
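
# Load the pretrained OPUS-MT Chinese->English model and its tokenizer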
model_checkpoint = "Helsinki-NLP/opus-mt-zh-en"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model = model.to(device)
# If a checkpoint path is specified, restore the model weights and running statistics
if checkpoint_path is not None:
    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint_data = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint_data["model_state_dict"])
    total_loss = checkpoint_data.get("total_loss", 0.0)
    step = checkpoint_data.get("step", 0)
    kmean_loss = total_loss / step if step > 0 else 0.0  # guard against a zero step count
    last_1k_loss = [kmean_loss] * 1000
def collate_fn(batch_samples):
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample["chinese"])
        batch_targets.append(sample["english"])
    batch_data = tokenizer(
        batch_inputs,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch_targets,
            padding=True,
            max_length=max_target_length,
            truncation=True,
            return_tensors="pt",
        )["input_ids"]
        batch_data["decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
            labels
        )
        # Mask everything after the EOS token with -100 so padding is ignored by the loss
        end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            labels[idx][end_idx + 1 :] = -100
        batch_data["labels"] = labels
    batch_data = {k: v.to(device) for k, v in batch_data.items()}
    return batch_data
train_dataloader = DataLoader(
    train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)
valid_dataloader = DataLoader(
    valid_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
test_dataloader = DataLoader(
    test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)
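
# One training epoch: forward/backward per batch, smoothed loss logging to TensorBoard,
# and a checkpoint every 250 steps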
def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss, step):
    global kmean_loss  # also read when the end-of-epoch checkpoint is saved
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f"loss: {0:>7f}")
    model.train()
    for batch, batch_data in enumerate(dataloader, start=1):
        outputs = model(**batch_data)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        total_loss += loss.item()
        # Sliding window over the last 1000 batch losses for a smoothed readout
        if len(last_1k_loss) >= 1000:
            del last_1k_loss[0]
        last_1k_loss.append(loss.item())
        kmean_loss = sum(last_1k_loss) / len(last_1k_loss)
        progress_bar.set_description(f"loss: {kmean_loss:>7f}")
        progress_bar.update(1)
        step += 1
        writer.add_scalar("Loss", kmean_loss, step)
        writer.add_scalar("Overall Loss", total_loss / step, step)
        if step % 250 == 0:
            checkpoint = {
                "model_state_dict": model.state_dict(),
                "total_loss": total_loss,
                "kmean_loss": kmean_loss,
                "step": step,
            }
            torch.save(checkpoint, f"./saves/checkpoint_{step}.bin")
    return total_loss, step
bleu = BLEU()
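
# Evaluation: generate translations for the validation set and score them with corpus-level BLEU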
def test_loop(dataloader, model):
    preds, labels = [], []
    model.eval()
    for batch_data in tqdm(dataloader):
        with torch.no_grad():
            generated_tokens = (
                model.generate(
                    batch_data["input_ids"],
                    attention_mask=batch_data["attention_mask"],
                    max_length=max_target_length,
                    no_repeat_ngram_size=3,
                )
                .cpu()
                .numpy()
            )
        label_tokens = batch_data["labels"].cpu().numpy()
        decoded_preds = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )
        # Replace the -100 loss mask with the pad token id before decoding references
        label_tokens = np.where(
            label_tokens != -100, label_tokens, tokenizer.pad_token_id
        )
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)
        preds += [pred.strip() for pred in decoded_preds]
        labels += [[label.strip()] for label in decoded_labels]
    bleu_score = bleu.corpus_score(preds, labels).score
    print(f"BLEU: {bleu_score:>0.2f}\n")
    return bleu_score
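
# AdamW optimizer with a linear schedule and 10% warmup over the total training steps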
optimizer = AdamW(model.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=int(0.1 * epoch_num * len(train_dataloader)),
    num_training_steps=epoch_num * len(train_dataloader),
)
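
# Main loop: train one epoch, evaluate BLEU on the validation set, then save the weights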
for t in range(epoch_num):
    print(f"Epoch {t+1}/{epoch_num}\n {'-'*20}")
    total_loss, step = train_loop(
        train_dataloader, model, optimizer, lr_scheduler, t + 1, total_loss, step
    )
    valid_bleu = test_loop(valid_dataloader, model)
    print("saving new weights...\n")
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "total_loss": total_loss,
        "kmean_loss": kmean_loss,
        "step": step,
    }
    torch.save(checkpoint, f"./saves/step_{step}_bleu_{valid_bleu:0.2f}.bin")
print("Done!")