sparkastML/translate-old/zh-en/train.ipynb

339 lines
18 KiB
Jupyter Notebook (JSON)

{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "e95d90ec-1f93-45d9-ab8a-ee3d0bae293d",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import os\n",
"import numpy as np\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader, random_split\n",
"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
"from transformers import get_scheduler\n",
"# Use the PyTorch AdamW: the transformers copy is deprecated and emits a\n",
"# FutureWarning (visible in this notebook's own output).\n",
"from torch.optim import AdamW\n",
"from sacrebleu.metrics import BLEU\n",
"from tqdm.auto import tqdm\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9b8e703a-a5b5-43bf-9b12-2220d869145a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cpu device\n"
]
}
],
"source": [
"# Prefer CUDA, then Apple-Silicon MPS, then CPU. The recorded OOM traceback\n",
"# below is from the MPS backend, so this notebook is actually run on Apple\n",
"# Silicon -- the original check fell through to 'cpu' and missed it.\n",
"if torch.cuda.is_available():\n",
"    device = 'cuda'\n",
"elif torch.backends.mps.is_available():\n",
"    device = 'mps'\n",
"else:\n",
"    device = 'cpu'\n",
"print(f'Using {device} device')\n",
"\n",
"max_dataset_size = 22000  # cap on lines read from the raw training corpus\n",
"train_set_size = 20000    # train_set_size + valid_set_size must equal max_dataset_size\n",
"valid_set_size = 2000\n",
"\n",
"max_input_length = 128    # token cap for source (Chinese) sentences\n",
"max_target_length = 128   # token cap for target (English) sentences\n",
"\n",
"batch_size = 16\n",
"learning_rate = 1e-5\n",
"epoch_num = 3"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3db1484a-e923-44b9-a2e6-52178a8c09ee",
"metadata": {},
"outputs": [],
"source": [
"class TRANS(Dataset):\n",
"    \"\"\"Line-delimited-JSON translation dataset.\n",
"\n",
"    Each line of ``data_file`` is one JSON object (here with ``chinese``\n",
"    and ``english`` keys). At most ``max_size`` lines are loaded; when\n",
"    ``max_size`` is None the notebook-level ``max_dataset_size`` is used,\n",
"    which keeps the original call sites working unchanged.\n",
"    \"\"\"\n",
"    def __init__(self, data_file, max_size=None):\n",
"        self.data = self.load_data(data_file, max_size)\n",
"\n",
"    def load_data(self, data_file, max_size=None):\n",
"        # Fall back to the global cap only when no explicit limit is given.\n",
"        limit = max_size if max_size is not None else max_dataset_size\n",
"        data = {}\n",
"        with open(data_file, 'rt', encoding='utf-8') as f:\n",
"            for idx, line in enumerate(f):\n",
"                if idx >= limit:\n",
"                    break\n",
"                data[idx] = json.loads(line.strip())\n",
"        return data\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.data[idx]\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0258cad4-f498-4952-ac29-e103ae8e9041",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/models/marian/tokenization_marian.py:175: UserWarning: Recommended: pip install sacremoses.\n",
" warnings.warn(\"Recommended: pip install sacremoses.\")\n",
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n"
]
}
],
"source": [
"data = TRANS('./data/translation2019zh/translation2019zh_train.json')\n",
"# Seed the split so train/valid membership is reproducible across runs;\n",
"# the unseeded original produced a different split on every Run-All.\n",
"train_data, valid_data = random_split(\n",
"    data, [train_set_size, valid_set_size],\n",
"    generator=torch.Generator().manual_seed(42)\n",
")\n",
"test_data = TRANS('./data/translation2019zh/translation2019zh_valid.json')\n",
"\n",
"model_checkpoint = \"Helsinki-NLP/opus-mt-zh-en\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)\n",
"model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "437fb69c-59f6-48f0-9c74-330cf4862b22",
"metadata": {},
"outputs": [],
"source": [
"def collote_fn(batch_samples):\n",
"    \"\"\"Collate raw samples into a tokenized seq2seq training batch.\n",
"\n",
"    Builds encoder inputs from each sample's 'chinese' field and labels\n",
"    from its 'english' field; every label token after the first EOS is\n",
"    set to -100 so the loss ignores padding.\n",
"    \"\"\"\n",
"    batch_inputs = [sample['chinese'] for sample in batch_samples]\n",
"    batch_targets = [sample['english'] for sample in batch_samples]\n",
"    batch_data = tokenizer(\n",
"        batch_inputs,\n",
"        padding=True,\n",
"        max_length=max_input_length,\n",
"        truncation=True,\n",
"        return_tensors=\"pt\"\n",
"    )\n",
"    # `text_target=` replaces the deprecated `as_target_tokenizer()` context\n",
"    # manager (the deprecation warning appears in this notebook's output).\n",
"    labels = tokenizer(\n",
"        text_target=batch_targets,\n",
"        padding=True,\n",
"        max_length=max_target_length,\n",
"        truncation=True,\n",
"        return_tensors=\"pt\"\n",
"    )[\"input_ids\"]\n",
"    batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)\n",
"    # Mask everything after the first EOS with -100 (ignored by the CE loss).\n",
"    end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]\n",
"    for idx, end_idx in enumerate(end_token_index):\n",
"        labels[idx][end_idx+1:] = -100\n",
"    batch_data['labels'] = labels\n",
"    return batch_data\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b9f261d8-02ca-47fc-92d7-6d495ae9c6a1",
"metadata": {},
"outputs": [],
"source": [
"# Shuffle only the training loader; evaluation order must be stable so that\n",
"# BLEU is computed over a deterministic sequence of batches.\n",
"train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collote_fn)\n",
"valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, collate_fn=collote_fn)\n",
"test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, collate_fn=collote_fn)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6fcfa14a-a81b-4a3f-a459-cc0c06f4fa70",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/optimization.py:591: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n"
]
}
],
"source": [
"def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):\n",
"    \"\"\"Run one training epoch and return the updated running loss total.\n",
"\n",
"    `total_loss` is carried across epochs so the progress bar shows the\n",
"    average loss over ALL batches seen so far, not just this epoch's.\n",
"    `epoch` is 1-based (callers pass t+1).\n",
"    \"\"\"\n",
"    progress_bar = tqdm(range(len(dataloader)))\n",
"    progress_bar.set_description(f'loss: {0:>7f}')\n",
"    # Number of batches completed in previous epochs, for the running average.\n",
"    finish_batch_num = (epoch-1) * len(dataloader)\n",
"    \n",
"    model.train()\n",
"    for batch, batch_data in enumerate(dataloader, start=1):\n",
"        batch_data = batch_data.to(device)\n",
"        outputs = model(**batch_data)\n",
"        loss = outputs.loss\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        lr_scheduler.step()\n",
"\n",
"        total_loss += loss.item()\n",
"        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')\n",
"        progress_bar.update(1)\n",
"    return total_loss\n",
"\n",
"bleu = BLEU()\n",
"\n",
"def test_loop(dataloader, model):\n",
"    \"\"\"Greedy-decode every batch and return the sacrebleu corpus BLEU score.\"\"\"\n",
"    preds, labels = [], []\n",
"    \n",
"    model.eval()\n",
"    for batch_data in tqdm(dataloader):\n",
"        batch_data = batch_data.to(device)\n",
"        with torch.no_grad():\n",
"            generated_tokens = model.generate(\n",
"                batch_data[\"input_ids\"],\n",
"                attention_mask=batch_data[\"attention_mask\"],\n",
"                max_length=max_target_length,\n",
"            ).cpu().numpy()\n",
"        label_tokens = batch_data[\"labels\"].cpu().numpy()\n",
"        \n",
"        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
"        # -100 is the loss-ignore marker, not a valid token id; map it back to\n",
"        # the pad id before decoding.\n",
"        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)\n",
"        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)\n",
"\n",
"        preds += [pred.strip() for pred in decoded_preds]\n",
"        # sacrebleu expects a list of reference lists, one inner list per hypothesis.\n",
"        labels += [[label.strip()] for label in decoded_labels]\n",
"    bleu_score = bleu.corpus_score(preds, labels).score\n",
"    print(f\"BLEU: {bleu_score:>0.2f}\\n\")\n",
"    return bleu_score\n",
"\n",
"optimizer = AdamW(model.parameters(), lr=learning_rate)\n",
"# Linear decay to zero over the full training run, no warmup.\n",
"lr_scheduler = get_scheduler(\n",
"    \"linear\",\n",
"    optimizer=optimizer,\n",
"    num_warmup_steps=0,\n",
"    num_training_steps=epoch_num*len(train_dataloader),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "12068522-df42-484f-97f1-13ce588bf47b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"# Lift the MPS allocator's high-watermark cap (added in reaction to the MPS\n",
"# OOM recorded below). NOTE(review): environment variables are typically read\n",
"# when the backend initializes, so this likely needs to be set before torch\n",
"# first touches MPS to take effect -- confirm, and consider reducing\n",
"# batch_size instead.\n",
"os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "896ba74b-1a6a-402c-b94a-e9cf47bb0d65",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"-------------------------------\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0453b70899854c0191a93b53748ddaa0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/12500 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:4126: UserWarning: `as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your labels by using the argument `text_target` of the regular `__call__` method (either in the same call as your input texts if you use the same keyword arguments, or in a separate call.\n",
" warnings.warn(\n"
]
},
{
"ename": "RuntimeError",
"evalue": "MPS backend out of memory (MPS allocated: 9.37 GB, other allocations: 8.66 GB, max allowed: 18.13 GB). Tried to allocate 222.17 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[12], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m t \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(epoch_num):\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mt\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch_num\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m-------------------------------\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 6\u001b[0m total_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr_scheduler\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtotal_loss\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m valid_bleu \u001b[38;5;241m=\u001b[39m test_loop(valid_dataloader, model)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m valid_bleu \u001b[38;5;241m>\u001b[39m best_bleu:\n",
"Cell \u001b[0;32mIn[10], line 13\u001b[0m, in \u001b[0;36mtrain_loop\u001b[0;34m(dataloader, model, optimizer, lr_scheduler, epoch, total_loss)\u001b[0m\n\u001b[1;32m 10\u001b[0m loss \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mloss\n\u001b[1;32m 12\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 13\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m 15\u001b[0m lr_scheduler\u001b[38;5;241m.\u001b[39mstep()\n",
"File \u001b[0;32m/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/torch/_tensor.py:522\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 512\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 513\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 514\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 515\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 520\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 521\u001b[0m )\n\u001b[0;32m--> 522\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 523\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 524\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/torch/autograd/__init__.py:347\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 342\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 344\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 345\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 346\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 347\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 348\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 349\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 350\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 351\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 352\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 353\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 355\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/torch/autograd/graph.py:818\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 816\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m 817\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 818\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 819\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 820\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m 821\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 822\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
"\u001b[0;31mRuntimeError\u001b[0m: MPS backend out of memory (MPS allocated: 9.37 GB, other allocations: 8.66 GB, max allowed: 18.13 GB). Tried to allocate 222.17 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure)."
]
}
],
"source": [
"# NOTE(review): epoch_num is already set in the config cell; this redefinition\n",
"# must match it, since the lr scheduler above was sized with the config value.\n",
"epoch_num = 3\n",
"total_loss = 0.\n",
"best_bleu = 0.\n",
"for t in range(epoch_num):\n",
"    print(f\"Epoch {t+1}/{epoch_num}\\n-------------------------------\")\n",
"    total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t+1, total_loss)\n",
"    valid_bleu = test_loop(valid_dataloader, model)\n",
"    # Checkpoint only when validation BLEU improves on the best seen so far.\n",
"    if valid_bleu > best_bleu:\n",
"        best_bleu = valid_bleu\n",
"        print('saving new weights...\\n')\n",
"        torch.save(\n",
"            model.state_dict(), \n",
"            f'epoch_{t+1}_valid_bleu_{valid_bleu:0.2f}_model_weights.bin'\n",
"        )\n",
"print(\"Done!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fd3439a-058a-4220-9b65-b355b52f74b5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}