diff --git a/filter/train.py b/filter/train.py
index 5295fde..99566d9 100644
--- a/filter/train.py
+++ b/filter/train.py
@@ -74,11 +74,11 @@ os.makedirs('./filter/checkpoints', exist_ok=True)
 # Optimizer
 eval_interval = 20
 num_epochs = 20
-total_steps = samples_count * num_epochs / train_loader.batch_size
+total_steps = samples_count * num_epochs / batch_size
 warmup_rate = 0.1
-optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-5)
+optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-3)
 cosine_annealing_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps - int(total_steps * warmup_rate))
-warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.4, end_factor=1.0, total_iters=int(total_steps * warmup_rate))
+warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.8, end_factor=1.0, total_iters=int(total_steps * warmup_rate))
 scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_annealing_scheduler], milestones=[int(total_steps * warmup_rate)])
 criterion = nn.CrossEntropyLoss(weight=class_weights).to(device)
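
The hunk chains a LinearLR warmup into CosineAnnealingLR via SequentialLR, so the shape of the resulting learning-rate curve depends on the milestone matching the warmup length exactly. Below is a minimal standalone sketch of that wiring for sanity-checking the schedule; the model, samples_count, and batch_size values are placeholders (not from this repo), and it uses integer division for total_steps so the step counts stay integral, unlike the float division in the hunk.

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(8, 2)  # stand-in for the real model
samples_count, batch_size, num_epochs = 1000, 32, 20

# Integer division so total_steps is usable directly as a step count.
total_steps = samples_count * num_epochs // batch_size
warmup_rate = 0.1
warmup_steps = int(total_steps * warmup_rate)

optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-3)
warmup = optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.8, end_factor=1.0, total_iters=warmup_steps)
cosine = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=total_steps - warmup_steps)
scheduler = optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])

# Step through the schedule and print a few checkpoints: the LR should climb
# from 0.8 * 3e-5 to 3e-5 during warmup, then decay along a cosine curve.
for step in range(total_steps):
    optimizer.step()  # normally preceded by forward/backward on a batch
    scheduler.step()
    if step in (0, warmup_steps - 1, total_steps // 2, total_steps - 1):
        print(f"step {step:4d}  lr = {scheduler.get_last_lr()[0]:.2e}")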