ref: the intention-classification model

parent 66cf093177
commit 9f071ee0a0

@@ -36,6 +36,7 @@
 "室外的温度是多少",
 "达拉斯今天热不热",
 "苏州现在天气怎么样",
+"明天悉尼会下雨吗?",
 "how's the weather",
 "What's going on with the weather?",
 "Can you give me an update on the weather?",
@@ -48,21 +49,21 @@
 "What's the weather like right now?",
 "Tell me the current weather conditions.",
 "How about the weather today?",
-"What's the weather looking like for the next few hours?",
-"Is it going to stay this way all day?",
-"Could you give me a brief overview of the weather?",
-"What's the general weather situation in our area?",
-"Is it cloudy or clear outside?",
-"What's the forecast saying for today's weather?",
-"Is it going to be a warm day?",
-"Are we expecting any storms today?",
-"What's the weather condition outside my window?",
-"Is it a typical day for this season in terms of weather?",
-"how's the weather now?",
-"What's the temperature like right now?",
-"Can you tell me the current temperature?",
-"How hot is it outside?",
-"What's the temperature supposed to be today?",
+"What's the weather looking like for the next few hours",
+"Is it going to stay this way all day",
+"Could you give me a brief overview of the weather",
+"What's the general weather situation in our area",
+"Is it cloudy or clear outside",
+"What's the forecast saying for today's weather",
+"Is it going to be a warm day",
+"Are we expecting any storms today",
+"What's the weather condition outside my window",
+"Is it a typical day for this season in terms of weather",
+"how's the weather now",
+"What's the temperature like right now",
+"Can you tell me the current temperature",
+"How hot is it outside",
+"What's the temperature supposed to be today",
 "What is the current temp outside?",
 "Could you tell me the outdoor temperature?",
 "Is it cold or warm outside?",
@@ -81,8 +82,8 @@
 "Can you tell me the temp in the nearby area?",
 "Is it below freezing outside?",
 "What's the average temperature for today?",
-"Is the temperature dropping or rising?",
-"What should I wear considering the temperature?"
+"Is the temperature dropping or rising",
+"What should I wear considering the temperature"
 ],
 "base64": [
 "请将数据使用base64编码",
@@ -110,17 +111,16 @@
 "解码 base64",
 "Please encode this data with base64:",
 "I need to encode the following data in base64",
-"Could you encode this string using base64?",
+"Could you encode this string using base64",
 "Convert this data to b64 encoding",
 "I want to encode this information with base64",
 "Help me encode this in base32",
-"Can you encode this data to base64 format?",
+"Can you encode this data to base64 format",
 "b64 encode",
 "base64 encode",
 "encode base64",
 "base 64 encode online"
 ],
-
 "url-encode": [
 "编码 url",
 "URL部分需要编码",
@@ -145,7 +145,6 @@
 "url decoder",
 "URL encoder"
 ],
-
 "html-encode": [
 "请编码HTML实体",
 "文本转为HTML实体",
@@ -186,7 +185,6 @@
 "html   conversion",
 "html nbsp meaning"
 ],
-
 "ai.command": [
 "写一个TypeScript的HelloWorld代码",
 "检查以下内容的语法和清晰度",
@@ -237,11 +235,11 @@
 "help me learn chinese",
 "how to let the screen reader automatically focused to an newly poped up element in the web development",
 "summarize following text:",
-"Is there anything wrong with this code or can it be simplified?",
+"Is there anything wrong with this code or can it be simplified",
 "generate a Python script that prints 'Hello, World!'",
-"Can you proofread this essay for grammar and punctuation errors?",
+"Can you proofread this essay for grammar and punctuation errors",
 "Create a list of ten example sentences for the word 'serendipity.'",
-"Can you reformat this JSON to be more readable?",
+"Can you reformat this JSON to be more readable",
 "Suggest a creative title for my blog post about healthy eating.",
 "Refactor this JavaScript function to make it more efficient.",
 "Help me practice French: provide a sentence with a missing word that I can guess.",
@@ -249,15 +247,15 @@
 "Summarize this news article for me.",
 "Can you review this code snippet for potential security vulnerabilities?",
 "Generate a SQL query to find all users who signed up in the last 30 days.",
-"Can you translate this paragraph into Spanish?",
+"Can you translate this paragraph into Spanish",
 "Create a flowchart based on the following process description.",
 "Write a Python function to calculate the factorial of a number.",
 "Provide a detailed explanation of how to implement OAuth2 in a web application.",
-"Can you optimize this image for faster loading on a website?",
+"Can you optimize this image for faster loading on a website",
 "Suggest some catchy taglines for a new mobile app focused on fitness.",
 "Write a Bash script to back up my documents folder daily.",
 "Help me draft an email to request a meeting with a potential client.",
-"Can you convert this Markdown document into HTML?",
+"Can you convert this Markdown document into HTML",
 "Generate a Python script that scrapes data from a specified website.",
 "Can you find the synonyms of the word 'meticulous'?",
 "Write a SQL query to join two tables based on a common column.",
@@ -267,31 +265,57 @@
 "Can you assist me in learning Japanese?",
 "How can I make an alert box appear when a user clicks a button on a webpage?",
 "Summarize this research paper into bullet points.",
-"Can you check if there are any logical errors in this algorithm?"
+"Can you check if there are any logical errors in this algorithm?",
+"请一步一步计算找到函数f(x)=U^2*x/(R+x)^2的顶点坐标。",
+"如何理解transformer自注意力机制中的Q,K,V?它们分别代表什么?",
+"帮我写一封求职信。先询问我的教育背景、技能和经验。",
+"总结这篇论文",
+"写一份10人晚宴的菜单",
+"写一篇博客",
+"写一段演讲稿"
 ],
-"ai.question": [
-"你认为哪个框架最适合性能敏感的项目?",
+"knowledge": [
 "什么是后量子密码学?",
 "什么是密钥派生函数",
 "什么是线性代数?",
+"量子计算的特点是什么",
+"哈希函数的作用?",
+"什么是微积分?",
+"什么是区块链技术",
+"What is post-quantum cryptography",
+"What is a key derivation function?",
+"What is Linear Algebra?",
+"What is the main use of linear algebra in computer science",
+"What is quantum computing",
+"What is a hash function",
+"What is calculus",
+"什么是站点隔离?",
+"What is blockchain technology?",
+"BLEU 是什么",
+"黎巴嫩在哪",
+"什么是转义字符",
+"MixAlpha售价多少",
+"什么是神经机器翻译",
+"什么是月食",
+"什么是人工智能",
+"什么是F1-score"
+],
+"ai.question": [
+"人工智能真的有智力吗",
+"你认为哪个框架最适合性能敏感的项目?",
 "线性代数在计算机科学中的主要用途是什么?",
 "我应该使用哪个IDE来编写Go语言?",
 "Go vs Java vs Kotlin,哪个适合后端",
 "哪种编程语言最适合数据分析",
-"什么是量子计算",
-"什么是哈希函数?",
-"什么是微积分?",
 "机器学习在金融中的主要应用有哪些?",
 "写Python代码最好的文本编辑器是哪个?",
 "Python vs R vs Julia,哪个更适合数据科学?",
 "监督学习和无监督学习的关键区别是什么?",
 "数据库在Web应用程序中的作用是什么",
-"什么是区块链技术",
 "使用Docker进行应用程序部署的优势是什么?",
 "哪个云服务提供商提供最好的AI工具?",
-"加密是如何工作的?",
-"负载均衡器在网络架构中的目的是什么?",
+"加密是如何工作的",
+"负载均衡器在网络架构中的目的是什么",
 "机器学习和深度学习有什么区别",
 "软件工程中最常见的设计模式有哪些",
 "神经网络是如何学习的",
@@ -300,31 +324,22 @@
 "Rust编程语言的关键特性是什么?",
 "HTTP和HTTPS有什么区别",
 "使用像Git这样的版本控制系统有什么优势?",
-"什么是'边缘计算'的概念",
 "哪种编程语言最适合构建移动应用?",
 "关系数据库和NoSQL数据库有什么不同?",
-"算法在计算机科学中的重要性是什么?",
+"算法在计算机科学中的重要性是什么",
 "API在软件开发中的作用是什么",
 "保护Web应用程序的最佳实践是什么",
 "虚拟现实和增强现实有什么区别?",
 "机器翻译是如何工作的?",
 "Which framework do you think is the most suitable for performance sensitive projects?",
-"What is post-quantum cryptography",
-"What is a key derivation function?",
-"What is Linear Algebra?",
-"What is the main use of linear algebra in computer science",
 "which IDE should I use for Go",
 "Go vs Java vs Koltin, which for a backend",
 "Which programming language is best suited for data analysis?",
-"What is quantum computing?",
-"What is a hash function?",
-"What is calculus?",
-"What are the main applications of machine learning in finance?",
-"Which text editor is best for writing Python code?",
-"Python vs R vs Julia, which is better for data science?",
-"What are the key differences between supervised and unsupervised learning?",
+"What are the main applications of machine learning in finance",
+"Which text editor is best for writing Python code",
+"Python vs R vs Julia, which is better for data science",
+"What are the key differences between supervised and unsupervised learning",
 "What is the role of a database in a web application?",
-"What is blockchain technology?",
 "What are the advantages of using Docker for application deployment?",
 "Which cloud service provider offers the best AI tools?",
 "How does encryption work?",
@@ -332,19 +347,20 @@
 "What is the difference between machine learning and deep learning?",
 "What are the most common design patterns in software engineering?",
 "How does a neural network learn?",
-"What is the main benefit of using a microservices architecture?",
-"What is the difference between a compiler and an interpreter?",
-"What are the key features of the Rust programming language?",
-"What is the difference between HTTP and HTTPS?",
-"What are the advantages of using a version control system like Git?",
-"What is the concept of 'edge computing'?",
-"Which programming language is best for building mobile apps?",
-"How does a relational database differ from a NoSQL database?",
-"What is the importance of algorithms in computer science?",
-"What is the role of an API in software development?",
+"What is the main benefit of using a microservices architecture",
+"What is the difference between a compiler and an interpreter",
+"What are the key features of the Rust programming language",
+"What is the difference between HTTP and HTTPS",
+"What are the advantages of using a version control system like Git",
+"What is the concept of 'edge computing'",
+"Which programming language is best for building mobile apps",
+"How does a relational database differ from a NoSQL database",
+"What is the importance of algorithms in computer science",
+"What is the role of an API in software development",
 "What are the best practices for securing a web application?",
 "What is the difference between virtual reality and augmented reality?",
-"How does machine translation work?"
+"How does machine translation work?",
+"MBTI有科学依据吗?"
 ],
 "datetime": ["明天周几", "16天后是几号", "一年前的今天是星期几"]
 }

@@ -28,7 +28,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"model_name=\"microsoft/Phi-3-mini-4k-instruct\""
+"model_name=\"Qwen/Qwen2.5-3B\""
 ]
 },
 {
@@ -37,17 +37,10 @@
 "id": "c1de25fc-e90a-425b-8520-3a57fa534b94",
 "metadata": {},
 "outputs": [
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
-]
-},
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "1aeb02c7c8084b1eb1b8e3178882fd60",
+"model_id": "11caef0e1b674f6ab15880f3f25eca6a",
 "version_major": 2,
 "version_minor": 0
 },
@@ -95,7 +88,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"DIMENSIONS = 128"
+"DIMENSIONS = 96"
 ]
 },
 {
@@ -168,11 +161,17 @@
 "import struct\n",
 "with open(\"token_embeddings.bin\", \"wb\") as f:\n",
 "    for token_id in range(len(vocab)):\n",
-"        # Write token id (2 bytes)\n",
-"        f.write(struct.pack('H', token_id))\n",
-"        # Write embedding vector (128 float numbers)\n",
-"        f.write(struct.pack('128f', *reduced_embeddings[token_id]))"
+"        # 将向量转换为半精度浮点数并保存\n",
+"        f.write(struct.pack('96e', *reduced_embeddings[token_id].astype(np.float16)))\n"
 ]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "511a7cc4-1b8c-468c-b2a0-16dc6d74ab44",
+"metadata": {},
+"outputs": [],
+"source": []
 },
 {
@@ -191,7 +190,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.9.19"
+"version": "3.10.14"
 }
 },
 "nbformat": 4,

@@ -4,5 +4,350 @@
 "我爱你",
 "嘿嘿嘿诶嘿",
 "为什么",
-"拼多多"
+"拼多多",
+"machine translation",
+"trustrank",
+"中文词典",
+"bin screen linux",
+"\"TinyBERT",
+"iconify",
+"反义词 英文",
+"referer",
+"watchos uiscreen",
+"张鑫旭",
+"google first result",
+"flutter text align center",
+"ASR model",
+"real time whisper",
+"千樱凛",
+"马嘉祺",
+"flutter widget catalog",
+"flutter BottomNavigationBar left",
+"flutter tab indent vscode",
+"react native 用 expo 吗",
+"latest monorepo tool",
+"\"vite\" \"abortController\" is not defined",
+"vim comment lines",
+"Error: unable to get issuer certificate",
+"uuidv4",
+"npm semver",
+"react polyfill vite",
+"vibrance",
+"I can eat glass, it doesn't hurt me \"japanese\"",
+"I can swallow glass without any harm to myself",
+"copilot pricing",
+"vim close window",
+"sensors macos command",
+"智乃",
+"pypi wikipedia",
+"tesseract macos m1",
+"rag prompt template",
+"英国 破产",
+"bewlybewly",
+"safari-web-extension-converter",
+"starcoder",
+"open source web search for ai",
+"gpt4o mini tokenizer",
+"gpt4o tokenizer",
+"reverse dns lookup linux",
+"online ping",
+"termux",
+"802.11 table",
+"optimize",
+"集群",
+"chrome us",
+"transflective",
+"ielts toefl",
+"react router",
+"摇曳露营 萌娘百科",
+"isrc",
+"apple-system",
+"-apple-system",
+"css clip path animation",
+"can i use relative path in og image",
+"GitSora",
+"matrix im",
+"test your vocabulary",
+"boarding pass",
+"函数签名",
+"类型谓词",
+"barcode",
+"智能",
+"threejs 入门",
+"南亚语系",
+"linux user's computer be like",
+"apple a16 显微图",
+"dallas",
+"恶魔 英文",
+"Rime meaning",
+"adobe media encoder macos download",
+"mp4 transparency",
+"webkit",
+"chromium",
+"献血",
+"软件强制更新",
+"If you don’t agree with its politics views, Notepad+ + will add random characters in your source code.",
+"Unmerged paths",
+"字数统计",
+"Use build.rollupOptions.output.manualChunks to improve chunking: https://rollupjs.org/configuration-options/#output-manualchunks",
+"世界人权宣言",
+"latex percent",
+"chord in keyboard",
+"Google is trying to kill the Open Web.",
+"silo'd",
+"swiftui 数组倒数访问",
+"swiftui link to another view",
+"fizzbuzz",
+"AppDelegate watchos",
+"Cannot find type 'UIApplicationDelegate' in scope",
+"swiftui web image",
+"spammer",
+"swiftui text",
+"钢琴",
+"disable webgl chrome",
+"online uuid",
+"cp show progress",
+"易容术",
+"fulilian",
+"cargo",
+"wordle",
+"mismatch",
+"btc",
+"squelch",
+"psql show table structure",
+"let padding don't effect when empty",
+"take over the world meaning",
+"brain teasers",
+"Google flight API",
+"square symbol",
+"sill",
+"nextjs layout per page",
+"UA 550 umol/L",
+"react production promotion page",
+"jupyter notebook",
+"wth meaning",
+"glove词向量",
+"google suggestion relevance",
+"YouTube advertising income",
+"PKI",
+"next client only component",
+"nextjs use client",
+"nextjs docker tailwind not working",
+"k8s",
+"Logistic Regression",
+"氯化钾注射死刑",
+"icloud photo loss",
+"芙宁娜 水上行走",
+"vector design tool",
+"netizen",
+"framework or next js documentation",
+"csync",
+"next js",
+"后量子正向保密",
+"nip05",
+"Sora技术原理",
+"wasm效率",
+"switch code",
+"online IPA pronunciation",
+"pnpm global adir",
+"如何搜索",
+"1999 抽卡期望",
+"swiftui background blur",
+"chrome macos fullscreen hide",
+"中英文空格自动",
+"ios 旁白 屏幕识别",
+"ios 旁白 转子",
+"http 404",
+"yaml缩进",
+"counter generator github",
+"git 服务器提供远端仓库",
+"ipfs companion",
+"supervisor config",
+"SSO",
+"slot embedding",
+"sql show tables",
+"The request signature we calculated does not match the signature you provided. Check your Secret Access Key and signing method.",
+"icloud.com,cn",
+"VuePress",
+"parser",
+"stackoverflow statistics",
+"sd xl",
+"Rollup failed to resolve import \"workbox-precaching\" from",
+"dep",
+"Cannot find module estree-walker.js docker",
+"nuxt run",
+"base58解码",
+"cga",
+"vscode",
+"vscode",
+"silicon",
+"macos m1 linux",
+"预处理 后处理",
+"is vp9 opensource",
+"Alice Blu",
+"失控玩家",
+"kv数据库",
+"redis 持久化",
+"firefox disable outline",
+"cd -2",
+"IM application",
+"2021国产电影",
+"youtube chat overlay obs",
+"obs add clock",
+"Z is not defined nuxt",
+"safari ios debug",
+"safari debug",
+"chat",
+"nuxt plugin inject",
+"twitch",
+"obs 绿幕",
+"gnupg",
+"kde plasma wallpaper engine",
+"Plasma",
+"dns over https",
+"localforage缺点",
+"watchOS 10",
+"noun of repeat",
+"微信输入法",
+"行业报告",
+"keepass",
+"platform",
+"steam",
+"java proxy",
+"0 design",
+"cefr word level list",
+"precipitation meaning",
+"international school of lausanne",
+"Vim Uganda",
+"抖音 推荐算法",
+"Meta NNLO",
+"windbg dump分析",
+"web image fft",
+"GPT-4 Pricing",
+"GPT-4",
+"Scala",
+"tauri教程",
+"asyncio.create_task用法",
+"H5 滚动到底部",
+"microsoft copilot",
+"枫丹文字",
+"brew pip",
+"TS7016: Could not find a declaration file for module react .",
+"fastapi websocket",
+"kazv",
+"The Type 孔雀计划",
+"第一个图形操作系统",
+"娱乐 诞生",
+"ffmpeg 音频封面",
+"Jean-Loup Gailly",
+"Linux用户软件位置",
+"\"ubuntu\" 平滑滚动",
+"python range函数",
+"KMP",
+"sd 8gen2 GPU GFLOPS",
+"mac语音输入法",
+"openai translate",
+"蔚蓝档案 初始抽卡",
+"free custom domain email",
+"洛天依",
+"b站 频道页Tab 跳转",
+"URL 重定向预览",
+"计算机",
+"sololearn",
+"PoS机制 通俗解释",
+"google search cost",
+"bos s3",
+"react 打包",
+"useeffect 用法",
+"ts 字典类型",
+"vscode 字典单词自动补全插件",
+"componentwillupdate",
+"iPad Mini 2",
+"use-immer",
+"reducer 和 context",
+"mint",
+"Elementary OS",
+"google科技新闻",
+"iCloud mail \"\"-9002\"\"",
+"氢氧化铁胶体制备",
+"react native 视频处理",
+"四川 2023 高考 复旦大学 分数线",
+"哑铃弯举",
+"m2 ultra",
+"电池循环计数 site:apple.com",
+"相机发明时间",
+"冯诺依曼结构",
+"哈佛架构",
+"nodejs 后端",
+"34.5M€ to CN¥",
+"NLP 实体关注",
+"monkey",
+"react 快捷键监听",
+"mac 好看的电子书阅读器",
+"新闻",
+"在线字体编辑器",
+"ars technica",
+"genshin 4.1 release time",
+"swift device activity report",
+"swiftui tabview background",
+"swiftui text space",
+"apple inc. wikipedia",
+"how long does it take Google to return the results",
+"云原神 web",
+"支持homekit的空调",
+"内核隔离",
+"海祇岛解密",
+"swiftui Textfield",
+"xcode",
+"qq 链接",
+"M1 推出时间",
+"USB-IF",
+"nvchat",
+"P1% FPS",
+"react i18next 当前语言",
+"js 获取语言",
+"MulType",
+"b站平均使用时间",
+"pip 阿里源",
+"ip info",
+"graphjet",
+"金融思维",
+"C#写入文件",
+"Last Day Sinista M",
+"在 系统 位置 xcode select 找 不 到 SDK",
+"Error: Could not find a valid Xcode app bundle at '/Library/Developer/CommandLineTools'. Please update your Apple SDK location in Visual Studio's preferences (Projects > SDK Locations > Apple > Apple SDK). (UniBattery)",
+".NET能做什么",
+"could i give no tip ",
+"miami university of ohio",
+"方正颜宋",
+"中文 标题字体",
+"聚典平台",
+"62 basic words for a language",
+"procrastination meaning",
+"Lingbe",
+"娱乐至死",
+"macOS 外接显示器渲染",
+"白玉袖",
+"SwiftUI入门",
+"html插入其它网页",
+"捆绑 小说",
+"apple music 无损下载",
+"一miumiu 赐予",
+"macos markdown",
+"safari 开发者工具",
+"\"百合\" \"武侠\" \"国漫\"",
+"epub 格式详解",
+"chrome 隐藏滚动条",
+"发宽空格",
+"U+200A",
+"无性人",
+"Spotify",
+"禾念",
+"how to pronounce Lorem ipsum",
+"言和为什么不是男孩子",
+"浏览器主页",
+"react",
+"Tailwindcss react 扩展怎么用",
+"Prettier 扩展怎么用",
+"linter\""
 ]

@@ -1,575 +0,0 @@
-{
-"cells": [
-{
-"cell_type": "markdown",
-"id": "a6a3195f-d099-4bf4-846f-51f403954818",
-"metadata": {},
-"source": [
-"# sparkastML: Training the Intention Classification Model\n",
-"\n",
-"This is the model we use for intent recognition, using a **CNN architecture** and an **Energy-based Model** to implement OSR (Open-set Recognition).\n",
-"\n",
-"In this case, **positive samples** refer to data that can be classified into an existing class, while **negative samples** are those that do not belong to any of the existing classes."
-]
-},
-{
-"cell_type": "code",
-"execution_count": 1,
-"id": "bddcdbb2-ccbc-4027-a38f-09c61ac94984",
-"metadata": {},
-"outputs": [],
-"source": [
-"import json\n",
-"import torch\n",
-"from torch.utils.data import Dataset, DataLoader\n",
-"from torch.nn.utils.rnn import pad_sequence\n",
-"from sklearn.model_selection import train_test_split\n",
-"from transformers import AutoTokenizer, AutoModel\n",
-"import torch\n",
-"import numpy as np\n",
-"from scipy.spatial.distance import euclidean\n",
-"from scipy.stats import weibull_min\n",
-"from sklearn.preprocessing import normalize\n",
-"import torch.nn.functional as F\n"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 2,
-"id": "d3a0e10f-9bc3-44c7-a109-786dd5cd25ea",
-"metadata": {},
-"outputs": [
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
-]
-}
-],
-"source": [
-"model_name=\"microsoft/Phi-3-mini-4k-instruct\"\n",
-"DIMENSIONS = 128\n",
-"tokenizer = AutoTokenizer.from_pretrained(model_name)"
-]
-},
-{
-"cell_type": "markdown",
-"id": "1ae14906-338d-4c99-87ed-bb1acd22b295",
-"metadata": {},
-"source": [
-"## Load Data\n",
-"\n",
-"We load the data from `data.json`, and also get the negative sample from the `noise.json`."
-]
-},
-{
-"cell_type": "code",
-"execution_count": 3,
-"id": "a206071c-ce4e-4de4-b936-bfc70d13708a",
-"metadata": {},
-"outputs": [
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_6446/1697839999.py:18: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_69nk78ncaj/croot/pytorch_1669252638507/work/torch/csrc/utils/tensor_new.cpp:204.)\n",
-"  embeddings = torch.tensor(embeddings)\n"
-]
-}
-],
-"source": [
-"# Load data\n",
-"with open('data.json', 'r') as f:\n",
-"    data = json.load(f)\n",
-"\n",
-"# Create map: class to index\n",
-"class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}\n",
-"idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}\n",
-"\n",
-"# Preprocess data, convert sentences to the format of (class idx, embedding)\n",
-"def preprocess_data(data, embedding_map, tokenizer, max_length=64):\n",
-"    dataset = []\n",
-"    for label, sentences in data.items():\n",
-"        for sentence in sentences:\n",
-"            # Tokenize the sentence and convert tokens to embedding vectors\n",
-"            tokens = tokenizer.tokenize(sentence)\n",
-"            token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
-"            embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
-"            embeddings = torch.tensor(embeddings)\n",
-"            dataset.append((class_to_idx[label], embeddings))\n",
-"    return dataset\n",
-"\n",
-"# Load embedding map\n",
-"embedding_map = torch.load('token_id_to_reduced_embedding.pt')\n",
-"\n",
-"# Get preprocessed dataset\n",
-"dataset = preprocess_data(data, embedding_map, tokenizer)\n",
-"\n",
-"# Train-test split\n",
-"train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)\n",
-"\n",
-"class TextDataset(Dataset):\n",
-"    def __init__(self, data):\n",
-"        self.data = data\n",
-"\n",
-"    def __len__(self):\n",
-"        return len(self.data)\n",
-"\n",
-"    def __getitem__(self, idx):\n",
-"        return self.data[idx]\n",
-"\n",
-"    def collate_fn(self, batch):\n",
-"        labels, embeddings = zip(*batch)\n",
-"        labels = torch.tensor(labels)\n",
-"        embeddings = pad_sequence(embeddings, batch_first=True)\n",
-"        return labels, embeddings\n",
-"\n",
-"train_dataset = TextDataset(train_data)\n",
-"val_dataset = TextDataset(val_data)\n",
-"\n",
-"train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn)\n",
-"val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False, collate_fn=val_dataset.collate_fn)\n"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 4,
-"id": "9adbe9b8-a2d2-4e1d-8620-457ed0e02fe6",
-"metadata": {},
-"outputs": [],
-"source": [
-"import torch\n",
-"from torch.utils.data import Dataset, DataLoader\n",
-"\n",
-"class NegativeSampleDataset(Dataset):\n",
-"    def __init__(self, negative_samples):\n",
-"        \"\"\"\n",
-"        negative_samples: List or array of negative sample embeddings or raw text\n",
-"        \"\"\"\n",
-"        self.samples = negative_samples\n",
-"\n",
-"    def __len__(self):\n",
-"        return len(self.samples)\n",
-"\n",
-"    def __getitem__(self, idx):\n",
-"        return self.samples[idx]\n",
-"\n",
-"    def collate_fn(self, batch):\n",
-"        embeddings = pad_sequence(batch, batch_first=True)\n",
-"        return embeddings\n",
-"\n",
-"with open('noise.json', 'r') as f:\n",
-"    negative_samples_list = json.load(f)\n",
-"\n",
-"negative_embedding_list = []\n",
-"\n",
-"for sentence in negative_samples_list:\n",
-"    tokens = tokenizer.tokenize(sentence)\n",
-"    token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
-"    embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]\n",
-"    embeddings = torch.tensor(embeddings)\n",
-"    negative_embedding_list.append(embeddings)\n",
-"\n",
-"negative_dataset = NegativeSampleDataset(negative_embedding_list)\n",
-"negative_loader = DataLoader(negative_dataset, batch_size=24, shuffle=True, collate_fn=negative_dataset.collate_fn)\n"
-]
-},
-{
-"cell_type": "markdown",
-"id": "600febe4-2484-4aad-90a1-2bc821fdce1a",
-"metadata": {},
-"source": [
-"## Implementing the Model"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 5,
-"id": "adf624ac-ad63-437b-95f6-b02b7253b91e",
-"metadata": {},
-"outputs": [],
-"source": [
-"import torch.nn as nn\n",
-"import torch.nn.functional as F\n",
-"\n",
-"class TextCNN(nn.Module):\n",
-"    def __init__(self, input_dim, num_classes):\n",
-"        super(TextCNN, self).__init__()\n",
-"        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n",
-"        self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n",
-"        self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n",
-"\n",
-"        self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n",
-"        self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n",
-"        self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n",
-"\n",
-"        self.dropout = nn.Dropout(0.5)\n",
-"        self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n",
-"\n",
-"    def forward(self, x):\n",
-"        x = x.permute(0, 2, 1)  # Change the input shape to (batch_size, embedding_dim, seq_length)\n",
-"\n",
-"        x1 = F.relu(self.bn1(self.conv1(x)))\n",
-"        x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n",
-"\n",
-"        x2 = F.relu(self.bn2(self.conv2(x)))\n",
-"        x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n",
-"\n",
-"        x3 = F.relu(self.bn3(self.conv3(x)))\n",
-"        x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n",
-"\n",
-"        x = torch.cat((x1, x2, x3), dim=1)\n",
-"        x = self.dropout(x)\n",
-"        x = self.fc(x)\n",
-"        return x\n",
-"\n",
-"# Initialize model\n",
-"input_dim = DIMENSIONS\n",
-"num_classes = len(class_to_idx)\n",
-"model = TextCNN(input_dim, num_classes)\n"
-]
-},
-{
-"cell_type": "markdown",
-"id": "2750e17d-8a60-40c7-851b-1a567d0ee82b",
-"metadata": {},
-"source": [
-"## Energy-based Models"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 6,
-"id": "a2d7c920-07d2-4d14-9cef-e2101b7a2ceb",
-"metadata": {},
-"outputs": [],
-"source": [
-"def energy_score(logits):\n",
-"    # Energy score is minus logsumexp\n",
-"    return -torch.logsumexp(logits, dim=1)\n",
-"\n",
-"def generate_noise(batch_size, seq_length, input_dim, device):\n",
-"    # Generate Gaussian noise\n",
-"    return torch.randn(batch_size, seq_length, input_dim).to(device)\n"
-]
-},
-{
-"cell_type": "markdown",
-"id": "904a60e4-95a0-4f7b-ad45-a7d8d0ac887d",
-"metadata": {},
-"source": [
-"## Training"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 7,
-"id": "19acb5bf-00b1-47d4-ad25-a13c6be09f65",
-"metadata": {
-"scrolled": true
-},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Epoch [1/50], Loss: 12.5108\n",
-"Epoch [2/50], Loss: 10.7305\n",
-"Epoch [3/50], Loss: 10.2943\n",
-"Epoch [4/50], Loss: 9.9350\n",
-"Epoch [5/50], Loss: 9.7991\n",
-"Epoch [6/50], Loss: 9.6443\n",
-"Epoch [7/50], Loss: 9.4762\n",
-"Epoch [8/50], Loss: 9.4637\n",
-"Epoch [9/50], Loss: 9.3025\n",
-"Epoch [10/50], Loss: 9.1719\n",
-"Epoch [11/50], Loss: 9.0632\n",
-"Epoch [12/50], Loss: 8.9741\n",
-"Epoch [13/50], Loss: 8.8487\n",
-"Epoch [14/50], Loss: 8.6565\n",
-"Epoch [15/50], Loss: 8.5830\n",
-"Epoch [16/50], Loss: 8.4196\n",
-"Epoch [17/50], Loss: 8.2319\n",
-"Epoch [18/50], Loss: 8.0655\n",
-"Epoch [19/50], Loss: 7.7140\n",
-"Epoch [20/50], Loss: 7.6921\n",
-"Epoch [21/50], Loss: 7.3375\n",
-"Epoch [22/50], Loss: 7.2297\n",
-"Epoch [23/50], Loss: 6.8833\n",
-"Epoch [24/50], Loss: 6.8534\n",
-"Epoch [25/50], Loss: 6.4557\n",
-"Epoch [26/50], Loss: 6.1365\n",
-"Epoch [27/50], Loss: 5.8558\n",
-"Epoch [28/50], Loss: 5.5030\n",
-"Epoch [29/50], Loss: 5.1604\n",
-"Epoch [30/50], Loss: 4.7742\n",
-"Epoch [31/50], Loss: 4.5958\n",
-"Epoch [32/50], Loss: 4.0713\n",
-"Epoch [33/50], Loss: 3.8872\n",
-"Epoch [34/50], Loss: 3.5240\n",
-"Epoch [35/50], Loss: 3.3115\n",
-"Epoch [36/50], Loss: 2.5667\n",
-"Epoch [37/50], Loss: 2.6709\n",
-"Epoch [38/50], Loss: 1.8075\n",
-"Epoch [39/50], Loss: 1.6654\n",
-"Epoch [40/50], Loss: 0.4622\n",
-"Epoch [41/50], Loss: 0.4719\n",
-"Epoch [42/50], Loss: -0.4037\n",
-"Epoch [43/50], Loss: -0.9405\n",
-"Epoch [44/50], Loss: -1.7204\n",
-"Epoch [45/50], Loss: -2.4124\n",
-"Epoch [46/50], Loss: -3.0032\n",
-"Epoch [47/50], Loss: -2.7123\n",
-"Epoch [48/50], Loss: -3.6953\n",
-"Epoch [49/50], Loss: -3.7212\n",
-"Epoch [50/50], Loss: -3.7558\n"
-]
-}
-],
-"source": [
-"import torch.optim as optim\n",
-"\n",
-"criterion = nn.CrossEntropyLoss()\n",
-"optimizer = optim.Adam(model.parameters(), lr=8e-4)\n",
-"\n",
-"from torch.utils.tensorboard import SummaryWriter\n",
-"import tensorboard\n",
-"writer = SummaryWriter()\n",
-"\n",
-"def train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=10):\n",
-"    model.train()\n",
-"    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
-"    model.to(device)\n",
-"\n",
-"    negative_iter = iter(negative_loader)\n",
-"\n",
-"    for epoch in range(num_epochs):\n",
-"        total_loss = 0\n",
-"        for batch_idx, (labels, embeddings) in enumerate(train_loader):\n",
-"            embeddings = embeddings.to(device)\n",
-"            labels = labels.to(device)\n",
-"\n",
-"            batch_size = embeddings.size(0)\n",
-"\n",
-"            # ---------------------\n",
-"            # 1. Positive sample\n",
-"            # ---------------------\n",
-"            optimizer.zero_grad()\n",
-"            outputs = model(embeddings)  # logits from the model\n",
-"\n",
-"            class_loss = criterion(outputs, labels)\n",
-"\n",
-"            # Energy of positive sample\n",
-"            known_energy = energy_score(outputs)\n",
-"            energy_loss_known = known_energy.mean()\n",
-"\n",
-"            # ------------------------------------\n",
-"            # 2. Negative sample - Random Noise\n",
-"            # ------------------------------------\n",
-"            noise_embeddings = torch.randn_like(embeddings).to(device)\n",
-"            noise_outputs = model(noise_embeddings)\n",
-"            noise_energy = energy_score(noise_outputs)\n",
-"            energy_loss_noise = F.relu(1 - noise_energy).mean()  # For the energy of noise, bigger is better\n",
-"\n",
-"            # ------------------------------------\n",
-"            # 3. Negative sample - custom corpus\n",
-"            # ------------------------------------\n",
-"            try:\n",
-"                negative_samples = next(negative_iter)\n",
-"            except StopIteration:\n",
-"                negative_iter = iter(negative_loader)\n",
-"                negative_samples = next(negative_iter)\n",
-"            negative_samples = negative_samples.to(device)\n",
-"            negative_outputs = model(negative_samples)\n",
-"            negative_energy = energy_score(negative_outputs)\n",
-"            energy_loss_negative = F.relu(1 - negative_energy).mean()  # For the energy of negative samples, bigger is better\n",
-"\n",
-"            # -----------------------------\n",
-"            # 4. Overall Loss calculation\n",
-"            # -----------------------------\n",
-"            total_energy_loss = energy_loss_known + energy_loss_noise + energy_loss_negative\n",
-"            total_loss_batch = class_loss + total_energy_loss * 0.1 + 10\n",
-"\n",
-"            writer.add_scalar(\"Energy Loss\", total_energy_loss, epoch)\n",
-"            writer.add_scalar(\"Loss\", total_loss_batch, epoch)\n",
-"            writer.add_scalar(\"Norm Loss\", torch.exp(total_loss_batch * 0.003) * 10, epoch)\n",
-"\n",
-"            total_loss_batch.backward()\n",
-"            optimizer.step()\n",
-"\n",
-"            total_loss += total_loss_batch.item()\n",
-"\n",
-"        avg_loss = total_loss / len(train_loader)\n",
-"        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')\n",
-"\n",
-"train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=50)\n",
-"writer.flush()\n"
-]
-},
-{
-"cell_type": "markdown",
-"id": "e6d29558-f497-4033-8488-169bd25ce881",
-"metadata": {},
-"source": [
-"## Evaluation"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 8,
-"id": "472702e6-db4a-4faa-9e92-7510e6eacbb1",
-"metadata": {},
-"outputs": [],
-"source": [
-"ENERGY_THRESHOLD = -3"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 9,
-"id": "d3a4fef8-37ab-45c8-b2b1-9fc8bdffcffd",
-"metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Accuracy: 0.9315\n",
-"Precision: 1.0000\n",
-"Recall: 0.9254\n",
-"F1 Score: 0.9612\n"
-]
-}
-],
-"source": [
-"from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support\n",
-"\n",
-"def evaluate_energy_model(model, known_loader, unknown_loader, energy_threshold):\n",
-"    model.eval()\n",
-"    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
-"\n",
-"    all_preds = []\n",
-"    all_labels = []\n",
-"\n",
-"    # Evaluate positive sample\n",
-"    with torch.no_grad():\n",
-"        for labels, embeddings in known_loader:\n",
-"            embeddings = embeddings.to(device)\n",
-"            logits = model(embeddings)\n",
-"            energy = energy_score(logits)\n",
-"\n",
-"            preds = (energy <= energy_threshold).long()\n",
-"            all_preds.extend(preds.cpu().numpy())\n",
-"            all_labels.extend([1] * len(preds))  # Positive sample labeled as 1\n",
-"\n",
-"    # Evaluate negative sample\n",
-"    with torch.no_grad():\n",
-"        for embeddings in unknown_loader:\n",
-"            embeddings = embeddings.to(device)\n",
-"            logits = model(embeddings)\n",
-"            energy = energy_score(logits)\n",
-"\n",
-"            preds = (energy <= energy_threshold).long()\n",
-"            all_preds.extend(preds.cpu().numpy())\n",
-"            all_labels.extend([0] * len(preds))  # Negative sample labeled as 0\n",
-"\n",
-"    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
-"    accuracy = accuracy_score(all_labels, all_preds)\n",
-"\n",
-"    print(f'Accuracy: {accuracy:.4f}')\n",
-"    print(f'Precision: {precision:.4f}')\n",
-"    print(f'Recall: {recall:.4f}')\n",
-"    print(f'F1 Score: {f1:.4f}')\n",
-"\n",
-"evaluate_energy_model(model, val_loader, negative_loader, ENERGY_THRESHOLD)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 20,
-"id": "ba614054-75e1-4f61-ace5-aeb11e29a222",
-"metadata": {},
-"outputs": [],
-"source": [
-"# Save the model\n",
-"torch.save(model, \"model.pt\")"
-]
-},
-{
-"cell_type": "markdown",
-"id": "fdfa0c5e-e6d3-4db0-a142-96645c92719c",
-"metadata": {},
-"source": [
-"## Inference"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 19,
-"id": "03928d75-81c8-4298-ab8a-d7f8a758b561",
-"metadata": {
-"scrolled": true
-},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Predicted: ['weather', 0.9989822506904602, -8.016249656677246]\n"
-]
-}
-],
-"source": [
-"def predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold, max_length=64):\n",
-"    model.eval()\n",
-"    tokens = tokenizer.tokenize(sentence)\n",
-"    token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
-"    embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
-"    embeddings = torch.tensor(embeddings).unsqueeze(0)  # Add batch dimension\n",
-"\n",
-"    with torch.no_grad():\n",
-"        logits = model(embeddings)\n",
-"        probabilities = F.softmax(logits, dim=1)\n",
-"        max_prob, predicted = torch.max(probabilities, 1)\n",
-"\n",
-"    # Calculate energy score\n",
-"    energy = energy_score(logits)\n",
-"\n",
-"    # If energy > threshold, consider the input as unknown class\n",
-"    if energy.item() > energy_threshold:\n",
-"        return [\"Unknown\", max_prob.item(), energy.item()]\n",
-"    else:\n",
-"        return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]\n",
-"\n",
-"# Example usage:\n",
-"sentence = \"weather today\"\n",
-"energy_threshold = ENERGY_THRESHOLD\n",
-"predicted = predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold)\n",
-"print(f'Predicted: {predicted}')\n"
-]
-}
-],
-"metadata": {
-"kernelspec": {
-"display_name": "Python 3 (ipykernel)",
-"language": "python",
-"name": "python3"
-},
-"language_info": {
-"codemirror_mode": {
-"name": "ipython",
-"version": 3
-},
-"file_extension": ".py",
-"mimetype": "text/x-python",
-"name": "python",
-"nbconvert_exporter": "python",
-"pygments_lexer": "ipython3",
-"version": "3.9.19"
-}
-},
-"nbformat": 4,
-"nbformat_minor": 5
-}

4
intention-classify/training/config.py
Normal file
4
intention-classify/training/config.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# config.py
|
||||||
|
|
||||||
|
model_name = "Qwen/Qwen2.5-3B"
|
||||||
|
DIMENSIONS = 96
|
64
intention-classify/training/data_utils.py
Normal file
64
intention-classify/training/data_utils.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# data_utils.py
|
||||||
|
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(file_path):
|
||||||
|
with open(file_path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def create_class_mappings(data):
|
||||||
|
class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
|
||||||
|
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
|
||||||
|
return class_to_idx, idx_to_class
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_data(data, embedding_map, tokenizer, class_to_idx, max_length=64):
|
||||||
|
dataset = []
|
||||||
|
for label, sentences in data.items():
|
||||||
|
for sentence in sentences:
|
||||||
|
tokens = tokenizer.tokenize(sentence)
|
||||||
|
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
embeddings = [
|
||||||
|
embedding_map[token_id] for token_id in token_ids[:max_length]
|
||||||
|
]
|
||||||
|
embeddings = torch.tensor(embeddings)
|
||||||
|
dataset.append((class_to_idx[label], embeddings))
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
class TextDataset(Dataset):
|
||||||
|
def __init__(self, data):
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.data[idx]
|
||||||
|
|
||||||
|
def collate_fn(self, batch):
|
||||||
|
labels, embeddings = zip(*batch)
|
||||||
|
labels = torch.tensor(labels)
|
||||||
|
embeddings = pad_sequence(embeddings, batch_first=True)
|
||||||
|
return labels, embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class NegativeSampleDataset(Dataset):
|
||||||
|
def __init__(self, negative_samples):
|
||||||
|
self.samples = negative_samples
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.samples[idx]
|
||||||
|
|
||||||
|
def collate_fn(self, batch):
|
||||||
|
embeddings = pad_sequence(batch, batch_first=True)
|
||||||
|
return embeddings
|
50
intention-classify/training/model.py
Normal file
50
intention-classify/training/model.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
# model.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from training.config import DIMENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
def __init__(self, input_dim, heads):
|
||||||
|
super(SelfAttention, self).__init__()
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = (input_dim // heads) ** -0.5
|
||||||
|
self.qkv = nn.Linear(input_dim, input_dim * 3)
|
||||||
|
self.fc = nn.Linear(input_dim, input_dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
batch_size, seq_length, embedding_dim = x.shape
|
||||||
|
qkv = self.qkv(x).view(
|
||||||
|
batch_size, seq_length, self.heads, 3, embedding_dim // self.heads
|
||||||
|
)
|
||||||
|
q, k, v = qkv[..., 0, :], qkv[..., 1, :], qkv[..., 2, :]
|
||||||
|
q = q.permute(0, 2, 1, 3)
|
||||||
|
k = k.permute(0, 2, 1, 3)
|
||||||
|
v = v.permute(0, 2, 1, 3)
|
||||||
|
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||||
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
attention_output = torch.matmul(attn_weights, v)
|
||||||
|
attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
|
||||||
|
attention_output = attention_output.view(batch_size, seq_length, embedding_dim)
|
||||||
|
return self.fc(attention_output)
|


class AttentionBasedModel(nn.Module):
    def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512):
        super(AttentionBasedModel, self).__init__()
        self.self_attention = SelfAttention(input_dim, heads)
        self.fc1 = nn.Linear(input_dim, dim_feedforward)
        self.fc2 = nn.Linear(dim_feedforward, num_classes)
        self.dropout = nn.Dropout(0.5)
        self.norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        # Residual connection around the attention block, then LayerNorm.
        attn_output = self.self_attention(x)
        attn_output = self.norm(attn_output + x)
        # Mean-pool over the sequence dimension before classifying.
        pooled_output = torch.mean(attn_output, dim=1)
        x = F.relu(self.fc1(pooled_output))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
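
A minimal shape check for the classifier (a sketch; 96 dimensions, 5 classes and a batch of 2 are arbitrary assumptions, but input_dim must stay divisible by heads=8):

import torch
from training.model import AttentionBasedModel

model = AttentionBasedModel(input_dim=96, num_classes=5)
logits = model(torch.randn(2, 10, 96))  # (batch, seq_len, input_dim)
print(logits.shape)                     # torch.Size([2, 5]), one logit row per sentence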
intention-classify/training/train.py (new file, 149 lines)
@ -0,0 +1,149 @@
# train.py

from sklearn.model_selection import train_test_split
import torch
import json
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from training.data_utils import (
    load_data,
    create_class_mappings,
    preprocess_data,
    TextDataset,
    NegativeSampleDataset,
)
from training.model import AttentionBasedModel
import torch.nn as nn


def energy_score(logits, temperature=1.0):
    # Free-energy score E(x) = -logsumexp(f(x) / T): confident, in-domain
    # logits yield low energy; flat logits yield higher energy.
    return -torch.logsumexp(logits / temperature, dim=1)
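
# To make the sign convention concrete (illustrative values, runnable in a REPL):
#
#     >>> energy_score(torch.tensor([[10.0, -5.0, -5.0]]))
#     tensor([-10.0000])   # peaked logits: low energy, treated as known
#     >>> energy_score(torch.tensor([[0.1, 0.0, -0.1]]))
#     tensor([-1.1019])    # near-uniform logits: higher energy, closer to unknown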


def generate_noise(batch_size, seq_length, input_dim, device):
    # Pure Gaussian sequences act as synthetic out-of-distribution inputs.
    return torch.randn(batch_size, seq_length, input_dim).to(device)


def train_energy_model(
    model,
    train_loader,
    negative_loader,
    criterion,
    optimizer,
    num_epochs=10,
    margin=1.0,
    temperature=0.4,
):
    model.train()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    negative_iter = iter(negative_loader)

    for epoch in range(num_epochs):
        total_loss = 0
        for labels, embeddings in train_loader:
            embeddings = embeddings.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(embeddings)
            # Standard classification loss on the known classes.
            class_loss = criterion(outputs, labels)
            # Hinge 1: push the energy of known samples below 0.
            known_energy = energy_score(outputs, temperature)
            positive_margin = 0.0
            energy_loss_known = F.relu(known_energy - positive_margin).mean()

            # Hinge 2: push the energy of Gaussian noise above `margin`.
            noise_embeddings = torch.randn_like(embeddings).to(device)
            noise_outputs = model(noise_embeddings)
            noise_energy = energy_score(noise_outputs, temperature)
            energy_loss_noise = F.relu(margin - noise_energy).mean()

            # Hinge 3: same for real negative sentences; restart the iterator
            # when the (smaller) negative set runs out mid-epoch.
            try:
                negative_samples = next(negative_iter)
            except StopIteration:
                negative_iter = iter(negative_loader)
                negative_samples = next(negative_iter)
            negative_samples = negative_samples.to(device)
            negative_outputs = model(negative_samples)
            negative_energy = energy_score(negative_outputs, temperature)
            energy_loss_negative = F.relu(margin - negative_energy).mean()

            total_energy_loss = (
                energy_loss_known + energy_loss_noise + energy_loss_negative
            )
            total_loss_batch = class_loss + total_energy_loss

            total_loss_batch.backward()
            optimizer.step()

            total_loss += total_loss_batch.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
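
# How the hinge terms behave on toy energies (illustrative numbers only):
#
#     >>> F.relu(torch.tensor([-3.0, -0.5]) - 0.0).mean()  # known energies already below 0
#     tensor(0.)
#     >>> F.relu(1.0 - torch.tensor([0.2, 1.5])).mean()    # noise energy pushed past the margin
#     tensor(0.4000)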


def main():
    from training.config import model_name, DIMENSIONS
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    data = load_data("data.json")
    class_to_idx, _ = create_class_mappings(data)
    embedding_map = torch.load("token_id_to_reduced_embedding.pt")
    dataset = preprocess_data(data, embedding_map, tokenizer, class_to_idx)
    train_data, _ = train_test_split(dataset, test_size=0.2)

    train_dataset = TextDataset(train_data)

    train_loader = DataLoader(
        train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn
    )

    # Out-of-domain sentences used as explicit negative samples.
    with open("noise.json", "r") as f:
        negative_samples_list = json.load(f)

    negative_embedding_list = []
    for sentence in negative_samples_list:
        tokens = tokenizer.tokenize(sentence)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]
        embeddings = torch.tensor(embeddings)
        negative_embedding_list.append(embeddings)

    negative_dataset = NegativeSampleDataset(negative_embedding_list)
    negative_loader = DataLoader(
        negative_dataset,
        batch_size=24,
        shuffle=True,
        collate_fn=negative_dataset.collate_fn,
    )

    input_dim = DIMENSIONS
    num_classes = len(class_to_idx)
    model = AttentionBasedModel(input_dim, num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=7e-4)

    train_energy_model(
        model, train_loader, negative_loader, criterion, optimizer, num_epochs=120
    )

    # Move to CPU and disable dropout before saving and tracing, so the ONNX
    # export does not depend on a CUDA device or training-mode behaviour.
    model.cpu().eval()
    torch.save(model.state_dict(), "model.pt")

    dummy_input = torch.randn(1, 64, DIMENSIONS)
    torch.onnx.export(
        model,
        dummy_input,
        "model.onnx",
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 1: "seq_length"},
            "output": {0: "batch_size"},
        },
        opset_version=11,
    )


if __name__ == "__main__":
    main()
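
A hypothetical cross-check of the exported graph (a sketch; it assumes the onnxruntime package, which this commit does not itself depend on):

import numpy as np
import onnxruntime as ort
from training.config import DIMENSIONS

sess = ort.InferenceSession("model.onnx")
x = np.random.randn(1, 64, DIMENSIONS).astype(np.float32)
(onnx_out,) = sess.run(["output"], {"input": x})
print(onnx_out.shape)  # (1, num_classes)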
intention-classify/validation/inference.py (new file, 75 lines)
@ -0,0 +1,75 @@
import json

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from training.config import DIMENSIONS, model_name
from training.model import AttentionBasedModel


def energy_score(logits):
    # Energy score is minus logsumexp (temperature 1.0 at inference time).
    return -torch.logsumexp(logits, dim=1)


def predict_with_energy(
    model,
    sentence,
    embedding_map,
    tokenizer,
    idx_to_class,
    energy_threshold,
    max_length=64,
):
    model.eval()
    tokens = tokenizer.tokenize(sentence)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
    embeddings = torch.tensor(embeddings).unsqueeze(0)  # Add batch dimension
    current_shape = embeddings.shape

    # Pad very short inputs to a sequence length of at least 2.
    if current_shape[1] < 2:
        pad_size = 2 - current_shape[1]
        embeddings = F.pad(
            embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
        )

    with torch.no_grad():
        logits = model(embeddings)
        probabilities = F.softmax(logits, dim=1)
        max_prob, predicted = torch.max(probabilities, 1)

    # Calculate the energy score of the input.
    energy = energy_score(logits)

    # If energy > threshold, treat the input as an unknown class.
    if energy.item() > energy_threshold:
        return ["Unknown", max_prob.item(), energy.item()]
    else:
        return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]


with open("data.json", "r") as f:
    data = json.load(f)
class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
num_classes = len(class_to_idx)

input_dim = DIMENSIONS
model = AttentionBasedModel(input_dim, num_classes)
model.load_state_dict(torch.load("./model.pt"))
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Example usage:
ENERGY_THRESHOLD = 0
sentence = "天气"  # "weather"
predicted = predict_with_energy(
    model, sentence, embedding_map, tokenizer, idx_to_class, ENERGY_THRESHOLD
)
print(f"Predicted: {predicted}")
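
One possible way to choose ENERGY_THRESHOLD (an assumption, not part of this commit): score the in-domain sentences with the threshold disabled and keep, say, 95% of them below the cut-off.

energies = []
for sentences in data.values():
    for s in sentences:
        _, _, e = predict_with_energy(
            model, s, embedding_map, tokenizer, idx_to_class, float("inf")
        )
        energies.append(e)
print(torch.tensor(energies).quantile(0.95))  # candidate threshold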
@ -24,6 +24,7 @@ cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))
# Load the model and tokenizer
model_name = 'jinaai/jina-embeddings-v2-base-zh'
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.to('cuda')

# Define file paths from command-line arguments
file_a_path = args.file_a