72 lines
1.9 KiB
Python
72 lines
1.9 KiB
Python
# data_utils.py
|
|
|
|
import json
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
|
|
|
|
def load_data(file_path):
|
|
with open(file_path, "r") as f:
|
|
data = json.load(f)
|
|
return data
|
|
|
|
|
|
def create_class_mappings(data):
|
|
class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
|
|
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
|
|
return class_to_idx, idx_to_class
|
|
|
|
|
|
def preprocess_data(data, embedding_map, tokenizer, class_to_idx, max_length=64):
|
|
dataset = []
|
|
for label, sentences in data.items():
|
|
for sentence in sentences:
|
|
tokens = tokenizer.tokenize(sentence)
|
|
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
|
embeddings = [
|
|
embedding_map[token_id] for token_id in token_ids[:max_length]
|
|
]
|
|
embeddings = torch.tensor(embeddings)
|
|
dataset.append((class_to_idx[label], embeddings))
|
|
return dataset
|
|
|
|
def get_sentences(data):
|
|
result = []
|
|
for _, sentences in data.items():
|
|
for sentence in sentences:
|
|
result.append(sentence)
|
|
return result
|
|
|
|
|
|
class TextDataset(Dataset):
|
|
def __init__(self, data):
|
|
self.data = data
|
|
|
|
def __len__(self):
|
|
return len(self.data)
|
|
|
|
def __getitem__(self, idx):
|
|
return self.data[idx]
|
|
|
|
def collate_fn(self, batch):
|
|
labels, embeddings = zip(*batch)
|
|
labels = torch.tensor(labels)
|
|
embeddings = pad_sequence(embeddings, batch_first=True)
|
|
return labels, embeddings
|
|
|
|
|
|
class NegativeSampleDataset(Dataset):
|
|
def __init__(self, negative_samples):
|
|
self.samples = negative_samples
|
|
|
|
def __len__(self):
|
|
return len(self.samples)
|
|
|
|
def __getitem__(self, idx):
|
|
return self.samples[idx]
|
|
|
|
def collate_fn(self, batch):
|
|
embeddings = pad_sequence(batch, batch_first=True)
|
|
return embeddings
|