ref: the intention-classification model
This commit is contained in:
parent
66cf093177
commit
9f071ee0a0
@ -36,6 +36,7 @@
|
||||
"室外的温度是多少",
|
||||
"达拉斯今天热不热",
|
||||
"苏州现在天气怎么样",
|
||||
"明天悉尼会下雨吗?",
|
||||
"how's the weather",
|
||||
"What's going on with the weather?",
|
||||
"Can you give me an update on the weather?",
|
||||
@ -48,21 +49,21 @@
|
||||
"What's the weather like right now?",
|
||||
"Tell me the current weather conditions.",
|
||||
"How about the weather today?",
|
||||
"What's the weather looking like for the next few hours?",
|
||||
"Is it going to stay this way all day?",
|
||||
"Could you give me a brief overview of the weather?",
|
||||
"What's the general weather situation in our area?",
|
||||
"Is it cloudy or clear outside?",
|
||||
"What's the forecast saying for today's weather?",
|
||||
"Is it going to be a warm day?",
|
||||
"Are we expecting any storms today?",
|
||||
"What's the weather condition outside my window?",
|
||||
"Is it a typical day for this season in terms of weather?",
|
||||
"how's the weather now?",
|
||||
"What's the temperature like right now?",
|
||||
"Can you tell me the current temperature?",
|
||||
"How hot is it outside?",
|
||||
"What's the temperature supposed to be today?",
|
||||
"What's the weather looking like for the next few hours",
|
||||
"Is it going to stay this way all day",
|
||||
"Could you give me a brief overview of the weather",
|
||||
"What's the general weather situation in our area",
|
||||
"Is it cloudy or clear outside",
|
||||
"What's the forecast saying for today's weather",
|
||||
"Is it going to be a warm day",
|
||||
"Are we expecting any storms today",
|
||||
"What's the weather condition outside my window",
|
||||
"Is it a typical day for this season in terms of weather",
|
||||
"how's the weather now",
|
||||
"What's the temperature like right now",
|
||||
"Can you tell me the current temperature",
|
||||
"How hot is it outside",
|
||||
"What's the temperature supposed to be today",
|
||||
"What is the current temp outside?",
|
||||
"Could you tell me the outdoor temperature?",
|
||||
"Is it cold or warm outside?",
|
||||
@ -81,8 +82,8 @@
|
||||
"Can you tell me the temp in the nearby area?",
|
||||
"Is it below freezing outside?",
|
||||
"What's the average temperature for today?",
|
||||
"Is the temperature dropping or rising?",
|
||||
"What should I wear considering the temperature?"
|
||||
"Is the temperature dropping or rising",
|
||||
"What should I wear considering the temperature"
|
||||
],
|
||||
"base64": [
|
||||
"请将数据使用base64编码",
|
||||
@ -110,17 +111,16 @@
|
||||
"解码 base64",
|
||||
"Please encode this data with base64:",
|
||||
"I need to encode the following data in base64",
|
||||
"Could you encode this string using base64?",
|
||||
"Could you encode this string using base64",
|
||||
"Convert this data to b64 encoding",
|
||||
"I want to encode this information with base64",
|
||||
"Help me encode this in base32",
|
||||
"Can you encode this data to base64 format?",
|
||||
"Can you encode this data to base64 format",
|
||||
"b64 encode",
|
||||
"base64 encode",
|
||||
"encode base64",
|
||||
"base 64 encode online"
|
||||
],
|
||||
|
||||
"url-encode": [
|
||||
"编码 url",
|
||||
"URL部分需要编码",
|
||||
@ -145,7 +145,6 @@
|
||||
"url decoder",
|
||||
"URL encoder"
|
||||
],
|
||||
|
||||
"html-encode": [
|
||||
"请编码HTML实体",
|
||||
"文本转为HTML实体",
|
||||
@ -186,7 +185,6 @@
|
||||
"html   conversion",
|
||||
"html nbsp meaning"
|
||||
],
|
||||
|
||||
"ai.command": [
|
||||
"写一个TypeScript的HelloWorld代码",
|
||||
"检查以下内容的语法和清晰度",
|
||||
@ -237,11 +235,11 @@
|
||||
"help me learn chinese",
|
||||
"how to let the screen reader automatically focused to an newly poped up element in the web development",
|
||||
"summarize following text:",
|
||||
"Is there anything wrong with this code or can it be simplified?",
|
||||
"Is there anything wrong with this code or can it be simplified",
|
||||
"generate a Python script that prints 'Hello, World!'",
|
||||
"Can you proofread this essay for grammar and punctuation errors?",
|
||||
"Can you proofread this essay for grammar and punctuation errors",
|
||||
"Create a list of ten example sentences for the word 'serendipity.'",
|
||||
"Can you reformat this JSON to be more readable?",
|
||||
"Can you reformat this JSON to be more readable",
|
||||
"Suggest a creative title for my blog post about healthy eating.",
|
||||
"Refactor this JavaScript function to make it more efficient.",
|
||||
"Help me practice French: provide a sentence with a missing word that I can guess.",
|
||||
@ -249,15 +247,15 @@
|
||||
"Summarize this news article for me.",
|
||||
"Can you review this code snippet for potential security vulnerabilities?",
|
||||
"Generate a SQL query to find all users who signed up in the last 30 days.",
|
||||
"Can you translate this paragraph into Spanish?",
|
||||
"Can you translate this paragraph into Spanish",
|
||||
"Create a flowchart based on the following process description.",
|
||||
"Write a Python function to calculate the factorial of a number.",
|
||||
"Provide a detailed explanation of how to implement OAuth2 in a web application.",
|
||||
"Can you optimize this image for faster loading on a website?",
|
||||
"Can you optimize this image for faster loading on a website",
|
||||
"Suggest some catchy taglines for a new mobile app focused on fitness.",
|
||||
"Write a Bash script to back up my documents folder daily.",
|
||||
"Help me draft an email to request a meeting with a potential client.",
|
||||
"Can you convert this Markdown document into HTML?",
|
||||
"Can you convert this Markdown document into HTML",
|
||||
"Generate a Python script that scrapes data from a specified website.",
|
||||
"Can you find the synonyms of the word 'meticulous'?",
|
||||
"Write a SQL query to join two tables based on a common column.",
|
||||
@ -267,31 +265,57 @@
|
||||
"Can you assist me in learning Japanese?",
|
||||
"How can I make an alert box appear when a user clicks a button on a webpage?",
|
||||
"Summarize this research paper into bullet points.",
|
||||
"Can you check if there are any logical errors in this algorithm?"
|
||||
"Can you check if there are any logical errors in this algorithm?",
|
||||
"请一步一步计算找到函数f(x)=U^2*x/(R+x)^2的顶点坐标。",
|
||||
"如何理解transformer自注意力机制中的Q,K,V?它们分别代表什么?",
|
||||
"帮我写一封求职信。先询问我的教育背景、技能和经验。",
|
||||
"总结这篇论文",
|
||||
"写一份10人晚宴的菜单",
|
||||
"写一篇博客",
|
||||
"写一段演讲稿"
|
||||
],
|
||||
|
||||
"ai.question": [
|
||||
"你认为哪个框架最适合性能敏感的项目?",
|
||||
"knowledge": [
|
||||
"什么是后量子密码学?",
|
||||
"什么是密钥派生函数",
|
||||
"什么是线性代数?",
|
||||
"量子计算的特点是什么",
|
||||
"哈希函数的作用?",
|
||||
"什么是微积分?",
|
||||
"什么是区块链技术",
|
||||
"What is post-quantum cryptography",
|
||||
"What is a key derivation function?",
|
||||
"What is Linear Algebra?",
|
||||
"What is the main use of linear algebra in computer science",
|
||||
"What is quantum computing",
|
||||
"What is a hash function",
|
||||
"What is calculus",
|
||||
"什么是站点隔离?",
|
||||
"What is blockchain technology?",
|
||||
"BLEU 是什么",
|
||||
"黎巴嫩在哪",
|
||||
"什么是转义字符",
|
||||
"MixAlpha售价多少",
|
||||
"什么是神经机器翻译",
|
||||
"什么是月食",
|
||||
"什么是人工智能",
|
||||
"什么是F1-score"
|
||||
],
|
||||
"ai.question": [
|
||||
"人工智能真的有智力吗",
|
||||
"你认为哪个框架最适合性能敏感的项目?",
|
||||
"线性代数在计算机科学中的主要用途是什么?",
|
||||
"我应该使用哪个IDE来编写Go语言?",
|
||||
"Go vs Java vs Kotlin,哪个适合后端",
|
||||
"哪种编程语言最适合数据分析",
|
||||
"什么是量子计算",
|
||||
"什么是哈希函数?",
|
||||
"什么是微积分?",
|
||||
"机器学习在金融中的主要应用有哪些?",
|
||||
"写Python代码最好的文本编辑器是哪个?",
|
||||
"Python vs R vs Julia,哪个更适合数据科学?",
|
||||
"监督学习和无监督学习的关键区别是什么?",
|
||||
"数据库在Web应用程序中的作用是什么",
|
||||
"什么是区块链技术",
|
||||
"使用Docker进行应用程序部署的优势是什么?",
|
||||
"哪个云服务提供商提供最好的AI工具?",
|
||||
"加密是如何工作的?",
|
||||
"负载均衡器在网络架构中的目的是什么?",
|
||||
"加密是如何工作的",
|
||||
"负载均衡器在网络架构中的目的是什么",
|
||||
"机器学习和深度学习有什么区别",
|
||||
"软件工程中最常见的设计模式有哪些",
|
||||
"神经网络是如何学习的",
|
||||
@ -300,31 +324,22 @@
|
||||
"Rust编程语言的关键特性是什么?",
|
||||
"HTTP和HTTPS有什么区别",
|
||||
"使用像Git这样的版本控制系统有什么优势?",
|
||||
"什么是'边缘计算'的概念",
|
||||
"哪种编程语言最适合构建移动应用?",
|
||||
"关系数据库和NoSQL数据库有什么不同?",
|
||||
"算法在计算机科学中的重要性是什么?",
|
||||
"算法在计算机科学中的重要性是什么",
|
||||
"API在软件开发中的作用是什么",
|
||||
"保护Web应用程序的最佳实践是什么",
|
||||
"虚拟现实和增强现实有什么区别?",
|
||||
"机器翻译是如何工作的?",
|
||||
"Which framework do you think is the most suitable for performance sensitive projects?",
|
||||
"What is post-quantum cryptography",
|
||||
"What is a key derivation function?",
|
||||
"What is Linear Algebra?",
|
||||
"What is the main use of linear algebra in computer science",
|
||||
"which IDE should I use for Go",
|
||||
"Go vs Java vs Koltin, which for a backend",
|
||||
"Which programming language is best suited for data analysis?",
|
||||
"What is quantum computing?",
|
||||
"What is a hash function?",
|
||||
"What is calculus?",
|
||||
"What are the main applications of machine learning in finance?",
|
||||
"Which text editor is best for writing Python code?",
|
||||
"Python vs R vs Julia, which is better for data science?",
|
||||
"What are the key differences between supervised and unsupervised learning?",
|
||||
"What are the main applications of machine learning in finance",
|
||||
"Which text editor is best for writing Python code",
|
||||
"Python vs R vs Julia, which is better for data science",
|
||||
"What are the key differences between supervised and unsupervised learning",
|
||||
"What is the role of a database in a web application?",
|
||||
"What is blockchain technology?",
|
||||
"What are the advantages of using Docker for application deployment?",
|
||||
"Which cloud service provider offers the best AI tools?",
|
||||
"How does encryption work?",
|
||||
@ -332,19 +347,20 @@
|
||||
"What is the difference between machine learning and deep learning?",
|
||||
"What are the most common design patterns in software engineering?",
|
||||
"How does a neural network learn?",
|
||||
"What is the main benefit of using a microservices architecture?",
|
||||
"What is the difference between a compiler and an interpreter?",
|
||||
"What are the key features of the Rust programming language?",
|
||||
"What is the difference between HTTP and HTTPS?",
|
||||
"What are the advantages of using a version control system like Git?",
|
||||
"What is the concept of 'edge computing'?",
|
||||
"Which programming language is best for building mobile apps?",
|
||||
"How does a relational database differ from a NoSQL database?",
|
||||
"What is the importance of algorithms in computer science?",
|
||||
"What is the role of an API in software development?",
|
||||
"What is the main benefit of using a microservices architecture",
|
||||
"What is the difference between a compiler and an interpreter",
|
||||
"What are the key features of the Rust programming language",
|
||||
"What is the difference between HTTP and HTTPS",
|
||||
"What are the advantages of using a version control system like Git",
|
||||
"What is the concept of 'edge computing'",
|
||||
"Which programming language is best for building mobile apps",
|
||||
"How does a relational database differ from a NoSQL database",
|
||||
"What is the importance of algorithms in computer science",
|
||||
"What is the role of an API in software development",
|
||||
"What are the best practices for securing a web application?",
|
||||
"What is the difference between virtual reality and augmented reality?",
|
||||
"How does machine translation work?"
|
||||
"How does machine translation work?",
|
||||
"MBTI有科学依据吗?"
|
||||
],
|
||||
"datetime": ["明天周几", "16天后是几号", "一年前的今天是星期几"]
|
||||
}
|
||||
|
@ -28,7 +28,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_name=\"microsoft/Phi-3-mini-4k-instruct\""
|
||||
"model_name=\"Qwen/Qwen2.5-3B\""
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -37,17 +37,10 @@
|
||||
"id": "c1de25fc-e90a-425b-8520-3a57fa534b94",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "1aeb02c7c8084b1eb1b8e3178882fd60",
|
||||
"model_id": "11caef0e1b674f6ab15880f3f25eca6a",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
@ -95,7 +88,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DIMENSIONS = 128"
|
||||
"DIMENSIONS = 96"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -168,11 +161,17 @@
|
||||
"import struct\n",
|
||||
"with open(\"token_embeddings.bin\", \"wb\") as f:\n",
|
||||
" for token_id in range(len(vocab)):\n",
|
||||
" # Write token id (2 bytes)\n",
|
||||
" f.write(struct.pack('H', token_id))\n",
|
||||
" # Write embedding vector (128 float numbers)\n",
|
||||
" f.write(struct.pack('128f', *reduced_embeddings[token_id]))"
|
||||
" # 将向量转换为半精度浮点数并保存\n",
|
||||
" f.write(struct.pack('96e', *reduced_embeddings[token_id].astype(np.float16)))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "511a7cc4-1b8c-468c-b2a0-16dc6d74ab44",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@ -191,7 +190,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.19"
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -4,5 +4,350 @@
|
||||
"我爱你",
|
||||
"嘿嘿嘿诶嘿",
|
||||
"为什么",
|
||||
"拼多多"
|
||||
]
|
||||
"拼多多",
|
||||
"machine translation",
|
||||
"trustrank",
|
||||
"中文词典",
|
||||
"bin screen linux",
|
||||
"\"TinyBERT",
|
||||
"iconify",
|
||||
"反义词 英文",
|
||||
"referer",
|
||||
"watchos uiscreen",
|
||||
"张鑫旭",
|
||||
"google first result",
|
||||
"flutter text align center",
|
||||
"ASR model",
|
||||
"real time whisper",
|
||||
"千樱凛",
|
||||
"马嘉祺",
|
||||
"flutter widget catalog",
|
||||
"flutter BottomNavigationBar left",
|
||||
"flutter tab indent vscode",
|
||||
"react native 用 expo 吗",
|
||||
"latest monorepo tool",
|
||||
"\"vite\" \"abortController\" is not defined",
|
||||
"vim comment lines",
|
||||
"Error: unable to get issuer certificate",
|
||||
"uuidv4",
|
||||
"npm semver",
|
||||
"react polyfill vite",
|
||||
"vibrance",
|
||||
"I can eat glass, it doesn't hurt me \"japanese\"",
|
||||
"I can swallow glass without any harm to myself",
|
||||
"copilot pricing",
|
||||
"vim close window",
|
||||
"sensors macos command",
|
||||
"智乃",
|
||||
"pypi wikipedia",
|
||||
"tesseract macos m1",
|
||||
"rag prompt template",
|
||||
"英国 破产",
|
||||
"bewlybewly",
|
||||
"safari-web-extension-converter",
|
||||
"starcoder",
|
||||
"open source web search for ai",
|
||||
"gpt4o mini tokenizer",
|
||||
"gpt4o tokenizer",
|
||||
"reverse dns lookup linux",
|
||||
"online ping",
|
||||
"termux",
|
||||
"802.11 table",
|
||||
"optimize",
|
||||
"集群",
|
||||
"chrome us",
|
||||
"transflective",
|
||||
"ielts toefl",
|
||||
"react router",
|
||||
"摇曳露营 萌娘百科",
|
||||
"isrc",
|
||||
"apple-system",
|
||||
"-apple-system",
|
||||
"css clip path animation",
|
||||
"can i use relative path in og image",
|
||||
"GitSora",
|
||||
"matrix im",
|
||||
"test your vocabulary",
|
||||
"boarding pass",
|
||||
"函数签名",
|
||||
"类型谓词",
|
||||
"barcode",
|
||||
"智能",
|
||||
"threejs 入门",
|
||||
"南亚语系",
|
||||
"linux user's computer be like",
|
||||
"apple a16 显微图",
|
||||
"dallas",
|
||||
"恶魔 英文",
|
||||
"Rime meaning",
|
||||
"adobe media encoder macos download",
|
||||
"mp4 transparency",
|
||||
"webkit",
|
||||
"chromium",
|
||||
"献血",
|
||||
"软件强制更新",
|
||||
"If you don’t agree with its politics views, Notepad+ + will add random characters in your source code.",
|
||||
"Unmerged paths",
|
||||
"字数统计",
|
||||
"Use build.rollupOptions.output.manualChunks to improve chunking: https://rollupjs.org/configuration-options/#output-manualchunks",
|
||||
"世界人权宣言",
|
||||
"latex percent",
|
||||
"chord in keyboard",
|
||||
"Google is trying to kill the Open Web.",
|
||||
"silo'd",
|
||||
"swiftui 数组倒数访问",
|
||||
"swiftui link to another view",
|
||||
"fizzbuzz",
|
||||
"AppDelegate watchos",
|
||||
"Cannot find type 'UIApplicationDelegate' in scope",
|
||||
"swiftui web image",
|
||||
"spammer",
|
||||
"swiftui text",
|
||||
"钢琴",
|
||||
"disable webgl chrome",
|
||||
"online uuid",
|
||||
"cp show progress",
|
||||
"易容术",
|
||||
"fulilian",
|
||||
"cargo",
|
||||
"wordle",
|
||||
"mismatch",
|
||||
"btc",
|
||||
"squelch",
|
||||
"psql show table structure",
|
||||
"let padding don't effect when empty",
|
||||
"take over the world meaning",
|
||||
"brain teasers",
|
||||
"Google flight API",
|
||||
"square symbol",
|
||||
"sill",
|
||||
"nextjs layout per page",
|
||||
"UA 550 umol/L",
|
||||
"react production promotion page",
|
||||
"jupyter notebook",
|
||||
"wth meaning",
|
||||
"glove词向量",
|
||||
"google suggestion relevance",
|
||||
"YouTube advertising income",
|
||||
"PKI",
|
||||
"next client only component",
|
||||
"nextjs use client",
|
||||
"nextjs docker tailwind not working",
|
||||
"k8s",
|
||||
"Logistic Regression",
|
||||
"氯化钾注射死刑",
|
||||
"icloud photo loss",
|
||||
"芙宁娜 水上行走",
|
||||
"vector design tool",
|
||||
"netizen",
|
||||
"framework or next js documentation",
|
||||
"csync",
|
||||
"next js",
|
||||
"后量子正向保密",
|
||||
"nip05",
|
||||
"Sora技术原理",
|
||||
"wasm效率",
|
||||
"switch code",
|
||||
"online IPA pronunciation",
|
||||
"pnpm global adir",
|
||||
"如何搜索",
|
||||
"1999 抽卡期望",
|
||||
"swiftui background blur",
|
||||
"chrome macos fullscreen hide",
|
||||
"中英文空格自动",
|
||||
"ios 旁白 屏幕识别",
|
||||
"ios 旁白 转子",
|
||||
"http 404",
|
||||
"yaml缩进",
|
||||
"counter generator github",
|
||||
"git 服务器提供远端仓库",
|
||||
"ipfs companion",
|
||||
"supervisor config",
|
||||
"SSO",
|
||||
"slot embedding",
|
||||
"sql show tables",
|
||||
"The request signature we calculated does not match the signature you provided. Check your Secret Access Key and signing method.",
|
||||
"icloud.com,cn",
|
||||
"VuePress",
|
||||
"parser",
|
||||
"stackoverflow statistics",
|
||||
"sd xl",
|
||||
"Rollup failed to resolve import \"workbox-precaching\" from",
|
||||
"dep",
|
||||
"Cannot find module estree-walker.js docker",
|
||||
"nuxt run",
|
||||
"base58解码",
|
||||
"cga",
|
||||
"vscode",
|
||||
"vscode",
|
||||
"silicon",
|
||||
"macos m1 linux",
|
||||
"预处理 后处理",
|
||||
"is vp9 opensource",
|
||||
"Alice Blu",
|
||||
"失控玩家",
|
||||
"kv数据库",
|
||||
"redis 持久化",
|
||||
"firefox disable outline",
|
||||
"cd -2",
|
||||
"IM application",
|
||||
"2021国产电影",
|
||||
"youtube chat overlay obs",
|
||||
"obs add clock",
|
||||
"Z is not defined nuxt",
|
||||
"safari ios debug",
|
||||
"safari debug",
|
||||
"chat",
|
||||
"nuxt plugin inject",
|
||||
"twitch",
|
||||
"obs 绿幕",
|
||||
"gnupg",
|
||||
"kde plasma wallpaper engine",
|
||||
"Plasma",
|
||||
"dns over https",
|
||||
"localforage缺点",
|
||||
"watchOS 10",
|
||||
"noun of repeat",
|
||||
"微信输入法",
|
||||
"行业报告",
|
||||
"keepass",
|
||||
"platform",
|
||||
"steam",
|
||||
"java proxy",
|
||||
"0 design",
|
||||
"cefr word level list",
|
||||
"precipitation meaning",
|
||||
"international school of lausanne",
|
||||
"Vim Uganda",
|
||||
"抖音 推荐算法",
|
||||
"Meta NNLO",
|
||||
"windbg dump分析",
|
||||
"web image fft",
|
||||
"GPT-4 Pricing",
|
||||
"GPT-4",
|
||||
"Scala",
|
||||
"tauri教程",
|
||||
"asyncio.create_task用法",
|
||||
"H5 滚动到底部",
|
||||
"microsoft copilot",
|
||||
"枫丹文字",
|
||||
"brew pip",
|
||||
"TS7016: Could not find a declaration file for module react .",
|
||||
"fastapi websocket",
|
||||
"kazv",
|
||||
"The Type 孔雀计划",
|
||||
"第一个图形操作系统",
|
||||
"娱乐 诞生",
|
||||
"ffmpeg 音频封面",
|
||||
"Jean-Loup Gailly",
|
||||
"Linux用户软件位置",
|
||||
"\"ubuntu\" 平滑滚动",
|
||||
"python range函数",
|
||||
"KMP",
|
||||
"sd 8gen2 GPU GFLOPS",
|
||||
"mac语音输入法",
|
||||
"openai translate",
|
||||
"蔚蓝档案 初始抽卡",
|
||||
"free custom domain email",
|
||||
"洛天依",
|
||||
"b站 频道页Tab 跳转",
|
||||
"URL 重定向预览",
|
||||
"计算机",
|
||||
"sololearn",
|
||||
"PoS机制 通俗解释",
|
||||
"google search cost",
|
||||
"bos s3",
|
||||
"react 打包",
|
||||
"useeffect 用法",
|
||||
"ts 字典类型",
|
||||
"vscode 字典单词自动补全插件",
|
||||
"componentwillupdate",
|
||||
"iPad Mini 2",
|
||||
"use-immer",
|
||||
"reducer 和 context",
|
||||
"mint",
|
||||
"Elementary OS",
|
||||
"google科技新闻",
|
||||
"iCloud mail \"\"-9002\"\"",
|
||||
"氢氧化铁胶体制备",
|
||||
"react native 视频处理",
|
||||
"四川 2023 高考 复旦大学 分数线",
|
||||
"哑铃弯举",
|
||||
"m2 ultra",
|
||||
"电池循环计数 site:apple.com",
|
||||
"相机发明时间",
|
||||
"冯诺依曼结构",
|
||||
"哈佛架构",
|
||||
"nodejs 后端",
|
||||
"34.5M€ to CN¥",
|
||||
"NLP 实体关注",
|
||||
"monkey",
|
||||
"react 快捷键监听",
|
||||
"mac 好看的电子书阅读器",
|
||||
"新闻",
|
||||
"在线字体编辑器",
|
||||
"ars technica",
|
||||
"genshin 4.1 release time",
|
||||
"swift device activity report",
|
||||
"swiftui tabview background",
|
||||
"swiftui text space",
|
||||
"apple inc. wikipedia",
|
||||
"how long does it take Google to return the results",
|
||||
"云原神 web",
|
||||
"支持homekit的空调",
|
||||
"内核隔离",
|
||||
"海祇岛解密",
|
||||
"swiftui Textfield",
|
||||
"xcode",
|
||||
"qq 链接",
|
||||
"M1 推出时间",
|
||||
"USB-IF",
|
||||
"nvchat",
|
||||
"P1% FPS",
|
||||
"react i18next 当前语言",
|
||||
"js 获取语言",
|
||||
"MulType",
|
||||
"b站平均使用时间",
|
||||
"pip 阿里源",
|
||||
"ip info",
|
||||
"graphjet",
|
||||
"金融思维",
|
||||
"C#写入文件",
|
||||
"Last Day Sinista M",
|
||||
"在 系统 位置 xcode select 找 不 到 SDK",
|
||||
"Error: Could not find a valid Xcode app bundle at '/Library/Developer/CommandLineTools'. Please update your Apple SDK location in Visual Studio's preferences (Projects > SDK Locations > Apple > Apple SDK). (UniBattery)",
|
||||
".NET能做什么",
|
||||
"could i give no tip ",
|
||||
"miami university of ohio",
|
||||
"方正颜宋",
|
||||
"中文 标题字体",
|
||||
"聚典平台",
|
||||
"62 basic words for a language",
|
||||
"procrastination meaning",
|
||||
"Lingbe",
|
||||
"娱乐至死",
|
||||
"macOS 外接显示器渲染",
|
||||
"白玉袖",
|
||||
"SwiftUI入门",
|
||||
"html插入其它网页",
|
||||
"捆绑 小说",
|
||||
"apple music 无损下载",
|
||||
"一miumiu 赐予",
|
||||
"macos markdown",
|
||||
"safari 开发者工具",
|
||||
"\"百合\" \"武侠\" \"国漫\"",
|
||||
"epub 格式详解",
|
||||
"chrome 隐藏滚动条",
|
||||
"发宽空格",
|
||||
"U+200A",
|
||||
"无性人",
|
||||
"Spotify",
|
||||
"禾念",
|
||||
"how to pronounce Lorem ipsum",
|
||||
"言和为什么不是男孩子",
|
||||
"浏览器主页",
|
||||
"react",
|
||||
"Tailwindcss react 扩展怎么用",
|
||||
"Prettier 扩展怎么用",
|
||||
"linter\""
|
||||
]
|
||||
|
@ -1,575 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a6a3195f-d099-4bf4-846f-51f403954818",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# sparkastML: Training the Intention Classification Model\n",
|
||||
"\n",
|
||||
"This is the model we use for intent recognition, using a **CNN architectur** and using an **Energy-based Model** to implement OSR (Open-set Recognition).\n",
|
||||
"\n",
|
||||
"In this case, **positive samples** refer to data that can be classified into existing class, while **negative samples** are those does not belong to any of the existing class."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "bddcdbb2-ccbc-4027-a38f-09c61ac94984",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import torch\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"from torch.nn.utils.rnn import pad_sequence\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from transformers import AutoTokenizer, AutoModel\n",
|
||||
"import torch\n",
|
||||
"import numpy as np\n",
|
||||
"from scipy.spatial.distance import euclidean\n",
|
||||
"from scipy.stats import weibull_min\n",
|
||||
"from sklearn.preprocessing import normalize\n",
|
||||
"import torch.nn.functional as F\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "d3a0e10f-9bc3-44c7-a109-786dd5cd25ea",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model_name=\"microsoft/Phi-3-mini-4k-instruct\"\n",
|
||||
"DIMENSIONS = 128\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1ae14906-338d-4c99-87ed-bb1acd22b295",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Load Data\n",
|
||||
"\n",
|
||||
"We load the data from `data.json`, and also get the negative sample from the `noise.json`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "a206071c-ce4e-4de4-b936-bfc70d13708a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_6446/1697839999.py:18: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_69nk78ncaj/croot/pytorch_1669252638507/work/torch/csrc/utils/tensor_new.cpp:204.)\n",
|
||||
" embeddings = torch.tensor(embeddings)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Load data\n",
|
||||
"with open('data.json', 'r') as f:\n",
|
||||
" data = json.load(f)\n",
|
||||
"\n",
|
||||
"# Create map: class to index\n",
|
||||
"class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}\n",
|
||||
"idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}\n",
|
||||
"\n",
|
||||
"# Preprocess data, convert sentences to the format of (class idx, embedding)\n",
|
||||
"def preprocess_data(data, embedding_map, tokenizer, max_length=64):\n",
|
||||
" dataset = []\n",
|
||||
" for label, sentences in data.items():\n",
|
||||
" for sentence in sentences:\n",
|
||||
" # Tokenize the sentence and convert tokens to embedding vectors\n",
|
||||
" tokens = tokenizer.tokenize(sentence)\n",
|
||||
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
||||
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
|
||||
" embeddings = torch.tensor(embeddings)\n",
|
||||
" dataset.append((class_to_idx[label], embeddings))\n",
|
||||
" return dataset\n",
|
||||
"\n",
|
||||
"# Load embedding map\n",
|
||||
"embedding_map = torch.load('token_id_to_reduced_embedding.pt')\n",
|
||||
"\n",
|
||||
"# Get preprocessed dataset\n",
|
||||
"dataset = preprocess_data(data, embedding_map, tokenizer)\n",
|
||||
"\n",
|
||||
"# Train-test split\n",
|
||||
"train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)\n",
|
||||
"\n",
|
||||
"class TextDataset(Dataset):\n",
|
||||
" def __init__(self, data):\n",
|
||||
" self.data = data\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.data)\n",
|
||||
"\n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" return self.data[idx]\n",
|
||||
"\n",
|
||||
" def collate_fn(self, batch):\n",
|
||||
" labels, embeddings = zip(*batch)\n",
|
||||
" labels = torch.tensor(labels)\n",
|
||||
" embeddings = pad_sequence(embeddings, batch_first=True)\n",
|
||||
" return labels, embeddings\n",
|
||||
"\n",
|
||||
"train_dataset = TextDataset(train_data)\n",
|
||||
"val_dataset = TextDataset(val_data)\n",
|
||||
"\n",
|
||||
"train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn)\n",
|
||||
"val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False, collate_fn=val_dataset.collate_fn)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "9adbe9b8-a2d2-4e1d-8620-457ed0e02fe6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torch.utils.data import Dataset, DataLoader\n",
|
||||
"\n",
|
||||
"class NegativeSampleDataset(Dataset):\n",
|
||||
" def __init__(self, negative_samples):\n",
|
||||
" \"\"\"\n",
|
||||
" negative_samples: List or array of negative sample embeddings or raw text\n",
|
||||
" \"\"\"\n",
|
||||
" self.samples = negative_samples\n",
|
||||
" \n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.samples)\n",
|
||||
" \n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" return self.samples[idx]\n",
|
||||
"\n",
|
||||
" def collate_fn(self, batch):\n",
|
||||
" embeddings = pad_sequence(batch, batch_first=True)\n",
|
||||
" return embeddings\n",
|
||||
"\n",
|
||||
"with open('noise.json', 'r') as f:\n",
|
||||
" negative_samples_list = json.load(f)\n",
|
||||
"\n",
|
||||
"negative_embedding_list = []\n",
|
||||
"\n",
|
||||
"for sentence in negative_samples_list:\n",
|
||||
" tokens = tokenizer.tokenize(sentence)\n",
|
||||
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
||||
" embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]\n",
|
||||
" embeddings = torch.tensor(embeddings)\n",
|
||||
" negative_embedding_list.append(embeddings)\n",
|
||||
"\n",
|
||||
"negative_dataset = NegativeSampleDataset(negative_embedding_list)\n",
|
||||
"negative_loader = DataLoader(negative_dataset, batch_size=24, shuffle=True, collate_fn=negative_dataset.collate_fn)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "600febe4-2484-4aad-90a1-2bc821fdce1a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Implementating the Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "adf624ac-ad63-437b-95f6-b02b7253b91e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"\n",
|
||||
"class TextCNN(nn.Module):\n",
|
||||
" def __init__(self, input_dim, num_classes):\n",
|
||||
" super(TextCNN, self).__init__()\n",
|
||||
" self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n",
|
||||
" self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n",
|
||||
" self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n",
|
||||
" \n",
|
||||
" self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n",
|
||||
" self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n",
|
||||
" self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n",
|
||||
" \n",
|
||||
" self.dropout = nn.Dropout(0.5)\n",
|
||||
" self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" x = x.permute(0, 2, 1) # Change the input shape to (batch_size, embedding_dim, seq_length)\n",
|
||||
" \n",
|
||||
" x1 = F.relu(self.bn1(self.conv1(x)))\n",
|
||||
" x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n",
|
||||
" \n",
|
||||
" x2 = F.relu(self.bn2(self.conv2(x)))\n",
|
||||
" x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n",
|
||||
" \n",
|
||||
" x3 = F.relu(self.bn3(self.conv3(x)))\n",
|
||||
" x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n",
|
||||
" \n",
|
||||
" x = torch.cat((x1, x2, x3), dim=1)\n",
|
||||
" x = self.dropout(x)\n",
|
||||
" x = self.fc(x)\n",
|
||||
" return x\n",
|
||||
"\n",
|
||||
"# Initialize model\n",
|
||||
"input_dim = DIMENSIONS\n",
|
||||
"num_classes = len(class_to_idx)\n",
|
||||
"model = TextCNN(input_dim, num_classes)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2750e17d-8a60-40c7-851b-1a567d0ee82b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Energy-based Models"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "a2d7c920-07d2-4d14-9cef-e2101b7a2ceb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def energy_score(logits):\n",
|
||||
" # Energy score is minus logsumexp\n",
|
||||
" return -torch.logsumexp(logits, dim=1)\n",
|
||||
"\n",
|
||||
"def generate_noise(batch_size, seq_length ,input_dim, device):\n",
|
||||
" # Generate a Gaussian noise\n",
|
||||
" return torch.randn(batch_size, seq_length, input_dim).to(device)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "904a60e4-95a0-4f7b-ad45-a7d8d0ac887d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "19acb5bf-00b1-47d4-ad25-a13c6be09f65",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch [1/50], Loss: 12.5108\n",
|
||||
"Epoch [2/50], Loss: 10.7305\n",
|
||||
"Epoch [3/50], Loss: 10.2943\n",
|
||||
"Epoch [4/50], Loss: 9.9350\n",
|
||||
"Epoch [5/50], Loss: 9.7991\n",
|
||||
"Epoch [6/50], Loss: 9.6443\n",
|
||||
"Epoch [7/50], Loss: 9.4762\n",
|
||||
"Epoch [8/50], Loss: 9.4637\n",
|
||||
"Epoch [9/50], Loss: 9.3025\n",
|
||||
"Epoch [10/50], Loss: 9.1719\n",
|
||||
"Epoch [11/50], Loss: 9.0632\n",
|
||||
"Epoch [12/50], Loss: 8.9741\n",
|
||||
"Epoch [13/50], Loss: 8.8487\n",
|
||||
"Epoch [14/50], Loss: 8.6565\n",
|
||||
"Epoch [15/50], Loss: 8.5830\n",
|
||||
"Epoch [16/50], Loss: 8.4196\n",
|
||||
"Epoch [17/50], Loss: 8.2319\n",
|
||||
"Epoch [18/50], Loss: 8.0655\n",
|
||||
"Epoch [19/50], Loss: 7.7140\n",
|
||||
"Epoch [20/50], Loss: 7.6921\n",
|
||||
"Epoch [21/50], Loss: 7.3375\n",
|
||||
"Epoch [22/50], Loss: 7.2297\n",
|
||||
"Epoch [23/50], Loss: 6.8833\n",
|
||||
"Epoch [24/50], Loss: 6.8534\n",
|
||||
"Epoch [25/50], Loss: 6.4557\n",
|
||||
"Epoch [26/50], Loss: 6.1365\n",
|
||||
"Epoch [27/50], Loss: 5.8558\n",
|
||||
"Epoch [28/50], Loss: 5.5030\n",
|
||||
"Epoch [29/50], Loss: 5.1604\n",
|
||||
"Epoch [30/50], Loss: 4.7742\n",
|
||||
"Epoch [31/50], Loss: 4.5958\n",
|
||||
"Epoch [32/50], Loss: 4.0713\n",
|
||||
"Epoch [33/50], Loss: 3.8872\n",
|
||||
"Epoch [34/50], Loss: 3.5240\n",
|
||||
"Epoch [35/50], Loss: 3.3115\n",
|
||||
"Epoch [36/50], Loss: 2.5667\n",
|
||||
"Epoch [37/50], Loss: 2.6709\n",
|
||||
"Epoch [38/50], Loss: 1.8075\n",
|
||||
"Epoch [39/50], Loss: 1.6654\n",
|
||||
"Epoch [40/50], Loss: 0.4622\n",
|
||||
"Epoch [41/50], Loss: 0.4719\n",
|
||||
"Epoch [42/50], Loss: -0.4037\n",
|
||||
"Epoch [43/50], Loss: -0.9405\n",
|
||||
"Epoch [44/50], Loss: -1.7204\n",
|
||||
"Epoch [45/50], Loss: -2.4124\n",
|
||||
"Epoch [46/50], Loss: -3.0032\n",
|
||||
"Epoch [47/50], Loss: -2.7123\n",
|
||||
"Epoch [48/50], Loss: -3.6953\n",
|
||||
"Epoch [49/50], Loss: -3.7212\n",
|
||||
"Epoch [50/50], Loss: -3.7558\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch.optim as optim\n",
|
||||
"\n",
|
||||
"criterion = nn.CrossEntropyLoss()\n",
|
||||
"optimizer = optim.Adam(model.parameters(), lr=8e-4)\n",
|
||||
"\n",
|
||||
"from torch.utils.tensorboard import SummaryWriter\n",
|
||||
"import tensorboard\n",
|
||||
"writer = SummaryWriter()\n",
|
||||
"\n",
|
||||
"def train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=10):\n",
|
||||
" model.train()\n",
|
||||
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||
" model.to(device)\n",
|
||||
" \n",
|
||||
" negative_iter = iter(negative_loader)\n",
|
||||
" \n",
|
||||
" for epoch in range(num_epochs):\n",
|
||||
" total_loss = 0\n",
|
||||
" for batch_idx, (labels, embeddings) in enumerate(train_loader):\n",
|
||||
" embeddings = embeddings.to(device)\n",
|
||||
" labels = labels.to(device)\n",
|
||||
" \n",
|
||||
" batch_size = embeddings.size(0)\n",
|
||||
" \n",
|
||||
" # ---------------------\n",
|
||||
" # 1. Positive sample\n",
|
||||
" # ---------------------\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" outputs = model(embeddings) # logits from the model\n",
|
||||
" \n",
|
||||
" class_loss = criterion(outputs, labels)\n",
|
||||
" \n",
|
||||
" # Energy of positive sample\n",
|
||||
" known_energy = energy_score(outputs)\n",
|
||||
" energy_loss_known = known_energy.mean()\n",
|
||||
" \n",
|
||||
" # ------------------------------------\n",
|
||||
" # 2. Negative sample - Random Noise\n",
|
||||
" # ------------------------------------\n",
|
||||
" noise_embeddings = torch.randn_like(embeddings).to(device)\n",
|
||||
" noise_outputs = model(noise_embeddings)\n",
|
||||
" noise_energy = energy_score(noise_outputs)\n",
|
||||
" energy_loss_noise = F.relu(1 - noise_energy).mean() # For the energy of noise, bigger is better \n",
|
||||
" \n",
|
||||
" # ------------------------------------\n",
|
||||
" # 3. Negative sample - custom corpus\n",
|
||||
" # ------------------------------------\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" negative_samples = next(negative_iter)\n",
|
||||
" except StopIteration:\n",
|
||||
" negative_iter = iter(negative_loader)\n",
|
||||
" negative_samples = next(negative_iter)\n",
|
||||
" negative_samples = negative_samples.to(device)\n",
|
||||
" negative_outputs = model(negative_samples)\n",
|
||||
" negative_energy = energy_score(negative_outputs)\n",
|
||||
" energy_loss_negative = F.relu(1 - negative_energy).mean() # For the energy of noise, bigger is better \n",
|
||||
" \n",
|
||||
" # -----------------------------\n",
|
||||
" # 4. Overall Loss calculation\n",
|
||||
" # -----------------------------\n",
|
||||
" total_energy_loss = energy_loss_known + energy_loss_noise + energy_loss_negative\n",
|
||||
" total_loss_batch = class_loss + total_energy_loss * 0.1 + 10\n",
|
||||
"\n",
|
||||
" writer.add_scalar(\"Engergy Loss\", total_energy_loss, epoch)\n",
|
||||
" writer.add_scalar(\"Loss\", total_loss_batch, epoch)\n",
|
||||
" writer.add_scalar(\"Norm Loss\", torch.exp(total_loss_batch * 0.003) * 10 , epoch)\n",
|
||||
" \n",
|
||||
" total_loss_batch.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" \n",
|
||||
" total_loss += total_loss_batch.item()\n",
|
||||
" \n",
|
||||
" avg_loss = total_loss / len(train_loader)\n",
|
||||
" print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')\n",
|
||||
"\n",
|
||||
"train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=50)\n",
|
||||
"writer.flush()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e6d29558-f497-4033-8488-169bd25ce881",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evalutation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "472702e6-db4a-4faa-9e92-7510e6eacbb1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ENERGY_THRESHOLD = -3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "d3a4fef8-37ab-45c8-b2b1-9fc8bdffcffd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Accuracy: 0.9315\n",
|
||||
"Precision: 1.0000\n",
|
||||
"Recall: 0.9254\n",
|
||||
"F1 Score: 0.9612\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support\n",
|
||||
"\n",
|
||||
"def evaluate_energy_model(model, known_loader, unknown_loader, energy_threshold):\n",
|
||||
" model.eval()\n",
|
||||
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||
" \n",
|
||||
" all_preds = []\n",
|
||||
" all_labels = []\n",
|
||||
" \n",
|
||||
" # Evaluate positive sample\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for labels, embeddings in known_loader:\n",
|
||||
" embeddings = embeddings.to(device)\n",
|
||||
" logits = model(embeddings)\n",
|
||||
" energy = energy_score(logits)\n",
|
||||
" \n",
|
||||
" preds = (energy <= energy_threshold).long()\n",
|
||||
" all_preds.extend(preds.cpu().numpy())\n",
|
||||
" all_labels.extend([1] * len(preds)) # Positive sample labeled as 1\n",
|
||||
" \n",
|
||||
" # Evaluate negative sample\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for embeddings in unknown_loader:\n",
|
||||
" embeddings = embeddings.to(device)\n",
|
||||
" logits = model(embeddings)\n",
|
||||
" energy = energy_score(logits)\n",
|
||||
" \n",
|
||||
" preds = (energy <= energy_threshold).long()\n",
|
||||
" all_preds.extend(preds.cpu().numpy())\n",
|
||||
" all_labels.extend([0] * len(preds)) # Negative sample labeled as 1\n",
|
||||
" \n",
|
||||
" precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
|
||||
" accuracy = accuracy_score(all_labels, all_preds)\n",
|
||||
"\n",
|
||||
" print(f'Accuracy: {accuracy:.4f}')\n",
|
||||
" print(f'Precision: {precision:.4f}')\n",
|
||||
" print(f'Recall: {recall:.4f}')\n",
|
||||
" print(f'F1 Score: {f1:.4f}')\n",
|
||||
"\n",
|
||||
"evaluate_energy_model(model, val_loader, negative_loader, ENERGY_THRESHOLD)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "ba614054-75e1-4f61-ace5-aeb11e29a222",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Save the model\n",
|
||||
"torch.save(model, \"model.pt\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fdfa0c5e-e6d3-4db0-a142-96645c92719c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Inference"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "03928d75-81c8-4298-ab8a-d7f8a758b561",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Predicted: ['weather', 0.9989822506904602, -8.016249656677246]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold, max_length=64):\n",
|
||||
" model.eval()\n",
|
||||
" tokens = tokenizer.tokenize(sentence)\n",
|
||||
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
||||
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
|
||||
" embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension\n",
|
||||
" \n",
|
||||
" with torch.no_grad():\n",
|
||||
" logits = model(embeddings)\n",
|
||||
" probabilities = F.softmax(logits, dim=1)\n",
|
||||
" max_prob, predicted = torch.max(probabilities, 1)\n",
|
||||
" \n",
|
||||
" # Calculate energy score\n",
|
||||
" energy = energy_score(logits)\n",
|
||||
"\n",
|
||||
" # If energy > threshold, consider the input as unknown class\n",
|
||||
" if energy.item() > energy_threshold:\n",
|
||||
" return [\"Unknown\", max_prob.item(), energy.item()]\n",
|
||||
" else:\n",
|
||||
" return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]\n",
|
||||
"\n",
|
||||
"# Example usage:\n",
|
||||
"sentence = \"weather today\"\n",
|
||||
"energy_threshold = ENERGY_THRESHOLD\n",
|
||||
"predicted = predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold)\n",
|
||||
"print(f'Predicted: {predicted}')\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.19"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
4
intention-classify/training/config.py
Normal file
4
intention-classify/training/config.py
Normal file
@ -0,0 +1,4 @@
|
||||
# config.py
|
||||
|
||||
model_name = "Qwen/Qwen2.5-3B"
|
||||
DIMENSIONS = 96
|
64
intention-classify/training/data_utils.py
Normal file
64
intention-classify/training/data_utils.py
Normal file
@ -0,0 +1,64 @@
|
||||
# data_utils.py
|
||||
|
||||
import json
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def load_data(file_path):
    """Read a JSON file and return its parsed contents."""
    with open(file_path, "r") as handle:
        return json.load(handle)
