commit f28f83b48ec1ec6394fa5a2e6ac0fc0e69fcebdf Author: alikia2x Date: Sun Sep 1 22:17:04 2024 +0800 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..610e5e3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +runs +.DS_Store +*.onnx +*.pt +*.bin +token_to_id.json +.ipynb_checkpoints \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..ea7c668 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 alikia2x + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/intention-classify/LICENSE b/intention-classify/LICENSE new file mode 100644 index 0000000..8ab7b49 --- /dev/null +++ b/intention-classify/LICENSE @@ -0,0 +1,22 @@ +Microsoft. +Copyright (c) Microsoft Corporation. 
+ +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/intention-classify/convert_onnx.ipynb b/intention-classify/convert_onnx.ipynb new file mode 100644 index 0000000..92c684a --- /dev/null +++ b/intention-classify/convert_onnx.ipynb @@ -0,0 +1,126 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "94ff7007", + "metadata": {}, + "source": [ + "# Convert to ONNX\n", + "\n", + "This notebook converts our model to [ONNX](https://onnx.ai/) format, which is the open standard for machine learning interoperability. 
In this way, we can run our model in JS (browser)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "311162fd-f957-4746-b524-25bb3e09efbc", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "from torch import nn\n", + "import torch.utils.model_zoo as model_zoo\n", + "import torch.onnx\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fd182c6d-1e77-4bbb-bb53-8321d40ae002", + "metadata": {}, + "outputs": [], + "source": [ + "class TextCNN(nn.Module):\n", + " def __init__(self, input_dim, num_classes):\n", + " super(TextCNN, self).__init__()\n", + " self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n", + " self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n", + " self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n", + " \n", + " self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n", + " self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n", + " self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n", + " \n", + " self.dropout = nn.Dropout(0.5)\n", + " self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n", + "\n", + " def forward(self, x):\n", + " x = x.permute(0, 2, 1) # Change the input shape to (batch_size, embedding_dim, seq_length)\n", + " \n", + " x1 = F.relu(self.bn1(self.conv1(x)))\n", + " x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n", + " \n", + " x2 = F.relu(self.bn2(self.conv2(x)))\n", + " x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n", + " \n", + " x3 = F.relu(self.bn3(self.conv3(x)))\n", + " x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n", + " \n", + " x = torch.cat((x1, x2, x3), dim=1)\n", + " x = self.dropout(x)\n", + " x = self.fc(x)\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "bdb597cb-d896-485c-8c9c-897b1d35e8d2", + "metadata": {}, + 
"outputs": [], + "source": [ + "model = torch.load(\"model.pt\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9f7a6e64-75f2-4fa9-8d1e-b83099765d02", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Example input: use random embedding vector to simulate real input\n", + "dummy_input = torch.randn(1, 64, 128) # (batch_size, seq_length, embedding_dim)\n", + "\n", + "# Export model\n", + "torch.onnx.export(\n", + " model, # The model to export\n", + " dummy_input, # Example input\n", + " \"model.onnx\", # File name\n", + " input_names=['input'], # Input name (Could customize)\n", + " output_names=['output'], # Output name (Could customize)\n", + " dynamic_axes={\n", + " 'input': {0: 'batch_size', 1: 'seq_length'}, # Dynamic batch and sequence length\n", + " 'output': {0: 'batch_size'}\n", + " },\n", + " opset_version=11 # ONNX version,ensure the ONNX runtime supports it\n", + ")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/intention-classify/data.json b/intention-classify/data.json new file mode 100644 index 0000000..27b5e93 --- /dev/null +++ b/intention-classify/data.json @@ -0,0 +1,350 @@ +{ + "weather": [ + "天气如何?", + "现在的天气", + "今天的天气预报", + "现在的天气状况", + "今天天气怎么样", + "目前是什么天气", + "今天的天气概述", + "当前天气状况如何", + "今天会下雨吗?", + "今天晴天吗?", + "今天的天气状况如何", + "现在外面是什么天气", + "今天天气好么", + "今天有没有雾霾", + "明天有没有大风", + "今天会不会很冷", + "今天的天气会变化吗", + "今天晚上的天气如何", + "今天夜里会下雨吗", + "明天会下雨吗?", + "北京今天的预报", + "两天后上海的天气情况", + "现在的温度", + "现在多少度", + "外面有多热", + "明天热不热", + "现在的气温是多少", + "今天最高温度是多少", + "今天最低温度是多少", + "现在外面感觉冷吗", + "明天会比今天热吗", + "明天会比今天冷吗", + 
"今天的温度变化大吗", + "室外的温度是多少", + "达拉斯今天热不热", + "苏州现在天气怎么样", + "how's the weather", + "What's going on with the weather?", + "Can you give me an update on the weather?", + "How's the forecast looking today?", + "Give me a summary of the current weather.", + "Can you tell me the current weather?", + "What is the weather situation at the moment?", + "Could you provide a quick weather update?", + "Is it raining or sunny outside?", + "What's the weather like right now?", + "Tell me the current weather conditions.", + "How about the weather today?", + "What's the weather looking like for the next few hours?", + "Is it going to stay this way all day?", + "Could you give me a brief overview of the weather?", + "What's the general weather situation in our area?", + "Is it cloudy or clear outside?", + "What's the forecast saying for today's weather?", + "Is it going to be a warm day?", + "Are we expecting any storms today?", + "What's the weather condition outside my window?", + "Is it a typical day for this season in terms of weather?", + "how's the weather now?", + "What's the temperature like right now?", + "Can you tell me the current temperature?", + "How hot is it outside?", + "What's the temperature supposed to be today?", + "What is the current temp outside?", + "Could you tell me the outdoor temperature?", + "Is it cold or warm outside?", + "What's the high temperature for today?", + "What's the low temperature expected tonight?", + "How does the temperature feel outside?", + "Is it going to get warmer or cooler today?", + "What's the temperature in the shade?", + "Can you provide the current temp in Celsius?", + "What's the temperature in Fahrenheit right now?", + "Is it too hot to be outside?", + "What's the temperature like in the morning?", + "How about the temperature in the evening?", + "Is it warm enough to go swimming?", + "What's the temperature in the city center?", + "Can you tell me the temp in the nearby area?", + "Is it below freezing outside?", + "What's 
the average temperature for today?", + "Is the temperature dropping or rising?", + "What should I wear considering the temperature?" + ], + "base64": [ + "请将数据使用base64编码", + "需要将以下数据base 64编码", + "请将此字符串转为base 58", + "将数据转为base58编码", + "信息base64编码", + "请帮忙编码base64", + "将数据编码为base64", + "base 64编码", + "base64 在线编码", + "在线编码 base 64", + "编码 base64", + "请解码这个base64数据", + "有base64编码字符串需要解码", + "帮忙解码base64", + "将base 58编码转回原数据", + "解码base58信息", + "解码这个base64", + "将base64转文本", + "base64解码", + "base64 解码", + "base58 在线解码", + "在线解码 base58", + "解码 base64", + "Please encode this data with base64:", + "I need to encode the following data in base64", + "Could you encode this string using base64?", + "Convert this data to b64 encoding", + "I want to encode this information with base64", + "Help me encode this in base32", + "Can you encode this data to base64 format?", + "b64 encode", + "base64 encode", + "encode base64", + "base 64 encode online" + ], + + "url-encode": [ + "编码 url", + "URL部分需要编码", + "请将URL部分编码", + "URL编码转换", + "编码url 段", + "URL数据编码", + "请解码这个URL", + "有 url 编码需要解码", + "解码这个URL", + "URL编码转回原URL", + "解码 URL部分", + "解码 URL 段", + "URL 编码转文本", + "encode URL: ", + "I need to encode this URL component: ", + "Convert url encoded", + "encode this URL for safe transmission", + "Encode this URL segment: ", + "decode this url ", + "convert encoded url to text", + "url decoder", + "URL encoder" + ], + + "html-encode": [ + "请编码HTML实体", + "文本转为HTML实体", + "编码为HTML实体", + "文本HTML实体编码", + "预防HTML解析编码", + "HTML实体编码", + "文本HTML使用编码", + "html 转义", + "请解码HTML实体", + "HTML实体需要解码", + "解码HTML实体", + "HTML实体转回文本", + "HTML实体解码", + "解码HTML实体", + "HTML实体转文本", + "html 实体转义", + "html实体 转换", + "html   转换", + "html nbsp 意思", + "Please encode HTML entities", + "Convert text to HTML entities", + "Encode to HTML entities", + "Encode text HTML entities", + "Prevent HTML parsing encoding", + "HTML entity encoding", + "Text HTML uses encoding", + "html escape", + "Please decode HTML entities", + "HTML 
entities need to be decoded", + "Decode HTML entities", + "Convert HTML entities to text", + "Decode HTML entities", + "Decode HTML entities", + "Convert HTML entities to text", + "HTML entity escape", + "html entity conversion", + "html   conversion", + "html nbsp meaning" + ], + + "ai.command": [ + "写一个TypeScript的HelloWorld代码", + "检查以下内容的语法和清晰度", + "帮助我学习词汇:为我写一个句子让我填空,我会尝试选择正确的选项", + "改进这个Markdown内容的语法和表达,用于GitHub仓库的README", + "你能想出一个我包的短名字吗", + "简化这段代码", + "帮助我学习中文", + "如何在网页开发中让屏幕阅读器自动聚焦到新弹出的元素上", + "总结以下文本:", + "这段代码有什么问题吗,或者可以简化吗?", + "生成一个打印'Hello, World!'的Python脚本", + "你能校对这篇论文的语法和标点错误吗?", + "为单词'serendipity'创建十个例句列表", + "你能重新格式化这个JSON使其更易读吗?", + "为我的关于健康饮食的博客文章建议一个有创意的标题。", + "重构这个JavaScript函数使其更高效。", + "帮助我练习法语:提供一个有缺失单词的句子,我可以猜。", + "如何在模态框出现在网页上时让按钮自动聚焦", + "为我总结这篇新闻文章", + "你能审查这段代码片段的潜在安全漏洞吗", + "生成一个SQL查询来查找所有在过去30天内注册的用户", + "你能把这个段落翻译成西班牙语吗", + "根据以下过程描述创建一个流程图。", + "写一个计算数字阶乘的Python函数。", + "提供如何在Web应用程序中实现OAuth2的详细解释。", + "你能优化这张图片以便在网站上更快加载吗?", + "为专注于健身的新移动应用建议一些吸引人的标语", + "写一个每天备份我的文档文件夹的Bash脚本", + "帮助我起草一封请求与潜在客户会面的电子邮件。", + "你能把这个Markdown文档转换成HTML吗?", + "生成一个从指定网站抓取数据的Python脚本。", + "你能找到单词'meticulous'的同义词吗?", + "写一个基于公共列连接两个表的SQL查询。", + "根据以下指令集创建一个流程图。", + "为一个新的任务管理工具建议一个名字。", + "在不改变其功能的情况下简化这段Python代码。", + "你能帮助我学习日语吗?", + "如何在用户点击网页上的按钮时让警告框出现?", + "把这个研究论文总结成要点", + "你能检查这个算法中是否有任何逻辑错误吗?", + "write a typescript helloworld code", + "Check the following content for grammar and clarity", + "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.", + "Improve this markdown content in asepcts like grammar and expression, for a GitHub repo README.", + "can u think of a short name of my package", + "simplify this code", + "help me learn chinese", + "how to let the screen reader automatically focused to an newly poped up element in the web development", + "summarize following text:", + "Is there anything wrong with this code or can it be simplified?", + "generate a Python script that 
prints 'Hello, World!'", + "Can you proofread this essay for grammar and punctuation errors?", + "Create a list of ten example sentences for the word 'serendipity.'", + "Can you reformat this JSON to be more readable?", + "Suggest a creative title for my blog post about healthy eating.", + "Refactor this JavaScript function to make it more efficient.", + "Help me practice French: provide a sentence with a missing word that I can guess.", + "How do I make a button automatically focus when a modal appears on a webpage?", + "Summarize this news article for me.", + "Can you review this code snippet for potential security vulnerabilities?", + "Generate a SQL query to find all users who signed up in the last 30 days.", + "Can you translate this paragraph into Spanish?", + "Create a flowchart based on the following process description.", + "Write a Python function to calculate the factorial of a number.", + "Provide a detailed explanation of how to implement OAuth2 in a web application.", + "Can you optimize this image for faster loading on a website?", + "Suggest some catchy taglines for a new mobile app focused on fitness.", + "Write a Bash script to back up my documents folder daily.", + "Help me draft an email to request a meeting with a potential client.", + "Can you convert this Markdown document into HTML?", + "Generate a Python script that scrapes data from a specified website.", + "Can you find the synonyms of the word 'meticulous'?", + "Write a SQL query to join two tables based on a common column.", + "Create a flowchart based on the following set of instructions.", + "Suggest a name for a new task management tool.", + "Simplify this Python code without changing its functionality.", + "Can you assist me in learning Japanese?", + "How can I make an alert box appear when a user clicks a button on a webpage?", + "Summarize this research paper into bullet points.", + "Can you check if there are any logical errors in this algorithm?" 
+ ], + + "ai.question": [ + "你认为哪个框架最适合性能敏感的项目?", + "什么是后量子密码学?", + "什么是密钥派生函数", + "什么是线性代数?", + "线性代数在计算机科学中的主要用途是什么?", + "我应该使用哪个IDE来编写Go语言?", + "Go vs Java vs Kotlin,哪个适合后端", + "哪种编程语言最适合数据分析", + "什么是量子计算", + "什么是哈希函数?", + "什么是微积分?", + "机器学习在金融中的主要应用有哪些?", + "写Python代码最好的文本编辑器是哪个?", + "Python vs R vs Julia,哪个更适合数据科学?", + "监督学习和无监督学习的关键区别是什么?", + "数据库在Web应用程序中的作用是什么", + "什么是区块链技术", + "使用Docker进行应用程序部署的优势是什么?", + "哪个云服务提供商提供最好的AI工具?", + "加密是如何工作的?", + "负载均衡器在网络架构中的目的是什么?", + "机器学习和深度学习有什么区别", + "软件工程中最常见的设计模式有哪些", + "神经网络是如何学习的", + "使用微服务架构的主要好处是什么", + "编译器和解释器有什么区别?", + "Rust编程语言的关键特性是什么?", + "HTTP和HTTPS有什么区别", + "使用像Git这样的版本控制系统有什么优势?", + "什么是'边缘计算'的概念", + "哪种编程语言最适合构建移动应用?", + "关系数据库和NoSQL数据库有什么不同?", + "算法在计算机科学中的重要性是什么?", + "API在软件开发中的作用是什么", + "保护Web应用程序的最佳实践是什么", + "虚拟现实和增强现实有什么区别?", + "机器翻译是如何工作的?", + "Which framework do you think is the most suitable for performance sensitive projects?", + "What is post-quantum cryptography", + "What is a key derivation function?", + "What is Linear Algebra?", + "What is the main use of linear algebra in computer science", + "which IDE should I use for Go", + "Go vs Java vs Koltin, which for a backend", + "Which programming language is best suited for data analysis?", + "What is quantum computing?", + "What is a hash function?", + "What is calculus?", + "What are the main applications of machine learning in finance?", + "Which text editor is best for writing Python code?", + "Python vs R vs Julia, which is better for data science?", + "What are the key differences between supervised and unsupervised learning?", + "What is the role of a database in a web application?", + "What is blockchain technology?", + "What are the advantages of using Docker for application deployment?", + "Which cloud service provider offers the best AI tools?", + "How does encryption work?", + "What is the purpose of a load balancer in network architecture?", + "What is the difference between machine learning and deep learning?", + "What are the most 
common design patterns in software engineering?", + "How does a neural network learn?", + "What is the main benefit of using a microservices architecture?", + "What is the difference between a compiler and an interpreter?", + "What are the key features of the Rust programming language?", + "What is the difference between HTTP and HTTPS?", + "What are the advantages of using a version control system like Git?", + "What is the concept of 'edge computing'?", + "Which programming language is best for building mobile apps?", + "How does a relational database differ from a NoSQL database?", + "What is the importance of algorithms in computer science?", + "What is the role of an API in software development?", + "What are the best practices for securing a web application?", + "What is the difference between virtual reality and augmented reality?", + "How does machine translation work?" + ], + "datetime": ["明天周几", "16天后是几号", "一年前的今天是星期几"] +} diff --git a/intention-classify/extract.ipynb b/intention-classify/extract.ipynb new file mode 100644 index 0000000..5fcbb2f --- /dev/null +++ b/intention-classify/extract.ipynb @@ -0,0 +1,199 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0d107178", + "metadata": {}, + "source": [ + "# Extract the Dictionary & Embedding\n", + "\n", + "Our model uses the dictionary and embedding layers from `Phi-3-mini-4k-instruct`, a pre-trained transformer model developed by Microsoft. Although our model architecture is not a transformer, we can still benefit from its tokenizer and embedding layers." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "222bdd54-c115-4845-b4e6-c63e4f9ac6b4", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer, AutoModel\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "13f1ed38-ad39-4e6a-8af7-65ace2d14f8d", + "metadata": {}, + "outputs": [], + "source": [ + "model_name=\"microsoft/Phi-3-mini-4k-instruct\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c1de25fc-e90a-425b-8520-3a57fa534b94", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1aeb02c7c8084b1eb1b8e3178882fd60", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00 dimension-reduced embedding\n", + "token_id_to_reduced_embedding = {token_id: reduced_embeddings[token_id] for token_id in range(len(vocab))}\n", + "\n", + "torch.save(token_id_to_reduced_embedding, \"token_id_to_reduced_embedding.pt\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e28a612c-7960-42e6-aa36-abd0732f404e", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "# Create dict of token to {token_id, reduced_embedding}\n", + "token_to_id = {}\n", + "for token, token_id in vocab.items():\n", + " token_to_id[token] = token_id\n", + "\n", + "# Save as JSON\n", + "with open(\"token_to_id.json\", \"w\") as f:\n", + " json.dump(token_to_id, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "7a466c54-7c55-4e84-957e-454ae35896ac", + "metadata": {}, + "outputs": [], + "source": [ + "import struct\n", + "with open(\"token_embeddings.bin\", \"wb\") as f:\n", + " for token_id in range(len(vocab)):\n", + " # 
Write token id (2 bytes)\n", + " f.write(struct.pack('H', token_id))\n", + " # Write embedding vector (128 float numbers)\n", + " f.write(struct.pack('128f', *reduced_embeddings[token_id]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/intention-classify/inference-js.src.md b/intention-classify/inference-js.src.md new file mode 100644 index 0000000..875c0b0 --- /dev/null +++ b/intention-classify/inference-js.src.md @@ -0,0 +1,179 @@ + + +# sparkastML intention classification + +###### package.json + +```json +{ + "type": "module", + "dependencies": { + "@types/node": "latest", + "@xenova/transformers": "^2.17.2", + "onnxruntime-web": "^1.19.0", + "tsx": "latest", + "typescript": "latest" + } +} + +``` + +###### tokenizer.ts + +```typescript +export type TokenDict = { [key: string]: number }; + +type TokenDict = { [key: string]: number }; + +function tokenize(query: string, tokenDict: TokenDict): number[] { + const tokenIds: number[] = []; + let index = 0; + + // Replace spaces with "▁" + query = "▁"+query.replace(/ /g, "▁"); + query = query.replace(/\n/g, "<0x0A>"); + + while (index < query.length) { + let bestToken = null; + let bestLength = 0; + + // Step 2: Find the longest token that matches the beginning of the remaining query + for (const token in tokenDict) { + if (query.startsWith(token, index) && token.length > bestLength) { + bestToken = token; + bestLength = token.length; + } + } + + if (bestToken) { + tokenIds.push(tokenDict[bestToken]); + index += bestLength; + } else { + // Step 3: Handle the case where no token matches + const char = query[index]; + 
if (char.charCodeAt(0) <= 127) { + // If the character is ASCII, and it doesn't match any token, treat it as an unknown token + throw new Error(`Unknown token: ${char}`); + } else { + // If the character is non-ASCII, convert it to a series of bytes and match each byte + const bytes = new TextEncoder().encode(char); + for (const byte of bytes) { + const byteToken = `<0x${byte.toString(16).toUpperCase()}>`; + if (tokenDict[byteToken] !== undefined) { + tokenIds.push(tokenDict[byteToken]); + } else { + throw new Error(`Unknown byte token: ${byteToken}`); + } + } + } + index += 1; + } + } + + return tokenIds; +} + +export default tokenize +``` + +###### embedding.ts + +```typescript +import * as fs from 'fs'; +import * as path from 'path'; + +type EmbeddingDict = { [key: number]: Float32Array }; + +function getEmbeddingLayer(buffer: Buffer): EmbeddingDict { + const dict: EmbeddingDict = {}; + + const entrySize = 514; + const numEntries = buffer.length / entrySize; + + for (let i = 0; i < numEntries; i++) { + const offset = i * entrySize; + const key = buffer.readUInt16LE(offset); + const floatArray = new Float32Array(128); + + for (let j = 0; j < 128; j++) { + floatArray[j] = buffer.readFloatLE(offset + 2 + j * 4); + } + + dict[key] = floatArray; + } + + return dict; +} + +function getEmbedding(tokenIds: number[], embeddingDict: EmbeddingDict, contextSize: number) { + let result = []; + for (let i = 0; i < contextSize; i++) { + if (i < tokenIds.length) { + const tokenId = tokenIds[i]; + result = result.concat(Array.from(embeddingDict[tokenId])) + } + else { + result = result.concat(new Array(128).fill(0)) + } + } + return new Float32Array(result); +} + +export {getEmbeddingLayer, getEmbedding}; +``` + +###### load.ts + +```typescript +import * as ort from 'onnxruntime-web'; +import * as fs from 'fs'; +import tokenize, {TokenDict} from "./tokenizer.ts" +import {getEmbeddingLayer, getEmbedding} from "./embedding.ts" + +const embedding_file = './token_embeddings.bin'; 
+const embedding_data = fs.readFileSync(embedding_file); +const embedding_buffer = Buffer.from(embedding_data); +const query = `Will it rain tomorrow`; +const model_path = './model.onnx'; +const vocabData = fs.readFileSync('./token_to_id.json'); +const vocabDict = JSON.parse(vocabData.toString()); + +let lastLogCall = new Date().getTime(); + +function log(task: string) { + const currentTime = new Date().getTime(); + const costTime = currentTime - lastLogCall; + console.log(`[${currentTime}] (+${costTime}ms) ${task}`) + lastLogCall = new Date().getTime(); +} + +async function loadModel(modelPath: string) { + const session = await ort.InferenceSession.create(modelPath); + return session; +} + +async function runInference(query: string, embedding_buffer: Buffer, modelPath: string, vocabDict: TokenDict) { + const session = await loadModel(modelPath); + log("loadModel:end"); + const inputText = query; + const queryLength = query.length; + const tokenIds = await tokenize(query, vocabDict); + log("tokenize:end"); + const embeddingDict = getEmbeddingLayer(embedding_buffer); + const e = getEmbedding(tokenIds, embeddingDict, 12); + log("getEmbedding:end"); + + const inputTensor = new ort.Tensor('float32', e, [1, 12, 128]); + + const feeds = { 'input': inputTensor }; + const results = await session.run(feeds); + log("inference:end"); + + const output = results.output.data; + const predictedClassIndex = output.indexOf(Math.max(...output)); + + return output; +} + +console.log("Perdicted class:", await runInference(query, embedding_buffer, model_path, vocabDict)); +``` diff --git a/intention-classify/noise.json b/intention-classify/noise.json new file mode 100644 index 0000000..dfc2666 --- /dev/null +++ b/intention-classify/noise.json @@ -0,0 +1,8 @@ +[ + "你好", + "你好谢谢小笼包再见", + "我爱你", + "嘿嘿嘿诶嘿", + "为什么", + "拼多多" +] \ No newline at end of file diff --git a/intention-classify/train.ipynb b/intention-classify/train.ipynb new file mode 100644 index 0000000..3fea561 --- /dev/null 
+++ b/intention-classify/train.ipynb @@ -0,0 +1,575 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a6a3195f-d099-4bf4-846f-51f403954818", + "metadata": {}, + "source": [ + "# sparkastML: Training the Intention Classification Model\n", + "\n", + "This is the model we use for intent recognition, using a **CNN architecture** and an **Energy-based Model** to implement OSR (Open-set Recognition).\n", + "\n", + "In this case, **positive samples** refer to data that can be classified into an existing class, while **negative samples** are those that do not belong to any of the existing classes." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "bddcdbb2-ccbc-4027-a38f-09c61ac94984", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import torch\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from torch.nn.utils.rnn import pad_sequence\n", + "from sklearn.model_selection import train_test_split\n", + "from transformers import AutoTokenizer, AutoModel\n", + "import torch\n", + "import numpy as np\n", + "from scipy.spatial.distance import euclidean\n", + "from scipy.stats import weibull_min\n", + "from sklearn.preprocessing import normalize\n", + "import torch.nn.functional as F\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d3a0e10f-9bc3-44c7-a109-786dd5cd25ea", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + } + ], + "source": [ + "model_name=\"microsoft/Phi-3-mini-4k-instruct\"\n", + "DIMENSIONS = 128\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name)" + ] + }, + { + "cell_type": "markdown", + "id": "1ae14906-338d-4c99-87ed-bb1acd22b295", + "metadata": {}, + "source": [ + "## Load Data\n", + "\n", + "We load the data from `data.json`, and also get the negative samples from `noise.json`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a206071c-ce4e-4de4-b936-bfc70d13708a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_6446/1697839999.py:18: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_69nk78ncaj/croot/pytorch_1669252638507/work/torch/csrc/utils/tensor_new.cpp:204.)\n", + " embeddings = torch.tensor(embeddings)\n" + ] + } + ], + "source": [ + "# Load data\n", + "with open('data.json', 'r') as f:\n", + " data = json.load(f)\n", + "\n", + "# Create map: class to index\n", + "class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}\n", + "idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}\n", + "\n", + "# Preprocess data, convert sentences to the format of (class idx, embedding)\n", + "def preprocess_data(data, embedding_map, tokenizer, max_length=64):\n", + " dataset = []\n", + " for label, sentences in data.items():\n", + " for sentence in sentences:\n", + " # Tokenize the sentence and convert tokens to embedding vectors\n", + " tokens = tokenizer.tokenize(sentence)\n", + " token_ids = tokenizer.convert_tokens_to_ids(tokens)\n", + " embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n", + " embeddings = torch.tensor(embeddings)\n", + " dataset.append((class_to_idx[label], embeddings))\n", + " return dataset\n", + "\n", + "# Load embedding map\n", + "embedding_map = torch.load('token_id_to_reduced_embedding.pt')\n", + "\n", + "# Get preprocessed dataset\n", + "dataset = preprocess_data(data, embedding_map, tokenizer)\n", + "\n", + "# Train-test split\n", + "train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)\n", + 
"\n", + "class TextDataset(Dataset):\n", + " def __init__(self, data):\n", + " self.data = data\n", + "\n", + " def __len__(self):\n", + " return len(self.data)\n", + "\n", + " def __getitem__(self, idx):\n", + " return self.data[idx]\n", + "\n", + " def collate_fn(self, batch):\n", + " labels, embeddings = zip(*batch)\n", + " labels = torch.tensor(labels)\n", + " embeddings = pad_sequence(embeddings, batch_first=True)\n", + " return labels, embeddings\n", + "\n", + "train_dataset = TextDataset(train_data)\n", + "val_dataset = TextDataset(val_data)\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn)\n", + "val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False, collate_fn=val_dataset.collate_fn)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9adbe9b8-a2d2-4e1d-8620-457ed0e02fe6", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch.utils.data import Dataset, DataLoader\n", + "\n", + "class NegativeSampleDataset(Dataset):\n", + " def __init__(self, negative_samples):\n", + " \"\"\"\n", + " negative_samples: List or array of negative sample embeddings or raw text\n", + " \"\"\"\n", + " self.samples = negative_samples\n", + " \n", + " def __len__(self):\n", + " return len(self.samples)\n", + " \n", + " def __getitem__(self, idx):\n", + " return self.samples[idx]\n", + "\n", + " def collate_fn(self, batch):\n", + " embeddings = pad_sequence(batch, batch_first=True)\n", + " return embeddings\n", + "\n", + "with open('noise.json', 'r') as f:\n", + " negative_samples_list = json.load(f)\n", + "\n", + "negative_embedding_list = []\n", + "\n", + "for sentence in negative_samples_list:\n", + " tokens = tokenizer.tokenize(sentence)\n", + " token_ids = tokenizer.convert_tokens_to_ids(tokens)\n", + " embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]\n", + " embeddings = torch.tensor(embeddings)\n", + " 
negative_embedding_list.append(embeddings)\n", + "\n", + "negative_dataset = NegativeSampleDataset(negative_embedding_list)\n", + "negative_loader = DataLoader(negative_dataset, batch_size=24, shuffle=True, collate_fn=negative_dataset.collate_fn)\n" + ] + }, + { + "cell_type": "markdown", + "id": "600febe4-2484-4aad-90a1-2bc821fdce1a", + "metadata": {}, + "source": [ + "## Implementing the Model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "adf624ac-ad63-437b-95f6-b02b7253b91e", + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "class TextCNN(nn.Module):\n", + " def __init__(self, input_dim, num_classes):\n", + " super(TextCNN, self).__init__()\n", + " self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n", + " self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n", + " self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n", + " \n", + " self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n", + " self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n", + " self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n", + " \n", + " self.dropout = nn.Dropout(0.5)\n", + " self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n", + "\n", + " def forward(self, x):\n", + " x = x.permute(0, 2, 1) # Change the input shape to (batch_size, embedding_dim, seq_length)\n", + " \n", + " x1 = F.relu(self.bn1(self.conv1(x)))\n", + " x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n", + " \n", + " x2 = F.relu(self.bn2(self.conv2(x)))\n", + " x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n", + " \n", + " x3 = F.relu(self.bn3(self.conv3(x)))\n", + " x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n", + " \n", + " x = torch.cat((x1, x2, x3), dim=1)\n", + " x = self.dropout(x)\n", + " x = self.fc(x)\n", + " return x\n", + "\n", + "# Initialize model\n", + "input_dim = 
DIMENSIONS\n", + "num_classes = len(class_to_idx)\n", + "model = TextCNN(input_dim, num_classes)\n" + ] + }, + { + "cell_type": "markdown", + "id": "2750e17d-8a60-40c7-851b-1a567d0ee82b", + "metadata": {}, + "source": [ + "## Energy-based Models" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a2d7c920-07d2-4d14-9cef-e2101b7a2ceb", + "metadata": {}, + "outputs": [], + "source": [ + "def energy_score(logits):\n", + " # Energy score is minus logsumexp\n", + " return -torch.logsumexp(logits, dim=1)\n", + "\n", + "def generate_noise(batch_size, seq_length ,input_dim, device):\n", + " # Generate a Gaussian noise\n", + " return torch.randn(batch_size, seq_length, input_dim).to(device)\n" + ] + }, + { + "cell_type": "markdown", + "id": "904a60e4-95a0-4f7b-ad45-a7d8d0ac887d", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "19acb5bf-00b1-47d4-ad25-a13c6be09f65", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [1/50], Loss: 12.5108\n", + "Epoch [2/50], Loss: 10.7305\n", + "Epoch [3/50], Loss: 10.2943\n", + "Epoch [4/50], Loss: 9.9350\n", + "Epoch [5/50], Loss: 9.7991\n", + "Epoch [6/50], Loss: 9.6443\n", + "Epoch [7/50], Loss: 9.4762\n", + "Epoch [8/50], Loss: 9.4637\n", + "Epoch [9/50], Loss: 9.3025\n", + "Epoch [10/50], Loss: 9.1719\n", + "Epoch [11/50], Loss: 9.0632\n", + "Epoch [12/50], Loss: 8.9741\n", + "Epoch [13/50], Loss: 8.8487\n", + "Epoch [14/50], Loss: 8.6565\n", + "Epoch [15/50], Loss: 8.5830\n", + "Epoch [16/50], Loss: 8.4196\n", + "Epoch [17/50], Loss: 8.2319\n", + "Epoch [18/50], Loss: 8.0655\n", + "Epoch [19/50], Loss: 7.7140\n", + "Epoch [20/50], Loss: 7.6921\n", + "Epoch [21/50], Loss: 7.3375\n", + "Epoch [22/50], Loss: 7.2297\n", + "Epoch [23/50], Loss: 6.8833\n", + "Epoch [24/50], Loss: 6.8534\n", + "Epoch [25/50], Loss: 6.4557\n", + "Epoch [26/50], Loss: 6.1365\n", + "Epoch [27/50], 
Loss: 5.8558\n", + "Epoch [28/50], Loss: 5.5030\n", + "Epoch [29/50], Loss: 5.1604\n", + "Epoch [30/50], Loss: 4.7742\n", + "Epoch [31/50], Loss: 4.5958\n", + "Epoch [32/50], Loss: 4.0713\n", + "Epoch [33/50], Loss: 3.8872\n", + "Epoch [34/50], Loss: 3.5240\n", + "Epoch [35/50], Loss: 3.3115\n", + "Epoch [36/50], Loss: 2.5667\n", + "Epoch [37/50], Loss: 2.6709\n", + "Epoch [38/50], Loss: 1.8075\n", + "Epoch [39/50], Loss: 1.6654\n", + "Epoch [40/50], Loss: 0.4622\n", + "Epoch [41/50], Loss: 0.4719\n", + "Epoch [42/50], Loss: -0.4037\n", + "Epoch [43/50], Loss: -0.9405\n", + "Epoch [44/50], Loss: -1.7204\n", + "Epoch [45/50], Loss: -2.4124\n", + "Epoch [46/50], Loss: -3.0032\n", + "Epoch [47/50], Loss: -2.7123\n", + "Epoch [48/50], Loss: -3.6953\n", + "Epoch [49/50], Loss: -3.7212\n", + "Epoch [50/50], Loss: -3.7558\n" + ] + } + ], + "source": [ + "import torch.optim as optim\n", + "\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(), lr=8e-4)\n", + "\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "import tensorboard\n", + "writer = SummaryWriter()\n", + "\n", + "def train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=10):\n", + " model.train()\n", + " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + " model.to(device)\n", + " \n", + " negative_iter = iter(negative_loader)\n", + " \n", + " for epoch in range(num_epochs):\n", + " total_loss = 0\n", + " for batch_idx, (labels, embeddings) in enumerate(train_loader):\n", + " embeddings = embeddings.to(device)\n", + " labels = labels.to(device)\n", + " \n", + " batch_size = embeddings.size(0)\n", + " \n", + " # ---------------------\n", + " # 1. 
Positive sample\n", + " # ---------------------\n", + " optimizer.zero_grad()\n", + " outputs = model(embeddings) # logits from the model\n", + " \n", + " class_loss = criterion(outputs, labels)\n", + " \n", + " # Energy of positive sample\n", + " known_energy = energy_score(outputs)\n", + " energy_loss_known = known_energy.mean()\n", + " \n", + " # ------------------------------------\n", + " # 2. Negative sample - Random Noise\n", + " # ------------------------------------\n", + " noise_embeddings = torch.randn_like(embeddings).to(device)\n", + " noise_outputs = model(noise_embeddings)\n", + " noise_energy = energy_score(noise_outputs)\n", + " energy_loss_noise = F.relu(1 - noise_energy).mean() # For the energy of noise, bigger is better \n", + " \n", + " # ------------------------------------\n", + " # 3. Negative sample - custom corpus\n", + " # ------------------------------------\n", + " \n", + " try:\n", + " negative_samples = next(negative_iter)\n", + " except StopIteration:\n", + " negative_iter = iter(negative_loader)\n", + " negative_samples = next(negative_iter)\n", + " negative_samples = negative_samples.to(device)\n", + " negative_outputs = model(negative_samples)\n", + " negative_energy = energy_score(negative_outputs)\n", + " energy_loss_negative = F.relu(1 - negative_energy).mean() # For the energy of negative samples, bigger is better \n", + " \n", + " # -----------------------------\n", + " # 4. 
Overall Loss calculation\n", + " # -----------------------------\n", + " total_energy_loss = energy_loss_known + energy_loss_noise + energy_loss_negative\n", + " total_loss_batch = class_loss + total_energy_loss * 0.1 + 10\n", + "\n", + " writer.add_scalar(\"Engergy Loss\", total_energy_loss, epoch)\n", + " writer.add_scalar(\"Loss\", total_loss_batch, epoch)\n", + " writer.add_scalar(\"Norm Loss\", torch.exp(total_loss_batch * 0.003) * 10 , epoch)\n", + " \n", + " total_loss_batch.backward()\n", + " optimizer.step()\n", + " \n", + " total_loss += total_loss_batch.item()\n", + " \n", + " avg_loss = total_loss / len(train_loader)\n", + " print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')\n", + "\n", + "train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=50)\n", + "writer.flush()\n" + ] + }, + { + "cell_type": "markdown", + "id": "e6d29558-f497-4033-8488-169bd25ce881", + "metadata": {}, + "source": [ + "## Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "472702e6-db4a-4faa-9e92-7510e6eacbb1", + "metadata": {}, + "outputs": [], + "source": [ + "ENERGY_THRESHOLD = -3" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d3a4fef8-37ab-45c8-b2b1-9fc8bdffcffd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: 0.9315\n", + "Precision: 1.0000\n", + "Recall: 0.9254\n", + "F1 Score: 0.9612\n" + ] + } + ], + "source": [ + "from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support\n", + "\n", + "def evaluate_energy_model(model, known_loader, unknown_loader, energy_threshold):\n", + " model.eval()\n", + " device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + " \n", + " all_preds = []\n", + " all_labels = []\n", + " \n", + " # Evaluate positive sample\n", + " with torch.no_grad():\n", + " for labels, embeddings in known_loader:\n", + " embeddings = embeddings.to(device)\n", 
+ " logits = model(embeddings)\n", + " energy = energy_score(logits)\n", + " \n", + " preds = (energy <= energy_threshold).long()\n", + " all_preds.extend(preds.cpu().numpy())\n", + " all_labels.extend([1] * len(preds)) # Positive sample labeled as 1\n", + " \n", + " # Evaluate negative sample\n", + " with torch.no_grad():\n", + " for embeddings in unknown_loader:\n", + " embeddings = embeddings.to(device)\n", + " logits = model(embeddings)\n", + " energy = energy_score(logits)\n", + " \n", + " preds = (energy <= energy_threshold).long()\n", + " all_preds.extend(preds.cpu().numpy())\n", + " all_labels.extend([0] * len(preds)) # Negative sample labeled as 0\n", + " \n", + " precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n", + " accuracy = accuracy_score(all_labels, all_preds)\n", + "\n", + " print(f'Accuracy: {accuracy:.4f}')\n", + " print(f'Precision: {precision:.4f}')\n", + " print(f'Recall: {recall:.4f}')\n", + " print(f'F1 Score: {f1:.4f}')\n", + "\n", + "evaluate_energy_model(model, val_loader, negative_loader, ENERGY_THRESHOLD)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "ba614054-75e1-4f61-ace5-aeb11e29a222", + "metadata": {}, + "outputs": [], + "source": [ + "# Save the model\n", + "torch.save(model, \"model.pt\")" + ] + }, + { + "cell_type": "markdown", + "id": "fdfa0c5e-e6d3-4db0-a142-96645c92719c", + "metadata": {}, + "source": [ + "## Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "03928d75-81c8-4298-ab8a-d7f8a758b561", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predicted: ['weather', 0.9989822506904602, -8.016249656677246]\n" + ] + } + ], + "source": [ + "def predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold, max_length=64):\n", + " model.eval()\n", + " tokens = tokenizer.tokenize(sentence)\n", + " token_ids = 
tokenizer.convert_tokens_to_ids(tokens)\n", + " embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n", + " embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension\n", + " \n", + " with torch.no_grad():\n", + " logits = model(embeddings)\n", + " probabilities = F.softmax(logits, dim=1)\n", + " max_prob, predicted = torch.max(probabilities, 1)\n", + " \n", + " # Calculate energy score\n", + " energy = energy_score(logits)\n", + "\n", + " # If energy > threshold, consider the input as unknown class\n", + " if energy.item() > energy_threshold:\n", + " return [\"Unknown\", max_prob.item(), energy.item()]\n", + " else:\n", + " return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]\n", + "\n", + "# Example usage:\n", + "sentence = \"weather today\"\n", + "energy_threshold = ENERGY_THRESHOLD\n", + "predicted = predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold)\n", + "print(f'Predicted: {predicted}')\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}