# Convert to ONNX

This notebook converts our trained PyTorch model to [ONNX](https://onnx.ai/) format, the open standard for machine-learning interoperability. This lets us run the model in JavaScript, directly in the browser.

In [1]:
import numpy as np

from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class TextCNN(nn.Module):
    """Text CNN classifier over sequences of embedding vectors.

    Three parallel Conv1d branches (kernel sizes 3, 4, 5) each read the same
    input. Each branch is followed by BatchNorm + ReLU and a global max-pool over
    the sequence axis. The pooled features are concatenated, passed through
    dropout, and projected to ``num_classes`` unnormalized scores (no softmax is
    applied here).

    Input:  tensor of shape (batch_size, seq_length, input_dim).
    Output: tensor of shape (batch_size, num_classes).

    The submodule attribute names (conv1..conv3, bn1..bn3, dropout, fc) define the
    checkpoint layout. Renaming them would break ``load_state_dict`` and
    previously pickled models.
    """

    def __init__(self, input_dim, num_classes, dimensions=None):
        """Build the layers.

        Args:
            input_dim: size of each input embedding vector (channels seen by every conv).
            num_classes: number of output classes.
            dimensions: number of filters per conv branch. If omitted, falls back to a
                module-level ``DIMENSIONS`` constant (the original behavior). That
                constant is not defined in this notebook, so pass ``dimensions``
                explicitly when building a fresh model. It is not needed for
                ``torch.load`` of a pickled model, because unpickling does not call
                ``__init__``.
        """
        super().__init__()
        if dimensions is None:
            # Legacy fallback; presumably defined in the notebook that trained the model.
            dimensions = DIMENSIONS  # noqa: F821

        # Every branch convolves the *same* input, so all three take input_dim
        # channels. (Previously conv2/conv3 used `dimensions`, which only worked when
        # dimensions == input_dim.)
        # Padding keeps the length for k=3 (p=1) and k=5 (p=2). k=4 with p=1 yields
        # seq_length - 1, so inputs need at least 2 positions.
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=dimensions, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=input_dim, out_channels=dimensions, kernel_size=4, padding=1)
        self.conv3 = nn.Conv1d(in_channels=input_dim, out_channels=dimensions, kernel_size=5, padding=2)

        self.bn1 = nn.BatchNorm1d(dimensions)
        self.bn2 = nn.BatchNorm1d(dimensions)
        self.bn3 = nn.BatchNorm1d(dimensions)

        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(dimensions * 3, num_classes)

    def forward(self, x):
        """Map (batch_size, seq_length, input_dim) input to (batch_size, num_classes) scores."""
        # Conv1d wants channels first: (batch_size, input_dim, seq_length).
        x = x.permute(0, 2, 1)

        # Conv -> BatchNorm -> ReLU, then global max-pool over the sequence axis,
        # giving one (batch_size, dimensions) feature tensor per branch.
        branches = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        pooled = [
            F.adaptive_max_pool1d(F.relu(bn(conv(x))), output_size=1).squeeze(2)
            for conv, bn in branches
        ]

        x = torch.cat(pooled, dim=1)  # (batch_size, 3 * dimensions)
        x = self.dropout(x)
        return self.fc(x)

In [3]:
# Load the full pickled nn.Module (it is passed directly to torch.onnx.export below).
# Unpickling resolves the TextCNN class by name, so the class cell above must run first.
# - map_location="cpu": the export uses a CPU dummy input, so keep the weights on the
#   CPU even if the model was saved from a GPU.
# - weights_only=False: since PyTorch 2.6, torch.load defaults to weights_only=True,
#   which refuses to unpickle custom classes such as TextCNN. Full unpickling can run
#   arbitrary code, so only load model files you trust.
model = torch.load("model.pt", map_location="cpu", weights_only=False)

In [4]:

# Export in inference mode: Dropout becomes a no-op and BatchNorm uses its running
# statistics instead of per-batch statistics.
model.eval()

# Example input: random values stand in for real embeddings, since only the shape
# matters for tracing. The layout is (batch_size, seq_length, embedding_dim).
# The embedding size must equal the channels the first conv layer expects (the
# model permutes the last two axes before its convolutions), so read it from the
# model instead of hard-coding it.
EXAMPLE_SEQ_LENGTH = 64
embedding_dim = model.conv1.in_channels
dummy_input = torch.randn(1, EXAMPLE_SEQ_LENGTH, embedding_dim)

# Export model
torch.onnx.export(
    model,                       # The model to export
    dummy_input,                 # Example input used for tracing
    "model.onnx",                # Output file name
    input_names=['input'],       # Graph input name; ONNX consumers feed tensors by this name
    output_names=['output'],     # Graph output name
    dynamic_axes={
        # Batch size and sequence length may vary at inference time; the embedding
        # dimension is fixed by the trained weights. The kernel_size=4 / padding=1
        # conv shortens the sequence by one, so inputs need seq_length >= 2.
        'input': {0: 'batch_size', 1: 'seq_length'},
        'output': {0: 'batch_size'},
    },
    opset_version=11,            # ONNX opset version; ensure the target ONNX runtime supports it
)