|
||||
|
||||
|
||||
def create_class_mappings(data):
    """Build forward (label -> index) and reverse (index -> label) mappings.

    Indices follow the iteration order of the data's keys.
    """
    class_to_idx = {}
    idx_to_class = {}
    for position, label in enumerate(data):
        class_to_idx[label] = position
        idx_to_class[position] = label
    return class_to_idx, idx_to_class
|
||||
|
||||
|
||||
def preprocess_data(data, embedding_map, tokenizer, class_to_idx, max_length=64):
    """Turn labeled sentences into (class index, embedding tensor) pairs.

    Each sentence is tokenized, its token ids are truncated to
    ``max_length``, and every id is looked up in ``embedding_map`` to form
    a 2-D tensor of per-token embeddings.
    """
    dataset = []
    for label, sentences in data.items():
        class_idx = class_to_idx[label]
        for sentence in sentences:
            ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))
            vectors = [embedding_map[token_id] for token_id in ids[:max_length]]
            dataset.append((class_idx, torch.tensor(vectors)))
    return dataset
|
||||
|
||||
|
||||
class TextDataset(Dataset):
    """Dataset of (label, embedding-tensor) pairs with padding-aware batching."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def collate_fn(self, batch):
        """Stack labels and zero-pad variable-length embeddings into a batch."""
        labels, embeddings = zip(*batch)
        return torch.tensor(labels), pad_sequence(embeddings, batch_first=True)
|
||||
|
||||
|
||||
class NegativeSampleDataset(Dataset):
    """Dataset of unlabeled negative-sample embedding tensors."""

    def __init__(self, negative_samples):
        self.samples = negative_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

    def collate_fn(self, batch):
        """Zero-pad a batch of variable-length embeddings to equal length."""
        return pad_sequence(batch, batch_first=True)
|
50
intention-classify/training/model.py
Normal file
50
intention-classify/training/model.py
Normal file
@ -0,0 +1,50 @@
|
||||
# model.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from training.config import DIMENSIONS
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
    """Multi-head self-attention over (batch, seq, dim) inputs.

    A single linear layer produces Q, K, and V jointly; the attended
    output is projected back to ``input_dim`` by a final linear layer.
    Assumes ``input_dim`` is divisible by ``heads``.
    """

    def __init__(self, input_dim, heads):
        super(SelfAttention, self).__init__()
        self.heads = heads
        # Scaled dot-product factor: 1 / sqrt(head_dim).
        self.scale = (input_dim // heads) ** -0.5
        # Joint projection producing Q, K, V concatenated along the last dim.
        self.qkv = nn.Linear(input_dim, input_dim * 3)
        self.fc = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        batch_size, seq_length, embedding_dim = x.shape
        # Carve the 3*dim projection into (heads, qkv-slot, head_dim).
        # NOTE(review): this splits as heads x 3 x head_dim rather than the
        # conventional 3 x heads x head_dim; it is self-consistent here but
        # nonstandard — confirm intent before reusing elsewhere.
        qkv = self.qkv(x).view(
            batch_size, seq_length, self.heads, 3, embedding_dim // self.heads
        )
        q, k, v = qkv[..., 0, :], qkv[..., 1, :], qkv[..., 2, :]
        # Move heads before the sequence axis: (batch, heads, seq, head_dim).
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        # Scaled dot-product attention weights over the sequence axis.
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        attention_output = torch.matmul(attn_weights, v)
        # Restore (batch, seq, heads, head_dim) and merge the heads back.
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
        attention_output = attention_output.view(batch_size, seq_length, embedding_dim)
        return self.fc(attention_output)
|
||||
|
||||
|
||||
class AttentionBasedModel(nn.Module):
    """Sentence classifier: self-attention + residual norm + mean pooling + MLP head."""

    def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512):
        super(AttentionBasedModel, self).__init__()
        self.self_attention = SelfAttention(input_dim, heads)
        self.fc1 = nn.Linear(input_dim, dim_feedforward)
        self.fc2 = nn.Linear(dim_feedforward, num_classes)
        self.dropout = nn.Dropout(0.5)
        self.norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        # Residual connection around the attention block, then layer norm.
        normed = self.norm(self.self_attention(x) + x)
        # Mean-pool over the sequence axis to a fixed-size sentence vector.
        pooled = normed.mean(dim=1)
        hidden = F.relu(self.fc1(pooled))
        return self.fc2(self.dropout(hidden))
|
149
intention-classify/training/train.py
Normal file
149
intention-classify/training/train.py
Normal file
@ -0,0 +1,149 @@
|
||||
# train.py
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
import torch
|
||||
import json
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn.functional as F
|
||||
from training.data_utils import (
|
||||
load_data,
|
||||
create_class_mappings,
|
||||
preprocess_data,
|
||||
TextDataset,
|
||||
NegativeSampleDataset,
|
||||
)
|
||||
from training.model import AttentionBasedModel
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def energy_score(logits, temperature=1.0):
    """Free energy of logits: negative temperature-scaled logsumexp over classes."""
    scaled = logits / temperature
    return -torch.logsumexp(scaled, dim=1)
|
||||
|
||||
|
||||
def generate_noise(batch_size, seq_length, input_dim, device):
    """Return a standard-normal noise tensor of shape (batch, seq, dim) on ``device``."""
    noise = torch.randn(batch_size, seq_length, input_dim)
    return noise.to(device)
|
||||
|
||||
|
||||
def train_energy_model(
    model,
    train_loader,
    negative_loader,
    criterion,
    optimizer,
    num_epochs=10,
    margin=1.0,
    temperature=0.4,
):
    """Jointly train classification and energy-based OOD separation.

    In-distribution batches are pushed toward correct labels and low
    (<= 0) energy; random Gaussian noise and a curated negative corpus
    are pushed above ``margin`` energy via hinge losses.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        train_loader: yields (labels, embeddings) batches of known intents.
        negative_loader: yields embedding batches of out-of-scope text.
        criterion: classification loss (e.g. CrossEntropyLoss).
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        margin: hinge margin for noise/negative energies.
        temperature: temperature for the logsumexp energy score.
    """
    model.train()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # The negative loader is cycled independently of the training loader.
    negative_iter = iter(negative_loader)

    for epoch in range(num_epochs):
        total_loss = 0
        for _, (labels, embeddings) in enumerate(train_loader):
            embeddings = embeddings.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(embeddings)
            class_loss = criterion(outputs, labels)
            known_energy = energy_score(outputs, temperature)
            positive_margin = 0.0
            # Hinge: penalize in-distribution energy only when above 0.
            energy_loss_known = F.relu(known_energy - positive_margin).mean()

            # Random Gaussian noise as a cheap negative sample: push its
            # energy above the margin.
            noise_embeddings = torch.randn_like(embeddings).to(device)
            noise_outputs = model(noise_embeddings)
            noise_energy = energy_score(noise_outputs, temperature)
            energy_loss_noise = F.relu(margin - noise_energy).mean()

            # Curated negative corpus; restart its iterator when exhausted.
            try:
                negative_samples = next(negative_iter)
            except StopIteration:
                negative_iter = iter(negative_loader)
                negative_samples = next(negative_iter)
            negative_samples = negative_samples.to(device)
            negative_outputs = model(negative_samples)
            negative_energy = energy_score(negative_outputs, temperature)
            energy_loss_negative = F.relu(margin - negative_energy).mean()

            total_energy_loss = (
                energy_loss_known + energy_loss_noise + energy_loss_negative
            )
            total_loss_batch = class_loss + total_energy_loss

            total_loss_batch.backward()
            optimizer.step()

            total_loss += total_loss_batch.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
|
||||
|
||||
|
||||
def main():
    """End-to-end training entry point: data prep, training, ONNX export."""
    from config import model_name, DIMENSIONS
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    data = load_data("data.json")
    class_to_idx, _ = create_class_mappings(data)
    # Precomputed per-token-id embedding lookup table.
    embedding_map = torch.load("token_id_to_reduced_embedding.pt")
    dataset = preprocess_data(data, embedding_map, tokenizer, class_to_idx)
    # 80/20 split; the held-out 20% is discarded here (training only).
    train_data, _ = train_test_split(dataset, test_size=0.2)

    train_dataset = TextDataset(train_data)

    train_loader = DataLoader(
        train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn
    )

    # Out-of-scope sentences, embedded the same way as the training data.
    with open("noise.json", "r") as f:
        negative_samples_list = json.load(f)

    negative_embedding_list = []
    for sentence in negative_samples_list:
        tokens = tokenizer.tokenize(sentence)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]
        embeddings = torch.tensor(embeddings)
        negative_embedding_list.append(embeddings)

    negative_dataset = NegativeSampleDataset(negative_embedding_list)
    negative_loader = DataLoader(
        negative_dataset,
        batch_size=24,
        shuffle=True,
        collate_fn=negative_dataset.collate_fn,
    )

    input_dim = DIMENSIONS
    num_classes = len(class_to_idx)
    model = AttentionBasedModel(input_dim, num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=7e-4)

    train_energy_model(
        model, train_loader, negative_loader, criterion, optimizer, num_epochs=120
    )

    torch.save(model.state_dict(), "model.pt")

    # Export with dynamic batch/sequence axes so inference can feed
    # variable-length inputs.
    dummy_input = torch.randn(1, 64, DIMENSIONS)
    torch.onnx.export(
        model,
        dummy_input,
        "model.onnx",
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 1: "seq_length"},
            "output": {0: "batch_size"},
        },
        opset_version=11,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
75
intention-classify/validation/inference.py
Normal file
75
intention-classify/validation/inference.py
Normal file
@ -0,0 +1,75 @@
|
||||
from training.model import AttentionBasedModel
|
||||
from training.config import model_name
|
||||
import json
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from training.config import DIMENSIONS
|
||||
from training.model import AttentionBasedModel
|
||||
|
||||
|
||||
def energy_score(logits):
    """Energy of logits: the negated logsumexp over the class dimension."""
    return torch.logsumexp(logits, dim=1).neg()
|
||||
|
||||
|
||||
def predict_with_energy(
    model,
    sentence,
    embedding_map,
    tokenizer,
    idx_to_class,
    energy_threshold,
    max_length=64,
):
    """Classify a sentence, falling back to "Unknown" when energy is high.

    Args:
        model: classifier returning logits of shape (1, num_classes).
        sentence: raw input text.
        embedding_map: token-id -> embedding-vector lookup.
        tokenizer: tokenizer providing ``tokenize`` / ``convert_tokens_to_ids``.
        idx_to_class: index -> class-name mapping.
        energy_threshold: energies above this mark the input out-of-scope.
        max_length: maximum number of tokens embedded.

    Returns:
        [label, max_softmax_probability, energy], where label is
        "Unknown" if the energy score exceeds ``energy_threshold``.
    """
    model.eval()
    tokens = tokenizer.tokenize(sentence)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
    embeddings = torch.tensor(embeddings).unsqueeze(0)  # Add batch dimension
    current_shape = embeddings.shape

    # Zero-pad the sequence axis to at least 2 steps; presumably the
    # attention model misbehaves on single-token inputs — TODO confirm.
    if current_shape[1] < 2:
        pad_size = 2 - current_shape[1]
        embeddings = F.pad(
            embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
        )

    with torch.no_grad():
        logits = model(embeddings)
        probabilities = F.softmax(logits, dim=1)
        max_prob, predicted = torch.max(probabilities, 1)

    # Calculate energy score
    energy = energy_score(logits)

    # If energy > threshold, consider the input as unknown class
    if energy.item() > energy_threshold:
        return ["Unknown", max_prob.item(), energy.item()]
    else:
        return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]
|
||||
|
||||
|
||||
with open("data.json", "r") as f:
|
||||
data = json.load(f)
|
||||
class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
|
||||
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
|
||||
num_classes = len(class_to_idx)
|
||||
|
||||
input_dim = DIMENSIONS
|
||||
model = AttentionBasedModel(input_dim, num_classes)
|
||||
model.load_state_dict(torch.load("./model.pt"))
|
||||
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# Example usage:
|
||||
ENERGY_THRESHOLD = 0
|
||||
sentence = "天气"
|
||||
energy_threshold = ENERGY_THRESHOLD
|
||||
predicted = predict_with_energy(
|
||||
model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold
|
||||
)
|
||||
print(f"Predicted: {predicted}")
|
@ -24,6 +24,7 @@ cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))
|
||||
# Load the model and tokenizer
|
||||
model_name = 'jinaai/jina-embeddings-v2-base-zh'
|
||||
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
||||
model.to('cuda')
|
||||
|
||||
# Define file paths from command-line arguments
|
||||
file_a_path = args.file_a
|
||||
|
Loading…
Reference in New Issue
Block a user