add: open set validation
commit ae6f10a6f0 (parent bf2c9a393a)
intention-classify/NLU_meta.json (new file, 1 line)
@@ -0,0 +1 @@
{"idx_to_class": {"0": "weather", "1": "base64", "2": "url-encode", "3": "html-encode", "4": "ai.command", "5": "knowledge", "6": "ai.question", "7": "datetime"}, "threshold": 1.7}
@@ -40,7 +40,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "11caef0e1b674f6ab15880f3f25eca6a",
"model_id": "38137fc55ad24a9785ecbe1978bbc605",
"version_major": 2,
"version_minor": 0
},
@@ -69,6 +69,122 @@
"vocab = tokenizer.get_vocab()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "21214ff4-018d-4230-81b9-331ebb42773b",
"metadata": {},
"outputs": [],
"source": [
"def bytes_to_unicode():\n",
"    \"\"\"\n",
"    Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control\n",
"    characters the bpe code barfs on.\n",
"\n",
"    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab\n",
"    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for\n",
"    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup\n",
"    tables between utf-8 bytes and unicode strings.\n",
"    \"\"\"\n",
"    bs = (\n",
"        list(range(ord(\"!\"), ord(\"~\") + 1)) + list(range(ord(\"¡\"), ord(\"¬\") + 1)) + list(range(ord(\"®\"), ord(\"ÿ\") + 1))\n",
"    )\n",
"    cs = bs[:]\n",
"    n = 0\n",
"    for b in range(2**8):\n",
"        if b not in bs:\n",
"            bs.append(b)\n",
"            cs.append(2**8 + n)\n",
"            n += 1\n",
"    cs = [chr(n) for n in cs]\n",
"    return dict(zip(bs, cs))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "cbc23d2d-985b-443a-83ee-c2286046ad5e",
"metadata": {},
"outputs": [],
"source": [
"btu=bytes_to_unicode()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "4a99fa07-4922-4d8d-9c28-2275bf9cb8df",
"metadata": {},
"outputs": [],
"source": [
"utb = reversed_dict = {value: key for key, value in btu.items()}"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "cb218ea7-50c7-4bb8-aa7f-0ee85da76147",
"metadata": {},
"outputs": [],
"source": [
"result = tokenizer.convert_ids_to_tokens([104307])[0]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "2dcb332a-cba9-4a14-9486-4e1ff6bd3dba",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"å\n",
"229\n",
"¤\n",
"164\n",
"©\n",
"169\n",
"æ\n",
"230\n",
"°\n",
"176\n",
"Ķ\n",
"148\n"
]
}
],
"source": [
"decoded=b\"\"\n",
"for chr in result:\n",
"    print(chr)\n",
"    if chr in utb:\n",
"        print(utb[chr])\n",
"        decoded+=bytes([utb[chr]])"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "b1bf1289-2cab-4a97-ad21-b2d24de6d688",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'天气'"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"decoded.decode(\"utf-8\", errors='replace')"
]
},
{
"cell_type": "code",
"execution_count": 5,
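The notebook cells above decode a single token id back to readable text through the GPT-2 style byte-to-unicode table. Condensed into one script, as a sketch that reuses the `bytes_to_unicode` function and the byte-level BPE `tokenizer` defined earlier in the notebook (the token id 104307 and variable names mirror the cells above):

# Sketch of the byte-level BPE roundtrip demonstrated above.
# Assumes `bytes_to_unicode` and `tokenizer` are the objects defined earlier in the notebook.
btu = bytes_to_unicode()                    # byte value -> printable unicode character
utb = {ch: b for b, ch in btu.items()}      # printable unicode character -> byte value

token = tokenizer.convert_ids_to_tokens([104307])[0]   # token whose surrogate chars spell 天气

decoded = b""
for ch in token:
    if ch in utb:
        decoded += bytes([utb[ch]])         # map each surrogate char back to its raw byte

print(decoded.decode("utf-8", errors="replace"))        # '天气'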
@@ -31,6 +31,13 @@ def preprocess_data(data, embedding_map, tokenizer, class_to_idx, max_length=64)
        dataset.append((class_to_idx[label], embeddings))
    return dataset

def get_sentences(data):
    result = []
    for _, sentences in data.items():
        for sentence in sentences:
            result.append(sentence)
    return result


class TextDataset(Dataset):
    def __init__(self, data):
@@ -32,18 +32,21 @@ class SelfAttention(nn.Module):


class AttentionBasedModel(nn.Module):
    def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512):
    def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512, num_layers=3):
        super(AttentionBasedModel, self).__init__()
        self.self_attention = SelfAttention(input_dim, heads)
        self.self_attention_layers = nn.ModuleList([
            SelfAttention(input_dim, heads) for _ in range(num_layers)
        ])
        self.fc1 = nn.Linear(input_dim, dim_feedforward)
        self.fc2 = nn.Linear(dim_feedforward, num_classes)
        self.dropout = nn.Dropout(0.5)
        self.norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        attn_output = self.self_attention(x)
        attn_output = self.norm(attn_output + x)
        pooled_output = torch.mean(attn_output, dim=1)
        for attn_layer in self.self_attention_layers:
            attn_output = attn_layer(x)
            x = self.norm(attn_output + x)
        pooled_output = torch.mean(x, dim=1)
        x = F.relu(self.fc1(pooled_output))
        x = self.dropout(x)
        x = self.fc2(x)
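The constructor change stacks num_layers self-attention blocks, and the reworked forward pass applies a residual connection plus LayerNorm around each block before mean pooling over tokens. A quick shape check, as a sketch; the concrete input_dim below is an assumption, the project reads it from training.config.DIMENSIONS:

import torch
from training.model import AttentionBasedModel

# Illustrative dimensions only; not taken from this commit.
input_dim, num_classes, seq_len = 128, 8, 12
model = AttentionBasedModel(input_dim, num_classes, heads=8, num_layers=3)

x = torch.randn(1, seq_len, input_dim)   # (batch, tokens, embedding_dim)
logits = model(x)                        # per layer: x = norm(self_attention(x) + x)
print(logits.shape)                      # expected: torch.Size([1, 8]) after pooling and the MLP head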
@@ -66,8 +66,8 @@ embedding_map = torch.load("token_id_to_reduced_embedding.pt")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Example usage:
ENERGY_THRESHOLD = 0
sentence = "天气"
ENERGY_THRESHOLD = 2
sentence = "what on earth is the cross entropy loss"
energy_threshold = ENERGY_THRESHOLD
predicted = predict_with_energy(
    model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold
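For reference, the energy used throughout this commit is the negative log-sum-exp of the class logits, and a sentence is kept as one of the known intents only when its energy falls below the threshold; higher energy means the classifier treats the input as out-of-distribution. Restated as a small sketch (the same formula as energy_score in the new validation script below):

import torch

def energy_score(logits):
    # E(x) = -log(sum_i exp(logit_i)); confident in-domain inputs yield low energy
    return -torch.logsumexp(logits, dim=1)

# Decision rule applied by the validation script below (and presumably by predict_with_energy):
# accept as in-distribution if E(x) < ENERGY_THRESHOLD, otherwise reject as open-set.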
intention-classify/validation/openset_validation.py (new file, 80 lines)
@@ -0,0 +1,80 @@
from training.model import AttentionBasedModel
from training.config import model_name
from training.config import DIMENSIONS
from training.data_utils import get_sentences
import json
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support

def energy_score(logits):
    # Energy score is minus logsumexp
    return -torch.logsumexp(logits, dim=1)


def get_energy(
    model,
    sentence,
    embedding_map,
    tokenizer,
    max_length=64,
):
    model.eval()
    tokens = tokenizer.tokenize(sentence)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
    embeddings = torch.tensor(embeddings).unsqueeze(0)  # Add batch dimension
    current_shape = embeddings.shape

    if current_shape[1] < 2:
        pad_size = 2 - current_shape[1]
        embeddings = F.pad(
            embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
        )

    with torch.no_grad():
        logits = model(embeddings)
        # Calculate energy score
        energy = energy_score(logits)

    return energy


with open("data.json", "r") as f:
    positive_data = json.load(f)
class_to_idx = {cls: idx for idx, cls in enumerate(positive_data.keys())}
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
num_classes = len(class_to_idx)

with open("noise.json", "r") as f:
    negative_data = json.load(f)

input_dim = DIMENSIONS
model = AttentionBasedModel(input_dim, num_classes)
model.load_state_dict(torch.load("./model.pt"))
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
tokenizer = AutoTokenizer.from_pretrained(model_name)


all_preds = []
all_labels = []
ENERGY_THRESHOLD = 2
for item in tqdm(get_sentences(positive_data)):
    result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
    all_preds.append(result)
    all_labels.append(1)

for item in tqdm(negative_data):
    result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
    all_preds.append(result)
    all_labels.append(0)

precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')
accuracy = accuracy_score(all_labels, all_preds)

print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')
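Since the threshold is the only knob in the open-set decision (1.7 in NLU_meta.json, 2 in the scripts), it can be useful to sweep it over the collected energies rather than fixing it up front. A hedged sketch that reuses get_energy and the data loaded above; it is not part of this commit:

import numpy as np
from sklearn.metrics import f1_score

# Collect raw energies once, then evaluate several candidate thresholds.
pos_energies = [get_energy(model, s, embedding_map, tokenizer).item()
                for s in get_sentences(positive_data)]
neg_energies = [get_energy(model, s, embedding_map, tokenizer).item()
                for s in negative_data]

labels = np.array([1] * len(pos_energies) + [0] * len(neg_energies))
energies = np.array(pos_energies + neg_energies)

for threshold in (1.0, 1.7, 2.0, 2.5):
    preds = (energies < threshold).astype(int)   # low energy -> in-distribution
    print(threshold, f1_score(labels, preds))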