{ "cells": [ { "cell_type": "markdown", "id": "a6a3195f-d099-4bf4-846f-51f403954818", "metadata": {}, "source": [ "# sparkastML: Training the Intention Classification Model\n", "\n", "This is the model we use for intent recognition, using a **CNN architectur** and using an **Energy-based Model** to implement OSR (Open-set Recognition).\n", "\n", "In this case, **positive samples** refer to data that can be classified into existing class, while **negative samples** are those does not belong to any of the existing class." ] }, { "cell_type": "code", "execution_count": 1, "id": "bddcdbb2-ccbc-4027-a38f-09c61ac94984", "metadata": {}, "outputs": [], "source": [ "import json\n", "import torch\n", "from torch.utils.data import Dataset, DataLoader\n", "from torch.nn.utils.rnn import pad_sequence\n", "from sklearn.model_selection import train_test_split\n", "from transformers import AutoTokenizer, AutoModel\n", "import torch\n", "import numpy as np\n", "from scipy.spatial.distance import euclidean\n", "from scipy.stats import weibull_min\n", "from sklearn.preprocessing import normalize\n", "import torch.nn.functional as F\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "d3a0e10f-9bc3-44c7-a109-786dd5cd25ea", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] } ], "source": [ "model_name=\"microsoft/Phi-3-mini-4k-instruct\"\n", "DIMENSIONS = 128\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)" ] }, { "cell_type": "markdown", "id": "1ae14906-338d-4c99-87ed-bb1acd22b295", "metadata": {}, "source": [ "## Load Data\n", "\n", "We load the data from `data.json`, and also get the negative sample from the `noise.json`." ] }, { "cell_type": "code", "execution_count": 3, "id": "a206071c-ce4e-4de4-b936-bfc70d13708a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_6446/1697839999.py:18: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. 
"cell_type": "code", "execution_count": 3, "id": "a206071c-ce4e-4de4-b936-bfc70d13708a", "metadata": {}, "outputs": [], "source": [ "# Load data\n", "with open('data.json', 'r') as f:\n", "    data = json.load(f)\n", "\n", "# Create maps between class names and indices\n", "class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}\n", "idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}\n", "\n", "# Preprocess data: convert each sentence to a (class idx, embedding) pair\n", "def preprocess_data(data, embedding_map, tokenizer, max_length=64):\n", "    dataset = []\n", "    for label, sentences in data.items():\n", "        for sentence in sentences:\n", "            # Tokenize the sentence and convert tokens to embedding vectors\n", "            tokens = tokenizer.tokenize(sentence)\n", "            token_ids = tokenizer.convert_tokens_to_ids(tokens)\n", "            embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n", "            # Stack into a single ndarray first to avoid the slow list-of-ndarrays path\n", "            embeddings = torch.tensor(np.array(embeddings))\n", "            dataset.append((class_to_idx[label], embeddings))\n", "    return dataset\n", "\n", "# Load embedding map (token id -> reduced embedding vector)\n", "embedding_map = torch.load('token_id_to_reduced_embedding.pt')\n", "\n", "# Get preprocessed dataset\n", "dataset = preprocess_data(data, embedding_map, tokenizer)\n", "\n", "# Train-test split\n", "train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)\n", "\n", "class TextDataset(Dataset):\n", "    def __init__(self, data):\n", "        self.data = data\n", "\n", "    def __len__(self):\n", "        return len(self.data)\n", "\n", "    def __getitem__(self, idx):\n", "        return self.data[idx]\n", "\n", "    def collate_fn(self, batch):\n", "        labels, embeddings = zip(*batch)\n", "        labels = torch.tensor(labels)\n", "        embeddings = pad_sequence(embeddings, batch_first=True)\n", "        return labels, embeddings\n", "\n", "train_dataset = TextDataset(train_data)\n", "val_dataset = TextDataset(val_data)\n", "\n", "train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn)\n", "val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False, collate_fn=val_dataset.collate_fn)\n" ] },
{ "cell_type": "code", "execution_count": 4, "id": "9adbe9b8-a2d2-4e1d-8620-457ed0e02fe6", "metadata": {}, "outputs": [], "source": [ "class NegativeSampleDataset(Dataset):\n", "    def __init__(self, negative_samples):\n", "        \"\"\"\n", "        negative_samples: list of negative sample embedding tensors\n", "        \"\"\"\n", "        self.samples = negative_samples\n", "\n", "    def __len__(self):\n", "        return len(self.samples)\n", "\n", "    def __getitem__(self, idx):\n", "        return self.samples[idx]\n", "\n", "    def collate_fn(self, batch):\n", "        embeddings = pad_sequence(batch, batch_first=True)\n", "        return embeddings\n", "\n", "with open('noise.json', 'r') as f:\n", "    negative_samples_list = json.load(f)\n", "\n", "negative_embedding_list = []\n", "\n", "for sentence in negative_samples_list:\n", "    tokens = tokenizer.tokenize(sentence)\n", "    token_ids = tokenizer.convert_tokens_to_ids(tokens)\n", "    embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]\n", "    embeddings = torch.tensor(np.array(embeddings))\n", "    negative_embedding_list.append(embeddings)\n", "\n", "negative_dataset = NegativeSampleDataset(negative_embedding_list)\n", "negative_loader = DataLoader(negative_dataset, batch_size=24, shuffle=True, collate_fn=negative_dataset.collate_fn)\n" ] }, {
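"cell_type": "markdown", "id": "1a2b3c4d-1111-4aaa-9aaa-000000000003", "metadata": {}, "source": [ "The notebook depends on `token_id_to_reduced_embedding.pt`, a precomputed map from each Phi-3 token id to a 128-dimensional vector (`DIMENSIONS`). How that file was produced is not shown here; one plausible sketch, and only an assumption, is to reduce the base model's input embedding matrix with PCA, which would also explain the otherwise-unused `AutoModel` import:" ] },
{ "cell_type": "code", "execution_count": null, "id": "1a2b3c4d-1111-4aaa-9aaa-000000000004", "metadata": {}, "outputs": [], "source": [ "# Hypothetical reconstruction of the embedding map (an assumption, not the\n", "# author's confirmed pipeline): PCA-reduce Phi-3's input embeddings.\n", "# Left commented out because it downloads the base model and is memory-heavy.\n", "#\n", "# from sklearn.decomposition import PCA\n", "#\n", "# base_model = AutoModel.from_pretrained(model_name)\n", "# full = base_model.get_input_embeddings().weight.detach().cpu().numpy()\n", "# reduced = PCA(n_components=DIMENSIONS).fit_transform(full)\n", "# rebuilt_map = {i: reduced[i] for i in range(len(reduced))}\n", "# # Saved under a different name to avoid clobbering the original file\n", "# torch.save(rebuilt_map, 'token_id_to_reduced_embedding.rebuilt.pt')" ] }, {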
"cell_type": "markdown", "id": "600febe4-2484-4aad-90a1-2bc821fdce1a", "metadata": {}, "source": [ "## Implementating the Model" ] }, { "cell_type": "code", "execution_count": 5, "id": "adf624ac-ad63-437b-95f6-b02b7253b91e", "metadata": {}, "outputs": [], "source": [ "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "class TextCNN(nn.Module):\n", " def __init__(self, input_dim, num_classes):\n", " super(TextCNN, self).__init__()\n", " self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n", " self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n", " self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n", " \n", " self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n", " self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n", " self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n", " \n", " self.dropout = nn.Dropout(0.5)\n", " self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n", "\n", " def forward(self, x):\n", " x = x.permute(0, 2, 1) # Change the input shape to (batch_size, embedding_dim, seq_length)\n", " \n", " x1 = F.relu(self.bn1(self.conv1(x)))\n", " x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n", " \n", " x2 = F.relu(self.bn2(self.conv2(x)))\n", " x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n", " \n", " x3 = F.relu(self.bn3(self.conv3(x)))\n", " x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n", " \n", " x = torch.cat((x1, x2, x3), dim=1)\n", " x = self.dropout(x)\n", " x = self.fc(x)\n", " return x\n", "\n", "# Initialize model\n", "input_dim = DIMENSIONS\n", "num_classes = len(class_to_idx)\n", "model = TextCNN(input_dim, num_classes)\n" ] }, { "cell_type": "markdown", "id": "2750e17d-8a60-40c7-851b-1a567d0ee82b", "metadata": {}, "source": [ "## Energy-based Models" ] }, { "cell_type": "code", "execution_count": 6, "id": "a2d7c920-07d2-4d14-9cef-e2101b7a2ceb", "metadata": {}, "outputs": [], "source": [ "def energy_score(logits):\n", " # Energy score is minus logsumexp\n", " return -torch.logsumexp(logits, dim=1)\n", "\n", "def generate_noise(batch_size, seq_length ,input_dim, device):\n", " # Generate a Gaussian noise\n", " return torch.randn(batch_size, seq_length, input_dim).to(device)\n" ] }, { "cell_type": "markdown", "id": "904a60e4-95a0-4f7b-ad45-a7d8d0ac887d", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 7, "id": "19acb5bf-00b1-47d4-ad25-a13c6be09f65", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch [1/50], Loss: 12.5108\n", "Epoch [2/50], Loss: 10.7305\n", "Epoch [3/50], Loss: 10.2943\n", "Epoch [4/50], Loss: 9.9350\n", "Epoch [5/50], Loss: 9.7991\n", "Epoch [6/50], Loss: 9.6443\n", "Epoch [7/50], Loss: 9.4762\n", "Epoch [8/50], Loss: 9.4637\n", "Epoch [9/50], Loss: 9.3025\n", "Epoch [10/50], Loss: 9.1719\n", "Epoch [11/50], Loss: 9.0632\n", "Epoch [12/50], Loss: 8.9741\n", "Epoch [13/50], Loss: 8.8487\n", "Epoch [14/50], Loss: 8.6565\n", "Epoch [15/50], Loss: 8.5830\n", "Epoch [16/50], Loss: 8.4196\n", "Epoch [17/50], Loss: 8.2319\n", "Epoch [18/50], Loss: 8.0655\n", "Epoch [19/50], Loss: 7.7140\n", "Epoch [20/50], Loss: 7.6921\n", "Epoch [21/50], Loss: 7.3375\n", "Epoch [22/50], Loss: 7.2297\n", "Epoch [23/50], Loss: 6.8833\n", "Epoch [24/50], Loss: 6.8534\n", "Epoch [25/50], Loss: 6.4557\n", "Epoch [26/50], Loss: 6.1365\n", "Epoch [27/50], Loss: 5.8558\n", "Epoch [28/50], Loss: 5.5030\n", "Epoch [29/50], Loss: 
"cell_type": "markdown", "id": "904a60e4-95a0-4f7b-ad45-a7d8d0ac887d", "metadata": {}, "source": [ "## Training" ] },
{ "cell_type": "code", "execution_count": 7, "id": "19acb5bf-00b1-47d4-ad25-a13c6be09f65", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch [1/50], Loss: 12.5108\n", "Epoch [2/50], Loss: 10.7305\n", "Epoch [3/50], Loss: 10.2943\n", "Epoch [4/50], Loss: 9.9350\n", "Epoch [5/50], Loss: 9.7991\n", "Epoch [6/50], Loss: 9.6443\n", "Epoch [7/50], Loss: 9.4762\n", "Epoch [8/50], Loss: 9.4637\n", "Epoch [9/50], Loss: 9.3025\n", "Epoch [10/50], Loss: 9.1719\n", "Epoch [11/50], Loss: 9.0632\n", "Epoch [12/50], Loss: 8.9741\n", "Epoch [13/50], Loss: 8.8487\n", "Epoch [14/50], Loss: 8.6565\n", "Epoch [15/50], Loss: 8.5830\n", "Epoch [16/50], Loss: 8.4196\n", "Epoch [17/50], Loss: 8.2319\n", "Epoch [18/50], Loss: 8.0655\n", "Epoch [19/50], Loss: 7.7140\n", "Epoch [20/50], Loss: 7.6921\n", "Epoch [21/50], Loss: 7.3375\n", "Epoch [22/50], Loss: 7.2297\n", "Epoch [23/50], Loss: 6.8833\n", "Epoch [24/50], Loss: 6.8534\n", "Epoch [25/50], Loss: 6.4557\n", "Epoch [26/50], Loss: 6.1365\n", "Epoch [27/50], Loss: 5.8558\n", "Epoch [28/50], Loss: 5.5030\n", "Epoch [29/50], Loss: 5.1604\n", "Epoch [30/50], Loss: 4.7742\n", "Epoch [31/50], Loss: 4.5958\n", "Epoch [32/50], Loss: 4.0713\n", "Epoch [33/50], Loss: 3.8872\n", "Epoch [34/50], Loss: 3.5240\n", "Epoch [35/50], Loss: 3.3115\n", "Epoch [36/50], Loss: 2.5667\n", "Epoch [37/50], Loss: 2.6709\n", "Epoch [38/50], Loss: 1.8075\n", "Epoch [39/50], Loss: 1.6654\n", "Epoch [40/50], Loss: 0.4622\n", "Epoch [41/50], Loss: 0.4719\n", "Epoch [42/50], Loss: -0.4037\n", "Epoch [43/50], Loss: -0.9405\n", "Epoch [44/50], Loss: -1.7204\n", "Epoch [45/50], Loss: -2.4124\n", "Epoch [46/50], Loss: -3.0032\n", "Epoch [47/50], Loss: -2.7123\n", "Epoch [48/50], Loss: -3.6953\n", "Epoch [49/50], Loss: -3.7212\n", "Epoch [50/50], Loss: -3.7558\n" ] } ], "source": [ "import torch.optim as optim\n", "from torch.utils.tensorboard import SummaryWriter\n", "\n", "criterion = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(model.parameters(), lr=8e-4)\n", "\n", "writer = SummaryWriter()\n", "\n", "def train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=10):\n", "    model.train()\n", "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "    model.to(device)\n", "\n", "    negative_iter = iter(negative_loader)\n", "\n", "    for epoch in range(num_epochs):\n", "        total_loss = 0\n", "        for batch_idx, (labels, embeddings) in enumerate(train_loader):\n", "            embeddings = embeddings.to(device)\n", "            labels = labels.to(device)\n", "\n", "            batch_size = embeddings.size(0)\n", "\n", "            # ---------------------\n", "            # 1. Positive samples\n", "            # ---------------------\n", "            optimizer.zero_grad()\n", "            outputs = model(embeddings)  # logits from the model\n", "\n", "            class_loss = criterion(outputs, labels)\n", "\n", "            # Energy of positive samples: minimizing the mean pushes it down\n", "            known_energy = energy_score(outputs)\n", "            energy_loss_known = known_energy.mean()\n", "\n", "            # ------------------------------------\n", "            # 2. Negative samples - random noise\n", "            # ------------------------------------\n", "            noise_embeddings = torch.randn_like(embeddings).to(device)\n", "            noise_outputs = model(noise_embeddings)\n", "            noise_energy = energy_score(noise_outputs)\n", "            # Hinge: penalize noise energy below the margin of 1 (higher energy = more unknown-like)\n", "            energy_loss_noise = F.relu(1 - noise_energy).mean()\n", "\n", "            # ------------------------------------\n", "            # 3. Negative samples - custom corpus\n", "            # ------------------------------------\n", "            try:\n", "                negative_samples = next(negative_iter)\n", "            except StopIteration:\n", "                negative_iter = iter(negative_loader)\n", "                negative_samples = next(negative_iter)\n", "            negative_samples = negative_samples.to(device)\n", "            negative_outputs = model(negative_samples)\n", "            negative_energy = energy_score(negative_outputs)\n", "            # Same hinge as above, applied to the noise corpus\n", "            energy_loss_negative = F.relu(1 - negative_energy).mean()\n", "\n", "            # -----------------------------\n", "            # 4. Overall loss calculation\n", "            # -----------------------------\n", "            total_energy_loss = energy_loss_known + energy_loss_noise + energy_loss_negative\n", "            # The constant 10 only shifts the reported value; it does not affect gradients\n", "            total_loss_batch = class_loss + total_energy_loss * 0.1 + 10\n", "\n", "            # Log per batch at a global step so successive batches don't overwrite each other\n", "            global_step = epoch * len(train_loader) + batch_idx\n", "            writer.add_scalar(\"Energy Loss\", total_energy_loss, global_step)\n", "            writer.add_scalar(\"Loss\", total_loss_batch, global_step)\n", "            writer.add_scalar(\"Norm Loss\", torch.exp(total_loss_batch * 0.003) * 10, global_step)\n", "\n", "            total_loss_batch.backward()\n", "            optimizer.step()\n", "\n", "            total_loss += total_loss_batch.item()\n", "\n", "        avg_loss = total_loss / len(train_loader)\n", "        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')\n", "\n", "train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=50)\n", "writer.flush()\n" ] }, {
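"cell_type": "markdown", "id": "1a2b3c4d-1111-4aaa-9aaa-000000000007", "metadata": {}, "source": [ "To summarize the loop above: with $E(\\cdot)$ the energy score and bars denoting batch means, each batch minimizes\n", "\n", "$$\\mathcal{L} = \\mathcal{L}_{\\mathrm{CE}} + \\lambda \\left( \\overline{E(x_{\\mathrm{known}})} + \\overline{\\max(0,\\, 1 - E(x_{\\mathrm{noise}}))} + \\overline{\\max(0,\\, 1 - E(x_{\\mathrm{neg}}))} \\right) + 10,$$\n", "\n", "where $\\lambda = 0.1$. The cross-entropy term keeps the known classes separable, while the energy terms drive known-sample energy down and push both kinds of negatives above the margin of 1." ] }, {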
"cell_type": "markdown", "id": "e6d29558-f497-4033-8488-169bd25ce881", "metadata": {}, "source": [ "## Evaluation" ] },
{ "cell_type": "code", "execution_count": 8, "id": "472702e6-db4a-4faa-9e92-7510e6eacbb1", "metadata": {}, "outputs": [], "source": [ "ENERGY_THRESHOLD = -3" ] },
{ "cell_type": "code", "execution_count": 9, "id": "d3a4fef8-37ab-45c8-b2b1-9fc8bdffcffd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.9315\n", "Precision: 1.0000\n", "Recall: 0.9254\n", "F1 Score: 0.9612\n" ] } ], "source": [ "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n", "\n", "def evaluate_energy_model(model, known_loader, unknown_loader, energy_threshold):\n", "    model.eval()\n", "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "    all_preds = []\n", "    all_labels = []\n", "\n", "    # Evaluate positive samples\n", "    with torch.no_grad():\n", "        for labels, embeddings in known_loader:\n", "            embeddings = embeddings.to(device)\n", "            logits = model(embeddings)\n", "            energy = energy_score(logits)\n", "\n", "            # Predict \"known\" (1) when the energy is at or below the threshold\n", "            preds = (energy <= energy_threshold).long()\n", "            all_preds.extend(preds.cpu().numpy())\n", "            all_labels.extend([1] * len(preds))  # Positive samples are labeled 1\n", "\n", "    # Evaluate negative samples\n", "    with torch.no_grad():\n", "        for embeddings in unknown_loader:\n", "            embeddings = embeddings.to(device)\n", "            logits = model(embeddings)\n", "            energy = energy_score(logits)\n", "\n", "            preds = (energy <= energy_threshold).long()\n", "            all_preds.extend(preds.cpu().numpy())\n", "            all_labels.extend([0] * len(preds))  # Negative samples are labeled 0\n", "\n", "    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n", "    accuracy = accuracy_score(all_labels, all_preds)\n", "\n", "    print(f'Accuracy: {accuracy:.4f}')\n", "    print(f'Precision: {precision:.4f}')\n", "    print(f'Recall: {recall:.4f}')\n", "    print(f'F1 Score: {f1:.4f}')\n", "\n", "evaluate_energy_model(model, val_loader, negative_loader, ENERGY_THRESHOLD)" ] },
{ "cell_type": "code", "execution_count": 20, "id": "ba614054-75e1-4f61-ace5-aeb11e29a222", "metadata": {}, "outputs": [], "source": [ "# Save the entire model object (loading it back requires the TextCNN class to be importable)\n", "torch.save(model, \"model.pt\")" ] }, {
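"cell_type": "markdown", "id": "1a2b3c4d-1111-4aaa-9aaa-000000000008", "metadata": {}, "source": [ "The threshold of -3 used above appears to be hand-picked. A data-driven alternative (an assumption, not the method used here) is to set it from the energy distribution of known validation samples, e.g. at a percentile that accepts a chosen fraction of them:" ] },
{ "cell_type": "code", "execution_count": null, "id": "1a2b3c4d-1111-4aaa-9aaa-000000000009", "metadata": {}, "outputs": [], "source": [ "# Hypothetical helper: choose the threshold as the q-th percentile of energies\n", "# over known validation samples, so roughly q% of known inputs fall at or\n", "# below it and are accepted as in-distribution.\n", "def pick_threshold(model, loader, q=95):\n", "    model.eval()\n", "    device = next(model.parameters()).device\n", "    energies = []\n", "    with torch.no_grad():\n", "        for labels, embeddings in loader:\n", "            logits = model(embeddings.to(device))\n", "            energies.extend(energy_score(logits).cpu().numpy())\n", "    return float(np.percentile(energies, q))\n", "\n", "# Example (left commented out so it does not override the value used above):\n", "# ENERGY_THRESHOLD = pick_threshold(model, val_loader, q=95)" ] }, {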
"cell_type": "markdown", "id": "fdfa0c5e-e6d3-4db0-a142-96645c92719c", "metadata": {}, "source": [ "## Inference" ] },
{ "cell_type": "code", "execution_count": 19, "id": "03928d75-81c8-4298-ab8a-d7f8a758b561", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted: ['weather', 0.9989822506904602, -8.016249656677246]\n" ] } ], "source": [ "def predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold, max_length=64):\n", "    model.eval()\n", "    device = next(model.parameters()).device\n", "    tokens = tokenizer.tokenize(sentence)\n", "    token_ids = tokenizer.convert_tokens_to_ids(tokens)\n", "    embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n", "    embeddings = torch.tensor(np.array(embeddings)).unsqueeze(0).to(device)  # Add batch dimension\n", "\n", "    with torch.no_grad():\n", "        logits = model(embeddings)\n", "        probabilities = F.softmax(logits, dim=1)\n", "        max_prob, predicted = torch.max(probabilities, 1)\n", "\n", "        # Calculate the energy score\n", "        energy = energy_score(logits)\n", "\n", "    # If the energy exceeds the threshold, treat the input as an unknown class\n", "    if energy.item() > energy_threshold:\n", "        return [\"Unknown\", max_prob.item(), energy.item()]\n", "    else:\n", "        return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]\n", "\n", "# Example usage:\n", "sentence = \"weather today\"\n", "energy_threshold = ENERGY_THRESHOLD\n", "predicted = predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold)\n", "print(f'Predicted: {predicted}')\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 5 }