76 lines
2.3 KiB
Python
76 lines
2.3 KiB
Python
from training.model import AttentionBasedModel
|
|
from training.config import model_name
|
|
import json
|
|
import torch
|
|
from transformers import AutoTokenizer
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from training.config import DIMENSIONS
|
|
from training.model import AttentionBasedModel
|
|
|
|
|
|
def energy_score(logits):
|
|
# Energy score is minus logsumexp
|
|
return -torch.logsumexp(logits, dim=1)
|
|
|
|
|
|
def predict_with_energy(
|
|
model,
|
|
sentence,
|
|
embedding_map,
|
|
tokenizer,
|
|
idx_to_class,
|
|
energy_threshold,
|
|
max_length=64,
|
|
):
|
|
model.eval()
|
|
tokens = tokenizer.tokenize(sentence)
|
|
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
|
print(token_ids)
|
|
embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
|
|
embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension
|
|
current_shape = embeddings.shape
|
|
|
|
if current_shape[1] < 2:
|
|
pad_size = 2 - current_shape[1]
|
|
embeddings = F.pad(
|
|
embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
|
|
)
|
|
|
|
with torch.no_grad():
|
|
logits = model(embeddings)
|
|
print(logits)
|
|
probabilities = F.softmax(logits, dim=1)
|
|
max_prob, predicted = torch.max(probabilities, 1)
|
|
|
|
# Calculate energy score
|
|
energy = energy_score(logits)
|
|
|
|
# If energy > threshold, consider the input as unknown class
|
|
if energy.item() > energy_threshold:
|
|
return ["Unknown", max_prob.item(), energy.item()]
|
|
else:
|
|
return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]
|
|
|
|
|
|
with open("data.json", "r") as f:
|
|
data = json.load(f)
|
|
class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
|
|
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
|
|
num_classes = len(class_to_idx)
|
|
|
|
input_dim = DIMENSIONS
|
|
model = AttentionBasedModel(input_dim, num_classes)
|
|
model.load_state_dict(torch.load("./model.pt"))
|
|
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
# Example usage:
|
|
ENERGY_THRESHOLD = 2
|
|
sentence = "what on earth is the cross entropy loss"
|
|
energy_threshold = ENERGY_THRESHOLD
|
|
predicted = predict_with_energy(
|
|
model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold
|
|
)
|
|
print(f"Predicted: {predicted}")
|