init
This commit is contained in:
commit
f28f83b48e
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
runs
|
||||||
|
.DS_Store
|
||||||
|
*.onnx
|
||||||
|
*.pt
|
||||||
|
*.bin
|
||||||
|
token_to_id.json
|
||||||
|
.ipynb_checkpoints
|
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024 alikia2x
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
22
intention-classify/LICENSE
Normal file
22
intention-classify/LICENSE
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
Microsoft.
|
||||||
|
Copyright (c) Microsoft Corporation.
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
126
intention-classify/convert_onnx.ipynb
Normal file
126
intention-classify/convert_onnx.ipynb
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "94ff7007",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Convert to ONNX\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook converts our model to the [ONNX](https://onnx.ai/) format, an open standard for machine learning interoperability, so that the model can run in JavaScript (e.g. in the browser)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "311162fd-f957-4746-b524-25bb3e09efbc",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import numpy as np\n",
|
||||||
|
"\n",
|
||||||
|
"from torch import nn\n",
|
||||||
|
"import torch.utils.model_zoo as model_zoo\n",
|
||||||
|
"import torch.onnx\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"import torch.nn.functional as F"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "fd182c6d-1e77-4bbb-bb53-8321d40ae002",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class TextCNN(nn.Module):\n",
|
||||||
|
" def __init__(self, input_dim, num_classes):\n",
|
||||||
|
" super(TextCNN, self).__init__()\n",
|
||||||
|
" self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n",
|
||||||
|
" self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n",
|
||||||
|
" self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n",
|
||||||
|
" \n",
|
||||||
|
" self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n",
|
||||||
|
" self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n",
|
||||||
|
" self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n",
|
||||||
|
" \n",
|
||||||
|
" self.dropout = nn.Dropout(0.5)\n",
|
||||||
|
" self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n",
|
||||||
|
"\n",
|
||||||
|
" def forward(self, x):\n",
|
||||||
|
" x = x.permute(0, 2, 1) # Change the input shape to (batch_size, embedding_dim, seq_length)\n",
|
||||||
|
" \n",
|
||||||
|
" x1 = F.relu(self.bn1(self.conv1(x)))\n",
|
||||||
|
" x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n",
|
||||||
|
" \n",
|
||||||
|
" x2 = F.relu(self.bn2(self.conv2(x)))\n",
|
||||||
|
" x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n",
|
||||||
|
" \n",
|
||||||
|
" x3 = F.relu(self.bn3(self.conv3(x)))\n",
|
||||||
|
" x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n",
|
||||||
|
" \n",
|
||||||
|
" x = torch.cat((x1, x2, x3), dim=1)\n",
|
||||||
|
" x = self.dropout(x)\n",
|
||||||
|
" x = self.fc(x)\n",
|
||||||
|
" return x"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "bdb597cb-d896-485c-8c9c-897b1d35e8d2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model = torch.load(\"model.pt\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "9f7a6e64-75f2-4fa9-8d1e-b83099765d02",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"\n",
|
||||||
|
"# Example input: use random embedding vector to simulate real input\n",
|
||||||
|
"dummy_input = torch.randn(1, 64, 128) # (batch_size, seq_length, embedding_dim)\n",
|
||||||
|
"\n",
|
||||||
|
"# Export model\n",
|
||||||
|
"torch.onnx.export(\n",
|
||||||
|
" model, # The model to export\n",
|
||||||
|
" dummy_input, # Example input\n",
|
||||||
|
" \"model.onnx\", # File name\n",
|
||||||
|
" input_names=['input'], # Input name (Could customize)\n",
|
||||||
|
" output_names=['output'], # Output name (Could customize)\n",
|
||||||
|
" dynamic_axes={\n",
|
||||||
|
" 'input': {0: 'batch_size', 1: 'seq_length'}, # Dynamic batch and sequence length\n",
|
||||||
|
" 'output': {0: 'batch_size'}\n",
|
||||||
|
" },\n",
|
||||||
|
" opset_version=11 # ONNX version,ensure the ONNX runtime supports it\n",
|
||||||
|
")\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.19"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
350
intention-classify/data.json
Normal file
350
intention-classify/data.json
Normal file
@ -0,0 +1,350 @@
|
|||||||
|
{
|
||||||
|
"weather": [
|
||||||
|
"天气如何?",
|
||||||
|
"现在的天气",
|
||||||
|
"今天的天气预报",
|
||||||
|
"现在的天气状况",
|
||||||
|
"今天天气怎么样",
|
||||||
|
"目前是什么天气",
|
||||||
|
"今天的天气概述",
|
||||||
|
"当前天气状况如何",
|
||||||
|
"今天会下雨吗?",
|
||||||
|
"今天晴天吗?",
|
||||||
|
"今天的天气状况如何",
|
||||||
|
"现在外面是什么天气",
|
||||||
|
"今天天气好么",
|
||||||
|
"今天有没有雾霾",
|
||||||
|
"明天有没有大风",
|
||||||
|
"今天会不会很冷",
|
||||||
|
"今天的天气会变化吗",
|
||||||
|
"今天晚上的天气如何",
|
||||||
|
"今天夜里会下雨吗",
|
||||||
|
"明天会下雨吗?",
|
||||||
|
"北京今天的预报",
|
||||||
|
"两天后上海的天气情况",
|
||||||
|
"现在的温度",
|
||||||
|
"现在多少度",
|
||||||
|
"外面有多热",
|
||||||
|
"明天热不热",
|
||||||
|
"现在的气温是多少",
|
||||||
|
"今天最高温度是多少",
|
||||||
|
"今天最低温度是多少",
|
||||||
|
"现在外面感觉冷吗",
|
||||||
|
"明天会比今天热吗",
|
||||||
|
"明天会比今天冷吗",
|
||||||
|
"今天的温度变化大吗",
|
||||||
|
"室外的温度是多少",
|
||||||
|
"达拉斯今天热不热",
|
||||||
|
"苏州现在天气怎么样",
|
||||||
|
"how's the weather",
|
||||||
|
"What's going on with the weather?",
|
||||||
|
"Can you give me an update on the weather?",
|
||||||
|
"How's the forecast looking today?",
|
||||||
|
"Give me a summary of the current weather.",
|
||||||
|
"Can you tell me the current weather?",
|
||||||
|
"What is the weather situation at the moment?",
|
||||||
|
"Could you provide a quick weather update?",
|
||||||
|
"Is it raining or sunny outside?",
|
||||||
|
"What's the weather like right now?",
|
||||||
|
"Tell me the current weather conditions.",
|
||||||
|
"How about the weather today?",
|
||||||
|
"What's the weather looking like for the next few hours?",
|
||||||
|
"Is it going to stay this way all day?",
|
||||||
|
"Could you give me a brief overview of the weather?",
|
||||||
|
"What's the general weather situation in our area?",
|
||||||
|
"Is it cloudy or clear outside?",
|
||||||
|
"What's the forecast saying for today's weather?",
|
||||||
|
"Is it going to be a warm day?",
|
||||||
|
"Are we expecting any storms today?",
|
||||||
|
"What's the weather condition outside my window?",
|
||||||
|
"Is it a typical day for this season in terms of weather?",
|
||||||
|
"how's the weather now?",
|
||||||
|
"What's the temperature like right now?",
|
||||||
|
"Can you tell me the current temperature?",
|
||||||
|
"How hot is it outside?",
|
||||||
|
"What's the temperature supposed to be today?",
|
||||||
|
"What is the current temp outside?",
|
||||||
|
"Could you tell me the outdoor temperature?",
|
||||||
|
"Is it cold or warm outside?",
|
||||||
|
"What's the high temperature for today?",
|
||||||
|
"What's the low temperature expected tonight?",
|
||||||
|
"How does the temperature feel outside?",
|
||||||
|
"Is it going to get warmer or cooler today?",
|
||||||
|
"What's the temperature in the shade?",
|
||||||
|
"Can you provide the current temp in Celsius?",
|
||||||
|
"What's the temperature in Fahrenheit right now?",
|
||||||
|
"Is it too hot to be outside?",
|
||||||
|
"What's the temperature like in the morning?",
|
||||||
|
"How about the temperature in the evening?",
|
||||||
|
"Is it warm enough to go swimming?",
|
||||||
|
"What's the temperature in the city center?",
|
||||||
|
"Can you tell me the temp in the nearby area?",
|
||||||
|
"Is it below freezing outside?",
|
||||||
|
"What's the average temperature for today?",
|
||||||
|
"Is the temperature dropping or rising?",
|
||||||
|
"What should I wear considering the temperature?"
|
||||||
|
],
|
||||||
|
"base64": [
|
||||||
|
"请将数据使用base64编码",
|
||||||
|
"需要将以下数据base 64编码",
|
||||||
|
"请将此字符串转为base 58",
|
||||||
|
"将数据转为base58编码",
|
||||||
|
"信息base64编码",
|
||||||
|
"请帮忙编码base64",
|
||||||
|
"将数据编码为base64",
|
||||||
|
"base 64编码",
|
||||||
|
"base64 在线编码",
|
||||||
|
"在线编码 base 64",
|
||||||
|
"编码 base64",
|
||||||
|
"请解码这个base64数据",
|
||||||
|
"有base64编码字符串需要解码",
|
||||||
|
"帮忙解码base64",
|
||||||
|
"将base 58编码转回原数据",
|
||||||
|
"解码base58信息",
|
||||||
|
"解码这个base64",
|
||||||
|
"将base64转文本",
|
||||||
|
"base64解码",
|
||||||
|
"base64 解码",
|
||||||
|
"base58 在线解码",
|
||||||
|
"在线解码 base58",
|
||||||
|
"解码 base64",
|
||||||
|
"Please encode this data with base64:",
|
||||||
|
"I need to encode the following data in base64",
|
||||||
|
"Could you encode this string using base64?",
|
||||||
|
"Convert this data to b64 encoding",
|
||||||
|
"I want to encode this information with base64",
|
||||||
|
"Help me encode this in base32",
|
||||||
|
"Can you encode this data to base64 format?",
|
||||||
|
"b64 encode",
|
||||||
|
"base64 encode",
|
||||||
|
"encode base64",
|
||||||
|
"base 64 encode online"
|
||||||
|
],
|
||||||
|
|
||||||
|
"url-encode": [
|
||||||
|
"编码 url",
|
||||||
|
"URL部分需要编码",
|
||||||
|
"请将URL部分编码",
|
||||||
|
"URL编码转换",
|
||||||
|
"编码url 段",
|
||||||
|
"URL数据编码",
|
||||||
|
"请解码这个URL",
|
||||||
|
"有 url 编码需要解码",
|
||||||
|
"解码这个URL",
|
||||||
|
"URL编码转回原URL",
|
||||||
|
"解码 URL部分",
|
||||||
|
"解码 URL 段",
|
||||||
|
"URL 编码转文本",
|
||||||
|
"encode URL: ",
|
||||||
|
"I need to encode this URL component: ",
|
||||||
|
"Convert url encoded",
|
||||||
|
"encode this URL for safe transmission",
|
||||||
|
"Encode this URL segment: ",
|
||||||
|
"decode this url ",
|
||||||
|
"convert encoded url to text",
|
||||||
|
"url decoder",
|
||||||
|
"URL encoder"
|
||||||
|
],
|
||||||
|
|
||||||
|
"html-encode": [
|
||||||
|
"请编码HTML实体",
|
||||||
|
"文本转为HTML实体",
|
||||||
|
"编码为HTML实体",
|
||||||
|
"文本HTML实体编码",
|
||||||
|
"预防HTML解析编码",
|
||||||
|
"HTML实体编码",
|
||||||
|
"文本HTML使用编码",
|
||||||
|
"html 转义",
|
||||||
|
"请解码HTML实体",
|
||||||
|
"HTML实体需要解码",
|
||||||
|
"解码HTML实体",
|
||||||
|
"HTML实体转回文本",
|
||||||
|
"HTML实体解码",
|
||||||
|
"解码HTML实体",
|
||||||
|
"HTML实体转文本",
|
||||||
|
"html 实体转义",
|
||||||
|
"html实体 转换",
|
||||||
|
"html   转换",
|
||||||
|
"html nbsp 意思",
|
||||||
|
"Please encode HTML entities",
|
||||||
|
"Convert text to HTML entities",
|
||||||
|
"Encode to HTML entities",
|
||||||
|
"Encode text HTML entities",
|
||||||
|
"Prevent HTML parsing encoding",
|
||||||
|
"HTML entity encoding",
|
||||||
|
"Text HTML uses encoding",
|
||||||
|
"html escape",
|
||||||
|
"Please decode HTML entities",
|
||||||
|
"HTML entities need to be decoded",
|
||||||
|
"Decode HTML entities",
|
||||||
|
"Convert HTML entities to text",
|
||||||
|
"Decode HTML entities",
|
||||||
|
"Decode HTML entities",
|
||||||
|
"Convert HTML entities to text",
|
||||||
|
"HTML entity escape",
|
||||||
|
"html entity conversion",
|
||||||
|
"html   conversion",
|
||||||
|
"html nbsp meaning"
|
||||||
|
],
|
||||||
|
|
||||||
|
"ai.command": [
|
||||||
|
"写一个TypeScript的HelloWorld代码",
|
||||||
|
"检查以下内容的语法和清晰度",
|
||||||
|
"帮助我学习词汇:为我写一个句子让我填空,我会尝试选择正确的选项",
|
||||||
|
"改进这个Markdown内容的语法和表达,用于GitHub仓库的README",
|
||||||
|
"你能想出一个我包的短名字吗",
|
||||||
|
"简化这段代码",
|
||||||
|
"帮助我学习中文",
|
||||||
|
"如何在网页开发中让屏幕阅读器自动聚焦到新弹出的元素上",
|
||||||
|
"总结以下文本:",
|
||||||
|
"这段代码有什么问题吗,或者可以简化吗?",
|
||||||
|
"生成一个打印'Hello, World!'的Python脚本",
|
||||||
|
"你能校对这篇论文的语法和标点错误吗?",
|
||||||
|
"为单词'serendipity'创建十个例句列表",
|
||||||
|
"你能重新格式化这个JSON使其更易读吗?",
|
||||||
|
"为我的关于健康饮食的博客文章建议一个有创意的标题。",
|
||||||
|
"重构这个JavaScript函数使其更高效。",
|
||||||
|
"帮助我练习法语:提供一个有缺失单词的句子,我可以猜。",
|
||||||
|
"如何在模态框出现在网页上时让按钮自动聚焦",
|
||||||
|
"为我总结这篇新闻文章",
|
||||||
|
"你能审查这段代码片段的潜在安全漏洞吗",
|
||||||
|
"生成一个SQL查询来查找所有在过去30天内注册的用户",
|
||||||
|
"你能把这个段落翻译成西班牙语吗",
|
||||||
|
"根据以下过程描述创建一个流程图。",
|
||||||
|
"写一个计算数字阶乘的Python函数。",
|
||||||
|
"提供如何在Web应用程序中实现OAuth2的详细解释。",
|
||||||
|
"你能优化这张图片以便在网站上更快加载吗?",
|
||||||
|
"为专注于健身的新移动应用建议一些吸引人的标语",
|
||||||
|
"写一个每天备份我的文档文件夹的Bash脚本",
|
||||||
|
"帮助我起草一封请求与潜在客户会面的电子邮件。",
|
||||||
|
"你能把这个Markdown文档转换成HTML吗?",
|
||||||
|
"生成一个从指定网站抓取数据的Python脚本。",
|
||||||
|
"你能找到单词'meticulous'的同义词吗?",
|
||||||
|
"写一个基于公共列连接两个表的SQL查询。",
|
||||||
|
"根据以下指令集创建一个流程图。",
|
||||||
|
"为一个新的任务管理工具建议一个名字。",
|
||||||
|
"在不改变其功能的情况下简化这段Python代码。",
|
||||||
|
"你能帮助我学习日语吗?",
|
||||||
|
"如何在用户点击网页上的按钮时让警告框出现?",
|
||||||
|
"把这个研究论文总结成要点",
|
||||||
|
"你能检查这个算法中是否有任何逻辑错误吗?",
|
||||||
|
"write a typescript helloworld code",
|
||||||
|
"Check the following content for grammar and clarity",
|
||||||
|
"Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
|
||||||
|
"Improve this markdown content in asepcts like grammar and expression, for a GitHub repo README.",
|
||||||
|
"can u think of a short name of my package",
|
||||||
|
"simplify this code",
|
||||||
|
"help me learn chinese",
|
||||||
|
"how to let the screen reader automatically focused to an newly poped up element in the web development",
|
||||||
|
"summarize following text:",
|
||||||
|
"Is there anything wrong with this code or can it be simplified?",
|
||||||
|
"generate a Python script that prints 'Hello, World!'",
|
||||||
|
"Can you proofread this essay for grammar and punctuation errors?",
|
||||||
|
"Create a list of ten example sentences for the word 'serendipity.'",
|
||||||
|
"Can you reformat this JSON to be more readable?",
|
||||||
|
"Suggest a creative title for my blog post about healthy eating.",
|
||||||
|
"Refactor this JavaScript function to make it more efficient.",
|
||||||
|
"Help me practice French: provide a sentence with a missing word that I can guess.",
|
||||||
|
"How do I make a button automatically focus when a modal appears on a webpage?",
|
||||||
|
"Summarize this news article for me.",
|
||||||
|
"Can you review this code snippet for potential security vulnerabilities?",
|
||||||
|
"Generate a SQL query to find all users who signed up in the last 30 days.",
|
||||||
|
"Can you translate this paragraph into Spanish?",
|
||||||
|
"Create a flowchart based on the following process description.",
|
||||||
|
"Write a Python function to calculate the factorial of a number.",
|
||||||
|
"Provide a detailed explanation of how to implement OAuth2 in a web application.",
|
||||||
|
"Can you optimize this image for faster loading on a website?",
|
||||||
|
"Suggest some catchy taglines for a new mobile app focused on fitness.",
|
||||||
|
"Write a Bash script to back up my documents folder daily.",
|
||||||
|
"Help me draft an email to request a meeting with a potential client.",
|
||||||
|
"Can you convert this Markdown document into HTML?",
|
||||||
|
"Generate a Python script that scrapes data from a specified website.",
|
||||||
|
"Can you find the synonyms of the word 'meticulous'?",
|
||||||
|
"Write a SQL query to join two tables based on a common column.",
|
||||||
|
"Create a flowchart based on the following set of instructions.",
|
||||||
|
"Suggest a name for a new task management tool.",
|
||||||
|
"Simplify this Python code without changing its functionality.",
|
||||||
|
"Can you assist me in learning Japanese?",
|
||||||
|
"How can I make an alert box appear when a user clicks a button on a webpage?",
|
||||||
|
"Summarize this research paper into bullet points.",
|
||||||
|
"Can you check if there are any logical errors in this algorithm?"
|
||||||
|
],
|
||||||
|
|
||||||
|
"ai.question": [
|
||||||
|
"你认为哪个框架最适合性能敏感的项目?",
|
||||||
|
"什么是后量子密码学?",
|
||||||
|
"什么是密钥派生函数",
|
||||||
|
"什么是线性代数?",
|
||||||
|
"线性代数在计算机科学中的主要用途是什么?",
|
||||||
|
"我应该使用哪个IDE来编写Go语言?",
|
||||||
|
"Go vs Java vs Kotlin,哪个适合后端",
|
||||||
|
"哪种编程语言最适合数据分析",
|
||||||
|
"什么是量子计算",
|
||||||
|
"什么是哈希函数?",
|
||||||
|
"什么是微积分?",
|
||||||
|
"机器学习在金融中的主要应用有哪些?",
|
||||||
|
"写Python代码最好的文本编辑器是哪个?",
|
||||||
|
"Python vs R vs Julia,哪个更适合数据科学?",
|
||||||
|
"监督学习和无监督学习的关键区别是什么?",
|
||||||
|
"数据库在Web应用程序中的作用是什么",
|
||||||
|
"什么是区块链技术",
|
||||||
|
"使用Docker进行应用程序部署的优势是什么?",
|
||||||
|
"哪个云服务提供商提供最好的AI工具?",
|
||||||
|
"加密是如何工作的?",
|
||||||
|
"负载均衡器在网络架构中的目的是什么?",
|
||||||
|
"机器学习和深度学习有什么区别",
|
||||||
|
"软件工程中最常见的设计模式有哪些",
|
||||||
|
"神经网络是如何学习的",
|
||||||
|
"使用微服务架构的主要好处是什么",
|
||||||
|
"编译器和解释器有什么区别?",
|
||||||
|
"Rust编程语言的关键特性是什么?",
|
||||||
|
"HTTP和HTTPS有什么区别",
|
||||||
|
"使用像Git这样的版本控制系统有什么优势?",
|
||||||
|
"什么是'边缘计算'的概念",
|
||||||
|
"哪种编程语言最适合构建移动应用?",
|
||||||
|
"关系数据库和NoSQL数据库有什么不同?",
|
||||||
|
"算法在计算机科学中的重要性是什么?",
|
||||||
|
"API在软件开发中的作用是什么",
|
||||||
|
"保护Web应用程序的最佳实践是什么",
|
||||||
|
"虚拟现实和增强现实有什么区别?",
|
||||||
|
"机器翻译是如何工作的?",
|
||||||
|
"Which framework do you think is the most suitable for performance sensitive projects?",
|
||||||
|
"What is post-quantum cryptography",
|
||||||
|
"What is a key derivation function?",
|
||||||
|
"What is Linear Algebra?",
|
||||||
|
"What is the main use of linear algebra in computer science",
|
||||||
|
"which IDE should I use for Go",
|
||||||
|
"Go vs Java vs Koltin, which for a backend",
|
||||||
|
"Which programming language is best suited for data analysis?",
|
||||||
|
"What is quantum computing?",
|
||||||
|
"What is a hash function?",
|
||||||
|
"What is calculus?",
|
||||||
|
"What are the main applications of machine learning in finance?",
|
||||||
|
"Which text editor is best for writing Python code?",
|
||||||
|
"Python vs R vs Julia, which is better for data science?",
|
||||||
|
"What are the key differences between supervised and unsupervised learning?",
|
||||||
|
"What is the role of a database in a web application?",
|
||||||
|
"What is blockchain technology?",
|
||||||
|
"What are the advantages of using Docker for application deployment?",
|
||||||
|
"Which cloud service provider offers the best AI tools?",
|
||||||
|
"How does encryption work?",
|
||||||
|
"What is the purpose of a load balancer in network architecture?",
|
||||||
|
"What is the difference between machine learning and deep learning?",
|
||||||
|
"What are the most common design patterns in software engineering?",
|
||||||
|
"How does a neural network learn?",
|
||||||
|
"What is the main benefit of using a microservices architecture?",
|
||||||
|
"What is the difference between a compiler and an interpreter?",
|
||||||
|
"What are the key features of the Rust programming language?",
|
||||||
|
"What is the difference between HTTP and HTTPS?",
|
||||||
|
"What are the advantages of using a version control system like Git?",
|
||||||
|
"What is the concept of 'edge computing'?",
|
||||||
|
"Which programming language is best for building mobile apps?",
|
||||||
|
"How does a relational database differ from a NoSQL database?",
|
||||||
|
"What is the importance of algorithms in computer science?",
|
||||||
|
"What is the role of an API in software development?",
|
||||||
|
"What are the best practices for securing a web application?",
|
||||||
|
"What is the difference between virtual reality and augmented reality?",
|
||||||
|
"How does machine translation work?"
|
||||||
|
],
|
||||||
|
"datetime": ["明天周几", "16天后是几号", "一年前的今天是星期几"]
|
||||||
|
}
|
199
intention-classify/extract.ipynb
Normal file
199
intention-classify/extract.ipynb
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "0d107178",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Extract the Dictionary & Embedding\n",
|
||||||
|
"\n",
|
||||||
|
"Our model uses the dictionary and embedding layers from `Phi-3-mini-4k-instruct`, a pre-trained transformer model developed by Microsoft. Although our model architecture is not a transformer, we can still benefit from its tokenizer and embedding layers."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "222bdd54-c115-4845-b4e6-c63e4f9ac6b4",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from transformers import AutoTokenizer, AutoModel\n",
|
||||||
|
"import torch"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "13f1ed38-ad39-4e6a-8af7-65ace2d14f8d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model_name=\"microsoft/Phi-3-mini-4k-instruct\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "c1de25fc-e90a-425b-8520-3a57fa534b94",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "1aeb02c7c8084b1eb1b8e3178882fd60",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Load models\n",
|
||||||
|
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||||
|
"model = AutoModel.from_pretrained(model_name)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "0b0ba45f-5de9-465d-aa45-57ae45c5fb36",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"embedding_layer = model.get_input_embeddings()\n",
|
||||||
|
"vocab = tokenizer.get_vocab()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "b2f7e08d-c578-4b0c-ad75-cdb545b5433f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from sklearn.decomposition import PCA\n",
|
||||||
|
"from transformers import AutoTokenizer, AutoModel\n",
|
||||||
|
"import numpy as np"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "686e1109-e210-45a3-8c58-b8e92bbe85ff",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"DIMENSIONS = 128"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "9c5d7235-f690-4b34-80a4-9b1db1c19100",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"embeddings = []\n",
|
||||||
|
"for token_id in range(len(vocab)):\n",
|
||||||
|
" embedding_vector = embedding_layer(torch.tensor([token_id])).detach().numpy()\n",
|
||||||
|
" embeddings.append(embedding_vector)\n",
|
||||||
|
"\n",
|
||||||
|
"# Convert vectors to np arrays\n",
|
||||||
|
"embeddings = np.vstack(embeddings)\n",
|
||||||
|
"\n",
|
||||||
|
"# Use PCA to decrease dimension\n",
|
||||||
|
"pca = PCA(n_components=DIMENSIONS)\n",
|
||||||
|
"reduced_embeddings = pca.fit_transform(embeddings)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "6b834f53-7f39-41ab-9d18-71978e988b30",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Save Model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "f231dfdd-5f4f-4d0f-ab5d-7a1365535713",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Create dict of tokenID -> dimension-reduced embedding\n",
|
||||||
|
"token_id_to_reduced_embedding = {token_id: reduced_embeddings[token_id] for token_id in range(len(vocab))}\n",
|
||||||
|
"\n",
|
||||||
|
"torch.save(token_id_to_reduced_embedding, \"token_id_to_reduced_embedding.pt\")\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "e28a612c-7960-42e6-aa36-abd0732f404e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import json\n",
|
||||||
|
"\n",
|
||||||
|
"# Create dict of token -> token_id\n",
|
||||||
|
"token_to_id = {}\n",
|
||||||
|
"for token, token_id in vocab.items():\n",
|
||||||
|
" token_to_id[token] = token_id\n",
|
||||||
|
"\n",
|
||||||
|
"# Save as JSON\n",
|
||||||
|
"with open(\"token_to_id.json\", \"w\") as f:\n",
|
||||||
|
" json.dump(token_to_id, f)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "7a466c54-7c55-4e84-957e-454ae35896ac",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import struct\n",
|
||||||
|
"with open(\"token_embeddings.bin\", \"wb\") as f:\n",
|
||||||
|
" for token_id in range(len(vocab)):\n",
|
||||||
|
" # Write token id (2 bytes; 'H' packs in native byte order, while the JS loader reads it as little-endian)\n",
|
||||||
|
" f.write(struct.pack('H', token_id))\n",
|
||||||
|
" # Write embedding vector (128 float numbers)\n",
|
||||||
|
" f.write(struct.pack('128f', *reduced_embeddings[token_id]))"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.19"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
179
intention-classify/inference-js.src.md
Normal file
179
intention-classify/inference-js.src.md
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
<!-- srcbook:{"language":"typescript","tsconfig.json":{"compilerOptions":{"module":"nodenext","moduleResolution":"nodenext","target":"es2022","resolveJsonModule":true,"noEmit":true,"allowImportingTsExtensions":true},"include":["src/**/*"],"exclude":["node_modules"]}} -->
|
||||||
|
|
||||||
|
# sparkastML intention classification
|
||||||
|
|
||||||
|
###### package.json
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"type": "module",
|
||||||
|
"dependencies": {
|
||||||
|
"@types/node": "latest",
|
||||||
|
"@xenova/transformers": "^2.17.2",
|
||||||
|
"onnxruntime-web": "^1.19.0",
|
||||||
|
"tsx": "latest",
|
||||||
|
"typescript": "latest"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
###### tokenizer.ts
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
/** Maps a vocabulary token's text to its integer token id. */
export type TokenDict = { [key: string]: number };

/**
 * Tokenize `query` by greedy longest-prefix matching against `tokenDict`.
 *
 * Before matching, every space is rewritten to "▁" (and one "▁" is
 * prepended), and every newline is rewritten to the literal text "<0x0A>".
 * At each position the longest vocabulary entry that matches is emitted.
 * When nothing matches, a non-ASCII character is emitted as one `<0xNN>`
 * byte token per UTF-8 byte of that character.
 *
 * @param query - Raw input text.
 * @param tokenDict - Vocabulary mapping token text to token id.
 * @returns The token ids, in input order.
 * @throws Error if an ASCII character matches no vocabulary entry, or if a
 *   byte token needed for the fallback is missing from `tokenDict`.
 */
function tokenize(query: string, tokenDict: TokenDict): number[] {
    const tokenIds: number[] = [];
    const encoder = new TextEncoder();
    let index = 0;

    // Replace spaces with "▁" and newlines with the byte-token text.
    query = "▁" + query.replace(/ /g, "▁");
    query = query.replace(/\n/g, "<0x0A>");

    while (index < query.length) {
        let bestToken: string | null = null;
        let bestLength = 0;

        // Find the longest token that matches at the current position.
        // The cheap length comparison runs first so startsWith is skipped
        // for tokens that could not improve on the current best.
        for (const token in tokenDict) {
            if (token.length > bestLength && query.startsWith(token, index)) {
                bestToken = token;
                bestLength = token.length;
            }
        }

        if (bestToken !== null) {
            tokenIds.push(tokenDict[bestToken]);
            index += bestLength;
            continue;
        }

        // No vocabulary entry matches here. Read a whole code point rather
        // than a single UTF-16 code unit: encoding half of a surrogate pair
        // would produce the bytes of U+FFFD instead of the character's own
        // UTF-8 bytes.
        const codePoint = query.codePointAt(index)!;
        const char = String.fromCodePoint(codePoint);

        if (codePoint <= 127) {
            // An ASCII character that matches no token is treated as unknown.
            throw new Error(`Unknown token: ${char}`);
        }

        // Non-ASCII: fall back to one byte token per UTF-8 byte.
        for (const byte of encoder.encode(char)) {
            const hex = byte.toString(16).toUpperCase().padStart(2, "0");
            const byteToken = `<0x${hex}>`;
            const byteTokenId = tokenDict[byteToken];
            if (byteTokenId === undefined) {
                throw new Error(`Unknown byte token: ${byteToken}`);
            }
            tokenIds.push(byteTokenId);
        }

        // Advance by the number of UTF-16 code units consumed (1 or 2).
        index += char.length;
    }

    return tokenIds;
}
|
||||||
|
|
||||||
|
export default tokenize
|
||||||
|
```
|
||||||
|
|
||||||
|
###### embedding.ts
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import * as fs from 'fs';
|
||||||
|
import * as path from 'path';
|
||||||
|
|
||||||
|
type EmbeddingDict = { [key: number]: Float32Array };
|
||||||
|
|
||||||
|
/**
 * Parse a packed binary embedding table into a lookup from key to vector.
 *
 * Each record is a little-endian uint16 key followed by `dimensions`
 * little-endian float32 values, i.e. 2 + 4 * 128 = 514 bytes per record with
 * the default dimension.
 *
 * @param buffer - Raw bytes of the embedding file.
 * @param dimensions - Number of floats per vector (defaults to 128).
 * @returns Map from the uint16 key to its Float32Array vector. If a key
 *   appears more than once, the last record wins.
 * @throws RangeError if `dimensions` is not a positive integer or the buffer
 *   length is not a whole number of records.
 */
function getEmbeddingLayer(buffer: Buffer, dimensions: number = 128): EmbeddingDict {
    const KEY_BYTES = 2;
    const FLOAT_BYTES = 4;

    if (!Number.isInteger(dimensions) || dimensions <= 0) {
        throw new RangeError(`dimensions must be a positive integer, got ${dimensions}`);
    }

    const entrySize = KEY_BYTES + dimensions * FLOAT_BYTES;

    // Fail up front on a truncated or mismatched file instead of hitting an
    // out-of-range read half-way through the last record.
    if (buffer.length % entrySize !== 0) {
        throw new RangeError(
            `Embedding buffer length ${buffer.length} is not a multiple of the record size ${entrySize}`
        );
    }

    const dict: EmbeddingDict = {};

    for (let offset = 0; offset < buffer.length; offset += entrySize) {
        const key = buffer.readUInt16LE(offset);
        const vector = new Float32Array(dimensions);

        for (let j = 0; j < dimensions; j++) {
            vector[j] = buffer.readFloatLE(offset + KEY_BYTES + j * FLOAT_BYTES);
        }

        dict[key] = vector;
    }

    return dict;
}
|
||||||
|
|
||||||
|
/**
 * Build the flat model input for a token sequence.
 *
 * The result holds `contextSize` consecutive slots of `dimensions` floats.
 * Slot `i` is the embedding of `tokenIds[i]`; slots past the end of
 * `tokenIds` stay zero (padding), and tokens beyond `contextSize` are ignored.
 *
 * @param tokenIds - Token ids to embed, in order.
 * @param embeddingDict - Lookup from token id to its embedding vector.
 * @param contextSize - Number of token slots in the output.
 * @param dimensions - Length of each embedding vector (defaults to 128).
 * @returns A Float32Array of length `contextSize * dimensions`.
 * @throws TypeError if a token id has no entry in `embeddingDict` (the same
 *   error type the previous implementation raised for this case).
 * @throws RangeError if an embedding vector's length differs from `dimensions`.
 */
function getEmbedding(
    tokenIds: number[],
    embeddingDict: EmbeddingDict,
    contextSize: number,
    dimensions: number = 128
): Float32Array {
    // Mirror the loop bound `i < contextSize` for odd inputs (negative, fractional).
    const slots = Math.max(0, Math.ceil(contextSize));

    // Typed arrays are zero-initialised, so unused slots are already padding.
    const result = new Float32Array(slots * dimensions);

    const usedTokens = Math.min(tokenIds.length, slots);
    for (let i = 0; i < usedTokens; i++) {
        const tokenId = tokenIds[i];
        const vector = embeddingDict[tokenId];

        if (vector === undefined) {
            throw new TypeError(`No embedding found for token id ${tokenId}`);
        }
        if (vector.length !== dimensions) {
            throw new RangeError(
                `Embedding for token id ${tokenId} has length ${vector.length}, expected ${dimensions}`
            );
        }

        // Copy in place instead of re-allocating the whole array per token.
        result.set(vector, i * dimensions);
    }

    return result;
}
|
||||||
|
|
||||||
|
export {getEmbeddingLayer, getEmbedding};
|
||||||
|
```
|
||||||
|
|
||||||
|
###### load.ts
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import * as ort from 'onnxruntime-web';
|
||||||
|
import * as fs from 'fs';
|
||||||
|
import tokenize, {TokenDict} from "./tokenizer.ts"
|
||||||
|
import {getEmbeddingLayer, getEmbedding} from "./embedding.ts"
|
||||||
|
|
||||||
|
const embedding_file = './token_embeddings.bin';
|
||||||
|
const embedding_data = fs.readFileSync(embedding_file);
|
||||||
|
const embedding_buffer = Buffer.from(embedding_data);
|
||||||
|
const query = `Will it rain tomorrow`;
|
||||||
|
const model_path = './model.onnx';
|
||||||
|
const vocabData = fs.readFileSync('./token_to_id.json');
|
||||||
|
const vocabDict = JSON.parse(vocabData.toString());
|
||||||
|
|
||||||
|
let lastLogCall = new Date().getTime();
|
||||||
|
|
||||||
|
/**
 * Print a progress line of the form `[<epoch ms>] (+<elapsed>ms) <task>`,
 * where the elapsed time is measured from the previous call (or from module
 * load for the first call), then restart the timer.
 */
function log(task: string) {
    const now = Date.now();
    const elapsed = now - lastLogCall;
    console.log(`[${now}] (+${elapsed}ms) ${task}`)
    // Re-read the clock so time spent printing is not charged to the next task.
    lastLogCall = Date.now();
}
|
||||||
|
|
||||||
|
/**
 * Create an ONNX Runtime inference session for the model at `modelPath`.
 *
 * @param modelPath - Location of the `.onnx` model file.
 * @returns The session produced by `ort.InferenceSession.create`.
 */
async function loadModel(modelPath: string) {
    return await ort.InferenceSession.create(modelPath);
}
|
||||||
|
|
||||||
|
/**
 * Run the classifier on a single query and return the raw model output.
 *
 * Steps: load the ONNX session, tokenize the query, look up the token
 * embeddings, and feed them to the model as a `[1, 12, 128]` float32 tensor
 * under the input name "input". A timing line is logged after each stage.
 *
 * @param query - Text to classify.
 * @param embedding_buffer - Raw bytes of the packed token-embedding table.
 * @param modelPath - Location of the `.onnx` model file.
 * @param vocabDict - Vocabulary mapping token strings to ids.
 * @returns The `data` of the model's "output" tensor, unmodified. Callers
 *   that want a class index must take the argmax themselves.
 */
async function runInference(query: string, embedding_buffer: Buffer, modelPath: string, vocabDict: TokenDict) {
    // Fixed input shape fed to the model; EMBEDDING_DIM must match the vector
    // length produced by getEmbeddingLayer / getEmbedding.
    const CONTEXT_SIZE = 12;
    const EMBEDDING_DIM = 128;

    const session = await loadModel(modelPath);
    log("loadModel:end");

    // tokenize is synchronous, so no await is needed here.
    const tokenIds = tokenize(query, vocabDict);
    log("tokenize:end");

    const embeddingDict = getEmbeddingLayer(embedding_buffer);
    const e = getEmbedding(tokenIds, embeddingDict, CONTEXT_SIZE);
    log("getEmbedding:end");

    const inputTensor = new ort.Tensor('float32', e, [1, CONTEXT_SIZE, EMBEDDING_DIM]);

    const feeds = { 'input': inputTensor };
    const results = await session.run(feeds);
    log("inference:end");

    return results.output.data;
}
|
||||||
|
|
||||||
|
// Entry point: classify the sample query and print the raw model output.
console.log("Predicted class:", await runInference(query, embedding_buffer, model_path, vocabDict));
|
||||||
|
```
|
8
intention-classify/noise.json
Normal file
8
intention-classify/noise.json
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
[
|
||||||
|
"你好",
|
||||||
|
"你好谢谢小笼包再见",
|
||||||
|
"我爱你",
|
||||||
|
"嘿嘿嘿诶嘿",
|
||||||
|
"为什么",
|
||||||
|
"拼多多"
|
||||||
|
]
|
575
intention-classify/train.ipynb
Normal file
575
intention-classify/train.ipynb
Normal file
@ -0,0 +1,575 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "a6a3195f-d099-4bf4-846f-51f403954818",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# sparkastML: Training the Intention Classification Model\n",
|
||||||
|
"\n",
|
||||||
|
"This is the model we use for intent recognition. It uses a **CNN architecture** together with an **Energy-based Model** to implement OSR (Open-set Recognition).\n",
|
||||||
|
"\n",
|
||||||
|
"In this case, **positive samples** refer to data that can be classified into one of the existing classes, while **negative samples** are those that do not belong to any of the existing classes."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "bddcdbb2-ccbc-4027-a38f-09c61ac94984",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import json\n",
|
||||||
|
"import torch\n",
|
||||||
|
"from torch.utils.data import Dataset, DataLoader\n",
|
||||||
|
"from torch.nn.utils.rnn import pad_sequence\n",
|
||||||
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
|
"from transformers import AutoTokenizer, AutoModel\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from scipy.spatial.distance import euclidean\n",
|
||||||
|
"from scipy.stats import weibull_min\n",
|
||||||
|
"from sklearn.preprocessing import normalize\n",
|
||||||
|
"import torch.nn.functional as F\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "d3a0e10f-9bc3-44c7-a109-786dd5cd25ea",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"model_name=\"microsoft/Phi-3-mini-4k-instruct\"\n",
|
||||||
|
"DIMENSIONS = 128\n",
|
||||||
|
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "1ae14906-338d-4c99-87ed-bb1acd22b295",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Load Data\n",
|
||||||
|
"\n",
|
||||||
|
"We load the data from `data.json`, and also get the negative samples from `noise.json`."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "a206071c-ce4e-4de4-b936-bfc70d13708a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_6446/1697839999.py:18: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_69nk78ncaj/croot/pytorch_1669252638507/work/torch/csrc/utils/tensor_new.cpp:204.)\n",
|
||||||
|
" embeddings = torch.tensor(embeddings)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Load data\n",
|
||||||
|
"with open('data.json', 'r') as f:\n",
|
||||||
|
" data = json.load(f)\n",
|
||||||
|
"\n",
|
||||||
|
"# Create map: class to index\n",
|
||||||
|
"class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}\n",
|
||||||
|
"idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}\n",
|
||||||
|
"\n",
|
||||||
|
"# Preprocess data, convert sentences to the format of (class idx, embedding)\n",
|
||||||
|
"def preprocess_data(data, embedding_map, tokenizer, max_length=64):\n",
|
||||||
|
" dataset = []\n",
|
||||||
|
" for label, sentences in data.items():\n",
|
||||||
|
" for sentence in sentences:\n",
|
||||||
|
" # Tokenize the sentence and convert tokens to embedding vectors\n",
|
||||||
|
" tokens = tokenizer.tokenize(sentence)\n",
|
||||||
|
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
||||||
|
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
|
||||||
|
" embeddings = torch.tensor(embeddings)\n",
|
||||||
|
" dataset.append((class_to_idx[label], embeddings))\n",
|
||||||
|
" return dataset\n",
|
||||||
|
"\n",
|
||||||
|
"# Load embedding map\n",
|
||||||
|
"embedding_map = torch.load('token_id_to_reduced_embedding.pt')\n",
|
||||||
|
"\n",
|
||||||
|
"# Get preprocessed dataset\n",
|
||||||
|
"dataset = preprocess_data(data, embedding_map, tokenizer)\n",
|
||||||
|
"\n",
|
||||||
|
"# Train-test split\n",
|
||||||
|
"train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)\n",
|
||||||
|
"\n",
|
||||||
|
"class TextDataset(Dataset):\n",
|
||||||
|
" def __init__(self, data):\n",
|
||||||
|
" self.data = data\n",
|
||||||
|
"\n",
|
||||||
|
" def __len__(self):\n",
|
||||||
|
" return len(self.data)\n",
|
||||||
|
"\n",
|
||||||
|
" def __getitem__(self, idx):\n",
|
||||||
|
" return self.data[idx]\n",
|
||||||
|
"\n",
|
||||||
|
" def collate_fn(self, batch):\n",
|
||||||
|
" labels, embeddings = zip(*batch)\n",
|
||||||
|
" labels = torch.tensor(labels)\n",
|
||||||
|
" embeddings = pad_sequence(embeddings, batch_first=True)\n",
|
||||||
|
" return labels, embeddings\n",
|
||||||
|
"\n",
|
||||||
|
"train_dataset = TextDataset(train_data)\n",
|
||||||
|
"val_dataset = TextDataset(val_data)\n",
|
||||||
|
"\n",
|
||||||
|
"train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn)\n",
|
||||||
|
"val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False, collate_fn=val_dataset.collate_fn)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "9adbe9b8-a2d2-4e1d-8620-457ed0e02fe6",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import torch\n",
|
||||||
|
"from torch.utils.data import Dataset, DataLoader\n",
|
||||||
|
"\n",
|
||||||
|
"class NegativeSampleDataset(Dataset):\n",
|
||||||
|
" def __init__(self, negative_samples):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" negative_samples: List or array of negative sample embeddings or raw text\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" self.samples = negative_samples\n",
|
||||||
|
" \n",
|
||||||
|
" def __len__(self):\n",
|
||||||
|
" return len(self.samples)\n",
|
||||||
|
" \n",
|
||||||
|
" def __getitem__(self, idx):\n",
|
||||||
|
" return self.samples[idx]\n",
|
||||||
|
"\n",
|
||||||
|
" def collate_fn(self, batch):\n",
|
||||||
|
" embeddings = pad_sequence(batch, batch_first=True)\n",
|
||||||
|
" return embeddings\n",
|
||||||
|
"\n",
|
||||||
|
"with open('noise.json', 'r') as f:\n",
|
||||||
|
" negative_samples_list = json.load(f)\n",
|
||||||
|
"\n",
|
||||||
|
"negative_embedding_list = []\n",
|
||||||
|
"\n",
|
||||||
|
"for sentence in negative_samples_list:\n",
|
||||||
|
" tokens = tokenizer.tokenize(sentence)\n",
|
||||||
|
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
||||||
|
" embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]\n",
|
||||||
|
" embeddings = torch.tensor(embeddings)\n",
|
||||||
|
" negative_embedding_list.append(embeddings)\n",
|
||||||
|
"\n",
|
||||||
|
"negative_dataset = NegativeSampleDataset(negative_embedding_list)\n",
|
||||||
|
"negative_loader = DataLoader(negative_dataset, batch_size=24, shuffle=True, collate_fn=negative_dataset.collate_fn)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "600febe4-2484-4aad-90a1-2bc821fdce1a",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Implementing the Model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "adf624ac-ad63-437b-95f6-b02b7253b91e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"import torch.nn.functional as F\n",
|
||||||
|
"\n",
|
||||||
|
"class TextCNN(nn.Module):\n",
|
||||||
|
" def __init__(self, input_dim, num_classes):\n",
|
||||||
|
" super(TextCNN, self).__init__()\n",
|
||||||
|
" self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n",
|
||||||
|
" self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n",
|
||||||
|
" self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n",
|
||||||
|
" \n",
|
||||||
|
" self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n",
|
||||||
|
" self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n",
|
||||||
|
" self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n",
|
||||||
|
" \n",
|
||||||
|
" self.dropout = nn.Dropout(0.5)\n",
|
||||||
|
" self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n",
|
||||||
|
"\n",
|
||||||
|
" def forward(self, x):\n",
|
||||||
|
" x = x.permute(0, 2, 1) # Change the input shape to (batch_size, embedding_dim, seq_length)\n",
|
||||||
|
" \n",
|
||||||
|
" x1 = F.relu(self.bn1(self.conv1(x)))\n",
|
||||||
|
" x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n",
|
||||||
|
" \n",
|
||||||
|
" x2 = F.relu(self.bn2(self.conv2(x)))\n",
|
||||||
|
" x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n",
|
||||||
|
" \n",
|
||||||
|
" x3 = F.relu(self.bn3(self.conv3(x)))\n",
|
||||||
|
" x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n",
|
||||||
|
" \n",
|
||||||
|
" x = torch.cat((x1, x2, x3), dim=1)\n",
|
||||||
|
" x = self.dropout(x)\n",
|
||||||
|
" x = self.fc(x)\n",
|
||||||
|
" return x\n",
|
||||||
|
"\n",
|
||||||
|
"# Initialize model\n",
|
||||||
|
"input_dim = DIMENSIONS\n",
|
||||||
|
"num_classes = len(class_to_idx)\n",
|
||||||
|
"model = TextCNN(input_dim, num_classes)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "2750e17d-8a60-40c7-851b-1a567d0ee82b",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Energy-based Models"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "a2d7c920-07d2-4d14-9cef-e2101b7a2ceb",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def energy_score(logits):\n",
|
||||||
|
" # Energy score is minus logsumexp\n",
|
||||||
|
" return -torch.logsumexp(logits, dim=1)\n",
|
||||||
|
"\n",
|
||||||
|
"def generate_noise(batch_size, seq_length ,input_dim, device):\n",
|
||||||
|
" # Generate a Gaussian noise\n",
|
||||||
|
" return torch.randn(batch_size, seq_length, input_dim).to(device)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "904a60e4-95a0-4f7b-ad45-a7d8d0ac887d",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Training"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "19acb5bf-00b1-47d4-ad25-a13c6be09f65",
|
||||||
|
"metadata": {
|
||||||
|
"scrolled": true
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch [1/50], Loss: 12.5108\n",
|
||||||
|
"Epoch [2/50], Loss: 10.7305\n",
|
||||||
|
"Epoch [3/50], Loss: 10.2943\n",
|
||||||
|
"Epoch [4/50], Loss: 9.9350\n",
|
||||||
|
"Epoch [5/50], Loss: 9.7991\n",
|
||||||
|
"Epoch [6/50], Loss: 9.6443\n",
|
||||||
|
"Epoch [7/50], Loss: 9.4762\n",
|
||||||
|
"Epoch [8/50], Loss: 9.4637\n",
|
||||||
|
"Epoch [9/50], Loss: 9.3025\n",
|
||||||
|
"Epoch [10/50], Loss: 9.1719\n",
|
||||||
|
"Epoch [11/50], Loss: 9.0632\n",
|
||||||
|
"Epoch [12/50], Loss: 8.9741\n",
|
||||||
|
"Epoch [13/50], Loss: 8.8487\n",
|
||||||
|
"Epoch [14/50], Loss: 8.6565\n",
|
||||||
|
"Epoch [15/50], Loss: 8.5830\n",
|
||||||
|
"Epoch [16/50], Loss: 8.4196\n",
|
||||||
|
"Epoch [17/50], Loss: 8.2319\n",
|
||||||
|
"Epoch [18/50], Loss: 8.0655\n",
|
||||||
|
"Epoch [19/50], Loss: 7.7140\n",
|
||||||
|
"Epoch [20/50], Loss: 7.6921\n",
|
||||||
|
"Epoch [21/50], Loss: 7.3375\n",
|
||||||
|
"Epoch [22/50], Loss: 7.2297\n",
|
||||||
|
"Epoch [23/50], Loss: 6.8833\n",
|
||||||
|
"Epoch [24/50], Loss: 6.8534\n",
|
||||||
|
"Epoch [25/50], Loss: 6.4557\n",
|
||||||
|
"Epoch [26/50], Loss: 6.1365\n",
|
||||||
|
"Epoch [27/50], Loss: 5.8558\n",
|
||||||
|
"Epoch [28/50], Loss: 5.5030\n",
|
||||||
|
"Epoch [29/50], Loss: 5.1604\n",
|
||||||
|
"Epoch [30/50], Loss: 4.7742\n",
|
||||||
|
"Epoch [31/50], Loss: 4.5958\n",
|
||||||
|
"Epoch [32/50], Loss: 4.0713\n",
|
||||||
|
"Epoch [33/50], Loss: 3.8872\n",
|
||||||
|
"Epoch [34/50], Loss: 3.5240\n",
|
||||||
|
"Epoch [35/50], Loss: 3.3115\n",
|
||||||
|
"Epoch [36/50], Loss: 2.5667\n",
|
||||||
|
"Epoch [37/50], Loss: 2.6709\n",
|
||||||
|
"Epoch [38/50], Loss: 1.8075\n",
|
||||||
|
"Epoch [39/50], Loss: 1.6654\n",
|
||||||
|
"Epoch [40/50], Loss: 0.4622\n",
|
||||||
|
"Epoch [41/50], Loss: 0.4719\n",
|
||||||
|
"Epoch [42/50], Loss: -0.4037\n",
|
||||||
|
"Epoch [43/50], Loss: -0.9405\n",
|
||||||
|
"Epoch [44/50], Loss: -1.7204\n",
|
||||||
|
"Epoch [45/50], Loss: -2.4124\n",
|
||||||
|
"Epoch [46/50], Loss: -3.0032\n",
|
||||||
|
"Epoch [47/50], Loss: -2.7123\n",
|
||||||
|
"Epoch [48/50], Loss: -3.6953\n",
|
||||||
|
"Epoch [49/50], Loss: -3.7212\n",
|
||||||
|
"Epoch [50/50], Loss: -3.7558\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import torch.optim as optim\n",
|
||||||
|
"\n",
|
||||||
|
"criterion = nn.CrossEntropyLoss()\n",
|
||||||
|
"optimizer = optim.Adam(model.parameters(), lr=8e-4)\n",
|
||||||
|
"\n",
|
||||||
|
"from torch.utils.tensorboard import SummaryWriter\n",
|
||||||
|
"import tensorboard\n",
|
||||||
|
"writer = SummaryWriter()\n",
|
||||||
|
"\n",
|
||||||
|
"def train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=10):\n",
|
||||||
|
" model.train()\n",
|
||||||
|
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||||
|
" model.to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" negative_iter = iter(negative_loader)\n",
|
||||||
|
" \n",
|
||||||
|
" for epoch in range(num_epochs):\n",
|
||||||
|
" total_loss = 0\n",
|
||||||
|
" for batch_idx, (labels, embeddings) in enumerate(train_loader):\n",
|
||||||
|
" embeddings = embeddings.to(device)\n",
|
||||||
|
" labels = labels.to(device)\n",
|
||||||
|
" \n",
|
||||||
|
" batch_size = embeddings.size(0)\n",
|
||||||
|
" \n",
|
||||||
|
" # ---------------------\n",
|
||||||
|
" # 1. Positive sample\n",
|
||||||
|
" # ---------------------\n",
|
||||||
|
" optimizer.zero_grad()\n",
|
||||||
|
" outputs = model(embeddings) # logits from the model\n",
|
||||||
|
" \n",
|
||||||
|
" class_loss = criterion(outputs, labels)\n",
|
||||||
|
" \n",
|
||||||
|
" # Energy of positive sample\n",
|
||||||
|
" known_energy = energy_score(outputs)\n",
|
||||||
|
" energy_loss_known = known_energy.mean()\n",
|
||||||
|
" \n",
|
||||||
|
" # ------------------------------------\n",
|
||||||
|
" # 2. Negative sample - Random Noise\n",
|
||||||
|
" # ------------------------------------\n",
|
||||||
|
" noise_embeddings = torch.randn_like(embeddings).to(device)\n",
|
||||||
|
" noise_outputs = model(noise_embeddings)\n",
|
||||||
|
" noise_energy = energy_score(noise_outputs)\n",
|
||||||
|
" energy_loss_noise = F.relu(1 - noise_energy).mean() # For the energy of noise, bigger is better \n",
|
||||||
|
" \n",
|
||||||
|
" # ------------------------------------\n",
|
||||||
|
" # 3. Negative sample - custom corpus\n",
|
||||||
|
" # ------------------------------------\n",
|
||||||
|
" \n",
|
||||||
|
" try:\n",
|
||||||
|
" negative_samples = next(negative_iter)\n",
|
||||||
|
" except StopIteration:\n",
|
||||||
|
" negative_iter = iter(negative_loader)\n",
|
||||||
|
" negative_samples = next(negative_iter)\n",
|
||||||
|
" negative_samples = negative_samples.to(device)\n",
|
||||||
|
" negative_outputs = model(negative_samples)\n",
|
||||||
|
" negative_energy = energy_score(negative_outputs)\n",
|
||||||
|
" energy_loss_negative = F.relu(1 - negative_energy).mean() # For the energy of noise, bigger is better \n",
|
||||||
|
" \n",
|
||||||
|
" # -----------------------------\n",
|
||||||
|
" # 4. Overall Loss calculation\n",
|
||||||
|
" # -----------------------------\n",
|
||||||
|
" total_energy_loss = energy_loss_known + energy_loss_noise + energy_loss_negative\n",
|
||||||
|
" total_loss_batch = class_loss + total_energy_loss * 0.1 + 10\n",
|
||||||
|
"\n",
|
||||||
|
" writer.add_scalar(\"Engergy Loss\", total_energy_loss, epoch)\n",
|
||||||
|
" writer.add_scalar(\"Loss\", total_loss_batch, epoch)\n",
|
||||||
|
" writer.add_scalar(\"Norm Loss\", torch.exp(total_loss_batch * 0.003) * 10 , epoch)\n",
|
||||||
|
" \n",
|
||||||
|
" total_loss_batch.backward()\n",
|
||||||
|
" optimizer.step()\n",
|
||||||
|
" \n",
|
||||||
|
" total_loss += total_loss_batch.item()\n",
|
||||||
|
" \n",
|
||||||
|
" avg_loss = total_loss / len(train_loader)\n",
|
||||||
|
" print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')\n",
|
||||||
|
"\n",
|
||||||
|
"train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=50)\n",
|
||||||
|
"writer.flush()\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "e6d29558-f497-4033-8488-169bd25ce881",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Evaluation"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "472702e6-db4a-4faa-9e92-7510e6eacbb1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"ENERGY_THRESHOLD = -3"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "d3a4fef8-37ab-45c8-b2b1-9fc8bdffcffd",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Accuracy: 0.9315\n",
|
||||||
|
"Precision: 1.0000\n",
|
||||||
|
"Recall: 0.9254\n",
|
||||||
|
"F1 Score: 0.9612\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support\n",
|
||||||
|
"\n",
|
||||||
|
"def evaluate_energy_model(model, known_loader, unknown_loader, energy_threshold):\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||||
|
" \n",
|
||||||
|
" all_preds = []\n",
|
||||||
|
" all_labels = []\n",
|
||||||
|
" \n",
|
||||||
|
" # Evaluate positive sample\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" for labels, embeddings in known_loader:\n",
|
||||||
|
" embeddings = embeddings.to(device)\n",
|
||||||
|
" logits = model(embeddings)\n",
|
||||||
|
" energy = energy_score(logits)\n",
|
||||||
|
" \n",
|
||||||
|
" preds = (energy <= energy_threshold).long()\n",
|
||||||
|
" all_preds.extend(preds.cpu().numpy())\n",
|
||||||
|
" all_labels.extend([1] * len(preds)) # Positive sample labeled as 1\n",
|
||||||
|
" \n",
|
||||||
|
" # Evaluate negative sample\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" for embeddings in unknown_loader:\n",
|
||||||
|
" embeddings = embeddings.to(device)\n",
|
||||||
|
" logits = model(embeddings)\n",
|
||||||
|
" energy = energy_score(logits)\n",
|
||||||
|
" \n",
|
||||||
|
" preds = (energy <= energy_threshold).long()\n",
|
||||||
|
" all_preds.extend(preds.cpu().numpy())\n",
|
||||||
|
" all_labels.extend([0] * len(preds)) # Negative sample labeled as 0\n",
|
||||||
|
" \n",
|
||||||
|
" precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
|
||||||
|
" accuracy = accuracy_score(all_labels, all_preds)\n",
|
||||||
|
"\n",
|
||||||
|
" print(f'Accuracy: {accuracy:.4f}')\n",
|
||||||
|
" print(f'Precision: {precision:.4f}')\n",
|
||||||
|
" print(f'Recall: {recall:.4f}')\n",
|
||||||
|
" print(f'F1 Score: {f1:.4f}')\n",
|
||||||
|
"\n",
|
||||||
|
"evaluate_energy_model(model, val_loader, negative_loader, ENERGY_THRESHOLD)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 20,
|
||||||
|
"id": "ba614054-75e1-4f61-ace5-aeb11e29a222",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Save the model\n",
|
||||||
|
"torch.save(model, \"model.pt\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "fdfa0c5e-e6d3-4db0-a142-96645c92719c",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Inference"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 19,
|
||||||
|
"id": "03928d75-81c8-4298-ab8a-d7f8a758b561",
|
||||||
|
"metadata": {
|
||||||
|
"scrolled": true
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Predicted: ['weather', 0.9989822506904602, -8.016249656677246]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"def predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold, max_length=64):\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
" tokens = tokenizer.tokenize(sentence)\n",
|
||||||
|
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
||||||
|
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
|
||||||
|
" embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension\n",
|
||||||
|
" \n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" logits = model(embeddings)\n",
|
||||||
|
" probabilities = F.softmax(logits, dim=1)\n",
|
||||||
|
" max_prob, predicted = torch.max(probabilities, 1)\n",
|
||||||
|
" \n",
|
||||||
|
" # Calculate energy score\n",
|
||||||
|
" energy = energy_score(logits)\n",
|
||||||
|
"\n",
|
||||||
|
" # If energy > threshold, consider the input as unknown class\n",
|
||||||
|
" if energy.item() > energy_threshold:\n",
|
||||||
|
" return [\"Unknown\", max_prob.item(), energy.item()]\n",
|
||||||
|
" else:\n",
|
||||||
|
" return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]\n",
|
||||||
|
"\n",
|
||||||
|
"# Example usage:\n",
|
||||||
|
"sentence = \"weather today\"\n",
|
||||||
|
"energy_threshold = ENERGY_THRESHOLD\n",
|
||||||
|
"predicted = predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold)\n",
|
||||||
|
"print(f'Predicted: {predicted}')\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.19"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user