sparkastML/translate-old/zh-en/dataloader/multiTrans19.py

49 lines
1.4 KiB
Python
Executable File

import json, random
from torch.utils.data import Dataset
# Target number of combined samples that MultiTRANS19.load_data builds.
max_dataset_size = 82000
class MultiTRANS19(Dataset):
    """Dataset of multi-sentence Chinese/English pairs built from a JSON-lines file.

    Every item is a dict with the keys ``"chinese"`` and ``"english"``. Each
    value is the concatenation of that field over 2 to 7 randomly chosen lines
    of the source file; a given line is used in at most one item.
    """

    def __init__(self, data_file, max_size=None):
        """Build the dataset eagerly from *data_file*.

        Args:
            data_file: Path to a UTF-8 text file holding one JSON object per
                line, each with ``"chinese"`` and ``"english"`` keys.
            max_size: Upper bound on the number of items to build. ``None``
                (the default) uses the module-level ``max_dataset_size``.
        """
        self.data = self.load_data(data_file, max_size)

    def load_data(self, data_file, max_size=None):
        """Read *data_file* and return the list of combined samples.

        Args:
            data_file: Path to the JSON-lines file (see ``__init__``).
            max_size: Upper bound on the number of samples. ``None`` uses the
                module-level ``max_dataset_size``.

        Returns:
            A list of ``{"chinese": str, "english": str}`` dicts. It holds
            fewer than ``max_size`` entries when the file does not contain
            enough lines: building stops at the first randomly sized group
            that no longer fits in the remaining lines.

        Raises:
            KeyError: If a decodable line lacks ``"chinese"`` or ``"english"``.
        """
        if max_size is None:
            # Resolved at call time so the module constant stays the default.
            max_size = max_dataset_size

        with open(data_file, "rt", encoding="utf-8") as f:
            file_lines = f.readlines()

        # Decide how many source lines go into each sample, never planning
        # more lines than the file actually has (random.sample would raise).
        group_sizes = []
        remaining = len(file_lines)
        for _ in range(max_size):
            size = random.randint(2, 7)
            if size > remaining:
                break
            group_sizes.append(size)
            remaining -= size

        # Sampling without replacement keeps each line in at most one sample.
        picked = random.sample(file_lines, sum(group_sizes))

        samples = []
        start = 0
        # Iterate over the sizes themselves; using a size as an index into the
        # size list would desynchronise the slices from the sampled lines.
        for size in group_sizes:
            chinese_parts = []
            english_parts = []
            for line in picked[start:start + size]:
                try:
                    record = json.loads(line.strip())
                except json.JSONDecodeError as e:
                    # Best effort: report the bad line and leave it out.
                    print(f"Error decoding line: {e}")
                    continue
                chinese_parts.append(record["chinese"])
                english_parts.append(record["english"])
            samples.append({
                "chinese": "".join(chinese_parts),
                "english": "".join(english_parts),
            })
            start += size
        return samples

    def __len__(self):
        """Return the number of combined samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample dict stored at position *idx*."""
        return self.data[idx]