127 lines
4.0 KiB
Plaintext
127 lines
4.0 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"id": "94ff7007",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Convert to ONNX\n",
|
||
"\n",
|
||
"This notebook converts our model to the [ONNX](https://onnx.ai/) format, an open standard for machine learning interoperability. This lets us run the model in JavaScript (for example, in the browser)."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"id": "311162fd-f957-4746-b524-25bb3e09efbc",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"\n",
|
||
"from torch import nn\n",
|
||
"import torch.utils.model_zoo as model_zoo\n",
|
||
"import torch.onnx\n",
|
||
"import torch.nn as nn\n",
|
||
"import torch.nn.functional as F"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"id": "fd182c6d-1e77-4bbb-bb53-8321d40ae002",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class TextCNN(nn.Module):\n",
|
||
"    \"\"\"CNN text classifier with three parallel Conv1d branches (kernel sizes 3, 4, 5).\n",
|
||
"\n",
|
||
"    Every branch runs Conv1d -> BatchNorm1d -> ReLU -> global max pooling on the\n",
|
||
"    same input. The pooled vectors are concatenated, passed through dropout and\n",
|
||
"    projected to `num_classes` outputs by a linear layer.\n",
|
||
"\n",
|
||
"    NOTE(review): DIMENSIONS (output channels per branch) is not defined in the\n",
|
||
"    notebook cells shown here; it must exist before TextCNN(...) is called.\n",
|
||
"    \"\"\"\n",
|
||
"\n",
|
||
"    def __init__(self, input_dim, num_classes):\n",
|
||
"        \"\"\"Build the layers.\n",
|
||
"\n",
|
||
"        Args:\n",
|
||
"            input_dim: Size of the last axis of the tensor given to `forward`\n",
|
||
"                (it becomes the channel axis after the permute).\n",
|
||
"            num_classes: Number of outputs of the final linear layer.\n",
|
||
"        \"\"\"\n",
|
||
"        super().__init__()\n",
|
||
"\n",
|
||
"        # forward() feeds the same permuted input to all three branches, so every\n",
|
||
"        # convolution takes `input_dim` input channels.\n",
|
||
"        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n",
|
||
"        self.conv2 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n",
|
||
"        self.conv3 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n",
|
||
"\n",
|
||
"        self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n",
|
||
"        self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n",
|
||
"        self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n",
|
||
"\n",
|
||
"        self.dropout = nn.Dropout(0.5)\n",
|
||
"        # Three branches of DIMENSIONS pooled features each are concatenated.\n",
|
||
"        self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n",
|
||
"\n",
|
||
"    def forward(self, x):\n",
|
||
"        \"\"\"Compute the outputs for a batch.\n",
|
||
"\n",
|
||
"        Args:\n",
|
||
"            x: Tensor laid out as (batch_size, seq_length, embedding_dim).\n",
|
||
"\n",
|
||
"        Returns:\n",
|
||
"            Tensor of shape (batch_size, num_classes).\n",
|
||
"        \"\"\"\n",
|
||
"        # Conv1d expects channels first: (batch_size, embedding_dim, seq_length).\n",
|
||
"        x = x.permute(0, 2, 1)\n",
|
||
"\n",
|
||
"        pooled = []\n",
|
||
"        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):\n",
|
||
"            branch = F.relu(bn(conv(x)))\n",
|
||
"            # Max over the whole sequence, then drop the length-1 axis.\n",
|
||
"            pooled.append(F.adaptive_max_pool1d(branch, output_size=1).squeeze(2))\n",
|
||
"\n",
|
||
"        x = torch.cat(pooled, dim=1)\n",
|
||
"        x = self.dropout(x)\n",
|
||
"        return self.fc(x)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"id": "bdb597cb-d896-485c-8c9c-897b1d35e8d2",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# model.pt is presumably a fully pickled nn.Module (it is handed straight to\n",
|
||
"# torch.onnx.export), so the TextCNN class must be defined before loading.\n",
|
||
"# SECURITY: unpickling can execute arbitrary code - only load a trusted file.\n",
|
||
"# map_location=\"cpu\" keeps the weights on the CPU, where the export cell builds its\n",
|
||
"# example input; weights_only=False is needed because recent PyTorch versions refuse\n",
|
||
"# to unpickle whole modules by default.\n",
|
||
"model = torch.load(\"model.pt\", map_location=\"cpu\", weights_only=False)\n",
|
||
"# Inference behaviour for Dropout / BatchNorm; the trailing ';' hides the repr output.\n",
|
||
"model.eval();"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"id": "9f7a6e64-75f2-4fa9-8d1e-b83099765d02",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Random stand-in for a real batch, shaped (batch_size, seq_length, embedding_dim).\n",
|
||
"dummy_input = torch.randn(1, 64, 128)\n",
|
||
"\n",
|
||
"# Axes that stay variable in the exported graph.\n",
|
||
"dynamic_axes = {\n",
|
||
"    'input': {0: 'batch_size', 1: 'seq_length'},  # batch and sequence length may vary\n",
|
||
"    'output': {0: 'batch_size'},\n",
|
||
"}\n",
|
||
"\n",
|
||
"# Export the model to model.onnx using the example input.\n",
|
||
"torch.onnx.export(\n",
|
||
"    model,\n",
|
||
"    dummy_input,\n",
|
||
"    \"model.onnx\",\n",
|
||
"    input_names=['input'],      # graph input name (customizable)\n",
|
||
"    output_names=['output'],    # graph output name (customizable)\n",
|
||
"    dynamic_axes=dynamic_axes,\n",
|
||
"    opset_version=11,           # ONNX opset; the target ONNX runtime must support it\n",
|
||
")\n"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.9.19"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|