# sparkastML/intention-classify/validation/openset_validation.py

import json

import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
from transformers import AutoTokenizer

from training.config import DIMENSIONS, model_name
from training.data_utils import get_sentences
from training.model import AttentionBasedModel


def energy_score(logits):
    # Energy is the negative log-sum-exp of the logits; confident,
    # in-distribution inputs produce lower (more negative) energy.
    return -torch.logsumexp(logits, dim=1)
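
# Rough intuition with hypothetical logits (not part of the pipeline):
# logsumexp([10.0, 0.0, 0.0]) is roughly 10.0, so the energy is about -10.0,
# while logsumexp([0.1, 0.0, 0.0]) is roughly 1.13, giving energy about -1.13;
# the confident prediction scores far lower energy than the ambiguous one.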


def get_energy(
    model,
    sentence,
    embedding_map,
    tokenizer,
    max_length=64,
):
    model.eval()
    # Tokenize and map each token id to its precomputed, dimension-reduced
    # embedding, truncating to max_length tokens.
    tokens = tokenizer.tokenize(sentence)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
    embeddings = torch.tensor(embeddings).unsqueeze(0)  # Add batch dimension
    # Pad sequences shorter than two tokens so the attention layers always
    # have at least two positions to attend over.
    if embeddings.shape[1] < 2:
        pad_size = 2 - embeddings.shape[1]
        embeddings = F.pad(
            embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
        )
    with torch.no_grad():
        logits = model(embeddings)
    return energy_score(logits)
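
# Example call (hypothetical sentence, assuming the artifacts loaded below):
#
#   energy = get_energy(model, "turn on the lights", embedding_map, tokenizer)
#   # energy is a 1-element tensor; energy.item() yields the float score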


# Load the in-distribution (positive) classes and rebuild the label mappings
# the classifier was trained with.
with open("data.json", "r") as f:
    positive_data = json.load(f)
class_to_idx = {cls: idx for idx, cls in enumerate(positive_data.keys())}
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
num_classes = len(class_to_idx)

# Load the out-of-distribution (negative) sentences.
with open("noise.json", "r") as f:
    negative_data = json.load(f)

input_dim = DIMENSIONS
model = AttentionBasedModel(input_dim, num_classes)
model.load_state_dict(torch.load("./model.pt"))
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
tokenizer = AutoTokenizer.from_pretrained(model_name)

all_preds = []
all_labels = []
ENERGY_THRESHOLD = 2
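
# Decision rule: predict in-distribution (label 1) when energy < ENERGY_THRESHOLD.
# The cutoff of 2 is simply the value this script ships with; it is not
# derived anywhere in the script.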

# In-distribution sentences: ground-truth label 1.
for item in tqdm(get_sentences(positive_data)):
    result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
    all_preds.append(int(result))
    all_labels.append(1)

# Out-of-distribution (noise) sentences: ground-truth label 0.
for item in tqdm(negative_data):
    result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
    all_preds.append(int(result))
    all_labels.append(0)

# int(result) collapses each 1-element boolean tensor to a plain 0/1 so that
# scikit-learn receives ordinary binary labels.
precision, recall, f1, _ = precision_recall_fscore_support(
    all_labels, all_preds, average="binary"
)
accuracy = accuracy_score(all_labels, all_preds)
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
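
# A threshold-free companion check (a sketch, not run here): gather the raw
# float energies in the loops above and score them with AUROC, negating the
# energy so that a higher score means "more in-distribution":
#
#   from sklearn.metrics import roc_auc_score
#   auroc = roc_auc_score(all_labels, [-e for e in all_energies])
#
# where all_energies would be a hypothetical list of energy.item() values
# collected alongside all_preds.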