ref: the intention-classification model

alikia2x (寒寒) 2024-09-22 03:58:56 +08:00
parent 66cf093177
commit 9f071ee0a0
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
10 changed files with 783 additions and 655 deletions

View File

@ -36,6 +36,7 @@
"室外的温度是多少",
"达拉斯今天热不热",
"苏州现在天气怎么样",
"明天悉尼会下雨吗?",
"how's the weather",
"What's going on with the weather?",
"Can you give me an update on the weather?",
@ -48,21 +49,21 @@
"What's the weather like right now?",
"Tell me the current weather conditions.",
"How about the weather today?",
"What's the weather looking like for the next few hours?",
"Is it going to stay this way all day?",
"Could you give me a brief overview of the weather?",
"What's the general weather situation in our area?",
"Is it cloudy or clear outside?",
"What's the forecast saying for today's weather?",
"Is it going to be a warm day?",
"Are we expecting any storms today?",
"What's the weather condition outside my window?",
"Is it a typical day for this season in terms of weather?",
"how's the weather now?",
"What's the temperature like right now?",
"Can you tell me the current temperature?",
"How hot is it outside?",
"What's the temperature supposed to be today?",
"What's the weather looking like for the next few hours",
"Is it going to stay this way all day",
"Could you give me a brief overview of the weather",
"What's the general weather situation in our area",
"Is it cloudy or clear outside",
"What's the forecast saying for today's weather",
"Is it going to be a warm day",
"Are we expecting any storms today",
"What's the weather condition outside my window",
"Is it a typical day for this season in terms of weather",
"how's the weather now",
"What's the temperature like right now",
"Can you tell me the current temperature",
"How hot is it outside",
"What's the temperature supposed to be today",
"What is the current temp outside?",
"Could you tell me the outdoor temperature?",
"Is it cold or warm outside?",
@ -81,8 +82,8 @@
"Can you tell me the temp in the nearby area?",
"Is it below freezing outside?",
"What's the average temperature for today?",
"Is the temperature dropping or rising?",
"What should I wear considering the temperature?"
"Is the temperature dropping or rising",
"What should I wear considering the temperature"
],
"base64": [
"请将数据使用base64编码",
@ -110,17 +111,16 @@
"解码 base64",
"Please encode this data with base64:",
"I need to encode the following data in base64",
"Could you encode this string using base64?",
"Could you encode this string using base64",
"Convert this data to b64 encoding",
"I want to encode this information with base64",
"Help me encode this in base32",
"Can you encode this data to base64 format?",
"Can you encode this data to base64 format",
"b64 encode",
"base64 encode",
"encode base64",
"base 64 encode online"
],
"url-encode": [
"编码 url",
"URL部分需要编码",
@ -145,7 +145,6 @@
"url decoder",
"URL encoder"
],
"html-encode": [
"请编码HTML实体",
"文本转为HTML实体",
@ -186,7 +185,6 @@
"html &nbsp conversion",
"html nbsp meaning"
],
"ai.command": [
"写一个TypeScript的HelloWorld代码",
"检查以下内容的语法和清晰度",
@ -237,11 +235,11 @@
"help me learn chinese",
"how to let the screen reader automatically focused to an newly poped up element in the web development",
"summarize following text:",
"Is there anything wrong with this code or can it be simplified?",
"Is there anything wrong with this code or can it be simplified",
"generate a Python script that prints 'Hello, World!'",
"Can you proofread this essay for grammar and punctuation errors?",
"Can you proofread this essay for grammar and punctuation errors",
"Create a list of ten example sentences for the word 'serendipity.'",
"Can you reformat this JSON to be more readable?",
"Can you reformat this JSON to be more readable",
"Suggest a creative title for my blog post about healthy eating.",
"Refactor this JavaScript function to make it more efficient.",
"Help me practice French: provide a sentence with a missing word that I can guess.",
@ -249,15 +247,15 @@
"Summarize this news article for me.",
"Can you review this code snippet for potential security vulnerabilities?",
"Generate a SQL query to find all users who signed up in the last 30 days.",
"Can you translate this paragraph into Spanish?",
"Can you translate this paragraph into Spanish",
"Create a flowchart based on the following process description.",
"Write a Python function to calculate the factorial of a number.",
"Provide a detailed explanation of how to implement OAuth2 in a web application.",
"Can you optimize this image for faster loading on a website?",
"Can you optimize this image for faster loading on a website",
"Suggest some catchy taglines for a new mobile app focused on fitness.",
"Write a Bash script to back up my documents folder daily.",
"Help me draft an email to request a meeting with a potential client.",
"Can you convert this Markdown document into HTML?",
"Can you convert this Markdown document into HTML",
"Generate a Python script that scrapes data from a specified website.",
"Can you find the synonyms of the word 'meticulous'?",
"Write a SQL query to join two tables based on a common column.",
@ -267,31 +265,57 @@
"Can you assist me in learning Japanese?",
"How can I make an alert box appear when a user clicks a button on a webpage?",
"Summarize this research paper into bullet points.",
"Can you check if there are any logical errors in this algorithm?"
"Can you check if there are any logical errors in this algorithm?",
"请一步一步计算找到函数f(x)=U^2*x/(R+x)^2的顶点坐标。",
"如何理解transformer自注意力机制中的Q,K,V它们分别代表什么",
"帮我写一封求职信。先询问我的教育背景、技能和经验。",
"总结这篇论文",
"写一份10人晚宴的菜单",
"写一篇博客",
"写一段演讲稿"
],
"ai.question": [
"你认为哪个框架最适合性能敏感的项目?",
"knowledge": [
"什么是后量子密码学?",
"什么是密钥派生函数",
"什么是线性代数?",
"量子计算的特点是什么",
"哈希函数的作用?",
"什么是微积分?",
"什么是区块链技术",
"What is post-quantum cryptography",
"What is a key derivation function?",
"What is Linear Algebra?",
"What is the main use of linear algebra in computer science",
"What is quantum computing",
"What is a hash function",
"What is calculus",
"什么是站点隔离?",
"What is blockchain technology?",
"BLEU 是什么",
"黎巴嫩在哪",
"什么是转义字符",
"MixAlpha售价多少",
"什么是神经机器翻译",
"什么是月食",
"什么是人工智能",
"什么是F1-score"
],
"ai.question": [
"人工智能真的有智力吗",
"你认为哪个框架最适合性能敏感的项目?",
"线性代数在计算机科学中的主要用途是什么?",
"我应该使用哪个IDE来编写Go语言",
"Go vs Java vs Kotlin哪个适合后端",
"哪种编程语言最适合数据分析",
"什么是量子计算",
"什么是哈希函数?",
"什么是微积分?",
"机器学习在金融中的主要应用有哪些?",
"写Python代码最好的文本编辑器是哪个",
"Python vs R vs Julia哪个更适合数据科学",
"监督学习和无监督学习的关键区别是什么?",
"数据库在Web应用程序中的作用是什么",
"什么是区块链技术",
"使用Docker进行应用程序部署的优势是什么",
"哪个云服务提供商提供最好的AI工具",
"加密是如何工作的?",
"负载均衡器在网络架构中的目的是什么?",
"加密是如何工作的",
"负载均衡器在网络架构中的目的是什么",
"机器学习和深度学习有什么区别",
"软件工程中最常见的设计模式有哪些",
"神经网络是如何学习的",
@ -300,31 +324,22 @@
"Rust编程语言的关键特性是什么",
"HTTP和HTTPS有什么区别",
"使用像Git这样的版本控制系统有什么优势",
"什么是'边缘计算'的概念",
"哪种编程语言最适合构建移动应用?",
"关系数据库和NoSQL数据库有什么不同",
"算法在计算机科学中的重要性是什么",
"算法在计算机科学中的重要性是什么",
"API在软件开发中的作用是什么",
"保护Web应用程序的最佳实践是什么",
"虚拟现实和增强现实有什么区别?",
"机器翻译是如何工作的?",
"Which framework do you think is the most suitable for performance sensitive projects?",
"What is post-quantum cryptography",
"What is a key derivation function?",
"What is Linear Algebra?",
"What is the main use of linear algebra in computer science",
"which IDE should I use for Go",
"Go vs Java vs Koltin, which for a backend",
"Which programming language is best suited for data analysis?",
"What is quantum computing?",
"What is a hash function?",
"What is calculus?",
"What are the main applications of machine learning in finance?",
"Which text editor is best for writing Python code?",
"Python vs R vs Julia, which is better for data science?",
"What are the key differences between supervised and unsupervised learning?",
"What are the main applications of machine learning in finance",
"Which text editor is best for writing Python code",
"Python vs R vs Julia, which is better for data science",
"What are the key differences between supervised and unsupervised learning",
"What is the role of a database in a web application?",
"What is blockchain technology?",
"What are the advantages of using Docker for application deployment?",
"Which cloud service provider offers the best AI tools?",
"How does encryption work?",
@ -332,19 +347,20 @@
"What is the difference between machine learning and deep learning?",
"What are the most common design patterns in software engineering?",
"How does a neural network learn?",
"What is the main benefit of using a microservices architecture?",
"What is the difference between a compiler and an interpreter?",
"What are the key features of the Rust programming language?",
"What is the difference between HTTP and HTTPS?",
"What are the advantages of using a version control system like Git?",
"What is the concept of 'edge computing'?",
"Which programming language is best for building mobile apps?",
"How does a relational database differ from a NoSQL database?",
"What is the importance of algorithms in computer science?",
"What is the role of an API in software development?",
"What is the main benefit of using a microservices architecture",
"What is the difference between a compiler and an interpreter",
"What are the key features of the Rust programming language",
"What is the difference between HTTP and HTTPS",
"What are the advantages of using a version control system like Git",
"What is the concept of 'edge computing'",
"Which programming language is best for building mobile apps",
"How does a relational database differ from a NoSQL database",
"What is the importance of algorithms in computer science",
"What is the role of an API in software development",
"What are the best practices for securing a web application?",
"What is the difference between virtual reality and augmented reality?",
"How does machine translation work?"
"How does machine translation work?",
"MBTI有科学依据吗"
],
"datetime": ["明天周几", "16天后是几号", "一年前的今天是星期几"]
}

View File

@ -28,7 +28,7 @@
"metadata": {},
"outputs": [],
"source": [
"model_name=\"microsoft/Phi-3-mini-4k-instruct\""
"model_name=\"Qwen/Qwen2.5-3B\""
]
},
{
@ -37,17 +37,10 @@
"id": "c1de25fc-e90a-425b-8520-3a57fa534b94",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1aeb02c7c8084b1eb1b8e3178882fd60",
"model_id": "11caef0e1b674f6ab15880f3f25eca6a",
"version_major": 2,
"version_minor": 0
},
@ -95,7 +88,7 @@
"metadata": {},
"outputs": [],
"source": [
"DIMENSIONS = 128"
"DIMENSIONS = 96"
]
},
{
@ -168,11 +161,17 @@
"import struct\n",
"with open(\"token_embeddings.bin\", \"wb\") as f:\n",
" for token_id in range(len(vocab)):\n",
" # Write token id (2 bytes)\n",
" f.write(struct.pack('H', token_id))\n",
" # Write embedding vector (128 float numbers)\n",
" f.write(struct.pack('128f', *reduced_embeddings[token_id]))"
" # 将向量转换为半精度浮点数并保存\n",
" f.write(struct.pack('96e', *reduced_embeddings[token_id].astype(np.float16)))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "511a7cc4-1b8c-468c-b2a0-16dc6d74ab44",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@ -191,7 +190,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
"version": "3.10.14"
}
},
"nbformat": 4,

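Note on the new embedding format above: the per-token id prefix is gone, and each token's vector is stored as 96 consecutive half-precision floats, written in ascending token-id order with no header. A minimal reader sketch under those assumptions (load_token_embeddings is a hypothetical helper, not part of this commit):

import struct
import numpy as np

DIMENSIONS = 96  # must match the dimension used when the file was written

def load_token_embeddings(path, vocab_size):
    # Each record is DIMENSIONS float16 values (2 bytes each), back to back,
    # so token_id * DIMENSIONS * 2 is the byte offset of a token's vector.
    embeddings = np.empty((vocab_size, DIMENSIONS), dtype=np.float16)
    with open(path, "rb") as f:
        for token_id in range(vocab_size):
            chunk = f.read(2 * DIMENSIONS)
            embeddings[token_id] = struct.unpack(f"{DIMENSIONS}e", chunk)
    return embeddings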
View File

@ -4,5 +4,350 @@
"我爱你",
"嘿嘿嘿诶嘿",
"为什么",
"拼多多"
]
"拼多多",
"machine translation",
"trustrank",
"中文词典",
"bin screen linux",
"\"TinyBERT",
"iconify",
"反义词 英文",
"referer",
"watchos uiscreen",
"张鑫旭",
"google first result",
"flutter text align center",
"ASR model",
"real time whisper",
"千樱凛",
"马嘉祺",
"flutter widget catalog",
"flutter BottomNavigationBar left",
"flutter tab indent vscode",
"react native 用 expo 吗",
"latest monorepo tool",
"\"vite\" \"abortController\" is not defined",
"vim comment lines",
"Error: unable to get issuer certificate",
"uuidv4",
"npm semver",
"react polyfill vite",
"vibrance",
"I can eat glass, it doesn't hurt me \"japanese\"",
"I can swallow glass without any harm to myself",
"copilot pricing",
"vim close window",
"sensors macos command",
"智乃",
"pypi wikipedia",
"tesseract macos m1",
"rag prompt template",
"英国 破产",
"bewlybewly",
"safari-web-extension-converter",
"starcoder",
"open source web search for ai",
"gpt4o mini tokenizer",
"gpt4o tokenizer",
"reverse dns lookup linux",
"online ping",
"termux",
"802.11 table",
"optimize",
"集群",
"chrome us",
"transflective",
"ielts toefl",
"react router",
"摇曳露营 萌娘百科",
"isrc",
"apple-system",
"-apple-system",
"css clip path animation",
"can i use relative path in og image",
"GitSora",
"matrix im",
"test your vocabulary",
"boarding pass",
"函数签名",
"类型谓词",
"barcode",
"智能",
"threejs 入门",
"南亚语系",
"linux user's computer be like",
"apple a16 显微图",
"dallas",
"恶魔 英文",
"Rime meaning",
"adobe media encoder macos download",
"mp4 transparency",
"webkit",
"chromium",
"献血",
"软件强制更新",
"If you dont agree with its politics views, Notepad+ + will add random characters in your source code.",
"Unmerged paths",
"字数统计",
"Use build.rollupOptions.output.manualChunks to improve chunking: https://rollupjs.org/configuration-options/#output-manualchunks",
"世界人权宣言",
"latex percent",
"chord in keyboard",
"Google is trying to kill the Open Web.",
"silo'd",
"swiftui 数组倒数访问",
"swiftui link to another view",
"fizzbuzz",
"AppDelegate watchos",
"Cannot find type 'UIApplicationDelegate' in scope",
"swiftui web image",
"spammer",
"swiftui text",
"钢琴",
"disable webgl chrome",
"online uuid",
"cp show progress",
"易容术",
"fulilian",
"cargo",
"wordle",
"mismatch",
"btc",
"squelch",
"psql show table structure",
"let padding don't effect when empty",
"take over the world meaning",
"brain teasers",
"Google flight API",
"square symbol",
"sill",
"nextjs layout per page",
"UA 550 umol/L",
"react production promotion page",
"jupyter notebook",
"wth meaning",
"glove词向量",
"google suggestion relevance",
"YouTube advertising income",
"PKI",
"next client only component",
"nextjs use client",
"nextjs docker tailwind not working",
"k8s",
"Logistic Regression",
"氯化钾注射死刑",
"icloud photo loss",
"芙宁娜 水上行走",
"vector design tool",
"netizen",
"framework or next js documentation",
"csync",
"next js",
"后量子正向保密",
"nip05",
"Sora技术原理",
"wasm效率",
"switch code",
"online IPA pronunciation",
"pnpm global adir",
"如何搜索",
"1999 抽卡期望",
"swiftui background blur",
"chrome macos fullscreen hide",
"中英文空格自动",
"ios 旁白 屏幕识别",
"ios 旁白 转子",
"http 404",
"yaml缩进",
"counter generator github",
"git 服务器提供远端仓库",
"ipfs companion",
"supervisor config",
"SSO",
"slot embedding",
"sql show tables",
"The request signature we calculated does not match the signature you provided. Check your Secret Access Key and signing method.",
"icloud.com,cn",
"VuePress",
"parser",
"stackoverflow statistics",
"sd xl",
"Rollup failed to resolve import \"workbox-precaching\" from",
"dep",
"Cannot find module estree-walker.js docker",
"nuxt run",
"base58解码",
"cga",
"vscode",
"vscode",
"silicon",
"macos m1 linux",
"预处理 后处理",
"is vp9 opensource",
"Alice Blu",
"失控玩家",
"kv数据库",
"redis 持久化",
"firefox disable outline",
"cd -2",
"IM application",
"2021国产电影",
"youtube chat overlay obs",
"obs add clock",
"Z is not defined nuxt",
"safari ios debug",
"safari debug",
"chat",
"nuxt plugin inject",
"twitch",
"obs 绿幕",
"gnupg",
"kde plasma wallpaper engine",
"Plasma",
"dns over https",
"localforage缺点",
"watchOS 10",
"noun of repeat",
"微信输入法",
"行业报告",
"keepass",
"platform",
"steam",
"java proxy",
"0 design",
"cefr word level list",
"precipitation meaning",
"international school of lausanne",
"Vim Uganda",
"抖音 推荐算法",
"Meta NNLO",
"windbg dump分析",
"web image fft",
"GPT-4 Pricing",
"GPT-4",
"Scala",
"tauri教程",
"asyncio.create_task用法",
"H5 滚动到底部",
"microsoft copilot",
"枫丹文字",
"brew pip",
"TS7016: Could not find a declaration file for module react .",
"fastapi websocket",
"kazv",
"The Type 孔雀计划",
"第一个图形操作系统",
"娱乐 诞生",
"ffmpeg 音频封面",
"Jean-Loup Gailly",
"Linux用户软件位置",
"\"ubuntu\" 平滑滚动",
"python range函数",
"KMP",
"sd 8gen2 GPU GFLOPS",
"mac语音输入法",
"openai translate",
"蔚蓝档案 初始抽卡",
"free custom domain email",
"洛天依",
"b站 频道页Tab 跳转",
"URL 重定向预览",
"计算机",
"sololearn",
"PoS机制 通俗解释",
"google search cost",
"bos s3",
"react 打包",
"useeffect 用法",
"ts 字典类型",
"vscode 字典单词自动补全插件",
"componentwillupdate",
"iPad Mini 2",
"use-immer",
"reducer 和 context",
"mint",
"Elementary OS",
"google科技新闻",
"iCloud mail \"\"-9002\"\"",
"氢氧化铁胶体制备",
"react native 视频处理",
"四川 2023 高考 复旦大学 分数线",
"哑铃弯举",
"m2 ultra",
"电池循环计数 site:apple.com",
"相机发明时间",
"冯诺依曼结构",
"哈佛架构",
"nodejs 后端",
"34.5M€ to CN¥",
"NLP 实体关注",
"monkey",
"react 快捷键监听",
"mac 好看的电子书阅读器",
"新闻",
"在线字体编辑器",
"ars technica",
"genshin 4.1 release time",
"swift device activity report",
"swiftui tabview background",
"swiftui text space",
"apple inc. wikipedia",
"how long does it take Google to return the results",
"云原神 web",
"支持homekit的空调",
"内核隔离",
"海祇岛解密",
"swiftui Textfield",
"xcode",
"qq 链接",
"M1 推出时间",
"USB-IF",
"nvchat",
"P1% FPS",
"react i18next 当前语言",
"js 获取语言",
"MulType",
"b站平均使用时间",
"pip 阿里源",
"ip info",
"graphjet",
"金融思维",
"C#写入文件",
"Last Day Sinista M",
"在 系统 位置 xcode select 找 不 到 SDK",
"Error: Could not find a valid Xcode app bundle at '/Library/Developer/CommandLineTools'. Please update your Apple SDK location in Visual Studio's preferences (Projects > SDK Locations > Apple > Apple SDK). (UniBattery)",
".NET能做什么",
"could i give no tip ",
"miami university of ohio",
"方正颜宋",
"中文 标题字体",
"聚典平台",
"62 basic words for a language",
"procrastination meaning",
"Lingbe",
"娱乐至死",
"macOS 外接显示器渲染",
"白玉袖",
"SwiftUI入门",
"html插入其它网页",
"捆绑 小说",
"apple music 无损下载",
"一miumiu 赐予",
"macos markdown",
"safari 开发者工具",
"\"百合\" \"武侠\" \"国漫\"",
"epub 格式详解",
"chrome 隐藏滚动条",
"发宽空格",
"U+200A",
"无性人",
"Spotify",
"禾念",
"how to pronounce Lorem ipsum",
"言和为什么不是男孩子",
"浏览器主页",
"react",
"Tailwindcss react 扩展怎么用",
"Prettier 扩展怎么用",
"linter\""
]

View File

@ -1,575 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a6a3195f-d099-4bf4-846f-51f403954818",
"metadata": {},
"source": [
"# sparkastML: Training the Intention Classification Model\n",
"\n",
"This is the model we use for intent recognition, using a **CNN architectur** and using an **Energy-based Model** to implement OSR (Open-set Recognition).\n",
"\n",
"In this case, **positive samples** refer to data that can be classified into existing class, while **negative samples** are those does not belong to any of the existing class."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "bddcdbb2-ccbc-4027-a38f-09c61ac94984",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from sklearn.model_selection import train_test_split\n",
"from transformers import AutoTokenizer, AutoModel\n",
"import torch\n",
"import numpy as np\n",
"from scipy.spatial.distance import euclidean\n",
"from scipy.stats import weibull_min\n",
"from sklearn.preprocessing import normalize\n",
"import torch.nn.functional as F\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d3a0e10f-9bc3-44c7-a109-786dd5cd25ea",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"model_name=\"microsoft/Phi-3-mini-4k-instruct\"\n",
"DIMENSIONS = 128\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "markdown",
"id": "1ae14906-338d-4c99-87ed-bb1acd22b295",
"metadata": {},
"source": [
"## Load Data\n",
"\n",
"We load the data from `data.json`, and also get the negative sample from the `noise.json`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a206071c-ce4e-4de4-b936-bfc70d13708a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_6446/1697839999.py:18: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_69nk78ncaj/croot/pytorch_1669252638507/work/torch/csrc/utils/tensor_new.cpp:204.)\n",
" embeddings = torch.tensor(embeddings)\n"
]
}
],
"source": [
"# Load data\n",
"with open('data.json', 'r') as f:\n",
" data = json.load(f)\n",
"\n",
"# Create map: class to index\n",
"class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}\n",
"idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}\n",
"\n",
"# Preprocess data, convert sentences to the format of (class idx, embedding)\n",
"def preprocess_data(data, embedding_map, tokenizer, max_length=64):\n",
" dataset = []\n",
" for label, sentences in data.items():\n",
" for sentence in sentences:\n",
" # Tokenize the sentence and convert tokens to embedding vectors\n",
" tokens = tokenizer.tokenize(sentence)\n",
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
" embeddings = torch.tensor(embeddings)\n",
" dataset.append((class_to_idx[label], embeddings))\n",
" return dataset\n",
"\n",
"# Load embedding map\n",
"embedding_map = torch.load('token_id_to_reduced_embedding.pt')\n",
"\n",
"# Get preprocessed dataset\n",
"dataset = preprocess_data(data, embedding_map, tokenizer)\n",
"\n",
"# Train-test split\n",
"train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)\n",
"\n",
"class TextDataset(Dataset):\n",
" def __init__(self, data):\n",
" self.data = data\n",
"\n",
" def __len__(self):\n",
" return len(self.data)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.data[idx]\n",
"\n",
" def collate_fn(self, batch):\n",
" labels, embeddings = zip(*batch)\n",
" labels = torch.tensor(labels)\n",
" embeddings = pad_sequence(embeddings, batch_first=True)\n",
" return labels, embeddings\n",
"\n",
"train_dataset = TextDataset(train_data)\n",
"val_dataset = TextDataset(val_data)\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn)\n",
"val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False, collate_fn=val_dataset.collate_fn)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9adbe9b8-a2d2-4e1d-8620-457ed0e02fe6",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"class NegativeSampleDataset(Dataset):\n",
" def __init__(self, negative_samples):\n",
" \"\"\"\n",
" negative_samples: List or array of negative sample embeddings or raw text\n",
" \"\"\"\n",
" self.samples = negative_samples\n",
" \n",
" def __len__(self):\n",
" return len(self.samples)\n",
" \n",
" def __getitem__(self, idx):\n",
" return self.samples[idx]\n",
"\n",
" def collate_fn(self, batch):\n",
" embeddings = pad_sequence(batch, batch_first=True)\n",
" return embeddings\n",
"\n",
"with open('noise.json', 'r') as f:\n",
" negative_samples_list = json.load(f)\n",
"\n",
"negative_embedding_list = []\n",
"\n",
"for sentence in negative_samples_list:\n",
" tokens = tokenizer.tokenize(sentence)\n",
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
" embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]\n",
" embeddings = torch.tensor(embeddings)\n",
" negative_embedding_list.append(embeddings)\n",
"\n",
"negative_dataset = NegativeSampleDataset(negative_embedding_list)\n",
"negative_loader = DataLoader(negative_dataset, batch_size=24, shuffle=True, collate_fn=negative_dataset.collate_fn)\n"
]
},
{
"cell_type": "markdown",
"id": "600febe4-2484-4aad-90a1-2bc821fdce1a",
"metadata": {},
"source": [
"## Implementating the Model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "adf624ac-ad63-437b-95f6-b02b7253b91e",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class TextCNN(nn.Module):\n",
" def __init__(self, input_dim, num_classes):\n",
" super(TextCNN, self).__init__()\n",
" self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n",
" self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n",
" self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n",
" \n",
" self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n",
" self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n",
" self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n",
" \n",
" self.dropout = nn.Dropout(0.5)\n",
" self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n",
"\n",
" def forward(self, x):\n",
" x = x.permute(0, 2, 1) # Change the input shape to (batch_size, embedding_dim, seq_length)\n",
" \n",
" x1 = F.relu(self.bn1(self.conv1(x)))\n",
" x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n",
" \n",
" x2 = F.relu(self.bn2(self.conv2(x)))\n",
" x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n",
" \n",
" x3 = F.relu(self.bn3(self.conv3(x)))\n",
" x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n",
" \n",
" x = torch.cat((x1, x2, x3), dim=1)\n",
" x = self.dropout(x)\n",
" x = self.fc(x)\n",
" return x\n",
"\n",
"# Initialize model\n",
"input_dim = DIMENSIONS\n",
"num_classes = len(class_to_idx)\n",
"model = TextCNN(input_dim, num_classes)\n"
]
},
{
"cell_type": "markdown",
"id": "2750e17d-8a60-40c7-851b-1a567d0ee82b",
"metadata": {},
"source": [
"## Energy-based Models"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a2d7c920-07d2-4d14-9cef-e2101b7a2ceb",
"metadata": {},
"outputs": [],
"source": [
"def energy_score(logits):\n",
" # Energy score is minus logsumexp\n",
" return -torch.logsumexp(logits, dim=1)\n",
"\n",
"def generate_noise(batch_size, seq_length ,input_dim, device):\n",
" # Generate a Gaussian noise\n",
" return torch.randn(batch_size, seq_length, input_dim).to(device)\n"
]
},
{
"cell_type": "markdown",
"id": "904a60e4-95a0-4f7b-ad45-a7d8d0ac887d",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "19acb5bf-00b1-47d4-ad25-a13c6be09f65",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/50], Loss: 12.5108\n",
"Epoch [2/50], Loss: 10.7305\n",
"Epoch [3/50], Loss: 10.2943\n",
"Epoch [4/50], Loss: 9.9350\n",
"Epoch [5/50], Loss: 9.7991\n",
"Epoch [6/50], Loss: 9.6443\n",
"Epoch [7/50], Loss: 9.4762\n",
"Epoch [8/50], Loss: 9.4637\n",
"Epoch [9/50], Loss: 9.3025\n",
"Epoch [10/50], Loss: 9.1719\n",
"Epoch [11/50], Loss: 9.0632\n",
"Epoch [12/50], Loss: 8.9741\n",
"Epoch [13/50], Loss: 8.8487\n",
"Epoch [14/50], Loss: 8.6565\n",
"Epoch [15/50], Loss: 8.5830\n",
"Epoch [16/50], Loss: 8.4196\n",
"Epoch [17/50], Loss: 8.2319\n",
"Epoch [18/50], Loss: 8.0655\n",
"Epoch [19/50], Loss: 7.7140\n",
"Epoch [20/50], Loss: 7.6921\n",
"Epoch [21/50], Loss: 7.3375\n",
"Epoch [22/50], Loss: 7.2297\n",
"Epoch [23/50], Loss: 6.8833\n",
"Epoch [24/50], Loss: 6.8534\n",
"Epoch [25/50], Loss: 6.4557\n",
"Epoch [26/50], Loss: 6.1365\n",
"Epoch [27/50], Loss: 5.8558\n",
"Epoch [28/50], Loss: 5.5030\n",
"Epoch [29/50], Loss: 5.1604\n",
"Epoch [30/50], Loss: 4.7742\n",
"Epoch [31/50], Loss: 4.5958\n",
"Epoch [32/50], Loss: 4.0713\n",
"Epoch [33/50], Loss: 3.8872\n",
"Epoch [34/50], Loss: 3.5240\n",
"Epoch [35/50], Loss: 3.3115\n",
"Epoch [36/50], Loss: 2.5667\n",
"Epoch [37/50], Loss: 2.6709\n",
"Epoch [38/50], Loss: 1.8075\n",
"Epoch [39/50], Loss: 1.6654\n",
"Epoch [40/50], Loss: 0.4622\n",
"Epoch [41/50], Loss: 0.4719\n",
"Epoch [42/50], Loss: -0.4037\n",
"Epoch [43/50], Loss: -0.9405\n",
"Epoch [44/50], Loss: -1.7204\n",
"Epoch [45/50], Loss: -2.4124\n",
"Epoch [46/50], Loss: -3.0032\n",
"Epoch [47/50], Loss: -2.7123\n",
"Epoch [48/50], Loss: -3.6953\n",
"Epoch [49/50], Loss: -3.7212\n",
"Epoch [50/50], Loss: -3.7558\n"
]
}
],
"source": [
"import torch.optim as optim\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=8e-4)\n",
"\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"import tensorboard\n",
"writer = SummaryWriter()\n",
"\n",
"def train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=10):\n",
" model.train()\n",
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
" model.to(device)\n",
" \n",
" negative_iter = iter(negative_loader)\n",
" \n",
" for epoch in range(num_epochs):\n",
" total_loss = 0\n",
" for batch_idx, (labels, embeddings) in enumerate(train_loader):\n",
" embeddings = embeddings.to(device)\n",
" labels = labels.to(device)\n",
" \n",
" batch_size = embeddings.size(0)\n",
" \n",
" # ---------------------\n",
" # 1. Positive sample\n",
" # ---------------------\n",
" optimizer.zero_grad()\n",
" outputs = model(embeddings) # logits from the model\n",
" \n",
" class_loss = criterion(outputs, labels)\n",
" \n",
" # Energy of positive sample\n",
" known_energy = energy_score(outputs)\n",
" energy_loss_known = known_energy.mean()\n",
" \n",
" # ------------------------------------\n",
" # 2. Negative sample - Random Noise\n",
" # ------------------------------------\n",
" noise_embeddings = torch.randn_like(embeddings).to(device)\n",
" noise_outputs = model(noise_embeddings)\n",
" noise_energy = energy_score(noise_outputs)\n",
" energy_loss_noise = F.relu(1 - noise_energy).mean() # For the energy of noise, bigger is better \n",
" \n",
" # ------------------------------------\n",
" # 3. Negative sample - custom corpus\n",
" # ------------------------------------\n",
" \n",
" try:\n",
" negative_samples = next(negative_iter)\n",
" except StopIteration:\n",
" negative_iter = iter(negative_loader)\n",
" negative_samples = next(negative_iter)\n",
" negative_samples = negative_samples.to(device)\n",
" negative_outputs = model(negative_samples)\n",
" negative_energy = energy_score(negative_outputs)\n",
" energy_loss_negative = F.relu(1 - negative_energy).mean() # For the energy of noise, bigger is better \n",
" \n",
" # -----------------------------\n",
" # 4. Overall Loss calculation\n",
" # -----------------------------\n",
" total_energy_loss = energy_loss_known + energy_loss_noise + energy_loss_negative\n",
" total_loss_batch = class_loss + total_energy_loss * 0.1 + 10\n",
"\n",
" writer.add_scalar(\"Engergy Loss\", total_energy_loss, epoch)\n",
" writer.add_scalar(\"Loss\", total_loss_batch, epoch)\n",
" writer.add_scalar(\"Norm Loss\", torch.exp(total_loss_batch * 0.003) * 10 , epoch)\n",
" \n",
" total_loss_batch.backward()\n",
" optimizer.step()\n",
" \n",
" total_loss += total_loss_batch.item()\n",
" \n",
" avg_loss = total_loss / len(train_loader)\n",
" print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')\n",
"\n",
"train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=50)\n",
"writer.flush()\n"
]
},
{
"cell_type": "markdown",
"id": "e6d29558-f497-4033-8488-169bd25ce881",
"metadata": {},
"source": [
"## Evalutation"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "472702e6-db4a-4faa-9e92-7510e6eacbb1",
"metadata": {},
"outputs": [],
"source": [
"ENERGY_THRESHOLD = -3"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d3a4fef8-37ab-45c8-b2b1-9fc8bdffcffd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.9315\n",
"Precision: 1.0000\n",
"Recall: 0.9254\n",
"F1 Score: 0.9612\n"
]
}
],
"source": [
"from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support\n",
"\n",
"def evaluate_energy_model(model, known_loader, unknown_loader, energy_threshold):\n",
" model.eval()\n",
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
" \n",
" all_preds = []\n",
" all_labels = []\n",
" \n",
" # Evaluate positive sample\n",
" with torch.no_grad():\n",
" for labels, embeddings in known_loader:\n",
" embeddings = embeddings.to(device)\n",
" logits = model(embeddings)\n",
" energy = energy_score(logits)\n",
" \n",
" preds = (energy <= energy_threshold).long()\n",
" all_preds.extend(preds.cpu().numpy())\n",
" all_labels.extend([1] * len(preds)) # Positive sample labeled as 1\n",
" \n",
" # Evaluate negative sample\n",
" with torch.no_grad():\n",
" for embeddings in unknown_loader:\n",
" embeddings = embeddings.to(device)\n",
" logits = model(embeddings)\n",
" energy = energy_score(logits)\n",
" \n",
" preds = (energy <= energy_threshold).long()\n",
" all_preds.extend(preds.cpu().numpy())\n",
" all_labels.extend([0] * len(preds)) # Negative sample labeled as 1\n",
" \n",
" precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
" accuracy = accuracy_score(all_labels, all_preds)\n",
"\n",
" print(f'Accuracy: {accuracy:.4f}')\n",
" print(f'Precision: {precision:.4f}')\n",
" print(f'Recall: {recall:.4f}')\n",
" print(f'F1 Score: {f1:.4f}')\n",
"\n",
"evaluate_energy_model(model, val_loader, negative_loader, ENERGY_THRESHOLD)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "ba614054-75e1-4f61-ace5-aeb11e29a222",
"metadata": {},
"outputs": [],
"source": [
"# Save the model\n",
"torch.save(model, \"model.pt\")"
]
},
{
"cell_type": "markdown",
"id": "fdfa0c5e-e6d3-4db0-a142-96645c92719c",
"metadata": {},
"source": [
"## Inference"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "03928d75-81c8-4298-ab8a-d7f8a758b561",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted: ['weather', 0.9989822506904602, -8.016249656677246]\n"
]
}
],
"source": [
"def predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold, max_length=64):\n",
" model.eval()\n",
" tokens = tokenizer.tokenize(sentence)\n",
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
" embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension\n",
" \n",
" with torch.no_grad():\n",
" logits = model(embeddings)\n",
" probabilities = F.softmax(logits, dim=1)\n",
" max_prob, predicted = torch.max(probabilities, 1)\n",
" \n",
" # Calculate energy score\n",
" energy = energy_score(logits)\n",
"\n",
" # If energy > threshold, consider the input as unknown class\n",
" if energy.item() > energy_threshold:\n",
" return [\"Unknown\", max_prob.item(), energy.item()]\n",
" else:\n",
" return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]\n",
"\n",
"# Example usage:\n",
"sentence = \"weather today\"\n",
"energy_threshold = ENERGY_THRESHOLD\n",
"predicted = predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold)\n",
"print(f'Predicted: {predicted}')\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,4 @@
# config.py
model_name = "Qwen/Qwen2.5-3B"
DIMENSIONS = 96

View File

@ -0,0 +1,64 @@
# data_utils.py
import json

import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence


def load_data(file_path):
    with open(file_path, "r") as f:
        data = json.load(f)
    return data


def create_class_mappings(data):
    class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
    idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
    return class_to_idx, idx_to_class


def preprocess_data(data, embedding_map, tokenizer, class_to_idx, max_length=64):
    dataset = []
    for label, sentences in data.items():
        for sentence in sentences:
            tokens = tokenizer.tokenize(sentence)
            token_ids = tokenizer.convert_tokens_to_ids(tokens)
            embeddings = [
                embedding_map[token_id] for token_id in token_ids[:max_length]
            ]
            embeddings = torch.tensor(embeddings)
            dataset.append((class_to_idx[label], embeddings))
    return dataset


class TextDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def collate_fn(self, batch):
        labels, embeddings = zip(*batch)
        labels = torch.tensor(labels)
        embeddings = pad_sequence(embeddings, batch_first=True)
        return labels, embeddings


class NegativeSampleDataset(Dataset):
    def __init__(self, negative_samples):
        self.samples = negative_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

    def collate_fn(self, batch):
        embeddings = pad_sequence(batch, batch_first=True)
        return embeddings

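A quick sketch of how these collate functions batch variable-length samples (dummy tensors, purely illustrative): pad_sequence zero-pads every sequence in the batch to the longest one.

import torch
from training.data_utils import TextDataset

# Two dummy (label, embedding) samples with different sequence lengths
batch = [(0, torch.randn(5, 96)), (1, torch.randn(9, 96))]
labels, embeddings = TextDataset(batch).collate_fn(batch)
print(labels.shape)      # torch.Size([2])
print(embeddings.shape)  # torch.Size([2, 9, 96]); the shorter sample is zero-padded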
View File

@ -0,0 +1,50 @@
# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from training.config import DIMENSIONS


class SelfAttention(nn.Module):
    def __init__(self, input_dim, heads):
        super(SelfAttention, self).__init__()
        self.heads = heads
        self.scale = (input_dim // heads) ** -0.5
        self.qkv = nn.Linear(input_dim, input_dim * 3)
        self.fc = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        batch_size, seq_length, embedding_dim = x.shape
        qkv = self.qkv(x).view(
            batch_size, seq_length, self.heads, 3, embedding_dim // self.heads
        )
        q, k, v = qkv[..., 0, :], qkv[..., 1, :], qkv[..., 2, :]
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        attention_output = torch.matmul(attn_weights, v)
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
        attention_output = attention_output.view(batch_size, seq_length, embedding_dim)
        return self.fc(attention_output)


class AttentionBasedModel(nn.Module):
    def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512):
        super(AttentionBasedModel, self).__init__()
        self.self_attention = SelfAttention(input_dim, heads)
        self.fc1 = nn.Linear(input_dim, dim_feedforward)
        self.fc2 = nn.Linear(dim_feedforward, num_classes)
        self.dropout = nn.Dropout(0.5)
        self.norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        attn_output = self.self_attention(x)
        attn_output = self.norm(attn_output + x)
        pooled_output = torch.mean(attn_output, dim=1)
        x = F.relu(self.fc1(pooled_output))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

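A minimal shape check for the model above (dummy input; the eight classes match the number of intent classes in data.json but are otherwise arbitrary):

import torch
from training.config import DIMENSIONS
from training.model import AttentionBasedModel

model = AttentionBasedModel(input_dim=DIMENSIONS, num_classes=8)
dummy = torch.randn(2, 10, DIMENSIONS)  # (batch, seq_length, embedding_dim)
logits = model(dummy)
print(logits.shape)  # torch.Size([2, 8])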
View File

@ -0,0 +1,149 @@
# train.py
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from training.data_utils import (
    load_data,
    create_class_mappings,
    preprocess_data,
    TextDataset,
    NegativeSampleDataset,
)
from training.model import AttentionBasedModel


def energy_score(logits, temperature=1.0):
    return -torch.logsumexp(logits / temperature, dim=1)


def generate_noise(batch_size, seq_length, input_dim, device):
    return torch.randn(batch_size, seq_length, input_dim).to(device)


def train_energy_model(
    model,
    train_loader,
    negative_loader,
    criterion,
    optimizer,
    num_epochs=10,
    margin=1.0,
    temperature=0.4,
):
    model.train()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    negative_iter = iter(negative_loader)
    for epoch in range(num_epochs):
        total_loss = 0
        for _, (labels, embeddings) in enumerate(train_loader):
            embeddings = embeddings.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(embeddings)
            class_loss = criterion(outputs, labels)
            # Push the energy of known (positive) samples below the positive margin
            known_energy = energy_score(outputs, temperature)
            positive_margin = 0.0
            energy_loss_known = F.relu(known_energy - positive_margin).mean()
            # Push the energy of random Gaussian noise above the margin
            noise_embeddings = torch.randn_like(embeddings).to(device)
            noise_outputs = model(noise_embeddings)
            noise_energy = energy_score(noise_outputs, temperature)
            energy_loss_noise = F.relu(margin - noise_energy).mean()
            # Push the energy of corpus negative samples above the margin
            try:
                negative_samples = next(negative_iter)
            except StopIteration:
                negative_iter = iter(negative_loader)
                negative_samples = next(negative_iter)
            negative_samples = negative_samples.to(device)
            negative_outputs = model(negative_samples)
            negative_energy = energy_score(negative_outputs, temperature)
            energy_loss_negative = F.relu(margin - noise_energy).mean() if False else F.relu(margin - negative_energy).mean()
            total_energy_loss = (
                energy_loss_known + energy_loss_noise + energy_loss_negative
            )
            total_loss_batch = class_loss + total_energy_loss
            total_loss_batch.backward()
            optimizer.step()
            total_loss += total_loss_batch.item()
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
def main():
    from transformers import AutoTokenizer

    from training.config import model_name, DIMENSIONS

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    data = load_data("data.json")
    class_to_idx, _ = create_class_mappings(data)
    embedding_map = torch.load("token_id_to_reduced_embedding.pt")
    dataset = preprocess_data(data, embedding_map, tokenizer, class_to_idx)
    train_data, _ = train_test_split(dataset, test_size=0.2)
    train_dataset = TextDataset(train_data)
    train_loader = DataLoader(
        train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn
    )
    with open("noise.json", "r") as f:
        negative_samples_list = json.load(f)
    negative_embedding_list = []
    for sentence in negative_samples_list:
        tokens = tokenizer.tokenize(sentence)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]
        embeddings = torch.tensor(embeddings)
        negative_embedding_list.append(embeddings)
    negative_dataset = NegativeSampleDataset(negative_embedding_list)
    negative_loader = DataLoader(
        negative_dataset,
        batch_size=24,
        shuffle=True,
        collate_fn=negative_dataset.collate_fn,
    )
    input_dim = DIMENSIONS
    num_classes = len(class_to_idx)
    model = AttentionBasedModel(input_dim, num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=7e-4)
    train_energy_model(
        model, train_loader, negative_loader, criterion, optimizer, num_epochs=120
    )
    torch.save(model.state_dict(), "model.pt")
    # Export to ONNX with dynamic batch and sequence-length axes
    dummy_input = torch.randn(1, 64, DIMENSIONS)
    torch.onnx.export(
        model,
        dummy_input,
        "model.onnx",
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 1: "seq_length"},
            "output": {0: "batch_size"},
        },
        opset_version=11,
    )


if __name__ == "__main__":
    main()

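Because train.py exports the model to ONNX with dynamic batch and sequence axes, inference does not require PyTorch. A minimal sketch using onnxruntime (assuming the model.onnx produced by the export above):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
x = np.random.randn(1, 64, 96).astype(np.float32)  # (batch, seq_length, DIMENSIONS)
(logits,) = session.run(["output"], {"input": x})
# Same open-set rule as in PyTorch: energy = -logsumexp(logits)
energy = -np.log(np.exp(logits).sum(axis=1))
print(logits.shape, energy)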
View File

@ -0,0 +1,75 @@
import json

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from training.config import DIMENSIONS, model_name
from training.model import AttentionBasedModel

def energy_score(logits):
    # Energy score is the negative logsumexp of the logits
    return -torch.logsumexp(logits, dim=1)

def predict_with_energy(
    model,
    sentence,
    embedding_map,
    tokenizer,
    idx_to_class,
    energy_threshold,
    max_length=64,
):
    model.eval()
    tokens = tokenizer.tokenize(sentence)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    print(token_ids)
    embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
    embeddings = torch.tensor(embeddings).unsqueeze(0)  # Add batch dimension
    current_shape = embeddings.shape
    if current_shape[1] < 2:
        # Pad the sequence dimension so it holds at least two tokens
        pad_size = 2 - current_shape[1]
        embeddings = F.pad(
            embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
        )
    with torch.no_grad():
        logits = model(embeddings)
        print(logits)
        probabilities = F.softmax(logits, dim=1)
        max_prob, predicted = torch.max(probabilities, 1)
    # Calculate energy score
    energy = energy_score(logits)
    # If energy > threshold, consider the input as unknown class
    if energy.item() > energy_threshold:
        return ["Unknown", max_prob.item(), energy.item()]
    else:
        return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]


with open("data.json", "r") as f:
    data = json.load(f)
class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
num_classes = len(class_to_idx)
input_dim = DIMENSIONS
model = AttentionBasedModel(input_dim, num_classes)
model.load_state_dict(torch.load("./model.pt"))
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Example usage:
ENERGY_THRESHOLD = 0
sentence = "天气"  # "weather"
energy_threshold = ENERGY_THRESHOLD
predicted = predict_with_energy(
    model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold
)
print(f"Predicted: {predicted}")

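As a quick numeric illustration of the open-set rule above (made-up logits; ENERGY_THRESHOLD = 0 as in the script): confident logits yield low energy and map to a known class, while flat, low-magnitude logits yield energy above the threshold and fall back to "Unknown".

import torch

def energy_score(logits):
    return -torch.logsumexp(logits, dim=1)

confident = torch.tensor([[9.0, 0.5, 0.3]])    # one class dominates
uncertain = torch.tensor([[-2.0, -1.9, -2.1]])  # no class stands out
print(energy_score(confident))  # ~ -9.0: below threshold, known class
print(energy_score(uncertain))  # ~ +0.9: above threshold, "Unknown"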
View File

@ -24,6 +24,7 @@ cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))
# Load the model and tokenizer
model_name = 'jinaai/jina-embeddings-v2-base-zh'
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.to('cuda')
# Define file paths from command-line arguments
file_a_path = args.file_a