{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "94ff7007",
   "metadata": {},
   "source": [
    "# Convert to ONNX\n",
    "\n",
    "This notebook converts our model to [ONNX](https://onnx.ai/) format, which is the open standard for machine learning interoperability. This lets us run the model from JavaScript in the browser."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "311162fd-f957-4746-b524-25bb3e09efbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.utils.model_zoo as model_zoo\n",
    "import torch.onnx\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fd182c6d-1e77-4bbb-bb53-8321d40ae002",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TextCNN(nn.Module):\n",
    "    \"\"\"Text classifier built from three parallel 1-D convolution branches.\n",
    "\n",
    "    Each branch applies Conv1d -> BatchNorm1d -> ReLU followed by a global\n",
    "    max-pool over the sequence axis. The three pooled vectors are concatenated,\n",
    "    passed through dropout and projected to ``num_classes`` values by a single\n",
    "    linear layer.\n",
    "\n",
    "    NOTE(review): this class is presumably defined here so that ``torch.load``\n",
    "    can unpickle a whole saved ``TextCNN`` instance - confirm against the\n",
    "    training code. If so, the attribute names and ``forward`` below must stay\n",
    "    in sync with the definition used for training.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, num_classes, hidden_dim=None):\n",
    "        \"\"\"Build the layers.\n",
    "\n",
    "        Args:\n",
    "            input_dim: Number of channels every convolution receives, i.e. the\n",
    "                size of the last axis of the tensor passed to ``forward``.\n",
    "            num_classes: Number of values produced by the final linear layer.\n",
    "            hidden_dim: Number of output channels of each convolution. When\n",
    "                omitted, the global ``DIMENSIONS`` is used if it is defined;\n",
    "                otherwise it falls back to ``input_dim``.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        if hidden_dim is None:\n",
    "            # DIMENSIONS is not defined anywhere in this notebook, so look it up\n",
    "            # lazily instead of failing with a NameError on a fresh kernel.\n",
    "            hidden_dim = globals().get(\"DIMENSIONS\", input_dim)\n",
    "\n",
    "        # All three branches read the same input tensor in forward(), so each\n",
    "        # of them needs in_channels=input_dim (not the hidden width).\n",
    "        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=3, padding=1)\n",
    "        self.conv2 = nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=4, padding=1)\n",
    "        self.conv3 = nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=5, padding=2)\n",
    "\n",
    "        self.bn1 = nn.BatchNorm1d(hidden_dim)\n",
    "        self.bn2 = nn.BatchNorm1d(hidden_dim)\n",
    "        self.bn3 = nn.BatchNorm1d(hidden_dim)\n",
    "\n",
    "        self.dropout = nn.Dropout(0.5)\n",
    "        # The classifier sees the three pooled branch outputs side by side.\n",
    "        self.fc = nn.Linear(hidden_dim * 3, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the output of the final linear layer for ``x``.\n",
    "\n",
    "        ``x`` is permuted with ``(0, 2, 1)`` before the convolutions, so it is\n",
    "        expected as (batch_size, seq_length, embedding_dim).\n",
    "        \"\"\"\n",
    "        x = x.permute(0, 2, 1)  # Change the input shape to (batch_size, embedding_dim, seq_length)\n",
    "\n",
    "        # Each branch: convolution -> batch norm -> ReLU -> max over the\n",
    "        # sequence axis, leaving one value per channel.\n",
    "        x1 = F.relu(self.bn1(self.conv1(x)))\n",
    "        x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n",
    "\n",
    "        x2 = F.relu(self.bn2(self.conv2(x)))\n",
    "        x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n",
    "\n",
    "        x3 = F.relu(self.bn3(self.conv3(x)))\n",
    "        x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n",
    "\n",
    "        x = torch.cat((x1, x2, x3), dim=1)\n",
    "        x = self.dropout(x)\n",
    "        x = self.fc(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bdb597cb-d896-485c-8c9c-897b1d35e8d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the export below needs `model` to be an nn.Module, so model.pt\n",
    "# presumably holds the whole pickled TextCNN (not a state dict) - confirm.\n",
    "#\n",
    "# map_location=\"cpu\": the dummy input used for the export is created on the CPU,\n",
    "# so a checkpoint saved from a GPU has to be remapped to match it.\n",
    "# weights_only=False: PyTorch >= 2.6 defaults to weights_only=True, which refuses\n",
    "# to unpickle a full module. Unpickling can execute arbitrary code, so only load\n",
    "# a model.pt that you trust.\n",
    "model = torch.load(\"model.pt\", map_location=\"cpu\", weights_only=False)\n",
    "\n",
    "# Make Dropout / BatchNorm use their inference behaviour.\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9f7a6e64-75f2-4fa9-8d1e-b83099765d02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example input: use random embedding vector to simulate real input\n",
    "dummy_input = torch.randn(1, 64, 128)  # (batch_size, seq_length, embedding_dim)\n",
    "\n",
    "# Export model\n",
    "torch.onnx.export(\n",
    "    model,                     # The model to export\n",
    "    dummy_input,               # Example input\n",
    "    \"model.onnx\",              # File name\n",
    "    input_names=['input'],     # Input name (Could customize)\n",
    "    output_names=['output'],   # Output name (Could customize)\n",
    "    dynamic_axes={\n",
    "        'input': {0: 'batch_size', 1: 'seq_length'},  # Dynamic batch and sequence length\n",
    "        'output': {0: 'batch_size'}\n",
    "    },\n",
    "    opset_version=11           # ONNX opset version; ensure the ONNX runtime supports it\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}