add: open set validation

alikia2x (寒寒) 2024-09-28 21:53:55 +08:00
parent bf2c9a393a
commit ae6f10a6f0
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
6 changed files with 215 additions and 8 deletions

View File

@ -0,0 +1 @@
{"idx_to_class": {"0": "weather", "1": "base64", "2": "url-encode", "3": "html-encode", "4": "ai.command", "5": "knowledge", "6": "ai.question", "7": "datetime"}, "threshold": 1.7}

View File

@ -40,7 +40,7 @@
{
 "data": {
  "application/vnd.jupyter.widget-view+json": {
-  "model_id": "11caef0e1b674f6ab15880f3f25eca6a",
+  "model_id": "38137fc55ad24a9785ecbe1978bbc605",
   "version_major": 2,
   "version_minor": 0
  },
@ -69,6 +69,122 @@
   "vocab = tokenizer.get_vocab()"
  ]
 },
{
"cell_type": "code",
"execution_count": 8,
"id": "21214ff4-018d-4230-81b9-331ebb42773b",
"metadata": {},
"outputs": [],
"source": [
"def bytes_to_unicode():\n",
" \"\"\"\n",
" Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control\n",
" characters the bpe code barfs on.\n",
"\n",
" The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab\n",
" if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for\n",
" decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup\n",
" tables between utf-8 bytes and unicode strings.\n",
" \"\"\"\n",
" bs = (\n",
" list(range(ord(\"!\"), ord(\"~\") + 1)) + list(range(ord(\"¡\"), ord(\"¬\") + 1)) + list(range(ord(\"®\"), ord(\"ÿ\") + 1))\n",
" )\n",
" cs = bs[:]\n",
" n = 0\n",
" for b in range(2**8):\n",
" if b not in bs:\n",
" bs.append(b)\n",
" cs.append(2**8 + n)\n",
" n += 1\n",
" cs = [chr(n) for n in cs]\n",
" return dict(zip(bs, cs))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "cbc23d2d-985b-443a-83ee-c2286046ad5e",
"metadata": {},
"outputs": [],
"source": [
"btu=bytes_to_unicode()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "4a99fa07-4922-4d8d-9c28-2275bf9cb8df",
"metadata": {},
"outputs": [],
"source": [
"utb = reversed_dict = {value: key for key, value in btu.items()}"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "cb218ea7-50c7-4bb8-aa7f-0ee85da76147",
"metadata": {},
"outputs": [],
"source": [
"result = tokenizer.convert_ids_to_tokens([104307])[0]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "2dcb332a-cba9-4a14-9486-4e1ff6bd3dba",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"å\n",
"229\n",
"¤\n",
"164\n",
"©\n",
"169\n",
"æ\n",
"230\n",
"°\n",
"176\n",
"Ķ\n",
"148\n"
]
}
],
"source": [
"decoded=b\"\"\n",
"for chr in result:\n",
" print(chr)\n",
" if chr in utb:\n",
" print(utb[chr])\n",
" decoded+=bytes([utb[chr]])"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "b1bf1289-2cab-4a97-ad21-b2d24de6d688",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'天气'"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"decoded.decode(\"utf-8\", errors='replace')"
]
},
{
 "cell_type": "code",
 "execution_count": 5,

View File

@ -31,6 +31,13 @@ def preprocess_data(data, embedding_map, tokenizer, class_to_idx, max_length=64)
        dataset.append((class_to_idx[label], embeddings))
    return dataset

def get_sentences(data):
    result = []
    for _, sentences in data.items():
        for sentence in sentences:
            result.append(sentence)
    return result

class TextDataset(Dataset):
    def __init__(self, data):
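get_sentences simply flattens the {class: [sentences]} dictionary used elsewhere in training into a single list, which the new validation script below iterates over. A quick illustration with made-up data:

data = {
    "weather": ["how is the weather today", "will it rain tomorrow"],
    "datetime": ["what time is it"],
}
print(get_sentences(data))
# ['how is the weather today', 'will it rain tomorrow', 'what time is it']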

View File

@ -32,18 +32,21 @@ class SelfAttention(nn.Module):
class AttentionBasedModel(nn.Module):
-    def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512):
+    def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512, num_layers=3):
        super(AttentionBasedModel, self).__init__()
-        self.self_attention = SelfAttention(input_dim, heads)
+        self.self_attention_layers = nn.ModuleList([
+            SelfAttention(input_dim, heads) for _ in range(num_layers)
+        ])
        self.fc1 = nn.Linear(input_dim, dim_feedforward)
        self.fc2 = nn.Linear(dim_feedforward, num_classes)
        self.dropout = nn.Dropout(0.5)
        self.norm = nn.LayerNorm(input_dim)

    def forward(self, x):
-        attn_output = self.self_attention(x)
-        attn_output = self.norm(attn_output + x)
-        pooled_output = torch.mean(attn_output, dim=1)
+        for attn_layer in self.self_attention_layers:
+            attn_output = attn_layer(x)
+            x = self.norm(attn_output + x)
+        pooled_output = torch.mean(x, dim=1)
        x = F.relu(self.fc1(pooled_output))
        x = self.dropout(x)
        x = self.fc2(x)
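The change above swaps the single self-attention block for a stack of num_layers blocks, each followed by a residual add and the shared LayerNorm, with mean-pooling over the sequence applied only after the last layer. A rough shape check of the new forward pass; the dimensions here are illustrative, not taken from training/config.py:

import torch

input_dim, num_classes = 128, 8   # illustrative; real values come from the config
model = AttentionBasedModel(input_dim, num_classes, num_layers=3)

x = torch.randn(4, 16, input_dim)  # (batch, tokens, embedding dim)
logits = model(x)
print(logits.shape)                # expected: torch.Size([4, 8]), one logit per class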

View File

@ -66,8 +66,8 @@ embedding_map = torch.load("token_id_to_reduced_embedding.pt")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Example usage:
-ENERGY_THRESHOLD = 0
-sentence = "天气"
+ENERGY_THRESHOLD = 2
+sentence = "what on earth is the cross entropy loss"
energy_threshold = ENERGY_THRESHOLD
predicted = predict_with_energy(
    model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold

View File

@ -0,0 +1,80 @@
from training.model import AttentionBasedModel
from training.config import model_name
from training.config import DIMENSIONS
from training.data_utils import get_sentences
import json
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support


def energy_score(logits):
    # Energy score is minus logsumexp
    return -torch.logsumexp(logits, dim=1)


def get_energy(
    model,
    sentence,
    embedding_map,
    tokenizer,
    max_length=64,
):
    model.eval()
    tokens = tokenizer.tokenize(sentence)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
    embeddings = torch.tensor(embeddings).unsqueeze(0)  # Add batch dimension
    current_shape = embeddings.shape
    if current_shape[1] < 2:
        pad_size = 2 - current_shape[1]
        embeddings = F.pad(
            embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
        )
    with torch.no_grad():
        logits = model(embeddings)
    # Calculate energy score
    energy = energy_score(logits)
    return energy


with open("data.json", "r") as f:
    positive_data = json.load(f)

class_to_idx = {cls: idx for idx, cls in enumerate(positive_data.keys())}
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
num_classes = len(class_to_idx)

with open("noise.json", "r") as f:
    negative_data = json.load(f)

input_dim = DIMENSIONS
model = AttentionBasedModel(input_dim, num_classes)
model.load_state_dict(torch.load("./model.pt"))

embedding_map = torch.load("token_id_to_reduced_embedding.pt")
tokenizer = AutoTokenizer.from_pretrained(model_name)

all_preds = []
all_labels = []
ENERGY_THRESHOLD = 2

for item in tqdm(get_sentences(positive_data)):
    result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
    all_preds.append(result)
    all_labels.append(1)

for item in tqdm(negative_data):
    result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
    all_preds.append(result)
    all_labels.append(0)

precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')
accuracy = accuracy_score(all_labels, all_preds)

print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')
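In effect, the script scores every sentence from data.json as in-domain (label 1) and every sentence from noise.json as out-of-domain (label 0), and predicts "in-domain" whenever the energy E(x) = -logsumexp(logits) falls below the threshold of 2. The same decision rule, restated as a standalone helper — is_in_domain is an illustrative name, not part of the repository:

import torch

def is_in_domain(logits, threshold=2.0):
    # Low energy (-logsumexp) means the classifier assigns high total probability
    # mass to the known classes, so the input is treated as in-distribution.
    energy = -torch.logsumexp(logits, dim=1)
    return bool((energy < threshold).item())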