In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import json
import numpy as np
from sacrebleu.metrics import BLEU
from tqdm import tqdm

In [1]:
# 定义参数
checkpoint_path = "./step_137000_valid_bleu_25.55_model_weights.bin"  # 假设你要加载第2个epoch中的500步的checkpoint
data_file = "./data/translation2019zh/translation2019zh_valid.json"  # 假设使用验证集来测试
model_checkpoint = "Helsinki-NLP/opus-mt-zh-en"
max_dataset_size = 100
max_input_length = 128
max_target_length = 128
batch_size = 8

In [3]:
class TRANS(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)
    
    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx >= max_dataset_size:
                    break
                sample = json.loads(line.strip())
                Data[idx] = sample
        return Data
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


In [4]:
def collote_fn(batch_samples):
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample['chinese'])
        batch_targets.append(sample['english'])
    batch_data = tokenizer(
        batch_inputs, 
        padding=True, 
        max_length=max_input_length,
        truncation=True, 
        return_tensors="pt"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch_targets, 
            padding=True, 
            max_length=max_target_length,
            truncation=True, 
            return_tensors="pt"
        )["input_ids"]
        batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)
        end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
        for idx, end_idx in enumerate(end_token_index):
            labels[idx][end_idx+1:] = -100
        batch_data['labels'] = labels
    return batch_data


In [11]:
# 加载模型和tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# 加载checkpoint
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
model.eval()

# 将模型转移到设备
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model = model.to(device)

# 加载测试数据
test_data = TRANS(data_file)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True, collate_fn=collote_fn)

# 定义BLEU评估函数
bleu = BLEU()

  model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))


In [12]:
import time
def test_model(dataloader, model):
    preds, labels = [], []

    model.eval()
    for batch_data in tqdm(dataloader):
        batch_data = batch_data.to(device)
        with torch.no_grad():
            generated_tokens = model.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
            ).cpu().numpy()

        label_tokens = batch_data["labels"].cpu().numpy()
        

        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        preds += [pred.strip() for pred in decoded_preds]
        labels += [[label.strip()] for label in decoded_labels]
    
    bleu_score = bleu.corpus_score(preds, labels).score
    print(f"BLEU: {bleu_score:>0.2f}")
    return bleu_score

In [13]:
print("Testing model...")
bleu_score = test_model(test_dataloader, model)
print(f"Test BLEU score: {bleu_score:0.2f}")

Testing model...


100%|███████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:33<00:00,  2.61s/it]

BLEU: 12.95
Test BLEU score: 12.95



