sparkastML/intention-classify/validation/inference.py

76 lines
2.3 KiB
Python

from training.model import AttentionBasedModel
from training.config import model_name
import json
import torch
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F
from training.config import DIMENSIONS
from training.model import AttentionBasedModel
def energy_score(logits):
    """Return the energy score of ``logits``.

    The energy is the negated log-sum-exp of the logits, reduced over
    dimension 1.
    """
    log_sum_exp = torch.logsumexp(logits, dim=1)
    return -log_sum_exp
def predict_with_energy(
    model,
    sentence,
    embedding_map,
    tokenizer,
    idx_to_class,
    energy_threshold,
    max_length=64,
):
    """Classify ``sentence`` and reject it as "Unknown" when its energy is high.

    The sentence is tokenized, each token id is looked up in
    ``embedding_map``, and the resulting sequence (truncated to
    ``max_length`` tokens) is fed to ``model`` as a batch of one.

    Args:
        model: Callable module mapping the embedding batch to logits; it is
            switched to eval mode as a side effect.
        sentence: Text to classify.
        embedding_map: Mapping indexed by token id that yields that token's
            embedding (anything ``torch.tensor`` can stack).
        tokenizer: Object providing ``tokenize`` and
            ``convert_tokens_to_ids``.
        idx_to_class: Mapping from predicted class index to class label.
        energy_threshold: Energies strictly above this value are reported
            as "Unknown".
        max_length: Maximum number of tokens kept from the sentence.

    Returns:
        A list ``[label, max_probability, energy]`` where ``label`` is either
        the predicted class label or ``"Unknown"``, ``max_probability`` is the
        largest softmax probability and ``energy`` is the negated
        log-sum-exp of the logits.

    Raises:
        ValueError: If no tokens remain after tokenization and truncation
            (for example an empty sentence or ``max_length`` below 1).
    """
    model.eval()
    tokens = tokenizer.tokenize(sentence)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    # NOTE(review): debug trace kept for output compatibility; consider logging.
    print(token_ids)

    token_ids = token_ids[:max_length]
    if not token_ids:
        # Without this guard the padding step below fails with an opaque
        # RuntimeError, because the stacked tensor would only be 2-D.
        raise ValueError(
            "sentence produced no tokens (empty input or max_length < 1)"
        )

    embeddings = [embedding_map[token_id] for token_id in token_ids]
    embeddings = torch.tensor(embeddings).unsqueeze(0)  # Add batch dimension

    # Zero-pad the sequence dimension up to a length of 2.
    # NOTE(review): presumably the model needs at least two positions --
    # confirm against AttentionBasedModel.
    seq_len = embeddings.shape[1]
    if seq_len < 2:
        pad_size = 2 - seq_len
        embeddings = F.pad(
            embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
        )

    with torch.no_grad():
        logits = model(embeddings)
        # NOTE(review): debug trace kept for output compatibility.
        print(logits)
        probabilities = F.softmax(logits, dim=1)
        max_prob, predicted = torch.max(probabilities, 1)
        # Energy score is minus logsumexp of the logits.
        energy = -torch.logsumexp(logits, dim=1)

    # If energy > threshold, consider the input as unknown class
    if energy.item() > energy_threshold:
        return ["Unknown", max_prob.item(), energy.item()]
    return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]
# Class labels are the top-level keys of data.json, indexed in file order.
# JSON is UTF-8 by specification, so do not rely on the locale's default
# encoding when reading it.
with open("data.json", "r", encoding="utf-8") as f:
    data = json.load(f)
class_to_idx = {cls: idx for idx, cls in enumerate(data)}
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
num_classes = len(class_to_idx)
input_dim = DIMENSIONS
model = AttentionBasedModel(input_dim, num_classes)
# map_location="cpu" lets checkpoints saved from a GPU load on CPU-only hosts;
# inference below builds its input tensors on the CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load("./model.pt", map_location="cpu"))
embedding_map = torch.load("token_id_to_reduced_embedding.pt", map_location="cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Example usage:
ENERGY_THRESHOLD = 2
sentence = "what on earth is the cross entropy loss"
energy_threshold = ENERGY_THRESHOLD
predicted = predict_with_energy(
    model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold
)
print(f"Predicted: {predicted}")