add: open set validation

parent bf2c9a393a
commit ae6f10a6f0

intention-classify/NLU_meta.json (new file, 1 line)
@@ -0,0 +1 @@
+{"idx_to_class": {"0": "weather", "1": "base64", "2": "url-encode", "3": "html-encode", "4": "ai.command", "5": "knowledge", "6": "ai.question", "7": "datetime"}, "threshold": 1.7}
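NLU_meta.json pins down how a consumer should interpret the exported classifier: the index-to-intent mapping and the energy threshold used for open-set rejection. A minimal sketch of reading it (variable names here are illustrative, not part of the commit):

```python
# Sketch: load the exported metadata; keys match the JSON added above.
import json

with open("intention-classify/NLU_meta.json") as f:
    meta = json.load(f)

idx_to_class = {int(k): v for k, v in meta["idx_to_class"].items()}  # 0 -> "weather", ...
energy_threshold = meta["threshold"]  # 1.7: inputs with higher energy are treated as out-of-scope
```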
@@ -40,7 +40,7 @@
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "11caef0e1b674f6ab15880f3f25eca6a",
+"model_id": "38137fc55ad24a9785ecbe1978bbc605",
 "version_major": 2,
 "version_minor": 0
 },
@@ -69,6 +69,122 @@
 "vocab = tokenizer.get_vocab()"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": 8,
+"id": "21214ff4-018d-4230-81b9-331ebb42773b",
+"metadata": {},
+"outputs": [],
+"source": [
+"def bytes_to_unicode():\n",
+"    \"\"\"\n",
+"    Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control\n",
+"    characters the bpe code barfs on.\n",
+"\n",
+"    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab\n",
+"    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for\n",
+"    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup\n",
+"    tables between utf-8 bytes and unicode strings.\n",
+"    \"\"\"\n",
+"    bs = (\n",
+"        list(range(ord(\"!\"), ord(\"~\") + 1)) + list(range(ord(\"¡\"), ord(\"¬\") + 1)) + list(range(ord(\"®\"), ord(\"ÿ\") + 1))\n",
+"    )\n",
+"    cs = bs[:]\n",
+"    n = 0\n",
+"    for b in range(2**8):\n",
+"        if b not in bs:\n",
+"            bs.append(b)\n",
+"            cs.append(2**8 + n)\n",
+"            n += 1\n",
+"    cs = [chr(n) for n in cs]\n",
+"    return dict(zip(bs, cs))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"id": "cbc23d2d-985b-443a-83ee-c2286046ad5e",
+"metadata": {},
+"outputs": [],
+"source": [
+"btu=bytes_to_unicode()"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 14,
+"id": "4a99fa07-4922-4d8d-9c28-2275bf9cb8df",
+"metadata": {},
+"outputs": [],
+"source": [
+"utb = reversed_dict = {value: key for key, value in btu.items()}"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 24,
+"id": "cb218ea7-50c7-4bb8-aa7f-0ee85da76147",
+"metadata": {},
+"outputs": [],
+"source": [
+"result = tokenizer.convert_ids_to_tokens([104307])[0]"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 27,
+"id": "2dcb332a-cba9-4a14-9486-4e1ff6bd3dba",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"å\n",
+"229\n",
+"¤\n",
+"164\n",
+"©\n",
+"169\n",
+"æ\n",
+"230\n",
+"°\n",
+"176\n",
+"Ķ\n",
+"148\n"
+]
+}
+],
+"source": [
+"decoded=b\"\"\n",
+"for chr in result:\n",
+"    print(chr)\n",
+"    if chr in utb:\n",
+"        print(utb[chr])\n",
+"        decoded+=bytes([utb[chr]])"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 29,
+"id": "b1bf1289-2cab-4a97-ad21-b2d24de6d688",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"'天气'"
+]
+},
+"execution_count": 29,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"decoded.decode(\"utf-8\", errors='replace')"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 5,
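The new notebook cells walk through decoding a single byte-level BPE token by hand: token id 104307 maps to characters that each stand for one raw UTF-8 byte, and those bytes decode to "天气" ("weather"). A compact helper doing the same round trip might look like this (a sketch reusing the `tokenizer` and `bytes_to_unicode()` defined in the cells above; `token_to_text` is a name introduced here for illustration):

```python
# Sketch: invert the GPT-2 style byte-to-unicode table and decode one token id.
utb = {u: b for b, u in bytes_to_unicode().items()}

def token_to_text(token_id):
    token = tokenizer.convert_ids_to_tokens([token_id])[0]
    raw = bytes(utb[ch] for ch in token if ch in utb)  # each char encodes one byte
    return raw.decode("utf-8", errors="replace")

# token_to_text(104307) -> '天气', matching the cell output above.
```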
@@ -31,6 +31,13 @@ def preprocess_data(data, embedding_map, tokenizer, class_to_idx, max_length=64)
         dataset.append((class_to_idx[label], embeddings))
     return dataset
 
+def get_sentences(data):
+    result = []
+    for _, sentences in data.items():
+        for sentence in sentences:
+            result.append(sentence)
+    return result
+
 
 class TextDataset(Dataset):
     def __init__(self, data):
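`get_sentences` simply flattens the `{intent: [sentences]}` layout of `data.json` into one list. A quick usage sketch (the example data is invented for illustration, but the class names come from NLU_meta.json above):

```python
data = {"weather": ["how hot is it today"], "datetime": ["what time is it"]}
get_sentences(data)  # -> ["how hot is it today", "what time is it"]
```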
@@ -32,18 +32,21 @@ class SelfAttention(nn.Module):
 
 
 class AttentionBasedModel(nn.Module):
-    def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512):
+    def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512, num_layers=3):
         super(AttentionBasedModel, self).__init__()
-        self.self_attention = SelfAttention(input_dim, heads)
+        self.self_attention_layers = nn.ModuleList([
+            SelfAttention(input_dim, heads) for _ in range(num_layers)
+        ])
         self.fc1 = nn.Linear(input_dim, dim_feedforward)
         self.fc2 = nn.Linear(dim_feedforward, num_classes)
         self.dropout = nn.Dropout(0.5)
         self.norm = nn.LayerNorm(input_dim)
 
     def forward(self, x):
-        attn_output = self.self_attention(x)
-        attn_output = self.norm(attn_output + x)
-        pooled_output = torch.mean(attn_output, dim=1)
+        for attn_layer in self.self_attention_layers:
+            attn_output = attn_layer(x)
+            x = self.norm(attn_output + x)
+        pooled_output = torch.mean(x, dim=1)
         x = F.relu(self.fc1(pooled_output))
         x = self.dropout(x)
         x = self.fc2(x)
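The model change stacks `num_layers` self-attention blocks, each followed by a residual add and a shared LayerNorm, before mean-pooling over the sequence. A quick shape check (a sketch; 128 stands in for whatever `DIMENSIONS` is set to in `training/config.py`, and 8 matches the class count in NLU_meta.json):

```python
import torch

model = AttentionBasedModel(input_dim=128, num_classes=8, num_layers=3)
x = torch.randn(4, 16, 128)   # [batch, seq_len, embedding_dim]
logits = model(x)             # expected shape: torch.Size([4, 8])
```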
@@ -66,8 +66,8 @@ embedding_map = torch.load("token_id_to_reduced_embedding.pt")
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 # Example usage:
-ENERGY_THRESHOLD = 0
-sentence = "天气"
+ENERGY_THRESHOLD = 2
+sentence = "what on earth is the cross entropy loss"
 energy_threshold = ENERGY_THRESHOLD
 predicted = predict_with_energy(
     model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold
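`ENERGY_THRESHOLD` drives open-set rejection: the energy score is the negative log-sum-exp of the logits, so inputs the classifier is confident about get low energy and are accepted, while out-of-scope inputs get high energy and are rejected. A minimal sketch of the rule the threshold feeds into (`predict_with_energy` is defined outside this hunk; `energy_score` here mirrors the one in `validation/openset_validation.py` below):

```python
import torch

def energy_score(logits):
    return -torch.logsumexp(logits, dim=1)

def accept(logits, threshold=2.0):
    # True -> keep the argmax class as in-domain; False -> reject as unknown intent.
    return bool(energy_score(logits) < threshold)
```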
intention-classify/validation/openset_validation.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+from training.model import AttentionBasedModel
+from training.config import model_name
+from training.config import DIMENSIONS
+from training.data_utils import get_sentences
+import json
+import torch
+import torch.nn.functional as F
+from transformers import AutoTokenizer
+from tqdm import tqdm
+from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support
+
+def energy_score(logits):
+    # Energy score is minus logsumexp
+    return -torch.logsumexp(logits, dim=1)
+
+
+def get_energy(
+    model,
+    sentence,
+    embedding_map,
+    tokenizer,
+    max_length=64,
+):
+    model.eval()
+    tokens = tokenizer.tokenize(sentence)
+    token_ids = tokenizer.convert_tokens_to_ids(tokens)
+    embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
+    embeddings = torch.tensor(embeddings).unsqueeze(0)  # Add batch dimension
+    current_shape = embeddings.shape
+
+    if current_shape[1] < 2:
+        pad_size = 2 - current_shape[1]
+        embeddings = F.pad(
+            embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
+        )
+
+    with torch.no_grad():
+        logits = model(embeddings)
+        # Calculate energy score
+        energy = energy_score(logits)
+
+    return energy
+
+
+with open("data.json", "r") as f:
+    positive_data = json.load(f)
+class_to_idx = {cls: idx for idx, cls in enumerate(positive_data.keys())}
+idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
+num_classes = len(class_to_idx)
+
+with open("noise.json", "r") as f:
+    negative_data = json.load(f)
+
+input_dim = DIMENSIONS
+model = AttentionBasedModel(input_dim, num_classes)
+model.load_state_dict(torch.load("./model.pt"))
+embedding_map = torch.load("token_id_to_reduced_embedding.pt")
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+
+all_preds = []
+all_labels = []
+ENERGY_THRESHOLD = 2
+for item in tqdm(get_sentences(positive_data)):
+    result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
+    all_preds.append(result)
+    all_labels.append(1)
+
+for item in tqdm(negative_data):
+    result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
+    all_preds.append(result)
+    all_labels.append(0)
+
+precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')
+accuracy = accuracy_score(all_labels, all_preds)
+
+print(f'Accuracy: {accuracy:.4f}')
+print(f'Precision: {precision:.4f}')
+print(f'Recall: {recall:.4f}')
+print(f'F1 Score: {f1:.4f}')
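The validation script frames open-set detection as a binary problem: sentences from `data.json` should fall below the threshold (label 1) and sentences from `noise.json` above it (label 0), and it reports accuracy, precision, recall and F1 at the fixed `ENERGY_THRESHOLD = 2`. A natural follow-up, not part of this commit, would be to sweep the threshold over the collected energies instead of hard-coding it; a sketch (assuming `pos_energies` and `neg_energies` hold the raw energies gathered in the two loops):

```python
import numpy as np
from sklearn.metrics import f1_score

def best_threshold(pos_energies, neg_energies):
    # Pick the threshold that maximizes F1 for the rule "energy < t means in-domain".
    energies = np.concatenate([pos_energies, neg_energies])
    labels = np.concatenate([np.ones(len(pos_energies), dtype=int),
                             np.zeros(len(neg_energies), dtype=int)])
    candidates = np.unique(energies)
    scores = [f1_score(labels, energies < t) for t in candidates]
    return candidates[int(np.argmax(scores))]
```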