576 lines
20 KiB
Plaintext
576 lines
20 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a6a3195f-d099-4bf4-846f-51f403954818",
|
|
"metadata": {},
|
|
"source": [
|
|
"# sparkastML: Training the Intention Classification Model\n",
|
|
"\n",
|
|
"This is the model we use for intent recognition: a **CNN architecture** combined with an **Energy-based Model** to implement OSR (Open-set Recognition).\n",
|
|
"\n",
|
|
"In this case, **positive samples** are data that can be classified into one of the existing classes, while **negative samples** are those that do not belong to any of the existing classes."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "bddcdbb2-ccbc-4027-a38f-09c61ac94984",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import json\n",
|
|
"import torch\n",
|
|
"from torch.utils.data import Dataset, DataLoader\n",
|
|
"from torch.nn.utils.rnn import pad_sequence\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"from transformers import AutoTokenizer, AutoModel\n",
|
|
"import torch\n",
|
|
"import numpy as np\n",
|
|
"from scipy.spatial.distance import euclidean\n",
|
|
"from scipy.stats import weibull_min\n",
|
|
"from sklearn.preprocessing import normalize\n",
|
|
"import torch.nn.functional as F\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "d3a0e10f-9bc3-44c7-a109-786dd5cd25ea",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Only the tokenizer of this model is used: its token ids index the pre-computed embedding map\n",
"model_name = \"microsoft/Phi-3-mini-4k-instruct\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"# Width of every token embedding vector and of each CNN feature map\n",
"DIMENSIONS = 128"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1ae14906-338d-4c99-87ed-bb1acd22b295",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Load Data\n",
|
|
"\n",
|
|
"We load the labelled data from `data.json`, and the negative samples from `noise.json`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "a206071c-ce4e-4de4-b936-bfc70d13708a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_6446/1697839999.py:18: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_69nk78ncaj/croot/pytorch_1669252638507/work/torch/csrc/utils/tensor_new.cpp:204.)\n",
|
|
" embeddings = torch.tensor(embeddings)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Load data: a JSON object mapping each class name to its list of example sentences\n",
"with open('data.json', 'r') as f:\n",
"    data = json.load(f)\n",
"\n",
"# Create map: class to index (in data.json key order) and its inverse\n",
"class_to_idx = {cls: position for position, cls in enumerate(data)}\n",
"idx_to_class = dict(enumerate(class_to_idx))\n",
|
|
"\n",
|
|
"# Preprocess data, convert sentences to the format of (class idx, embedding)\n",
"def preprocess_data(data, embedding_map, tokenizer, max_length=64, class_to_idx=None):\n",
"    \"\"\"Convert labelled sentences into ``(class_idx, embeddings)`` pairs.\n",
"\n",
"    Args:\n",
"        data: mapping of class name -> list of sentences (as loaded from ``data.json``).\n",
"        embedding_map: indexable by token id, returning that token's embedding vector.\n",
"        tokenizer: tokenizer used to split each sentence into token ids.\n",
"        max_length: maximum number of tokens kept per sentence (longer ones are truncated).\n",
"        class_to_idx: optional class-name -> label mapping. Defaults to the key order of\n",
"            ``data``, i.e. the same mapping the data-loading cell builds.\n",
"\n",
"    Returns:\n",
"        List of ``(label, embeddings)`` tuples; ``embeddings`` stacks one row per token\n",
"        (assumes every embedding_map entry is a 1-D vector). Sentences that produce no\n",
"        tokens are skipped because an empty sequence cannot be padded into a batch.\n",
"    \"\"\"\n",
"    if class_to_idx is None:\n",
"        class_to_idx = {cls: idx for idx, cls in enumerate(data)}\n",
"    dataset = []\n",
"    for label, sentences in data.items():\n",
"        for sentence in sentences:\n",
"            # Tokenize the sentence and convert tokens to embedding vectors\n",
"            tokens = tokenizer.tokenize(sentence)\n",
"            token_ids = tokenizer.convert_tokens_to_ids(tokens)[:max_length]\n",
"            if not token_ids:\n",
"                continue\n",
"            # Stack into one ndarray first: a tensor built from a list of ndarrays is very slow\n",
"            embeddings = torch.tensor(np.array([embedding_map[token_id] for token_id in token_ids]))\n",
"            dataset.append((class_to_idx[label], embeddings))\n",
"    return dataset\n",
|
|
"\n",
|
|
"# Load embedding map: token id -> reduced embedding vector\n",
"# NOTE: torch.load unpickles the file, so only load artifacts you trust\n",
"embedding_map = torch.load('token_id_to_reduced_embedding.pt')\n",
"\n",
"# Get preprocessed dataset: a list of (class idx, embedding tensor) pairs\n",
"dataset = preprocess_data(data=data, embedding_map=embedding_map, tokenizer=tokenizer)\n",
"\n",
"# Train-test split: hold out 20% for validation, seeded so the split is reproducible\n",
"train_data, val_data = train_test_split(dataset, random_state=42, test_size=0.2)\n",
|
|
"\n",
|
|
"class TextDataset(Dataset):\n",
"    \"\"\"Map-style dataset over pre-computed ``(label, embeddings)`` pairs.\"\"\"\n",
"\n",
"    def __init__(self, data):\n",
"        self.data = data\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.data[idx]\n",
"\n",
"    def collate_fn(self, batch):\n",
"        \"\"\"Turn a list of ``(label, embeddings)`` items into ``(labels, padded_embeddings)``.\n",
"\n",
"        Labels become a 1-D tensor; embeddings are zero-padded to the longest item with\n",
"        ``batch_first=True`` (so ``(batch, max_len, dim)`` when each item is ``(seq_len, dim)``).\n",
"        \"\"\"\n",
"        labels = torch.tensor([label for label, _ in batch])\n",
"        padded = pad_sequence([emb for _, emb in batch], batch_first=True)\n",
"        return labels, padded\n",
|
|
"\n",
|
|
"train_dataset, val_dataset = TextDataset(train_data), TextDataset(val_data)\n",
"\n",
"# Only the training split is shuffled; each collate_fn zero-pads a batch to its longest sentence\n",
"train_loader = DataLoader(train_dataset, collate_fn=train_dataset.collate_fn, batch_size=24, shuffle=True)\n",
"val_loader = DataLoader(val_dataset, collate_fn=val_dataset.collate_fn, batch_size=24, shuffle=False)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "9adbe9b8-a2d2-4e1d-8620-457ed0e02fe6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"from torch.utils.data import Dataset, DataLoader\n",
|
|
"\n",
|
|
"class NegativeSampleDataset(Dataset):\n",
"    \"\"\"Unlabelled out-of-domain samples used as negatives during energy training.\n",
"\n",
"    negative_samples: sequence of per-sentence embedding tensors. Items are returned\n",
"    unchanged, so ``collate_fn`` only works when they are tensors (not raw text).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, negative_samples):\n",
"        self.samples = negative_samples\n",
"\n",
"    def __len__(self):\n",
"        return len(self.samples)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.samples[idx]\n",
"\n",
"    def collate_fn(self, batch):\n",
"        \"\"\"Zero-pad a list of embedding tensors to the longest one; returns one tensor, no labels.\"\"\"\n",
"        return pad_sequence(batch, batch_first=True)\n",
|
|
"\n",
|
|
"with open('noise.json', 'r') as f:\n",
"    negative_samples_list = json.load(f)\n",
"\n",
"# Embed every negative sentence with the same tokenize -> truncate-to-64 -> lookup\n",
"# pipeline used for the positive samples\n",
"negative_embedding_list = []\n",
"for sentence in negative_samples_list:\n",
"    tokens = tokenizer.tokenize(sentence)\n",
"    token_ids = tokenizer.convert_tokens_to_ids(tokens)[:64]\n",
"    if not token_ids:\n",
"        # An empty sequence cannot be padded into a batch; skip it\n",
"        continue\n",
"    # Stack into one ndarray first: a tensor built from a list of ndarrays is very slow\n",
"    embeddings = torch.tensor(np.array([embedding_map[token_id] for token_id in token_ids]))\n",
"    negative_embedding_list.append(embeddings)\n",
"\n",
"negative_dataset = NegativeSampleDataset(negative_embedding_list)\n",
"negative_loader = DataLoader(negative_dataset, batch_size=24, shuffle=True, collate_fn=negative_dataset.collate_fn)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "600febe4-2484-4aad-90a1-2bc821fdce1a",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Implementing the Model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "adf624ac-ad63-437b-95f6-b02b7253b91e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch.nn as nn\n",
|
|
"import torch.nn.functional as F\n",
|
|
"\n",
|
|
"class TextCNN(nn.Module):\n",
"    \"\"\"Multi-kernel 1-D CNN classifier over a sequence of token embeddings.\n",
"\n",
"    Three parallel convolutions (kernel sizes 3, 4 and 5) each produce ``DIMENSIONS``\n",
"    feature maps. Each map is batch-normalised, ReLU-activated and max-pooled over the\n",
"    sequence. The pooled features are concatenated, passed through dropout and projected\n",
"    to ``num_classes`` logits.\n",
"\n",
"    Args:\n",
"        input_dim: size of each token embedding (input channels of every convolution).\n",
"        num_classes: number of known intent classes.\n",
"\n",
"    NOTE(review): zero-padded positions produced by ``pad_sequence`` are not masked, so\n",
"    they take part in the max-pool. A sentence's logits can therefore depend on the\n",
"    lengths of the other sentences in its batch.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_dim, num_classes):\n",
"        super().__init__()\n",
"        # All three branches read the raw input, so all take input_dim channels\n",
"        # (identical to the old definition whenever input_dim == DIMENSIONS)\n",
"        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n",
"        self.conv2 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n",
"        self.conv3 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n",
"\n",
"        self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n",
"        self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n",
"        self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n",
"\n",
"        self.dropout = nn.Dropout(0.5)\n",
"        self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map ``x`` of shape ``(batch, seq_len, input_dim)`` to logits ``(batch, num_classes)``.\"\"\"\n",
"        x = x.permute(0, 2, 1)  # Conv1d expects (batch_size, embedding_dim, seq_length)\n",
"\n",
"        # conv2 (kernel 4, padding 1) outputs seq_len - 1 positions, so a one-token input\n",
"        # would give an empty feature map and crash; right-pad with zeros for conv2 only.\n",
"        min_len = self.conv2.kernel_size[0] - 2 * self.conv2.padding[0]\n",
"        x2_in = F.pad(x, (0, min_len - x.size(2))) if x.size(2) < min_len else x\n",
"\n",
"        x1 = F.relu(self.bn1(self.conv1(x)))\n",
"        x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n",
"\n",
"        x2 = F.relu(self.bn2(self.conv2(x2_in)))\n",
"        x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n",
"\n",
"        x3 = F.relu(self.bn3(self.conv3(x)))\n",
"        x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n",
"\n",
"        x = torch.cat((x1, x2, x3), dim=1)\n",
"        x = self.dropout(x)\n",
"        x = self.fc(x)\n",
"        return x\n",
|
|
"\n",
|
|
"# Initialize model\n",
"# Seed torch's global RNG so that weight initialisation, DataLoader shuffling, dropout and\n",
"# the Gaussian noise drawn during training all start from a fixed seed on every re-run\n",
"torch.manual_seed(42)\n",
"input_dim = DIMENSIONS  # width of each token embedding fed to the CNN\n",
"num_classes = len(class_to_idx)  # one logit per known intent class\n",
"model = TextCNN(input_dim, num_classes)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2750e17d-8a60-40c7-851b-1a567d0ee82b",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Energy-based Models"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "a2d7c920-07d2-4d14-9cef-e2101b7a2ceb",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def energy_score(logits):\n",
"    \"\"\"Free energy of each row of ``logits``: ``E(x) = -logsumexp_k logits[:, k]``.\n",
"\n",
"    Reduces over ``dim=1``, giving one value per sample. Larger logits mean lower\n",
"    energy; the evaluation and inference code treats low energy as a known class.\n",
"    \"\"\"\n",
"    return torch.logsumexp(logits, dim=1).neg()\n",
|
|
"\n",
|
|
"def generate_noise(batch_size, seq_length, input_dim, device):\n",
"    \"\"\"Standard-normal noise of shape ``(batch_size, seq_length, input_dim)`` placed on ``device``.\"\"\"\n",
"    noise = torch.randn(batch_size, seq_length, input_dim)\n",
"    return noise.to(device)\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "904a60e4-95a0-4f7b-ad45-a7d8d0ac887d",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Training"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "19acb5bf-00b1-47d4-ad25-a13c6be09f65",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch [1/50], Loss: 12.5108\n",
|
|
"Epoch [2/50], Loss: 10.7305\n",
|
|
"Epoch [3/50], Loss: 10.2943\n",
|
|
"Epoch [4/50], Loss: 9.9350\n",
|
|
"Epoch [5/50], Loss: 9.7991\n",
|
|
"Epoch [6/50], Loss: 9.6443\n",
|
|
"Epoch [7/50], Loss: 9.4762\n",
|
|
"Epoch [8/50], Loss: 9.4637\n",
|
|
"Epoch [9/50], Loss: 9.3025\n",
|
|
"Epoch [10/50], Loss: 9.1719\n",
|
|
"Epoch [11/50], Loss: 9.0632\n",
|
|
"Epoch [12/50], Loss: 8.9741\n",
|
|
"Epoch [13/50], Loss: 8.8487\n",
|
|
"Epoch [14/50], Loss: 8.6565\n",
|
|
"Epoch [15/50], Loss: 8.5830\n",
|
|
"Epoch [16/50], Loss: 8.4196\n",
|
|
"Epoch [17/50], Loss: 8.2319\n",
|
|
"Epoch [18/50], Loss: 8.0655\n",
|
|
"Epoch [19/50], Loss: 7.7140\n",
|
|
"Epoch [20/50], Loss: 7.6921\n",
|
|
"Epoch [21/50], Loss: 7.3375\n",
|
|
"Epoch [22/50], Loss: 7.2297\n",
|
|
"Epoch [23/50], Loss: 6.8833\n",
|
|
"Epoch [24/50], Loss: 6.8534\n",
|
|
"Epoch [25/50], Loss: 6.4557\n",
|
|
"Epoch [26/50], Loss: 6.1365\n",
|
|
"Epoch [27/50], Loss: 5.8558\n",
|
|
"Epoch [28/50], Loss: 5.5030\n",
|
|
"Epoch [29/50], Loss: 5.1604\n",
|
|
"Epoch [30/50], Loss: 4.7742\n",
|
|
"Epoch [31/50], Loss: 4.5958\n",
|
|
"Epoch [32/50], Loss: 4.0713\n",
|
|
"Epoch [33/50], Loss: 3.8872\n",
|
|
"Epoch [34/50], Loss: 3.5240\n",
|
|
"Epoch [35/50], Loss: 3.3115\n",
|
|
"Epoch [36/50], Loss: 2.5667\n",
|
|
"Epoch [37/50], Loss: 2.6709\n",
|
|
"Epoch [38/50], Loss: 1.8075\n",
|
|
"Epoch [39/50], Loss: 1.6654\n",
|
|
"Epoch [40/50], Loss: 0.4622\n",
|
|
"Epoch [41/50], Loss: 0.4719\n",
|
|
"Epoch [42/50], Loss: -0.4037\n",
|
|
"Epoch [43/50], Loss: -0.9405\n",
|
|
"Epoch [44/50], Loss: -1.7204\n",
|
|
"Epoch [45/50], Loss: -2.4124\n",
|
|
"Epoch [46/50], Loss: -3.0032\n",
|
|
"Epoch [47/50], Loss: -2.7123\n",
|
|
"Epoch [48/50], Loss: -3.6953\n",
|
|
"Epoch [49/50], Loss: -3.7212\n",
|
|
"Epoch [50/50], Loss: -3.7558\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import torch.optim as optim\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"import tensorboard\n",
"\n",
"# Classification loss on known classes; the energy terms are added inside the training loop\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=8e-4)\n",
"\n",
"# TensorBoard logger used by train_energy_model\n",
"writer = SummaryWriter()\n",
|
|
"\n",
|
|
"def train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=10,\n",
"                       energy_weight=0.1, energy_margin=1.0):\n",
"    \"\"\"Jointly train the classifier and shape its energy landscape for open-set recognition.\n",
"\n",
"    Every optimisation step combines:\n",
"      * ``criterion`` on the labelled (known-class) batch;\n",
"      * the mean energy of that batch, which is pushed down;\n",
"      * a hinge ``relu(energy_margin - E)`` on Gaussian-noise inputs and on a batch from\n",
"        ``negative_loader``, which pushes their energy up to at least ``energy_margin``.\n",
"\n",
"    Args:\n",
"        model: classifier returning logits of shape ``(batch, num_classes)``.\n",
"        train_loader: yields ``(labels, embeddings)`` batches of known-class samples.\n",
"        negative_loader: yields embedding batches of out-of-domain samples; it is\n",
"            restarted whenever it runs out.\n",
"        criterion: classification loss applied to ``(logits, labels)``.\n",
"        optimizer: optimizer over ``model.parameters()``.\n",
"        num_epochs: number of passes over ``train_loader``.\n",
"        energy_weight: weight of the summed energy terms relative to the classification loss.\n",
"        energy_margin: minimum energy targeted for noise and negative samples.\n",
"\n",
"    Side effects: moves ``model`` to CUDA when available, logs per-step scalars to the\n",
"    global TensorBoard ``writer`` and prints the average loss of every epoch.\n",
"\n",
"    NOTE(review): the known-sample energy term is unbounded below, so the total loss can\n",
"    (and in the recorded run does) become negative as training continues.\n",
"    \"\"\"\n",
"    model.train()\n",
"    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"    model.to(device)\n",
"\n",
"    negative_iter = iter(negative_loader)\n",
"    global_step = 0  # TensorBoard x-axis: one point per optimisation step, not per epoch\n",
"\n",
"    for epoch in range(num_epochs):\n",
"        total_loss = 0\n",
"        for labels, embeddings in train_loader:\n",
"            embeddings = embeddings.to(device)\n",
"            labels = labels.to(device)\n",
"\n",
"            # ---------------------\n",
"            # 1. Positive sample\n",
"            # ---------------------\n",
"            optimizer.zero_grad()\n",
"            outputs = model(embeddings)  # logits from the model\n",
"            class_loss = criterion(outputs, labels)\n",
"\n",
"            # Energy of positive samples: lower is better\n",
"            energy_loss_known = energy_score(outputs).mean()\n",
"\n",
"            # ------------------------------------\n",
"            # 2. Negative sample - Random Noise\n",
"            # ------------------------------------\n",
"            noise_outputs = model(torch.randn_like(embeddings))  # randn_like keeps device and dtype\n",
"            noise_energy = energy_score(noise_outputs)\n",
"            energy_loss_noise = F.relu(energy_margin - noise_energy).mean()  # For the energy of noise, bigger is better\n",
"\n",
"            # ------------------------------------\n",
"            # 3. Negative sample - custom corpus\n",
"            # ------------------------------------\n",
"            try:\n",
"                negative_samples = next(negative_iter)\n",
"            except StopIteration:\n",
"                # Negative corpus exhausted: start another pass over it\n",
"                negative_iter = iter(negative_loader)\n",
"                negative_samples = next(negative_iter)\n",
"            negative_samples = negative_samples.to(device)\n",
"            negative_outputs = model(negative_samples)\n",
"            negative_energy = energy_score(negative_outputs)\n",
"            energy_loss_negative = F.relu(energy_margin - negative_energy).mean()  # For the energy of negative samples, bigger is better\n",
"\n",
"            # -----------------------------\n",
"            # 4. Overall Loss calculation\n",
"            # -----------------------------\n",
"            total_energy_loss = energy_loss_known + energy_loss_noise + energy_loss_negative\n",
"            # The constant +10 only offsets the reported loss; it does not affect the gradients\n",
"            total_loss_batch = class_loss + total_energy_loss * energy_weight + 10\n",
"\n",
"            writer.add_scalar(\"Energy Loss\", total_energy_loss.item(), global_step)\n",
"            writer.add_scalar(\"Loss\", total_loss_batch.item(), global_step)\n",
"            writer.add_scalar(\"Norm Loss\", torch.exp(total_loss_batch.detach() * 0.003).item() * 10, global_step)\n",
"            global_step += 1\n",
"\n",
"            total_loss_batch.backward()\n",
"            optimizer.step()\n",
"\n",
"            total_loss += total_loss_batch.item()\n",
"\n",
"        avg_loss = total_loss / max(len(train_loader), 1)  # guard against an empty loader\n",
"        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')\n",
|
|
"\n",
|
|
"train_energy_model(\n",
"    model, train_loader, negative_loader, criterion, optimizer,\n",
"    num_epochs=50,\n",
")\n",
"# Push any buffered TensorBoard events to disk\n",
"writer.flush()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e6d29558-f497-4033-8488-169bd25ce881",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Evaluation"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "472702e6-db4a-4faa-9e92-7510e6eacbb1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Energy cut-off for open-set rejection: energy <= threshold -> known class, otherwise \"Unknown\"\n",
"ENERGY_THRESHOLD = -3"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "d3a4fef8-37ab-45c8-b2b1-9fc8bdffcffd",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Accuracy: 0.9315\n",
|
|
"Precision: 1.0000\n",
|
|
"Recall: 0.9254\n",
|
|
"F1 Score: 0.9612\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support\n",
|
|
"\n",
|
|
"def evaluate_energy_model(model, known_loader, unknown_loader, energy_threshold):\n",
"    \"\"\"Score the energy threshold as a binary known-vs-unknown detector and print the metrics.\n",
"\n",
"    Known-class samples are the positive class (label 1) and out-of-domain samples the\n",
"    negative class (label 0). A sample is predicted \"known\" when its energy is\n",
"    ``<= energy_threshold``. Prints accuracy, precision, recall and F1; returns nothing.\n",
"\n",
"    Args:\n",
"        model: classifier returning logits of shape ``(batch, num_classes)``.\n",
"        known_loader: yields ``(labels, embeddings)`` batches of known-class samples.\n",
"        unknown_loader: yields embedding batches of out-of-domain samples.\n",
"        energy_threshold: maximum energy for a sample to be accepted as a known class.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    # Run on whatever device the model already lives on (avoids a CPU/CUDA mismatch)\n",
"    device = next(model.parameters()).device\n",
"\n",
"    def _predict_known(embeddings):\n",
"        # 1 = accepted as a known class, 0 = rejected as unknown\n",
"        energy = energy_score(model(embeddings.to(device)))\n",
"        return (energy <= energy_threshold).long().cpu().numpy()\n",
"\n",
"    all_preds = []\n",
"    all_labels = []\n",
"    with torch.no_grad():\n",
"        # Evaluate positive samples: known classes, labelled 1\n",
"        for _, embeddings in known_loader:\n",
"            preds = _predict_known(embeddings)\n",
"            all_preds.extend(preds)\n",
"            all_labels.extend([1] * len(preds))\n",
"\n",
"        # Evaluate negative samples: unknown / out-of-domain, labelled 0\n",
"        for embeddings in unknown_loader:\n",
"            preds = _predict_known(embeddings)\n",
"            all_preds.extend(preds)\n",
"            all_labels.extend([0] * len(preds))\n",
"\n",
"    # zero_division=0 reports 0 instead of warning when nothing is predicted as known\n",
"    precision, recall, f1, _ = precision_recall_fscore_support(\n",
"        all_labels, all_preds, average='binary', zero_division=0)\n",
"    accuracy = accuracy_score(all_labels, all_preds)\n",
"\n",
"    print(f'Accuracy: {accuracy:.4f}')\n",
"    print(f'Precision: {precision:.4f}')\n",
"    print(f'Recall: {recall:.4f}')\n",
"    print(f'F1 Score: {f1:.4f}')\n",
|
|
"\n",
|
|
"# NOTE(review): negative_loader holds the same noise.json samples used during training,\n",
"# so the unknown-class numbers are optimistic; hold out part of noise.json for a fair test.\n",
"evaluate_energy_model(model, val_loader, negative_loader, energy_threshold=ENERGY_THRESHOLD)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "ba614054-75e1-4f61-ace5-aeb11e29a222",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Save the model\n",
"# NOTE: this pickles the whole nn.Module, so loading it needs the TextCNN class importable\n",
"# under the same module path; saving model.state_dict() would be the more portable artifact.\n",
"torch.save(model, \"model.pt\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "fdfa0c5e-e6d3-4db0-a142-96645c92719c",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Inference"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "03928d75-81c8-4298-ab8a-d7f8a758b561",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Predicted: ['weather', 0.9989822506904602, -8.016249656677246]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold, max_length=64):\n",
"    \"\"\"Classify one sentence, rejecting it as ``\"Unknown\"`` when its energy is too high.\n",
"\n",
"    Args:\n",
"        model: trained classifier returning logits of shape ``(batch, num_classes)``.\n",
"        sentence: raw input text.\n",
"        embedding_map: token id -> embedding vector lookup (the one used for training).\n",
"        tokenizer: tokenizer that produces the token ids.\n",
"        idx_to_class: label index -> class name.\n",
"        energy_threshold: inputs whose energy is above this value are reported as unknown.\n",
"        max_length: maximum number of tokens used (longer inputs are truncated).\n",
"\n",
"    Returns:\n",
"        ``[class_name or \"Unknown\", max softmax probability, energy]``.\n",
"\n",
"    Raises:\n",
"        ValueError: if the sentence produces no tokens.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    tokens = tokenizer.tokenize(sentence)\n",
"    token_ids = tokenizer.convert_tokens_to_ids(tokens)[:max_length]\n",
"    if not token_ids:\n",
"        raise ValueError(f'sentence produced no tokens: {sentence!r}')\n",
"\n",
"    # Stack into one ndarray first (a list of ndarrays -> tensor is very slow), add the batch\n",
"    # dimension, and put the input on the same device as the model\n",
"    device = next(model.parameters()).device\n",
"    embeddings = torch.tensor(np.array([embedding_map[token_id] for token_id in token_ids]))\n",
"    embeddings = embeddings.unsqueeze(0).to(device)\n",
"\n",
"    with torch.no_grad():\n",
"        logits = model(embeddings)\n",
"        probabilities = F.softmax(logits, dim=1)\n",
"        max_prob, predicted = torch.max(probabilities, 1)\n",
"\n",
"        # Calculate energy score\n",
"        energy = energy_score(logits)\n",
"\n",
"    # If energy > threshold, consider the input as unknown class\n",
"    if energy.item() > energy_threshold:\n",
"        return [\"Unknown\", max_prob.item(), energy.item()]\n",
"    return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]\n",
|
|
"\n",
|
|
"# Example usage:\n",
"sentence = \"weather today\"\n",
"energy_threshold = ENERGY_THRESHOLD\n",
"predicted = predict_with_energy(\n",
"    model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold\n",
")\n",
"# -> [class name or \"Unknown\", max softmax probability, energy]\n",
"print(f'Predicted: {predicted}')\n"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.19"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|