cvsa/filter/dataset.py

import json
import os

import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset


def split_and_save_data(file_path, test_size=0.1, eval_size=0.1, random_state=42):
    """Split the data and save the train, eval, and test sets."""
    # Read the raw data
    with open(file_path, 'r', encoding='utf-8') as f:
        all_examples = [json.loads(line) for line in f]

    # Check whether a test file already exists
    test_file = os.path.join(os.path.dirname(file_path), 'test.jsonl')
    if not os.path.exists(test_file):
        # If the test file does not exist, split off the test data and save it
        train_eval_examples, test_examples = train_test_split(
            all_examples, test_size=test_size, random_state=random_state
        )
        # Save the test set
        with open(test_file, 'w', encoding='utf-8') as f:
            for example in test_examples:
                f.write(json.dumps(example, ensure_ascii=False) + '\n')
    else:
        # If the test file exists, read the test data directly
        with open(test_file, 'r', encoding='utf-8') as f:
            test_examples = [json.loads(line) for line in f]
        # Everything not in the test set becomes train_eval_examples;
        # canonical JSON strings serve as identity keys for the exclusion
        test_ids = set(json.dumps(example, sort_keys=True) for example in test_examples)
        train_eval_examples = [
            example for example in all_examples
            if json.dumps(example, sort_keys=True) not in test_ids
        ]

    # Split the remainder into train and eval
    train_examples, eval_examples = train_test_split(
        train_eval_examples, test_size=eval_size, random_state=random_state
    )
    # Save the train set
    train_file = os.path.join(os.path.dirname(file_path), 'train.jsonl')
    with open(train_file, 'w', encoding='utf-8') as f:
        for example in train_examples:
            f.write(json.dumps(example, ensure_ascii=False) + '\n')
    # Save the eval set
    eval_file = os.path.join(os.path.dirname(file_path), 'eval.jsonl')
    with open(eval_file, 'w', encoding='utf-8') as f:
        for example in eval_examples:
            f.write(json.dumps(example, ensure_ascii=False) + '\n')

    return train_examples, eval_examples, test_examples
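
# Usage sketch (the 'data/data.jsonl' path is hypothetical): the first call
# splits the raw file and writes train.jsonl / eval.jsonl / test.jsonl next to
# it; later calls reuse the saved test.jsonl, so the test split stays stable
# across re-runs while train/eval are re-split from the remainder:
#
#   train_ex, eval_ex, test_ex = split_and_save_data('data/data.jsonl')
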
class MultiChannelDataset(Dataset):
    def __init__(self, file_path, max_length=128, mode='train', test_size=0.1, eval_size=0.1, random_state=42):
        self.max_length = max_length  # stored but not used within this class
        self.mode = mode

        # Check whether the train, eval, and test files already exist
        train_file = os.path.join(os.path.dirname(file_path), 'train.jsonl')
        eval_file = os.path.join(os.path.dirname(file_path), 'eval.jsonl')
        test_file = os.path.join(os.path.dirname(file_path), 'test.jsonl')
        if os.path.exists(train_file) and os.path.exists(eval_file) and os.path.exists(test_file):
            # If the files exist, read the split that matches the requested mode
            if self.mode == 'train':
                with open(train_file, 'r', encoding='utf-8') as f:
                    self.examples = [json.loads(line) for line in f]
            elif self.mode == 'eval':
                with open(eval_file, 'r', encoding='utf-8') as f:
                    self.examples = [json.loads(line) for line in f]
            elif self.mode == 'test':
                with open(test_file, 'r', encoding='utf-8') as f:
                    self.examples = [json.loads(line) for line in f]
            else:
                raise ValueError("Invalid mode. Choose from 'train', 'eval', or 'test'.")
        else:
            # If the files do not exist, perform the split and save them
            train_examples, eval_examples, test_examples = split_and_save_data(
                file_path, test_size, eval_size, random_state
            )
            # Pick the split that matches the requested mode
            if self.mode == 'train':
                self.examples = train_examples
            elif self.mode == 'eval':
                self.examples = eval_examples
            elif self.mode == 'test':
                self.examples = test_examples
            else:
                raise ValueError("Invalid mode. Choose from 'train', 'eval', or 'test'.")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]
        # Join the tags list into a single comma-separated string
        tags_text = ",".join(example['tags'])
        # Build the per-channel text dict
        texts = {
            'title': example['title'],
            'description': example['description'],
            'tags': tags_text
        }
        return {
            'texts': texts,  # text dict
            'label': torch.tensor(example['label'], dtype=torch.long)  # label
        }
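

# A minimal usage sketch, assuming a raw JSONL file whose lines are JSON
# objects with 'title', 'description', 'tags', and 'label' fields; the
# 'data/data.jsonl' path is hypothetical.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = MultiChannelDataset('data/data.jsonl', mode='train')
    # The default collate_fn batches the nested dict: the string fields become
    # lists of strings and the labels are stacked into a LongTensor.
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    print(batch['texts']['title'][:2], batch['label'].shape)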