sparkastML/translate-old/zh-en/dataloader/wikititle.py

42 lines
1.3 KiB
Python

import random
from torch.utils.data import Dataset
# Upper bound on how many lines are sampled from a data file when loading.
max_dataset_size = 100_000
class Wikititle(Dataset):
    """In-memory dataset of Chinese/English pairs read from a TSV file.

    Each line of the data file must hold exactly two tab-separated fields:
    the Chinese text followed by the English text. A random subset of lines
    is loaded, and samples are kept in file order.
    """

    def __init__(self, data_file, max_size=None):
        """Load samples from *data_file*.

        Args:
            data_file: Path to a UTF-8 text file with one ``zh<TAB>en`` pair
                per line.
            max_size: Maximum number of lines to sample. ``None`` (the
                default) uses the module-level ``max_dataset_size``.
        """
        self.data = self.load_data(data_file, max_size)

    def load_data(self, data_file, max_size=None):
        """Return a list of randomly sampled pairs from *data_file*.

        Args:
            data_file: Path to a UTF-8 text file with one ``zh<TAB>en`` pair
                per line.
            max_size: Maximum number of lines to sample. ``None`` (the
                default) uses the module-level ``max_dataset_size``.

        Returns:
            A list of ``{"chinese": str, "english": str}`` dicts, ordered as
            the selected lines appear in the file. The line terminator is
            not included in either field.

        Raises:
            ValueError: If a selected line does not contain exactly two
                tab-separated fields.
        """
        if max_size is None:
            max_size = max_dataset_size

        # First pass: count lines so that line numbers can be sampled
        # without holding the whole file in memory.
        with open(data_file, "rt", encoding="utf-8") as f:
            total_lines = sum(1 for _ in f)

        # Draw distinct random line numbers, sorted so the second pass can
        # pick them up in a single sequential read.
        selected = sorted(
            random.sample(range(total_lines), min(max_size, total_lines))
        )

        samples = []
        if not selected:
            return samples

        cursor = 0  # index into `selected` of the next line number to load
        with open(data_file, "rt", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                if idx != selected[cursor]:
                    continue
                # Drop the line terminator so it does not leak into the
                # English field.
                fields = line.rstrip("\n").split("\t")
                if len(fields) != 2:
                    raise ValueError(
                        f"{data_file}, line {idx + 1}: expected 2 "
                        f"tab-separated fields, got {len(fields)}"
                    )
                zh, en = fields
                samples.append({"chinese": zh, "english": en})
                cursor += 1
                if cursor == len(selected):
                    break  # every selected line has been read
        return samples

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample dict at position *idx*."""
        return self.data[idx]