49 lines
1.4 KiB
Python
Executable File
49 lines
1.4 KiB
Python
Executable File
import json, random
|
|
from torch.utils.data import Dataset
|
|
|
|
# Number of line groups MultiTRANS19.load_data draws; each group is merged
# into one dataset sample.
max_dataset_size = 82000
|
|
|
|
class MultiTRANS19(Dataset):
    """Translation dataset whose samples are concatenations of random lines.

    The data file is read as one JSON object per line, each providing a
    "chinese" and an "english" string.  Randomly chosen lines are merged in
    groups of 2 to 7 so that every sample holds several sentences.
    """

    def __init__(self, data_file, max_size=None):
        """Load and group the lines of ``data_file``.

        Args:
            data_file: Path to the UTF-8 JSON-lines file.
            max_size: Upper bound on the number of samples to build.  When
                ``None`` the module-level ``max_dataset_size`` is used.
        """
        self.data = self.load_data(data_file, max_size)

    def load_data(self, data_file, max_size=None):
        """Build the list of combined samples.

        Args:
            data_file: Path to the UTF-8 JSON-lines file.
            max_size: Upper bound on the number of samples to build.  When
                ``None`` the module-level ``max_dataset_size`` is used.

        Returns:
            A list of ``{"chinese": str, "english": str}`` dicts.  It holds
            at most ``max_size`` entries, and fewer when the file runs out
            of lines (the last sample may then merge fewer lines than drawn).

        Raises:
            KeyError: If a decoded line lacks "chinese" or "english".
        """
        if max_size is None:
            max_size = max_dataset_size

        with open(data_file, "rt", encoding="utf-8") as f:
            file_lines = f.readlines()

        # One random group size per sample we would like to build.
        group_sizes = [random.randint(2, 7) for _ in range(max_size)]

        # Never ask random.sample for more lines than exist: a short file
        # yields a smaller dataset instead of raising ValueError.
        wanted = min(sum(group_sizes), len(file_lines))
        picked = random.sample(file_lines, wanted)

        data = []
        start = 0
        # Each drawn value IS the group size; it must not be used as an
        # index back into group_sizes.
        for size in group_sizes:
            chunk = picked[start:start + size]
            if not chunk:
                # Sampled lines are exhausted.
                break

            chinese_parts = []
            english_parts = []
            for line in chunk:
                try:
                    record = json.loads(line.strip())
                except json.JSONDecodeError as e:
                    # Best effort: report the malformed line and skip it.
                    print(f"Error decoding line: {e}")
                    continue
                chinese_parts.append(record["chinese"])
                english_parts.append(record["english"])

            data.append({
                "chinese": "".join(chinese_parts),
                "english": "".join(english_parts),
            })
            start += size

        return data

    def __len__(self):
        """Return the number of combined samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample dict stored at position ``idx``."""
        return self.data[idx]
|