{ "cells": [ { "cell_type": "code", "execution_count": 16, "id": "07b697c8-5cc2-4021-9ab8-e7e3c90065ee", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/models/marian/tokenization_marian.py:175: UserWarning: Recommended: pip install sacremoses.\n", " warnings.warn(\"Recommended: pip install sacremoses.\")\n", "/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", " warnings.warn(\n", "/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_69064/1647496252.py:14: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature.\n", " model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))\n" ] } ], "source": [
    "import re\n",
    "import time\n",
    "\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
    "\n",
    "# Parameters\n",
    "model_checkpoint = \"Helsinki-NLP/opus-mt-zh-en\"\n",
    "checkpoint_path = \"./saves/step_74500_valid_bleu_30.28_model_weights.bin\"  # assumes a checkpoint saved during training\n",
    "\n",
    "# Pick the compute device: CUDA if available, then Apple MPS, otherwise CPU\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
    "\n",
    "# Load the tokenizer and the base model architecture\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)\n",
    "\n",
    "# Load the fine-tuned weights. weights_only=True restricts unpickling to tensors and\n",
    "# plain containers, so a tampered checkpoint file cannot execute arbitrary code.\n",
    "model.load_state_dict(torch.load(checkpoint_path, map_location='cpu', weights_only=True))\n",
    "model.eval()\n",
    "\n",
    "# Move the model to the selected device\n",
    "model = model.to(device)"
   ] }, { "cell_type": "code", "execution_count": 24, "id": "ccfb5004-2bdd-4d64-88a3-2af96b87092c", "metadata": {}, "outputs": [], "source": [
    "def infer_translation(input_text, model, tokenizer, max_length=128, num_beams=1, length_penalty=1.2,\n",
    "                      no_repeat_ngram_size=2):\n",
    "    \"\"\"Translate a single piece of text and time the call.\n",
    "\n",
    "    Args:\n",
    "        input_text: Source text (one line or one sentence).\n",
    "        model: Seq2seq model exposing ``.device`` and ``.generate()``.\n",
    "        tokenizer: Tokenizer matching ``model``.\n",
    "        max_length: Maximum length of the generated sequence, in tokens.\n",
    "        num_beams: Beam width; 1 means greedy decoding.\n",
    "        length_penalty: Beam-search length penalty; only used when ``num_beams > 1``.\n",
    "        no_repeat_ngram_size: Forbid repeating n-grams of this size in the output.\n",
    "\n",
    "    Returns:\n",
    "        (translation, inference_time): the decoded translation and the elapsed\n",
    "        seconds spent on tokenization, generation and decoding.\n",
    "    \"\"\"\n",
    "    start_time = time.perf_counter()\n",
    "\n",
    "    # A single sequence needs no padding. Send the inputs to the model's own device\n",
    "    # rather than relying on the global `device` from the loading cell.\n",
    "    inputs = tokenizer(input_text, return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "    # Deterministic decoding: temperature/top_p are omitted because they only apply\n",
    "    # when do_sample=True.\n",
    "    generate_kwargs = {\n",
    "        \"max_length\": max_length,\n",
    "        \"num_beams\": num_beams,\n",
    "        \"no_repeat_ngram_size\": no_repeat_ngram_size,\n",
    "        \"do_sample\": False,\n",
    "    }\n",
    "    # early_stopping and length_penalty only affect beam search; with num_beams=1\n",
    "    # they are ignored and merely trigger warnings.\n",
    "    if num_beams > 1:\n",
    "        generate_kwargs.update(early_stopping=True, length_penalty=length_penalty)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # **inputs passes the attention mask together with input_ids\n",
    "        output_tokens = model.generate(**inputs, **generate_kwargs)\n",
    "\n",
    "    translation = tokenizer.decode(output_tokens[0], skip_special_tokens=True)\n",
    "    inference_time = time.perf_counter() - start_time\n",
    "    return translation, inference_time\n",
    "\n",
    "\n",
    "# Sentence-final punctuation after which long lines are split. The zero-width\n",
    "# lookbehind keeps each mark attached to its sentence, so the model still sees\n",
    "# complete sentences (dropping the mark lost the period in the English output).\n",
    "_SENTENCE_END = re.compile(r\"(?<=[。！？])\")\n",
    "\n",
    "\n",
    "def _split_sentences(line, max_chars):\n",
    "    \"\"\"Return the non-blank pieces of `line` to translate separately (the whole line if short).\"\"\"\n",
    "    if len(line) <= max_chars:\n",
    "        return [line]\n",
    "    return [piece for piece in _SENTENCE_END.split(line) if piece.strip()]\n",
    "\n",
    "\n",
    "def translate(input_text, model, tokenizer, max_chars=64, **generate_kwargs):\n",
    "    \"\"\"Translate multi-line text line by line, preserving the line structure.\n",
    "\n",
    "    Blank lines stay blank. Lines longer than `max_chars` characters are split\n",
    "    after 。！？ (the punctuation is kept), translated sentence by sentence and\n",
    "    re-joined with single spaces.\n",
    "\n",
    "    Args:\n",
    "        input_text: Source text, possibly spanning several lines.\n",
    "        model: Seq2seq model, passed to `infer_translation`.\n",
    "        tokenizer: Tokenizer, passed to `infer_translation`.\n",
    "        max_chars: Length threshold above which a line is split into sentences.\n",
    "        **generate_kwargs: Extra keyword arguments forwarded to `infer_translation`\n",
    "            (e.g. num_beams, max_length).\n",
    "\n",
    "    Returns:\n",
    "        (final_translation, total_time): the translated lines joined with newlines\n",
    "        and the summed inference time in seconds.\n",
    "    \"\"\"\n",
    "    translations = []\n",
    "    total_time = 0.0\n",
    "\n",
    "    for line in input_text.splitlines():\n",
    "        if not line.strip():\n",
    "            translations.append(\"\")\n",
    "            continue\n",
    "        translated_pieces = []\n",
    "        for piece in _split_sentences(line, max_chars):\n",
    "            translation, time_cost = infer_translation(piece, model, tokenizer, **generate_kwargs)\n",
    "            translated_pieces.append(translation)\n",
    "            total_time += time_cost\n",
    "        translations.append(\" \".join(translated_pieces))\n",
    "\n",
    "    final_translation = \"\\n\".join(translations)\n",
    "    return final_translation, total_time"
   ] }, { "cell_type": "code", "execution_count": 25, "id": "d5d35c96-3c4a-487c-ac26-d3d97f1208a6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:567: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.3` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", " warnings.warn(\n", "/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:572: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.85` -- this flag is only used in sample-based generation modes. 
You should set `do_sample=True` or unset `top_p`.\n", " warnings.warn(\n", "/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:615: UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`.\n", " warnings.warn(\n", "/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:634: UserWarning: `num_beams` is set to 1. However, `length_penalty` is set to `1.2` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `length_penalty`.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Original Text: \n", "自2000年左右,台湾的珍珠奶茶传入中国大陆,市场规模逐步扩大。当地不断推出新口味的奶茶、水果茶和奶盖茶等创新饮品,并提供多样化的配料选择,统称为新式茶饮。2018年起,奶茶品牌开始采用网红营销策略,使得部分城市门店顾客络绎不绝。尽管消费者有多达两千种的搭配选择,但销量最高的依旧是珍珠、红豆和布丁这三种经典配料。\n", "面对激烈的市场竞争,茶饮品牌开始区分不同的档次,从使用红茶粉和奶精的低成本产品,到采用新鲜牛奶和现场煮制的高级奶茶,甚至高端茶叶如大红袍、龙井茶也成为一些品牌的选用。\n", "\n", "\n", "Translated Text: \n", "Since about 2000, the Pearl Milk Tea of Taiwan has been spreading into the mainland, and the market has gradually expanded The new tea, fruit tea and milk tea are introduced in the local market, and the variety of ingredients is offered, collectively known as new-style tea. Since 2018, the milk tea brand has adopted a mesh marketing strategy, which has made some city stores more and more customers. 
Despite the fact that consumers have as many as 2, 000 combinations, the highest sales are still the three classic ingredients: pearls and red beans and pudding.\n", "In the face of fierce market competition, tea and tea brands began to differentiate between low-cost products using red tea powder and cream, high-grade milk tea made from fresh milk and live, even high tea such as big red robes and dragon well tea have become a few brand selections.\n", "\n", "Inference Time: 3.8918 seconds\n" ] } ], "source": [
    "# Sample input: two paragraphs of Chinese text separated by one newline\n",
    "input_text = (\n",
    "    \"自2000年左右,台湾的珍珠奶茶传入中国大陆,市场规模逐步扩大。当地不断推出新口味的奶茶、水果茶和奶盖茶等创新饮品,并提供多样化的配料选择,统称为新式茶饮。2018年起,奶茶品牌开始采用网红营销策略,使得部分城市门店顾客络绎不绝。尽管消费者有多达两千种的搭配选择,但销量最高的依旧是珍珠、红豆和布丁这三种经典配料。\\n\"\n",
    "    \"面对激烈的市场竞争,茶饮品牌开始区分不同的档次,从使用红茶粉和奶精的低成本产品,到采用新鲜牛奶和现场煮制的高级奶茶,甚至高端茶叶如大红袍、龙井茶也成为一些品牌的选用。\"\n",
    ")\n",
    "\n",
    "# Translate and measure the total inference time\n",
    "translated_text, time_taken = translate(input_text, model, tokenizer)\n",
    "\n",
    "# Print the source text, its translation and the timing, one block after another\n",
    "print(\n",
    "    f\"Original Text: \\n{input_text}\\n\\n\",\n",
    "    f\"Translated Text: \\n{translated_text}\\n\",\n",
    "    f\"Inference Time: {time_taken:.4f} seconds\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }