diff --git a/filter/train.py b/filter/train.py
index 8f68d62..065d38e 100644
--- a/filter/train.py
+++ b/filter/train.py
@@ -4,7 +4,7 @@ import numpy as np
 from torch.utils.data import DataLoader
 import torch.optim as optim
 from dataset import MultiChannelDataset
-from filter.modelV3_10 import VideoClassifierV3_10
+from filter.modelV3_12 import VideoClassifierV3_12
 from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
 import os
 import torch
@@ -51,39 +51,30 @@ class_weights = torch.tensor(
 )
 
 # Initialize the model and the SentenceTransformer
-model1 = VideoClassifierV3_10(output_dim=2, hidden_dim=384)
-model2 = VideoClassifierV3_10(output_dim=2, hidden_dim=384)
-checkpoint1_name = './filter/checkpoints/best_model_V3.14-part1.pt'
-checkpoint2_name = './filter/checkpoints/best_model_V3.14-part2.pt'
+model = VideoClassifierV3_12()
+checkpoint_name = './filter/checkpoints/best_model_V3.12.pt'
 
 # Model save path
 os.makedirs('./filter/checkpoints', exist_ok=True)
 
 # Optimizer
-optimizer1 = optim.AdamW(model1.parameters(), lr=4e-4)
-optimizer2 = optim.AdamW(model2.parameters(), lr=4e-4)
+optimizer = optim.AdamW(model.parameters(), lr=4e-4)
 
 # Cross entropy loss
-criterion1 = nn.CrossEntropyLoss()
-criterion2 = nn.CrossEntropyLoss()
+criterion = nn.CrossEntropyLoss()
 
 def count_trainable_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
-def evaluate(model1, model2, dataloader):
-    model1.eval()
-    model2.eval()
+def evaluate(model, dataloader):
+    model.eval()
     all_preds = []
     all_labels = []
 
     with torch.no_grad():
         for batch in dataloader:
             batch_tensor = prepare_batch(batch['texts'])
-            logits1 = model1(batch_tensor)
-            logits2 = model2(batch_tensor)
-            preds1 = torch.argmax(logits1, dim=1)
-            preds2 = torch.argmax(logits2, dim=1)
-            # If preds1 outputs 0, the final prediction is 0; otherwise use preds2's result plus 1
-            preds = torch.where(preds1 == 0, preds1, preds2 + 1)
+            logits = model(batch_tensor)
+            preds = torch.argmax(logits, dim=1)
             all_preds.extend(preds.cpu().numpy())
             all_labels.extend(batch['label'].cpu().numpy())
 
@@ -95,10 +86,10 @@ def evaluate(model1, model2, dataloader):
 
     # Get detailed per-class metrics
     class_report = classification_report(all_labels, all_preds, output_dict=True)
-    
-    return f1, recall, precision, accuracy, class_report
 
-print(f"Trainable parameters: {count_trainable_parameters(model1) + count_trainable_parameters(model2)}")
+    return f1, recall, precision, accuracy, class_report
+
+print(f"Trainable parameters: {count_trainable_parameters(model)}")
 
 # Training loop
 best_f1 = 0
@@ -107,44 +98,29 @@ eval_interval = 20
 num_epochs = 8
 
 for epoch in range(num_epochs):
-    model1.train()
-    model2.train()
-    epoch_loss_1 = 0
-    epoch_loss_2 = 0
+    model.train()
+    epoch_loss = 0
 
     # Training phase
     for batch_idx, batch in enumerate(train_loader):
-        optimizer1.zero_grad()
-        optimizer2.zero_grad()
+        optimizer.zero_grad()
 
         batch_tensor = prepare_batch(batch['texts'])
-        batch_tensor_1 = batch_tensor
-        mask = batch['label'] != 0
-        batch_tensor_2 = batch_tensor_1[mask]
-        logits1 = model1(batch_tensor_1)
-        logits2 = model2(batch_tensor_2)
+        logits = model(batch_tensor)
 
-        label1 = torch.where(batch['label'] == 0, 0, 1)
-        label2 = torch.where(batch['label'][mask] == 1, 0, 1)
-        loss1 = criterion1(logits1, label1)
-        loss1.backward()
-        loss2 = criterion2(logits2, label2)
-        loss2.backward()
-        optimizer1.step()
-        optimizer2.step()
-        epoch_loss_1 += loss1.item()
-        epoch_loss_2 += loss2.item()
+        loss = criterion(logits, batch['label'])
+        loss.backward()
+        optimizer.step()
+        epoch_loss += loss.item()
 
         # Log the training loss
-        writer.add_scalar('Train/Loss-1', loss1.item(), step)
-        writer.add_scalar('Train/Loss-2', loss2.item(), step)
+        writer.add_scalar('Train/Loss', loss.item(), step)
 
        step += 1
 
         # Run validation every eval_interval steps
        if step % eval_interval == 0:
-            eval_f1, eval_recall, eval_precision, eval_accuracy, eval_class_report = evaluate(model1, model2, eval_loader)
-            
+            eval_f1, eval_recall, eval_precision, eval_accuracy, eval_class_report = evaluate(model, eval_loader)
             writer.add_scalar('Eval/F1', eval_f1, step)
             writer.add_scalar('Eval/Recall', eval_recall, step)
             writer.add_scalar('Eval/Precision', eval_precision, step)
@@ -160,19 +136,17 @@ for epoch in range(num_epochs):
             # Save the best model
             if eval_f1 > best_f1:
                 best_f1 = eval_f1
-                torch.save(model1.state_dict(), checkpoint1_name)
-                torch.save(model2.state_dict(), checkpoint2_name)
+                torch.save(model.state_dict(), checkpoint_name)
                 print(" Saved best model")
 
-    print("Channel weights: ", model1.get_channel_weights())
-    print("Channel weights: ", model2.get_channel_weights())
+    print("Channel weights: ", model.get_channel_weights())
 
     # Log the average training loss for this epoch
-    avg_epoch_loss = (epoch_loss_1 + epoch_loss_2) / 2 / len(train_loader)
+    avg_epoch_loss = epoch_loss / len(train_loader)
     writer.add_scalar('Train/Epoch_Loss', avg_epoch_loss, epoch)
 
     # Run a full validation at the end of each epoch
-    train_f1, train_recall, train_precision, train_accuracy, train_class_report = evaluate(model1, model2, train_loader)
-    eval_f1, eval_recall, eval_precision, eval_accuracy, eval_class_report = evaluate(model1, model2, eval_loader)
+    train_f1, train_recall, train_precision, train_accuracy, train_class_report = evaluate(model, train_loader)
+    eval_f1, eval_recall, eval_precision, eval_accuracy, eval_class_report = evaluate(model, eval_loader)
 
     writer.add_scalar('Train/Epoch_F1', train_f1, epoch)
     writer.add_scalar('Train/Epoch_Recall', train_recall, epoch)
@@ -199,9 +173,8 @@ for epoch in range(num_epochs):
 
 # Test phase
 print("\nTesting...")
-model1.load_state_dict(torch.load(checkpoint1_name))
-model2.load_state_dict(torch.load(checkpoint2_name))
-test_f1, test_recall, test_precision, test_accuracy, test_class_report = evaluate(model1, model2, test_loader)
+model.load_state_dict(torch.load(checkpoint_name))
+test_f1, test_recall, test_precision, test_accuracy, test_class_report = evaluate(model, test_loader)
 writer.add_scalar('Test/F1', test_f1, step)
 writer.add_scalar('Test/Recall', test_recall, step)
 writer.add_scalar('Test/Precision', test_precision, step)
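
For context, this diff collapses a two-stage cascade into a single 3-way classifier: previously model1 separated class 0 from the rest and model2 split the remainder into classes 1 and 2, with the final label assembled via torch.where(preds1 == 0, preds1, preds2 + 1). Below is a minimal sketch of that removed combination step next to the new single-argmax path; the logits are dummy values for illustration, and the 3-class label set {0, 1, 2} is inferred from the diff rather than stated in it.

    import torch

    # Old scheme: two binary heads chained together.
    logits1 = torch.tensor([[2.0, -1.0], [-1.0, 2.0]])  # stage 1: class 0 vs. rest
    logits2 = torch.tensor([[0.5, 1.5], [0.5, 1.5]])    # stage 2: class 1 vs. class 2
    preds1 = torch.argmax(logits1, dim=1)
    preds2 = torch.argmax(logits2, dim=1)
    # If stage 1 says 0, the label is 0; otherwise preds2 + 1 maps {0, 1} -> {1, 2}.
    cascade_preds = torch.where(preds1 == 0, preds1, preds2 + 1)

    # New scheme: one 3-way head, one argmax.
    logits = torch.tensor([[3.0, 0.1, 0.2], [0.1, 0.2, 3.0]])
    single_preds = torch.argmax(logits, dim=1)

    print(cascade_preds)  # tensor([0, 2])
    print(single_preds)   # tensor([0, 2])

The training side mirrors this: the removed label remapping (label1 = torch.where(batch['label'] == 0, 0, 1) plus the masked label2) gives way to a single cross-entropy loss on the original labels, so only one optimizer step and one logged loss remain per batch.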