# cvsa/ml_new/training/trainer.py
"""
Model trainer for embedding classification
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
import numpy as np
from typing import Dict, Any, Optional
import json
from datetime import datetime
from pathlib import Path
from ml_new.training.models import EmbeddingClassifier
from ml_new.config.logger_config import get_logger
logger = get_logger(__name__)
class ModelTrainer:
"""
Trainer for embedding classification models
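
    Example (a minimal sketch; the loaders come from the project's data
    pipeline and the epoch count is illustrative):

        trainer = ModelTrainer(model, train_loader, val_loader)
        results = trainer.train(num_epochs=50)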
"""
def __init__(
self,
model: EmbeddingClassifier,
train_loader: DataLoader,
val_loader: DataLoader,
test_loader: Optional[DataLoader] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
criterion: Optional[nn.Module] = None,
        scheduler: Optional[Any] = None,  # any LR scheduler, including ReduceLROnPlateau (not an _LRScheduler in all torch versions)
device: Optional[torch.device] = None,
save_dir: str = "training/checkpoints",
experiment_name: Optional[str] = None
):
"""
Initialize model trainer
Args:
model: PyTorch model to train
train_loader: Training data loader
val_loader: Validation data loader
test_loader: Test data loader (optional)
optimizer: Optimizer instance (default: Adam)
criterion: Loss function (default: BCEWithLogitsLoss)
scheduler: Learning rate scheduler (optional)
device: Device to use (default: auto-detect)
save_dir: Directory to save checkpoints
experiment_name: Name for the experiment
"""
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.test_loader = test_loader
# Device setup
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device
self.model.to(self.device)
# Loss function and optimizer
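        # BCEWithLogitsLoss fuses sigmoid and binary cross-entropy for numerical
        # stability, so the model is expected to return raw logits; train_epoch()
        # and validate() apply torch.sigmoid() separately when computing metrics.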
self.criterion = criterion or nn.BCEWithLogitsLoss()
self.optimizer = optimizer or optim.Adam(self.model.parameters(), lr=0.001)
self.scheduler = scheduler
# Training configuration
self.save_dir = Path(save_dir)
self.save_dir.mkdir(parents=True, exist_ok=True)
# Experiment tracking
self.experiment_name = experiment_name or f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
self.checkpoint_dir = self.save_dir / self.experiment_name
self.checkpoint_dir.mkdir(exist_ok=True)
# TensorBoard logging
self.writer = SummaryWriter(log_dir=str(self.checkpoint_dir / "logs"))
# Training state
self.current_epoch = 0
self.best_val_score = 0.0
self.patience_counter = 0
self.training_history = {
'train_loss': [],
'train_acc': [],
'val_loss': [],
'val_acc': [],
'val_auc': [],
'learning_rates': []
}
logger.info(f"Initialized trainer with device: {self.device}")
logger.info(f"Model info: {self.model.get_model_info()}")
def train_epoch(self) -> Dict[str, float]:
"""Train for one epoch"""
self.model.train()
total_loss = 0.0
all_predictions = []
all_labels = []
for batch_idx, (embeddings, labels, metadata) in enumerate(self.train_loader):
embeddings = embeddings.to(self.device)
labels = labels.to(self.device).float()
# Zero gradients
self.optimizer.zero_grad()
# Forward pass
outputs = self.model(embeddings)
                loss = self.criterion(outputs.squeeze(-1), labels)  # squeeze(-1) keeps the batch dim when batch size is 1
# Backward pass
loss.backward()
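            # Clip the global gradient norm to guard against exploding gradients;
            # this must happen after backward() and before optimizer.step().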
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
# Collect statistics
total_loss += loss.item()
# Get predictions for metrics
with torch.no_grad():
                predictions = torch.sigmoid(outputs).squeeze(-1)
all_predictions.extend(predictions.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
# Log progress
if batch_idx % 5 == 0:
logger.info(f"Batch {batch_idx}/{len(self.train_loader)}, Loss: {loss.item():.4f}")
# Calculate epoch metrics
epoch_loss = total_loss / len(self.train_loader)
epoch_acc = accuracy_score(all_labels, (np.array(all_predictions) > 0.5).astype(int))
return {
'loss': epoch_loss,
'accuracy': epoch_acc
}
def validate(self) -> Dict[str, float]:
"""Validate the model"""
self.model.eval()
total_loss = 0.0
all_predictions = []
all_labels = []
all_probabilities = []
with torch.no_grad():
for embeddings, labels, metadata in self.val_loader:
embeddings = embeddings.to(self.device)
labels = labels.to(self.device).float()
# Forward pass
outputs = self.model(embeddings)
                loss = self.criterion(outputs.squeeze(-1), labels)
# Collect statistics
total_loss += loss.item()
# Get predictions and probabilities
                probabilities = torch.sigmoid(outputs).squeeze(-1)
predictions = (probabilities > 0.5).long()
all_predictions.extend(predictions.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
all_probabilities.extend(probabilities.cpu().numpy())
# Calculate metrics
val_loss = total_loss / len(self.val_loader)
val_accuracy = accuracy_score(all_labels, all_predictions)
# Calculate additional metrics
precision, recall, f1, _ = precision_recall_fscore_support(
all_labels, all_predictions, average='binary', zero_division=0
)
try:
val_auc = roc_auc_score(all_labels, all_probabilities)
except ValueError:
val_auc = 0.0 # AUC not defined for single class
# Confusion matrix
cm = confusion_matrix(all_labels, all_predictions)
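        # A 2x2 matrix ravels to (tn, fp, fn, tp); when only one class is
        # present the matrix is 1x1, so fall back to zeros.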
        tn, fp, fn, tp = (int(v) for v in cm.ravel()) if cm.size == 4 else (0, 0, 0, 0)
return {
'loss': val_loss,
'accuracy': val_accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
'auc': val_auc,
'true_negatives': tn,
'false_positives': fp,
'false_negatives': fn,
'true_positives': tp
}
def train(
self,
num_epochs: int = 100,
save_every: int = 10,
early_stopping_patience: int = 15,
min_delta: float = 0.001,
validation_score: str = 'f1'
) -> Dict[str, Any]:
"""
Train the model
Args:
num_epochs: Number of training epochs
save_every: Save checkpoint every N epochs
early_stopping_patience: Patience for early stopping
min_delta: Minimum improvement for early stopping
validation_score: Metric to use for early stopping ('f1', 'auc', 'accuracy')
Returns:
Training results dictionary
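
        Example (illustrative arguments):

            results = trainer.train(num_epochs=50, validation_score='auc')
            best = results['best_val_score']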
"""
logger.info(f"Starting training for {num_epochs} epochs")
for epoch in range(num_epochs):
self.current_epoch = epoch
# Training
train_metrics = self.train_epoch()
# Validation
val_metrics = self.validate()
# Learning rate scheduling
if self.scheduler:
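                # ReduceLROnPlateau needs the monitored value each epoch;
                # other schedulers step unconditionally.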
if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
self.scheduler.step(val_metrics['loss'])
else:
self.scheduler.step()
# Log metrics
current_lr = self.optimizer.param_groups[0]['lr']
# TensorBoard logging
self.writer.add_scalar('Loss/Train', train_metrics['loss'], epoch)
self.writer.add_scalar('Accuracy/Train', train_metrics['accuracy'], epoch)
self.writer.add_scalar('Loss/Validation', val_metrics['loss'], epoch)
self.writer.add_scalar('Accuracy/Validation', val_metrics['accuracy'], epoch)
self.writer.add_scalar('F1/Validation', val_metrics['f1'], epoch)
self.writer.add_scalar('AUC/Validation', val_metrics['auc'], epoch)
self.writer.add_scalar('Learning_Rate', current_lr, epoch)
# Update training history
self.training_history['train_loss'].append(train_metrics['loss'])
self.training_history['train_acc'].append(train_metrics['accuracy'])
self.training_history['val_loss'].append(val_metrics['loss'])
self.training_history['val_acc'].append(val_metrics['accuracy'])
self.training_history['val_auc'].append(val_metrics['auc'])
self.training_history['learning_rates'].append(current_lr)
# Print progress
logger.info(
f"Epoch {epoch+1}/{num_epochs} - "
f"Train Loss: {train_metrics['loss']:.4f}, Train Acc: {train_metrics['accuracy']:.4f} - "
f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}, "
f"Val F1: {val_metrics['f1']:.4f}, Val AUC: {val_metrics['auc']:.4f}"
)
# Check for best model
val_score = val_metrics[validation_score]
if val_score > self.best_val_score + min_delta:
self.best_val_score = val_score
self.patience_counter = 0
# Save best model
self.save_checkpoint(is_best=True)
logger.info(f"New best {validation_score}: {val_score:.4f}")
else:
self.patience_counter += 1
# Save checkpoint
if (epoch + 1) % save_every == 0:
self.save_checkpoint(is_best=False)
# Early stopping
if self.patience_counter >= early_stopping_patience:
logger.info(f"Early stopping triggered after {epoch+1} epochs")
break
# Final evaluation on test set if available
test_metrics = None
if self.test_loader:
test_metrics = self.evaluate_test_set()
# Save final training state
self.save_training_state()
# Close TensorBoard writer
self.writer.close()
results = {
'experiment_name': self.experiment_name,
'best_val_score': self.best_val_score,
'total_epochs': self.current_epoch + 1,
'training_history': self.training_history,
            'final_val_metrics': self.validate(),  # computed on the last-epoch weights, not the best checkpoint
'test_metrics': test_metrics,
'model_info': self.model.get_model_info()
}
logger.info(f"Training completed. Best validation {validation_score}: {self.best_val_score:.4f}")
return results
def evaluate_test_set(self) -> Dict[str, float]:
"""Evaluate model on test set"""
self.model.eval()
total_loss = 0.0
all_predictions = []
all_labels = []
all_probabilities = []
with torch.no_grad():
for embeddings, labels, metadata in self.test_loader:
embeddings = embeddings.to(self.device)
labels = labels.to(self.device).float()
# Forward pass
outputs = self.model(embeddings)
                loss = self.criterion(outputs.squeeze(-1), labels)
# Collect statistics
total_loss += loss.item()
# Get predictions and probabilities
                probabilities = torch.sigmoid(outputs).squeeze(-1)
predictions = (probabilities > 0.5).long()
all_predictions.extend(predictions.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
all_probabilities.extend(probabilities.cpu().numpy())
# Calculate metrics
test_loss = total_loss / len(self.test_loader)
test_accuracy = accuracy_score(all_labels, all_predictions)
precision, recall, f1, _ = precision_recall_fscore_support(
all_labels, all_predictions, average='binary', zero_division=0
)
try:
test_auc = roc_auc_score(all_labels, all_probabilities)
except ValueError:
test_auc = 0.0
logger.info(f"Test Set Results - Loss: {test_loss:.4f}, Acc: {test_accuracy:.4f}, F1: {f1:.4f}, AUC: {test_auc:.4f}")
return {
'loss': test_loss,
'accuracy': test_accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
'auc': test_auc
}
def save_checkpoint(self, is_best: bool = False) -> None:
"""Save model checkpoint"""
checkpoint = {
'epoch': self.current_epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'best_val_score': self.best_val_score,
'training_history': self.training_history,
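            # Persist the architecture so the model can be rebuilt before
            # calling load_state_dict() at inference time.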
'model_config': {
'input_dim': self.model.input_dim,
'hidden_dims': self.model.hidden_dims,
'dropout_rate': self.model.dropout_rate,
'batch_norm': self.model.batch_norm
}
}
if self.scheduler:
checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
# Save checkpoint
checkpoint_path = self.checkpoint_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth"
torch.save(checkpoint, checkpoint_path)
# Save best model separately
if is_best:
best_model_path = self.checkpoint_dir / "best_model.pth"
torch.save(checkpoint, best_model_path)
logger.info(f"Saved best model to {best_model_path}")
def save_training_state(self) -> None:
"""Save final training state"""
state = {
'experiment_name': self.experiment_name,
'best_val_score': self.best_val_score,
'total_epochs': self.current_epoch + 1,
'training_history': self.training_history,
'model_info': self.model.get_model_info(),
'training_config': {
'optimizer': self.optimizer.__class__.__name__,
'criterion': self.criterion.__class__.__name__,
'device': str(self.device)
}
}
state_path = self.checkpoint_dir / "training_state.json"
with open(state_path, 'w') as f:
json.dump(state, f, indent=2)
logger.info(f"Saved training state to {state_path}")
def load_checkpoint(self, checkpoint_path: str, load_optimizer: bool = True) -> None:
"""Load model from checkpoint"""
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
self.model.load_state_dict(checkpoint['model_state_dict'])
if load_optimizer:
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if self.scheduler and 'scheduler_state_dict' in checkpoint:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.current_epoch = checkpoint['epoch']
self.best_val_score = checkpoint.get('best_val_score', 0.0)
self.training_history = checkpoint.get('training_history', {})
logger.info(f"Loaded checkpoint from {checkpoint_path}")
def create_trainer(
model: EmbeddingClassifier,
train_loader: DataLoader,
val_loader: DataLoader,
test_loader: Optional[DataLoader] = None,
learning_rate: float = 0.001,
weight_decay: float = 1e-5,
scheduler_type: Optional[str] = 'plateau',
**kwargs
) -> ModelTrainer:
"""
Factory function to create a configured trainer
Args:
model: The model to train
train_loader: Training data loader
val_loader: Validation data loader
test_loader: Test data loader (optional)
learning_rate: Initial learning rate
weight_decay: L2 regularization
scheduler_type: Learning rate scheduler type ('plateau', 'step', None)
**kwargs: Additional trainer arguments
Returns:
Configured ModelTrainer instance
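
    Example (a minimal sketch; hyperparameters are illustrative):

        trainer = create_trainer(model, train_loader, val_loader,
                                 learning_rate=1e-3, scheduler_type='plateau')
        results = trainer.train(num_epochs=100)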
"""
# Create optimizer
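    # AdamW decouples weight decay from the gradient update, which is usually
    # preferable to the L2 penalty baked into plain Adam.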
optimizer = optim.AdamW(
model.parameters(),
lr=learning_rate,
weight_decay=weight_decay
)
# Create scheduler
scheduler = None
if scheduler_type == 'plateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',  # ModelTrainer steps this scheduler on validation loss, which should decrease
            factor=0.5,
            patience=10
        )
elif scheduler_type == 'step':
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=30,  # decay every 30 epochs; step_size=1 would shrink the LR tenfold per epoch
            gamma=0.1
        )
# Create trainer
trainer = ModelTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
test_loader=test_loader,
optimizer=optimizer,
scheduler=scheduler,
**kwargs
)
return trainer
if __name__ == "__main__":
# Test trainer creation
from ml_new.training.models import create_model
from ml_new.training.data_loader import DatasetLoader
# Create dummy model and data
model = create_model(
input_dim=2048,
hidden_dims=(512, 256, 128)
)
loader = DatasetLoader()
datasets = loader.list_datasets()
if datasets:
train_loader, val_loader, test_loader, data_info = loader.create_data_loaders(
datasets[0],
batch_size=8,
normalize=True
)
# Create trainer
trainer = create_trainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
test_loader=test_loader,
experiment_name="test_experiment"
)
# Test one epoch
print("Testing trainer...")
train_metrics = trainer.train_epoch()
val_metrics = trainer.validate()
print(f"Train metrics: {train_metrics}")
print(f"Val metrics: {val_metrics}")