ref: return to cascade classifier
This commit is contained in:
parent
f08a863ac6
commit
b84f8a1f3e
@@ -4,7 +4,7 @@ import numpy as np
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from dataset import MultiChannelDataset
|
from dataset import MultiChannelDataset
|
||||||
from filter.modelV3_10 import VideoClassifierV3_10
|
from filter.modelV3_12 import VideoClassifierV3_12
|
||||||
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
|
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
@@ -51,39 +51,30 @@ class_weights = torch.tensor(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 初始化模型和SentenceTransformer
|
# 初始化模型和SentenceTransformer
|
||||||
model1 = VideoClassifierV3_10(output_dim=2, hidden_dim=384)
|
model = VideoClassifierV3_12()
|
||||||
model2 = VideoClassifierV3_10(output_dim=2, hidden_dim=384)
|
checkpoint_name = './filter/checkpoints/best_model_V3.12.pt'
|
||||||
checkpoint1_name = './filter/checkpoints/best_model_V3.14-part1.pt'
|
|
||||||
checkpoint2_name = './filter/checkpoints/best_model_V3.14-part2.pt'
|
|
||||||
|
|
||||||
# 模型保存路径
|
# 模型保存路径
|
||||||
os.makedirs('./filter/checkpoints', exist_ok=True)
|
os.makedirs('./filter/checkpoints', exist_ok=True)
|
||||||
|
|
||||||
# 优化器
|
# 优化器
|
||||||
optimizer1 = optim.AdamW(model1.parameters(), lr=4e-4)
|
optimizer = optim.AdamW(model.parameters(), lr=4e-4)
|
||||||
optimizer2 = optim.AdamW(model2.parameters(), lr=4e-4)
|
|
||||||
# Cross entropy loss
|
# Cross entropy loss
|
||||||
criterion1 = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
criterion2 = nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
def count_trainable_parameters(model):
|
def count_trainable_parameters(model):
|
||||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
def evaluate(model1, model2, dataloader):
|
def evaluate(model, dataloader):
|
||||||
model1.eval()
|
model.eval()
|
||||||
model2.eval()
|
|
||||||
all_preds = []
|
all_preds = []
|
||||||
all_labels = []
|
all_labels = []
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
batch_tensor = prepare_batch(batch['texts'])
|
batch_tensor = prepare_batch(batch['texts'])
|
||||||
logits1 = model1(batch_tensor)
|
logits = model(batch_tensor)
|
||||||
logits2 = model2(batch_tensor)
|
preds = torch.argmax(logits, dim=1)
|
||||||
preds1 = torch.argmax(logits1, dim=1)
|
|
||||||
preds2 = torch.argmax(logits2, dim=1)
|
|
||||||
# 如果preds1输出为0,那么预测结果为0,否则使用preds2的结果加上1
|
|
||||||
preds = torch.where(preds1 == 0, preds1, preds2 + 1)
|
|
||||||
all_preds.extend(preds.cpu().numpy())
|
all_preds.extend(preds.cpu().numpy())
|
||||||
all_labels.extend(batch['label'].cpu().numpy())
|
all_labels.extend(batch['label'].cpu().numpy())
|
||||||
|
|
||||||
@@ -98,7 +89,7 @@ def evaluate(model1, model2, dataloader):
|
|||||||
|
|
||||||
return f1, recall, precision, accuracy, class_report
|
return f1, recall, precision, accuracy, class_report
|
||||||
|
|
||||||
print(f"Trainable parameters: {count_trainable_parameters(model1) + count_trainable_parameters(model2)}")
|
print(f"Trainable parameters: {count_trainable_parameters(model)}")
|
||||||
|
|
||||||
# 训练循环
|
# 训练循环
|
||||||
best_f1 = 0
|
best_f1 = 0
|
||||||
@@ -107,44 +98,29 @@ eval_interval = 20
|
|||||||
num_epochs = 8
|
num_epochs = 8
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
model1.train()
|
model.train()
|
||||||
model2.train()
|
epoch_loss = 0
|
||||||
epoch_loss_1 = 0
|
|
||||||
epoch_loss_2 = 0
|
|
||||||
|
|
||||||
# 训练阶段
|
# 训练阶段
|
||||||
for batch_idx, batch in enumerate(train_loader):
|
for batch_idx, batch in enumerate(train_loader):
|
||||||
optimizer1.zero_grad()
|
optimizer.zero_grad()
|
||||||
optimizer2.zero_grad()
|
|
||||||
|
|
||||||
batch_tensor = prepare_batch(batch['texts'])
|
batch_tensor = prepare_batch(batch['texts'])
|
||||||
batch_tensor_1 = batch_tensor
|
|
||||||
mask = batch['label'] != 0
|
|
||||||
batch_tensor_2 = batch_tensor_1[mask]
|
|
||||||
|
|
||||||
logits1 = model1(batch_tensor_1)
|
logits = model(batch_tensor)
|
||||||
logits2 = model2(batch_tensor_2)
|
|
||||||
|
|
||||||
label1 = torch.where(batch['label'] == 0, 0, 1)
|
loss = criterion(logits, batch['label'])
|
||||||
label2 = torch.where(batch['label'][mask] == 1, 0, 1)
|
loss.backward()
|
||||||
loss1 = criterion1(logits1, label1)
|
optimizer.step()
|
||||||
loss1.backward()
|
epoch_loss += loss.item()
|
||||||
loss2 = criterion2(logits2, label2)
|
|
||||||
loss2.backward()
|
|
||||||
optimizer1.step()
|
|
||||||
optimizer2.step()
|
|
||||||
epoch_loss_1 += loss1.item()
|
|
||||||
epoch_loss_2 += loss2.item()
|
|
||||||
|
|
||||||
# 记录训练损失
|
# 记录训练损失
|
||||||
writer.add_scalar('Train/Loss-1', loss1.item(), step)
|
writer.add_scalar('Train/Loss', loss.item(), step)
|
||||||
writer.add_scalar('Train/Loss-2', loss2.item(), step)
|
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
# 每隔 eval_interval 步执行验证
|
# 每隔 eval_interval 步执行验证
|
||||||
if step % eval_interval == 0:
|
if step % eval_interval == 0:
|
||||||
eval_f1, eval_recall, eval_precision, eval_accuracy, eval_class_report = evaluate(model1, model2, eval_loader)
|
eval_f1, eval_recall, eval_precision, eval_accuracy, eval_class_report = evaluate(model, eval_loader)
|
||||||
|
|
||||||
writer.add_scalar('Eval/F1', eval_f1, step)
|
writer.add_scalar('Eval/F1', eval_f1, step)
|
||||||
writer.add_scalar('Eval/Recall', eval_recall, step)
|
writer.add_scalar('Eval/Recall', eval_recall, step)
|
||||||
writer.add_scalar('Eval/Precision', eval_precision, step)
|
writer.add_scalar('Eval/Precision', eval_precision, step)
|
||||||
@@ -160,19 +136,17 @@ for epoch in range(num_epochs):
|
|||||||
# 保存最佳模型
|
# 保存最佳模型
|
||||||
if eval_f1 > best_f1:
|
if eval_f1 > best_f1:
|
||||||
best_f1 = eval_f1
|
best_f1 = eval_f1
|
||||||
torch.save(model1.state_dict(), checkpoint1_name)
|
torch.save(model.state_dict(), checkpoint_name)
|
||||||
torch.save(model2.state_dict(), checkpoint2_name)
|
|
||||||
print(" Saved best model")
|
print(" Saved best model")
|
||||||
print("Channel weights: ", model1.get_channel_weights())
|
print("Channel weights: ", model.get_channel_weights())
|
||||||
print("Channel weights: ", model2.get_channel_weights())
|
|
||||||
|
|
||||||
# 记录每个 epoch 的平均训练损失
|
# 记录每个 epoch 的平均训练损失
|
||||||
avg_epoch_loss = (epoch_loss_1 + epoch_loss_2) / 2 / len(train_loader)
|
avg_epoch_loss = epoch_loss / len(train_loader)
|
||||||
writer.add_scalar('Train/Epoch_Loss', avg_epoch_loss, epoch)
|
writer.add_scalar('Train/Epoch_Loss', avg_epoch_loss, epoch)
|
||||||
|
|
||||||
# 每个 epoch 结束后执行一次完整验证
|
# 每个 epoch 结束后执行一次完整验证
|
||||||
train_f1, train_recall, train_precision, train_accuracy, train_class_report = evaluate(model1, model2, train_loader)
|
train_f1, train_recall, train_precision, train_accuracy, train_class_report = evaluate(model, train_loader)
|
||||||
eval_f1, eval_recall, eval_precision, eval_accuracy, eval_class_report = evaluate(model1, model2, eval_loader)
|
eval_f1, eval_recall, eval_precision, eval_accuracy, eval_class_report = evaluate(model, eval_loader)
|
||||||
|
|
||||||
writer.add_scalar('Train/Epoch_F1', train_f1, epoch)
|
writer.add_scalar('Train/Epoch_F1', train_f1, epoch)
|
||||||
writer.add_scalar('Train/Epoch_Recall', train_recall, epoch)
|
writer.add_scalar('Train/Epoch_Recall', train_recall, epoch)
|
||||||
@@ -199,9 +173,8 @@ for epoch in range(num_epochs):
|
|||||||
|
|
||||||
# 测试阶段
|
# 测试阶段
|
||||||
print("\nTesting...")
|
print("\nTesting...")
|
||||||
model1.load_state_dict(torch.load(checkpoint1_name))
|
model.load_state_dict(torch.load(checkpoint_name))
|
||||||
model2.load_state_dict(torch.load(checkpoint2_name))
|
test_f1, test_recall, test_precision, test_accuracy, test_class_report = evaluate(model, test_loader)
|
||||||
test_f1, test_recall, test_precision, test_accuracy, test_class_report = evaluate(model1, model2, test_loader)
|
|
||||||
writer.add_scalar('Test/F1', test_f1, step)
|
writer.add_scalar('Test/F1', test_f1, step)
|
||||||
writer.add_scalar('Test/Recall', test_recall, step)
|
writer.add_scalar('Test/Recall', test_recall, step)
|
||||||
writer.add_scalar('Test/Precision', test_precision, step)
|
writer.add_scalar('Test/Precision', test_precision, step)
|
||||||
|
Loading…
Reference in New Issue
Block a user