alikia2x (寒寒) 2024-09-01 22:17:04 +08:00
commit f28f83b48e
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
9 changed files with 1487 additions and 0 deletions

.gitignore vendored Normal file
View File

@@ -0,0 +1,7 @@
runs
.DS_Store
*.onnx
*.pt
*.bin
token_to_id.json
.ipynb_checkpoints

LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 alikia2x
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,22 @@
Microsoft.
Copyright (c) Microsoft Corporation.
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,126 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "94ff7007",
"metadata": {},
"source": [
"# Convert to ONNX\n",
"\n",
"This notebook converts our model to [ONNX](https://onnx.ai/) format, which is the open standard for machine learning interoperability. In this way, we can run our model in JS (browser)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "311162fd-f957-4746-b524-25bb3e09efbc",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from torch import nn\n",
"import torch.utils.model_zoo as model_zoo\n",
"import torch.onnx\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fd182c6d-1e77-4bbb-bb53-8321d40ae002",
"metadata": {},
"outputs": [],
"source": [
"class TextCNN(nn.Module):\n",
" def __init__(self, input_dim, num_classes):\n",
" super(TextCNN, self).__init__()\n",
" self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n",
" self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n",
" self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n",
" \n",
" self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n",
" self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n",
" self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n",
" \n",
" self.dropout = nn.Dropout(0.5)\n",
" self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n",
"\n",
" def forward(self, x):\n",
" x = x.permute(0, 2, 1) # Change the input shape to (batch_size, embedding_dim, seq_length)\n",
" \n",
" x1 = F.relu(self.bn1(self.conv1(x)))\n",
" x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n",
" \n",
" x2 = F.relu(self.bn2(self.conv2(x)))\n",
" x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n",
" \n",
" x3 = F.relu(self.bn3(self.conv3(x)))\n",
" x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n",
" \n",
" x = torch.cat((x1, x2, x3), dim=1)\n",
" x = self.dropout(x)\n",
" x = self.fc(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bdb597cb-d896-485c-8c9c-897b1d35e8d2",
"metadata": {},
"outputs": [],
"source": [
"model = torch.load(\"model.pt\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9f7a6e64-75f2-4fa9-8d1e-b83099765d02",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# Example input: use random embedding vector to simulate real input\n",
"dummy_input = torch.randn(1, 64, 128) # (batch_size, seq_length, embedding_dim)\n",
"\n",
"# Export model\n",
"torch.onnx.export(\n",
" model, # The model to export\n",
" dummy_input, # Example input\n",
" \"model.onnx\", # File name\n",
" input_names=['input'], # Input name (Could customize)\n",
" output_names=['output'], # Output name (Could customize)\n",
" dynamic_axes={\n",
" 'input': {0: 'batch_size', 1: 'seq_length'}, # Dynamic batch and sequence length\n",
" 'output': {0: 'batch_size'}\n",
" },\n",
" opset_version=11 # ONNX versionensure the ONNX runtime supports it\n",
")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,350 @@
{
"weather": [
"天气如何?",
"现在的天气",
"今天的天气预报",
"现在的天气状况",
"今天天气怎么样",
"目前是什么天气",
"今天的天气概述",
"当前天气状况如何",
"今天会下雨吗?",
"今天晴天吗?",
"今天的天气状况如何",
"现在外面是什么天气",
"今天天气好么",
"今天有没有雾霾",
"明天有没有大风",
"今天会不会很冷",
"今天的天气会变化吗",
"今天晚上的天气如何",
"今天夜里会下雨吗",
"明天会下雨吗?",
"北京今天的预报",
"两天后上海的天气情况",
"现在的温度",
"现在多少度",
"外面有多热",
"明天热不热",
"现在的气温是多少",
"今天最高温度是多少",
"今天最低温度是多少",
"现在外面感觉冷吗",
"明天会比今天热吗",
"明天会比今天冷吗",
"今天的温度变化大吗",
"室外的温度是多少",
"达拉斯今天热不热",
"苏州现在天气怎么样",
"how's the weather",
"What's going on with the weather?",
"Can you give me an update on the weather?",
"How's the forecast looking today?",
"Give me a summary of the current weather.",
"Can you tell me the current weather?",
"What is the weather situation at the moment?",
"Could you provide a quick weather update?",
"Is it raining or sunny outside?",
"What's the weather like right now?",
"Tell me the current weather conditions.",
"How about the weather today?",
"What's the weather looking like for the next few hours?",
"Is it going to stay this way all day?",
"Could you give me a brief overview of the weather?",
"What's the general weather situation in our area?",
"Is it cloudy or clear outside?",
"What's the forecast saying for today's weather?",
"Is it going to be a warm day?",
"Are we expecting any storms today?",
"What's the weather condition outside my window?",
"Is it a typical day for this season in terms of weather?",
"how's the weather now?",
"What's the temperature like right now?",
"Can you tell me the current temperature?",
"How hot is it outside?",
"What's the temperature supposed to be today?",
"What is the current temp outside?",
"Could you tell me the outdoor temperature?",
"Is it cold or warm outside?",
"What's the high temperature for today?",
"What's the low temperature expected tonight?",
"How does the temperature feel outside?",
"Is it going to get warmer or cooler today?",
"What's the temperature in the shade?",
"Can you provide the current temp in Celsius?",
"What's the temperature in Fahrenheit right now?",
"Is it too hot to be outside?",
"What's the temperature like in the morning?",
"How about the temperature in the evening?",
"Is it warm enough to go swimming?",
"What's the temperature in the city center?",
"Can you tell me the temp in the nearby area?",
"Is it below freezing outside?",
"What's the average temperature for today?",
"Is the temperature dropping or rising?",
"What should I wear considering the temperature?"
],
"base64": [
"请将数据使用base64编码",
"需要将以下数据base 64编码",
"请将此字符串转为base 58",
"将数据转为base58编码",
"信息base64编码",
"请帮忙编码base64",
"将数据编码为base64",
"base 64编码",
"base64 在线编码",
"在线编码 base 64",
"编码 base64",
"请解码这个base64数据",
"有base64编码字符串需要解码",
"帮忙解码base64",
"将base 58编码转回原数据",
"解码base58信息",
"解码这个base64",
"将base64转文本",
"base64解码",
"base64 解码",
"base58 在线解码",
"在线解码 base58",
"解码 base64",
"Please encode this data with base64:",
"I need to encode the following data in base64",
"Could you encode this string using base64?",
"Convert this data to b64 encoding",
"I want to encode this information with base64",
"Help me encode this in base32",
"Can you encode this data to base64 format?",
"b64 encode",
"base64 encode",
"encode base64",
"base 64 encode online"
],
"url-encode": [
"编码 url",
"URL部分需要编码",
"请将URL部分编码",
"URL编码转换",
"编码url 段",
"URL数据编码",
"请解码这个URL",
"有 url 编码需要解码",
"解码这个URL",
"URL编码转回原URL",
"解码 URL部分",
"解码 URL 段",
"URL 编码转文本",
"encode URL: ",
"I need to encode this URL component: ",
"Convert url encoded",
"encode this URL for safe transmission",
"Encode this URL segment: ",
"decode this url ",
"convert encoded url to text",
"url decoder",
"URL encoder"
],
"html-encode": [
"请编码HTML实体",
"文本转为HTML实体",
"编码为HTML实体",
"文本HTML实体编码",
"预防HTML解析编码",
"HTML实体编码",
"文本HTML使用编码",
"html 转义",
"请解码HTML实体",
"HTML实体需要解码",
"解码HTML实体",
"HTML实体转回文本",
"HTML实体解码",
"解码HTML实体",
"HTML实体转文本",
"html 实体转义",
"html实体 转换",
"html &nbsp 转换",
"html nbsp 意思",
"Please encode HTML entities",
"Convert text to HTML entities",
"Encode to HTML entities",
"Encode text HTML entities",
"Prevent HTML parsing encoding",
"HTML entity encoding",
"Text HTML uses encoding",
"html escape",
"Please decode HTML entities",
"HTML entities need to be decoded",
"Decode HTML entities",
"Convert HTML entities to text",
"Decode HTML entities",
"Decode HTML entities",
"Convert HTML entities to text",
"HTML entity escape",
"html entity conversion",
"html &nbsp conversion",
"html nbsp meaning"
],
"ai.command": [
"写一个TypeScript的HelloWorld代码",
"检查以下内容的语法和清晰度",
"帮助我学习词汇:为我写一个句子让我填空,我会尝试选择正确的选项",
"改进这个Markdown内容的语法和表达用于GitHub仓库的README",
"你能想出一个我包的短名字吗",
"简化这段代码",
"帮助我学习中文",
"如何在网页开发中让屏幕阅读器自动聚焦到新弹出的元素上",
"总结以下文本:",
"这段代码有什么问题吗,或者可以简化吗?",
"生成一个打印'Hello, World!'的Python脚本",
"你能校对这篇论文的语法和标点错误吗?",
"为单词'serendipity'创建十个例句列表",
"你能重新格式化这个JSON使其更易读吗",
"为我的关于健康饮食的博客文章建议一个有创意的标题。",
"重构这个JavaScript函数使其更高效。",
"帮助我练习法语:提供一个有缺失单词的句子,我可以猜。",
"如何在模态框出现在网页上时让按钮自动聚焦",
"为我总结这篇新闻文章",
"你能审查这段代码片段的潜在安全漏洞吗",
"生成一个SQL查询来查找所有在过去30天内注册的用户",
"你能把这个段落翻译成西班牙语吗",
"根据以下过程描述创建一个流程图。",
"写一个计算数字阶乘的Python函数。",
"提供如何在Web应用程序中实现OAuth2的详细解释。",
"你能优化这张图片以便在网站上更快加载吗?",
"为专注于健身的新移动应用建议一些吸引人的标语",
"写一个每天备份我的文档文件夹的Bash脚本",
"帮助我起草一封请求与潜在客户会面的电子邮件。",
"你能把这个Markdown文档转换成HTML吗",
"生成一个从指定网站抓取数据的Python脚本。",
"你能找到单词'meticulous'的同义词吗?",
"写一个基于公共列连接两个表的SQL查询。",
"根据以下指令集创建一个流程图。",
"为一个新的任务管理工具建议一个名字。",
"在不改变其功能的情况下简化这段Python代码。",
"你能帮助我学习日语吗?",
"如何在用户点击网页上的按钮时让警告框出现?",
"把这个研究论文总结成要点",
"你能检查这个算法中是否有任何逻辑错误吗?",
"write a typescript helloworld code",
"Check the following content for grammar and clarity",
"Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
"Improve this markdown content in asepcts like grammar and expression, for a GitHub repo README.",
"can u think of a short name of my package",
"simplify this code",
"help me learn chinese",
"how to let the screen reader automatically focused to an newly poped up element in the web development",
"summarize following text:",
"Is there anything wrong with this code or can it be simplified?",
"generate a Python script that prints 'Hello, World!'",
"Can you proofread this essay for grammar and punctuation errors?",
"Create a list of ten example sentences for the word 'serendipity.'",
"Can you reformat this JSON to be more readable?",
"Suggest a creative title for my blog post about healthy eating.",
"Refactor this JavaScript function to make it more efficient.",
"Help me practice French: provide a sentence with a missing word that I can guess.",
"How do I make a button automatically focus when a modal appears on a webpage?",
"Summarize this news article for me.",
"Can you review this code snippet for potential security vulnerabilities?",
"Generate a SQL query to find all users who signed up in the last 30 days.",
"Can you translate this paragraph into Spanish?",
"Create a flowchart based on the following process description.",
"Write a Python function to calculate the factorial of a number.",
"Provide a detailed explanation of how to implement OAuth2 in a web application.",
"Can you optimize this image for faster loading on a website?",
"Suggest some catchy taglines for a new mobile app focused on fitness.",
"Write a Bash script to back up my documents folder daily.",
"Help me draft an email to request a meeting with a potential client.",
"Can you convert this Markdown document into HTML?",
"Generate a Python script that scrapes data from a specified website.",
"Can you find the synonyms of the word 'meticulous'?",
"Write a SQL query to join two tables based on a common column.",
"Create a flowchart based on the following set of instructions.",
"Suggest a name for a new task management tool.",
"Simplify this Python code without changing its functionality.",
"Can you assist me in learning Japanese?",
"How can I make an alert box appear when a user clicks a button on a webpage?",
"Summarize this research paper into bullet points.",
"Can you check if there are any logical errors in this algorithm?"
],
"ai.question": [
"你认为哪个框架最适合性能敏感的项目?",
"什么是后量子密码学?",
"什么是密钥派生函数",
"什么是线性代数?",
"线性代数在计算机科学中的主要用途是什么?",
"我应该使用哪个IDE来编写Go语言",
"Go vs Java vs Kotlin哪个适合后端",
"哪种编程语言最适合数据分析",
"什么是量子计算",
"什么是哈希函数?",
"什么是微积分?",
"机器学习在金融中的主要应用有哪些?",
"写Python代码最好的文本编辑器是哪个",
"Python vs R vs Julia哪个更适合数据科学",
"监督学习和无监督学习的关键区别是什么?",
"数据库在Web应用程序中的作用是什么",
"什么是区块链技术",
"使用Docker进行应用程序部署的优势是什么",
"哪个云服务提供商提供最好的AI工具",
"加密是如何工作的?",
"负载均衡器在网络架构中的目的是什么?",
"机器学习和深度学习有什么区别",
"软件工程中最常见的设计模式有哪些",
"神经网络是如何学习的",
"使用微服务架构的主要好处是什么",
"编译器和解释器有什么区别?",
"Rust编程语言的关键特性是什么",
"HTTP和HTTPS有什么区别",
"使用像Git这样的版本控制系统有什么优势",
"什么是'边缘计算'的概念",
"哪种编程语言最适合构建移动应用?",
"关系数据库和NoSQL数据库有什么不同",
"算法在计算机科学中的重要性是什么?",
"API在软件开发中的作用是什么",
"保护Web应用程序的最佳实践是什么",
"虚拟现实和增强现实有什么区别?",
"机器翻译是如何工作的?",
"Which framework do you think is the most suitable for performance sensitive projects?",
"What is post-quantum cryptography",
"What is a key derivation function?",
"What is Linear Algebra?",
"What is the main use of linear algebra in computer science",
"which IDE should I use for Go",
"Go vs Java vs Koltin, which for a backend",
"Which programming language is best suited for data analysis?",
"What is quantum computing?",
"What is a hash function?",
"What is calculus?",
"What are the main applications of machine learning in finance?",
"Which text editor is best for writing Python code?",
"Python vs R vs Julia, which is better for data science?",
"What are the key differences between supervised and unsupervised learning?",
"What is the role of a database in a web application?",
"What is blockchain technology?",
"What are the advantages of using Docker for application deployment?",
"Which cloud service provider offers the best AI tools?",
"How does encryption work?",
"What is the purpose of a load balancer in network architecture?",
"What is the difference between machine learning and deep learning?",
"What are the most common design patterns in software engineering?",
"How does a neural network learn?",
"What is the main benefit of using a microservices architecture?",
"What is the difference between a compiler and an interpreter?",
"What are the key features of the Rust programming language?",
"What is the difference between HTTP and HTTPS?",
"What are the advantages of using a version control system like Git?",
"What is the concept of 'edge computing'?",
"Which programming language is best for building mobile apps?",
"How does a relational database differ from a NoSQL database?",
"What is the importance of algorithms in computer science?",
"What is the role of an API in software development?",
"What are the best practices for securing a web application?",
"What is the difference between virtual reality and augmented reality?",
"How does machine translation work?"
],
"datetime": ["明天周几", "16天后是几号", "一年前的今天是星期几"]
}

View File

@@ -0,0 +1,199 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0d107178",
"metadata": {},
"source": [
"# Extract the Dictionary & Embedding\n",
"\n",
"Our model uses the dictionary and embedding layers from `Phi-3-mini-4k-instruct`, a pre-trained transformer model developed by Microsoft. Although our model architecture is not a transformer, we can still benefit from its tokenizer and embedding layers."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "222bdd54-c115-4845-b4e6-c63e4f9ac6b4",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModel\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "13f1ed38-ad39-4e6a-8af7-65ace2d14f8d",
"metadata": {},
"outputs": [],
"source": [
"model_name=\"microsoft/Phi-3-mini-4k-instruct\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c1de25fc-e90a-425b-8520-3a57fa534b94",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1aeb02c7c8084b1eb1b8e3178882fd60",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Load models\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModel.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0b0ba45f-5de9-465d-aa45-57ae45c5fb36",
"metadata": {},
"outputs": [],
"source": [
"embedding_layer = model.get_input_embeddings()\n",
"vocab = tokenizer.get_vocab()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b2f7e08d-c578-4b0c-ad75-cdb545b5433f",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.decomposition import PCA\n",
"from transformers import AutoTokenizer, AutoModel\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "686e1109-e210-45a3-8c58-b8e92bbe85ff",
"metadata": {},
"outputs": [],
"source": [
"DIMENSIONS = 128"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9c5d7235-f690-4b34-80a4-9b1db1c19100",
"metadata": {},
"outputs": [],
"source": [
"embeddings = []\n",
"for token_id in range(len(vocab)):\n",
" embedding_vector = embedding_layer(torch.tensor([token_id])).detach().numpy()\n",
" embeddings.append(embedding_vector)\n",
"\n",
"# Convert vectors to np arrays\n",
"embeddings = np.vstack(embeddings)\n",
"\n",
"# Use PCA to decrease dimension\n",
"pca = PCA(n_components=DIMENSIONS)\n",
"reduced_embeddings = pca.fit_transform(embeddings)"
]
},
{
"cell_type": "markdown",
"id": "6b834f53-7f39-41ab-9d18-71978e988b30",
"metadata": {},
"source": [
"## Save Model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f231dfdd-5f4f-4d0f-ab5d-7a1365535713",
"metadata": {},
"outputs": [],
"source": [
"# Create dict of tokenID -> dimension-reduced embedding\n",
"token_id_to_reduced_embedding = {token_id: reduced_embeddings[token_id] for token_id in range(len(vocab))}\n",
"\n",
"torch.save(token_id_to_reduced_embedding, \"token_id_to_reduced_embedding.pt\")\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e28a612c-7960-42e6-aa36-abd0732f404e",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"# Create dict of token to {token_id, reduced_embedding}\n",
"token_to_id = {}\n",
"for token, token_id in vocab.items():\n",
" token_to_id[token] = token_id\n",
"\n",
"# Save as JSON\n",
"with open(\"token_to_id.json\", \"w\") as f:\n",
" json.dump(token_to_id, f)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7a466c54-7c55-4e84-957e-454ae35896ac",
"metadata": {},
"outputs": [],
"source": [
"import struct\n",
"with open(\"token_embeddings.bin\", \"wb\") as f:\n",
" for token_id in range(len(vocab)):\n",
" # Write token id (2 bytes)\n",
" f.write(struct.pack('H', token_id))\n",
" # Write embedding vector (128 float numbers)\n",
" f.write(struct.pack('128f', *reduced_embeddings[token_id]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,179 @@
<!-- srcbook:{"language":"typescript","tsconfig.json":{"compilerOptions":{"module":"nodenext","moduleResolution":"nodenext","target":"es2022","resolveJsonModule":true,"noEmit":true,"allowImportingTsExtensions":true},"include":["src/**/*"],"exclude":["node_modules"]}} -->
# sparkastML intention classification
###### package.json
```json
{
"type": "module",
"dependencies": {
"@types/node": "latest",
"@xenova/transformers": "^2.17.2",
"onnxruntime-web": "^1.19.0",
"tsx": "latest",
"typescript": "latest"
}
}
```
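With these dependencies, the cells below can be run with `tsx` (for example `npx tsx load.ts`), assuming the `model.onnx`, `token_embeddings.bin`, and `token_to_id.json` artifacts produced by the notebooks are in the working directory.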
###### tokenizer.ts
```typescript
export type TokenDict = { [key: string]: number };
function tokenize(query: string, tokenDict: TokenDict): number[] {
const tokenIds: number[] = [];
let index = 0;
// SentencePiece-style preprocessing: prepend "▁", replace spaces with "▁",
// and newlines with the <0x0A> byte token
query = "▁" + query.replace(/ /g, "▁");
query = query.replace(/\n/g, "<0x0A>");
while (index < query.length) {
let bestToken = null;
let bestLength = 0;
// Step 2: Find the longest token that matches the beginning of the remaining query
for (const token in tokenDict) {
if (query.startsWith(token, index) && token.length > bestLength) {
bestToken = token;
bestLength = token.length;
}
}
if (bestToken) {
tokenIds.push(tokenDict[bestToken]);
index += bestLength;
} else {
// Step 3: Handle the case where no token matches
const char = query[index];
if (char.charCodeAt(0) <= 127) {
// If the character is ASCII, and it doesn't match any token, treat it as an unknown token
throw new Error(`Unknown token: ${char}`);
} else {
// If the character is non-ASCII, convert it to a series of bytes and match each byte
const bytes = new TextEncoder().encode(char);
for (const byte of bytes) {
const byteToken = `<0x${byte.toString(16).toUpperCase()}>`;
if (tokenDict[byteToken] !== undefined) {
tokenIds.push(tokenDict[byteToken]);
} else {
throw new Error(`Unknown byte token: ${byteToken}`);
}
}
}
index += 1;
}
}
return tokenIds;
}
export default tokenize;
```
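A minimal usage sketch of the greedy longest-match tokenizer; the ids in `toyDict` are hypothetical, while real entries come from `token_to_id.json`:
```typescript
import tokenize, { TokenDict } from "./tokenizer.ts";

// Hypothetical toy dictionary; real ids come from token_to_id.json.
const toyDict: TokenDict = { "▁hello": 22172, "▁world": 3186 };

// "hello world" is preprocessed to "▁hello▁world" and matched greedily,
// longest token first, yielding [22172, 3186].
console.log(tokenize("hello world", toyDict));
```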
###### embedding.ts
```typescript
type EmbeddingDict = { [key: number]: Float32Array };
function getEmbeddingLayer(buffer: Buffer): EmbeddingDict {
const dict: EmbeddingDict = {};
const entrySize = 514; // 2-byte uint16 token id + 128 × 4-byte float32 values
const numEntries = buffer.length / entrySize;
for (let i = 0; i < numEntries; i++) {
const offset = i * entrySize;
const key = buffer.readUInt16LE(offset);
const floatArray = new Float32Array(128);
for (let j = 0; j < 128; j++) {
floatArray[j] = buffer.readFloatLE(offset + 2 + j * 4);
}
dict[key] = floatArray;
}
return dict;
}
function getEmbedding(tokenIds: number[], embeddingDict: EmbeddingDict, contextSize: number) {
let result: number[] = [];
for (let i = 0; i < contextSize; i++) {
if (i < tokenIds.length) {
const tokenId = tokenIds[i];
result = result.concat(Array.from(embeddingDict[tokenId]))
}
else {
result = result.concat(new Array(128).fill(0))
}
}
return new Float32Array(result);
}
export {getEmbeddingLayer, getEmbedding};
```
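Each record in `token_embeddings.bin` is 514 bytes: a little-endian `uint16` token id followed by 128 little-endian `float32` values, mirroring the `struct.pack` writer in the extraction notebook. For symmetry with `getEmbeddingLayer`, a sketch of the encoder side:
```typescript
// Sketch: encode one (tokenId, embedding) record in the 514-byte layout.
function encodeRecord(tokenId: number, embedding: Float32Array): Buffer {
  if (embedding.length !== 128) throw new Error("expected a 128-dim embedding");
  const buf = Buffer.alloc(2 + 128 * 4); // 514 bytes per record
  buf.writeUInt16LE(tokenId, 0);
  for (let j = 0; j < 128; j++) {
    buf.writeFloatLE(embedding[j], 2 + j * 4);
  }
  return buf;
}
```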
###### load.ts
```typescript
import * as ort from 'onnxruntime-web';
import * as fs from 'fs';
import tokenize, {TokenDict} from "./tokenizer.ts"
import {getEmbeddingLayer, getEmbedding} from "./embedding.ts"
const embedding_file = './token_embeddings.bin';
const embedding_data = fs.readFileSync(embedding_file);
const embedding_buffer = Buffer.from(embedding_data);
const query = `Will it rain tomorrow`;
const model_path = './model.onnx';
const vocabData = fs.readFileSync('./token_to_id.json');
const vocabDict = JSON.parse(vocabData.toString());
let lastLogCall = new Date().getTime();
function log(task: string) {
const currentTime = new Date().getTime();
const costTime = currentTime - lastLogCall;
console.log(`[${currentTime}] (+${costTime}ms) ${task}`)
lastLogCall = new Date().getTime();
}
async function loadModel(modelPath: string) {
const session = await ort.InferenceSession.create(modelPath);
return session;
}
async function runInference(query: string, embedding_buffer: Buffer, modelPath: string, vocabDict: TokenDict) {
const session = await loadModel(modelPath);
log("loadModel:end");
const tokenIds = await tokenize(query, vocabDict);
log("tokenize:end");
const embeddingDict = getEmbeddingLayer(embedding_buffer);
const e = getEmbedding(tokenIds, embeddingDict, 12); // pad or truncate to a 12-token context
log("getEmbedding:end");
const inputTensor = new ort.Tensor('float32', e, [1, 12, 128]);
const feeds = { 'input': inputTensor };
const results = await session.run(feeds);
log("inference:end");
const output = results.output.data as Float32Array;
const predictedClassIndex = output.indexOf(Math.max(...output));
return predictedClassIndex;
}
console.log("Perdicted class:", await runInference(query, embedding_buffer, model_path, vocabDict));
```
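`load.ts` returns the argmax over raw logits; the training notebook additionally applies softmax and an energy threshold before accepting a prediction. A sketch of the same post-processing in TypeScript (the `-3` default mirrors `ENERGY_THRESHOLD` from the training notebook and is an assumption here):
```typescript
// Sketch: softmax + energy-based rejection over raw logits,
// mirroring predict_with_energy from the training notebook.
function postprocess(logits: number[], energyThreshold = -3) {
  const m = Math.max(...logits);
  const exps = logits.map((v) => Math.exp(v - m));
  const sum = exps.reduce((a, b) => a + b, 0);
  const probs = exps.map((v) => v / sum);
  const energy = -(m + Math.log(sum)); // -logsumexp(logits)
  const best = probs.indexOf(Math.max(...probs));
  return energy > energyThreshold
    ? { label: "Unknown" as const, prob: probs[best], energy }
    : { label: best, prob: probs[best], energy };
}
```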

View File

@@ -0,0 +1,8 @@
[
"你好",
"你好谢谢小笼包再见",
"我爱你",
"嘿嘿嘿诶嘿",
"为什么",
"拼多多"
]

View File

@@ -0,0 +1,575 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a6a3195f-d099-4bf4-846f-51f403954818",
"metadata": {},
"source": [
"# sparkastML: Training the Intention Classification Model\n",
"\n",
"This is the model we use for intent recognition, using a **CNN architectur** and using an **Energy-based Model** to implement OSR (Open-set Recognition).\n",
"\n",
"In this case, **positive samples** refer to data that can be classified into existing class, while **negative samples** are those does not belong to any of the existing class."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "bddcdbb2-ccbc-4027-a38f-09c61ac94984",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from sklearn.model_selection import train_test_split\n",
"from transformers import AutoTokenizer, AutoModel\n",
"import torch\n",
"import numpy as np\n",
"from scipy.spatial.distance import euclidean\n",
"from scipy.stats import weibull_min\n",
"from sklearn.preprocessing import normalize\n",
"import torch.nn.functional as F\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d3a0e10f-9bc3-44c7-a109-786dd5cd25ea",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"model_name=\"microsoft/Phi-3-mini-4k-instruct\"\n",
"DIMENSIONS = 128\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "markdown",
"id": "1ae14906-338d-4c99-87ed-bb1acd22b295",
"metadata": {},
"source": [
"## Load Data\n",
"\n",
"We load the data from `data.json`, and also get the negative sample from the `noise.json`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a206071c-ce4e-4de4-b936-bfc70d13708a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_6446/1697839999.py:18: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_69nk78ncaj/croot/pytorch_1669252638507/work/torch/csrc/utils/tensor_new.cpp:204.)\n",
" embeddings = torch.tensor(embeddings)\n"
]
}
],
"source": [
"# Load data\n",
"with open('data.json', 'r') as f:\n",
" data = json.load(f)\n",
"\n",
"# Create map: class to index\n",
"class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}\n",
"idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}\n",
"\n",
"# Preprocess data, convert sentences to the format of (class idx, embedding)\n",
"def preprocess_data(data, embedding_map, tokenizer, max_length=64):\n",
" dataset = []\n",
" for label, sentences in data.items():\n",
" for sentence in sentences:\n",
" # Tokenize the sentence and convert tokens to embedding vectors\n",
" tokens = tokenizer.tokenize(sentence)\n",
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
" embeddings = torch.tensor(embeddings)\n",
" dataset.append((class_to_idx[label], embeddings))\n",
" return dataset\n",
"\n",
"# Load embedding map\n",
"embedding_map = torch.load('token_id_to_reduced_embedding.pt')\n",
"\n",
"# Get preprocessed dataset\n",
"dataset = preprocess_data(data, embedding_map, tokenizer)\n",
"\n",
"# Train-test split\n",
"train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)\n",
"\n",
"class TextDataset(Dataset):\n",
" def __init__(self, data):\n",
" self.data = data\n",
"\n",
" def __len__(self):\n",
" return len(self.data)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.data[idx]\n",
"\n",
" def collate_fn(self, batch):\n",
" labels, embeddings = zip(*batch)\n",
" labels = torch.tensor(labels)\n",
" embeddings = pad_sequence(embeddings, batch_first=True)\n",
" return labels, embeddings\n",
"\n",
"train_dataset = TextDataset(train_data)\n",
"val_dataset = TextDataset(val_data)\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn)\n",
"val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False, collate_fn=val_dataset.collate_fn)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9adbe9b8-a2d2-4e1d-8620-457ed0e02fe6",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"class NegativeSampleDataset(Dataset):\n",
" def __init__(self, negative_samples):\n",
" \"\"\"\n",
" negative_samples: List or array of negative sample embeddings or raw text\n",
" \"\"\"\n",
" self.samples = negative_samples\n",
" \n",
" def __len__(self):\n",
" return len(self.samples)\n",
" \n",
" def __getitem__(self, idx):\n",
" return self.samples[idx]\n",
"\n",
" def collate_fn(self, batch):\n",
" embeddings = pad_sequence(batch, batch_first=True)\n",
" return embeddings\n",
"\n",
"with open('noise.json', 'r') as f:\n",
" negative_samples_list = json.load(f)\n",
"\n",
"negative_embedding_list = []\n",
"\n",
"for sentence in negative_samples_list:\n",
" tokens = tokenizer.tokenize(sentence)\n",
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
" embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]\n",
" embeddings = torch.tensor(embeddings)\n",
" negative_embedding_list.append(embeddings)\n",
"\n",
"negative_dataset = NegativeSampleDataset(negative_embedding_list)\n",
"negative_loader = DataLoader(negative_dataset, batch_size=24, shuffle=True, collate_fn=negative_dataset.collate_fn)\n"
]
},
{
"cell_type": "markdown",
"id": "600febe4-2484-4aad-90a1-2bc821fdce1a",
"metadata": {},
"source": [
"## Implementating the Model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "adf624ac-ad63-437b-95f6-b02b7253b91e",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class TextCNN(nn.Module):\n",
" def __init__(self, input_dim, num_classes):\n",
" super(TextCNN, self).__init__()\n",
" self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n",
" self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n",
" self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n",
" \n",
" self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n",
" self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n",
" self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n",
" \n",
" self.dropout = nn.Dropout(0.5)\n",
" self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n",
"\n",
" def forward(self, x):\n",
" x = x.permute(0, 2, 1) # Change the input shape to (batch_size, embedding_dim, seq_length)\n",
" \n",
" x1 = F.relu(self.bn1(self.conv1(x)))\n",
" x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n",
" \n",
" x2 = F.relu(self.bn2(self.conv2(x)))\n",
" x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n",
" \n",
" x3 = F.relu(self.bn3(self.conv3(x)))\n",
" x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n",
" \n",
" x = torch.cat((x1, x2, x3), dim=1)\n",
" x = self.dropout(x)\n",
" x = self.fc(x)\n",
" return x\n",
"\n",
"# Initialize model\n",
"input_dim = DIMENSIONS\n",
"num_classes = len(class_to_idx)\n",
"model = TextCNN(input_dim, num_classes)\n"
]
},
{
"cell_type": "markdown",
"id": "2750e17d-8a60-40c7-851b-1a567d0ee82b",
"metadata": {},
"source": [
"## Energy-based Models"
]
},
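{
"cell_type": "markdown",
"id": "energy-score-definition",
"metadata": {},
"source": [
"For logits $f_k(x)$ over $K$ classes, the energy score computed below is the negative log-sum-exp:\n",
"\n",
"$$E(x) = -\\log \\sum_{k=1}^{K} e^{f_k(x)}$$\n",
"\n",
"Inputs the classifier is confident about get low (very negative) energy, while noise and out-of-set inputs tend to get higher energy; the hinge terms in the training loop push their energy above a margin of 1."
]
},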
{
"cell_type": "code",
"execution_count": 6,
"id": "a2d7c920-07d2-4d14-9cef-e2101b7a2ceb",
"metadata": {},
"outputs": [],
"source": [
"def energy_score(logits):\n",
" # Energy score is minus logsumexp\n",
" return -torch.logsumexp(logits, dim=1)\n",
"\n",
"def generate_noise(batch_size, seq_length ,input_dim, device):\n",
" # Generate a Gaussian noise\n",
" return torch.randn(batch_size, seq_length, input_dim).to(device)\n"
]
},
{
"cell_type": "markdown",
"id": "904a60e4-95a0-4f7b-ad45-a7d8d0ac887d",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "19acb5bf-00b1-47d4-ad25-a13c6be09f65",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/50], Loss: 12.5108\n",
"Epoch [2/50], Loss: 10.7305\n",
"Epoch [3/50], Loss: 10.2943\n",
"Epoch [4/50], Loss: 9.9350\n",
"Epoch [5/50], Loss: 9.7991\n",
"Epoch [6/50], Loss: 9.6443\n",
"Epoch [7/50], Loss: 9.4762\n",
"Epoch [8/50], Loss: 9.4637\n",
"Epoch [9/50], Loss: 9.3025\n",
"Epoch [10/50], Loss: 9.1719\n",
"Epoch [11/50], Loss: 9.0632\n",
"Epoch [12/50], Loss: 8.9741\n",
"Epoch [13/50], Loss: 8.8487\n",
"Epoch [14/50], Loss: 8.6565\n",
"Epoch [15/50], Loss: 8.5830\n",
"Epoch [16/50], Loss: 8.4196\n",
"Epoch [17/50], Loss: 8.2319\n",
"Epoch [18/50], Loss: 8.0655\n",
"Epoch [19/50], Loss: 7.7140\n",
"Epoch [20/50], Loss: 7.6921\n",
"Epoch [21/50], Loss: 7.3375\n",
"Epoch [22/50], Loss: 7.2297\n",
"Epoch [23/50], Loss: 6.8833\n",
"Epoch [24/50], Loss: 6.8534\n",
"Epoch [25/50], Loss: 6.4557\n",
"Epoch [26/50], Loss: 6.1365\n",
"Epoch [27/50], Loss: 5.8558\n",
"Epoch [28/50], Loss: 5.5030\n",
"Epoch [29/50], Loss: 5.1604\n",
"Epoch [30/50], Loss: 4.7742\n",
"Epoch [31/50], Loss: 4.5958\n",
"Epoch [32/50], Loss: 4.0713\n",
"Epoch [33/50], Loss: 3.8872\n",
"Epoch [34/50], Loss: 3.5240\n",
"Epoch [35/50], Loss: 3.3115\n",
"Epoch [36/50], Loss: 2.5667\n",
"Epoch [37/50], Loss: 2.6709\n",
"Epoch [38/50], Loss: 1.8075\n",
"Epoch [39/50], Loss: 1.6654\n",
"Epoch [40/50], Loss: 0.4622\n",
"Epoch [41/50], Loss: 0.4719\n",
"Epoch [42/50], Loss: -0.4037\n",
"Epoch [43/50], Loss: -0.9405\n",
"Epoch [44/50], Loss: -1.7204\n",
"Epoch [45/50], Loss: -2.4124\n",
"Epoch [46/50], Loss: -3.0032\n",
"Epoch [47/50], Loss: -2.7123\n",
"Epoch [48/50], Loss: -3.6953\n",
"Epoch [49/50], Loss: -3.7212\n",
"Epoch [50/50], Loss: -3.7558\n"
]
}
],
"source": [
"import torch.optim as optim\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=8e-4)\n",
"\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"import tensorboard\n",
"writer = SummaryWriter()\n",
"\n",
"def train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=10):\n",
" model.train()\n",
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
" model.to(device)\n",
" \n",
" negative_iter = iter(negative_loader)\n",
" \n",
" for epoch in range(num_epochs):\n",
" total_loss = 0\n",
" for batch_idx, (labels, embeddings) in enumerate(train_loader):\n",
" embeddings = embeddings.to(device)\n",
" labels = labels.to(device)\n",
" \n",
" batch_size = embeddings.size(0)\n",
" \n",
" # ---------------------\n",
" # 1. Positive sample\n",
" # ---------------------\n",
" optimizer.zero_grad()\n",
" outputs = model(embeddings) # logits from the model\n",
" \n",
" class_loss = criterion(outputs, labels)\n",
" \n",
" # Energy of positive sample\n",
" known_energy = energy_score(outputs)\n",
" energy_loss_known = known_energy.mean()\n",
" \n",
" # ------------------------------------\n",
" # 2. Negative sample - Random Noise\n",
" # ------------------------------------\n",
" noise_embeddings = torch.randn_like(embeddings).to(device)\n",
" noise_outputs = model(noise_embeddings)\n",
" noise_energy = energy_score(noise_outputs)\n",
" energy_loss_noise = F.relu(1 - noise_energy).mean() # For the energy of noise, bigger is better \n",
" \n",
" # ------------------------------------\n",
" # 3. Negative sample - custom corpus\n",
" # ------------------------------------\n",
" \n",
" try:\n",
" negative_samples = next(negative_iter)\n",
" except StopIteration:\n",
" negative_iter = iter(negative_loader)\n",
" negative_samples = next(negative_iter)\n",
" negative_samples = negative_samples.to(device)\n",
" negative_outputs = model(negative_samples)\n",
" negative_energy = energy_score(negative_outputs)\n",
" energy_loss_negative = F.relu(1 - negative_energy).mean() # For the energy of noise, bigger is better \n",
" \n",
" # -----------------------------\n",
" # 4. Overall Loss calculation\n",
" # -----------------------------\n",
" total_energy_loss = energy_loss_known + energy_loss_noise + energy_loss_negative\n",
" total_loss_batch = class_loss + total_energy_loss * 0.1 + 10\n",
"\n",
" writer.add_scalar(\"Engergy Loss\", total_energy_loss, epoch)\n",
" writer.add_scalar(\"Loss\", total_loss_batch, epoch)\n",
" writer.add_scalar(\"Norm Loss\", torch.exp(total_loss_batch * 0.003) * 10 , epoch)\n",
" \n",
" total_loss_batch.backward()\n",
" optimizer.step()\n",
" \n",
" total_loss += total_loss_batch.item()\n",
" \n",
" avg_loss = total_loss / len(train_loader)\n",
" print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')\n",
"\n",
"train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=50)\n",
"writer.flush()\n"
]
},
{
"cell_type": "markdown",
"id": "e6d29558-f497-4033-8488-169bd25ce881",
"metadata": {},
"source": [
"## Evalutation"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "472702e6-db4a-4faa-9e92-7510e6eacbb1",
"metadata": {},
"outputs": [],
"source": [
"ENERGY_THRESHOLD = -3"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d3a4fef8-37ab-45c8-b2b1-9fc8bdffcffd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.9315\n",
"Precision: 1.0000\n",
"Recall: 0.9254\n",
"F1 Score: 0.9612\n"
]
}
],
"source": [
"from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support\n",
"\n",
"def evaluate_energy_model(model, known_loader, unknown_loader, energy_threshold):\n",
" model.eval()\n",
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
" \n",
" all_preds = []\n",
" all_labels = []\n",
" \n",
" # Evaluate positive sample\n",
" with torch.no_grad():\n",
" for labels, embeddings in known_loader:\n",
" embeddings = embeddings.to(device)\n",
" logits = model(embeddings)\n",
" energy = energy_score(logits)\n",
" \n",
" preds = (energy <= energy_threshold).long()\n",
" all_preds.extend(preds.cpu().numpy())\n",
" all_labels.extend([1] * len(preds)) # Positive sample labeled as 1\n",
" \n",
" # Evaluate negative sample\n",
" with torch.no_grad():\n",
" for embeddings in unknown_loader:\n",
" embeddings = embeddings.to(device)\n",
" logits = model(embeddings)\n",
" energy = energy_score(logits)\n",
" \n",
" preds = (energy <= energy_threshold).long()\n",
" all_preds.extend(preds.cpu().numpy())\n",
" all_labels.extend([0] * len(preds)) # Negative sample labeled as 1\n",
" \n",
" precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
" accuracy = accuracy_score(all_labels, all_preds)\n",
"\n",
" print(f'Accuracy: {accuracy:.4f}')\n",
" print(f'Precision: {precision:.4f}')\n",
" print(f'Recall: {recall:.4f}')\n",
" print(f'F1 Score: {f1:.4f}')\n",
"\n",
"evaluate_energy_model(model, val_loader, negative_loader, ENERGY_THRESHOLD)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "ba614054-75e1-4f61-ace5-aeb11e29a222",
"metadata": {},
"outputs": [],
"source": [
"# Save the model\n",
"torch.save(model, \"model.pt\")"
]
},
{
"cell_type": "markdown",
"id": "fdfa0c5e-e6d3-4db0-a142-96645c92719c",
"metadata": {},
"source": [
"## Inference"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "03928d75-81c8-4298-ab8a-d7f8a758b561",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted: ['weather', 0.9989822506904602, -8.016249656677246]\n"
]
}
],
"source": [
"def predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold, max_length=64):\n",
" model.eval()\n",
" tokens = tokenizer.tokenize(sentence)\n",
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
" embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension\n",
" \n",
" with torch.no_grad():\n",
" logits = model(embeddings)\n",
" probabilities = F.softmax(logits, dim=1)\n",
" max_prob, predicted = torch.max(probabilities, 1)\n",
" \n",
" # Calculate energy score\n",
" energy = energy_score(logits)\n",
"\n",
" # If energy > threshold, consider the input as unknown class\n",
" if energy.item() > energy_threshold:\n",
" return [\"Unknown\", max_prob.item(), energy.item()]\n",
" else:\n",
" return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]\n",
"\n",
"# Example usage:\n",
"sentence = \"weather today\"\n",
"energy_threshold = ENERGY_THRESHOLD\n",
"predicted = predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold)\n",
"print(f'Predicted: {predicted}')\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}