import random

from torch.utils.data import Dataset

# Upper bound on the number of sentence pairs kept in memory.
max_dataset_size = 100000


class Wikititle(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        # First pass: count how many lines the corpus file contains.
        with open(data_file, "rt", encoding="utf-8") as f:
            total_lines = sum(1 for _ in f)

        # Generate a list of unique random line numbers (sampled without replacement).
        random_line_numbers = random.sample(
            range(total_lines), min(max_dataset_size, total_lines)
        )
        random_line_numbers.sort()  # sort so the file can be read in one sequential pass

        data = []
        current_line_number = 0  # index into random_line_numbers

        # Second pass: keep only the sampled lines.
        with open(data_file, "rt", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                if current_line_number >= len(random_line_numbers):
                    break
                if idx == random_line_numbers[current_line_number]:
                    # Each line is a tab-separated Chinese/English sentence pair.
                    zh, en = line.rstrip("\n").split("\t")
                    sample = {
                        "chinese": zh,
                        "english": en
                    }
                    data.append(sample)
                    current_line_number += 1

        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
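

# Usage sketch (assumption, not part of the original file): loads the dataset from a
# tab-separated zh/en corpus and batches it with a DataLoader. The file path below is
# a placeholder; substitute your own corpus.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = Wikititle("data/wikititles.zh-en.tsv")  # placeholder path
    print(f"Loaded {len(dataset)} sentence pairs")
    print(dataset[0])  # e.g. {"chinese": "...", "english": "..."}

    # The default collate_fn turns a batch of dicts of strings into a dict of lists.
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(dataloader))
    print(batch["chinese"])
    print(batch["english"])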