{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "e95d90ec-1f93-45d9-ab8a-ee3d0bae293d", "metadata": {}, "outputs": [], "source": [ "import random\n", "import os\n", "import numpy as np\n", "import torch\n", "from torch.utils.data import Dataset, DataLoader, random_split\n", "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n", "from transformers import AdamW, get_scheduler\n", "from sacrebleu.metrics import BLEU\n", "from tqdm.auto import tqdm\n", "import json" ] }, { "cell_type": "code", "execution_count": 3, "id": "9b8e703a-a5b5-43bf-9b12-2220d869145a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cpu device\n" ] } ], "source": [ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f'Using {device} device')\n", "\n", "max_dataset_size = 22000\n", "train_set_size = 20000\n", "valid_set_size = 2000\n", "\n", "max_input_length = 128\n", "max_target_length = 128\n", "\n", "batch_size = 16\n", "learning_rate = 1e-5\n", "epoch_num = 3" ] }, { "cell_type": "code", "execution_count": 4, "id": "3db1484a-e923-44b9-a2e6-52178a8c09ee", "metadata": {}, "outputs": [], "source": [ "class TRANS(Dataset):\n", " def __init__(self, data_file):\n", " self.data = self.load_data(data_file)\n", " \n", " def load_data(self, data_file):\n", " Data = {}\n", " with open(data_file, 'rt', encoding='utf-8') as f:\n", " for idx, line in enumerate(f):\n", " if idx >= max_dataset_size:\n", " break\n", " sample = json.loads(line.strip())\n", " Data[idx] = sample\n", " return Data\n", " \n", " def __len__(self):\n", " return len(self.data)\n", "\n", " def __getitem__(self, idx):\n", " return self.data[idx]\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "0258cad4-f498-4952-ac29-e103ae8e9041", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/models/marian/tokenization_marian.py:175: UserWarning: Recommended: pip install sacremoses.\n", " warnings.warn(\"Recommended: pip install sacremoses.\")\n", "/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", " warnings.warn(\n" ] } ], "source": [ "data = TRANS('./data/translation2019zh/translation2019zh_train.json')\n", "train_data, valid_data = random_split(data, [train_set_size, valid_set_size])\n", "test_data = TRANS('./data/translation2019zh/translation2019zh_valid.json')\n", "\n", "model_checkpoint = \"Helsinki-NLP/opus-mt-zh-en\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)\n", "model = model.to(device)" ] }, { "cell_type": "code", "execution_count": 6, "id": "437fb69c-59f6-48f0-9c74-330cf4862b22", "metadata": {}, "outputs": [], "source": [ "def collote_fn(batch_samples):\n", " batch_inputs, batch_targets = [], []\n", " for sample in batch_samples:\n", " batch_inputs.append(sample['chinese'])\n", " batch_targets.append(sample['english'])\n", " batch_data = tokenizer(\n", " batch_inputs, \n", " padding=True, \n", " max_length=max_input_length,\n", " truncation=True, \n", " return_tensors=\"pt\"\n", " )\n", " with tokenizer.as_target_tokenizer():\n", " labels = tokenizer(\n", " batch_targets, \n", " padding=True, \n", " max_length=max_target_length,\n", " truncation=True, \n", " return_tensors=\"pt\"\n", " )[\"input_ids\"]\n", " batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)\n", " end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]\n", " for idx, end_idx in enumerate(end_token_index):\n", " labels[idx][end_idx+1:] = -100\n", " batch_data['labels'] = labels\n", " return batch_data\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "b9f261d8-02ca-47fc-92d7-6d495ae9c6a1", "metadata": {}, "outputs": [], "source": [ "train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collote_fn)\n", "valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, collate_fn=collote_fn)\n", "test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, collate_fn=collote_fn)" ] }, { "cell_type": "code", "execution_count": 8, "id": "6fcfa14a-a81b-4a3f-a459-cc0c06f4fa70", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/optimization.py:591: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n" ] } ], "source": [ "def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):\n", " progress_bar = tqdm(range(len(dataloader)))\n", " progress_bar.set_description(f'loss: {0:>7f}')\n", " finish_batch_num = (epoch-1) * len(dataloader)\n", " \n", " model.train()\n", " for batch, batch_data in enumerate(dataloader, start=1):\n", " batch_data = batch_data.to(device)\n", " outputs = model(**batch_data)\n", " loss = outputs.loss\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " lr_scheduler.step()\n", "\n", " total_loss += loss.item()\n", " progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')\n", " progress_bar.update(1)\n", " return total_loss\n", "\n", "bleu = BLEU()\n", "\n", "def test_loop(dataloader, model):\n", " preds, labels = [], []\n", " \n", " model.eval()\n", " for batch_data in tqdm(dataloader):\n", " batch_data = batch_data.to(device)\n", " with torch.no_grad():\n", " generated_tokens = model.generate(\n", " batch_data[\"input_ids\"],\n", " attention_mask=batch_data[\"attention_mask\"],\n", " max_length=max_target_length,\n", " ).cpu().numpy()\n", " label_tokens = batch_data[\"labels\"].cpu().numpy()\n", " \n", " decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n", " label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)\n", " decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)\n", "\n", " preds += [pred.strip() for pred in decoded_preds]\n", " labels += [[label.strip()] for label in decoded_labels]\n", " bleu_score = bleu.corpus_score(preds, labels).score\n", " print(f\"BLEU: {bleu_score:>0.2f}\\n\")\n", " return bleu_score\n", "\n", "optimizer = AdamW(model.parameters(), lr=learning_rate)\n", "lr_scheduler = get_scheduler(\n", " \"linear\",\n", " optimizer=optimizer,\n", " num_warmup_steps=0,\n", " num_training_steps=epoch_num*len(train_dataloader),\n", ")\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "12068522-df42-484f-97f1-13ce588bf47b", "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'" ] }, { "cell_type": "code", "execution_count": 12, "id": "896ba74b-1a6a-402c-b94a-e9cf47bb0d65", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n", "-------------------------------\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0453b70899854c0191a93b53748ddaa0", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/12500 [00:00 6\u001b[0m total_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr_scheduler\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtotal_loss\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m valid_bleu \u001b[38;5;241m=\u001b[39m test_loop(valid_dataloader, model)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m valid_bleu \u001b[38;5;241m>\u001b[39m best_bleu:\n", "Cell \u001b[0;32mIn[10], line 13\u001b[0m, in \u001b[0;36mtrain_loop\u001b[0;34m(dataloader, model, optimizer, lr_scheduler, epoch, total_loss)\u001b[0m\n\u001b[1;32m 10\u001b[0m loss \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mloss\n\u001b[1;32m 12\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 13\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m 15\u001b[0m lr_scheduler\u001b[38;5;241m.\u001b[39mstep()\n", "File \u001b[0;32m/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/torch/_tensor.py:522\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 512\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 513\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 514\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 515\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 520\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 521\u001b[0m )\n\u001b[0;32m--> 522\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 523\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 524\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/torch/autograd/__init__.py:347\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 342\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 344\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 345\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 346\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 347\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 348\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 349\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 350\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 351\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 352\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 353\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 355\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/torch/autograd/graph.py:818\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 816\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m 817\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 818\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 819\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 820\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m 821\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 822\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n", "\u001b[0;31mRuntimeError\u001b[0m: MPS backend out of memory (MPS allocated: 9.37 GB, other allocations: 8.66 GB, max allowed: 18.13 GB). Tried to allocate 222.17 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure)." ] } ], "source": [ "epoch_num = 3\n", "total_loss = 0.\n", "best_bleu = 0.\n", "for t in range(epoch_num):\n", " print(f\"Epoch {t+1}/{epoch_num}\\n-------------------------------\")\n", " total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t+1, total_loss)\n", " valid_bleu = test_loop(valid_dataloader, model)\n", " if valid_bleu > best_bleu:\n", " best_bleu = valid_bleu\n", " print('saving new weights...\\n')\n", " torch.save(\n", " model.state_dict(), \n", " f'epoch_{t+1}_valid_bleu_{valid_bleu:0.2f}_model_weights.bin'\n", " )\n", "print(\"Done!\")" ] }, { "cell_type": "code", "execution_count": null, "id": "6fd3439a-058a-4220-9b65-b355b52f74b5", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }