diff --git a/ml_new/training/__init__.py b/ml_new/training/__init__.py index d78502f..2433a61 100644 --- a/ml_new/training/__init__.py +++ b/ml_new/training/__init__.py @@ -2,8 +2,8 @@ Training module for ML models """ -from .models import EmbeddingClassifier +from .models import EmbeddingClassifier, FocalLoss from .trainer import ModelTrainer from .data_loader import DatasetLoader -__all__ = ['EmbeddingClassifier', 'ModelTrainer', 'DatasetLoader'] \ No newline at end of file +__all__ = ['EmbeddingClassifier', 'FocalLoss', 'ModelTrainer', 'DatasetLoader'] \ No newline at end of file diff --git a/ml_new/training/models.py b/ml_new/training/models.py index 740424b..9450fa6 100644 --- a/ml_new/training/models.py +++ b/ml_new/training/models.py @@ -10,6 +10,90 @@ from ml_new.config.logger_config import get_logger logger = get_logger(__name__) +class FocalLoss(nn.Module): + """ + Focal Loss for binary classification + + Focal Loss = -α_t * (1 - p_t)^γ * log(p_t) + + Where: + - p_t is the predicted probability for the true class + - α_t is the weighting factor for class imbalance + - γ (gamma) is the focusing parameter that reduces the relative loss for well-classified examples + """ + + def __init__( + self, + alpha: float = 1.0, + gamma: float = 2.0, + reduction: str = 'mean', + class_weights: Optional[torch.Tensor] = None + ): + """ + Initialize Focal Loss + + Args: + alpha: Weighting factor for class imbalance (default: 1.0) + gamma: Focusing parameter (default: 2.0) + reduction: Reduction method ('mean', 'sum', 'none') + class_weights: Optional tensor of class weights for imbalance + """ + super(FocalLoss, self).__init__() + self.alpha = alpha + self.gamma = gamma + self.reduction = reduction + self.class_weights = class_weights + + def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + Compute focal loss + + Args: + inputs: Logits of shape (batch_size,) or (batch_size, 1) + targets: Binary labels of shape (batch_size,) or (batch_size, 1) + + Returns: + Focal loss value + """ + # Ensure inputs are properly shaped + if inputs.dim() > 1: + inputs = inputs.squeeze() + if targets.dim() > 1: + targets = targets.squeeze() + + # Compute sigmoid probabilities + prob = torch.sigmoid(inputs) + + # Convert targets to float for computation + targets = targets.float() + + # Compute focal loss components + # pt = p if y=1, pt = 1-p if y=0 + pt = torch.where(targets == 1, prob, 1 - prob) + + # focal weight = (1 - pt)^gamma + focal_weight = (1 - pt).pow(self.gamma) + + # alpha weight = alpha for positive class, 1-alpha for negative class + alpha_weight = torch.where(targets == 1, + torch.tensor(self.alpha, device=inputs.device), + torch.tensor(1 - self.alpha, device=inputs.device)) + + # Apply class weights if provided + if self.class_weights is not None: + alpha_weight = alpha_weight * self.class_weights[targets.long()] + + # focal loss = -alpha * (1-pt)^gamma * log(pt) + focal_loss = -alpha_weight * focal_weight * pt.log() + + # Apply reduction + if self.reduction == 'mean': + return focal_loss.mean() + elif self.reduction == 'sum': + return focal_loss.sum() + else: + return focal_loss + class EmbeddingClassifier(nn.Module): """ diff --git a/ml_new/training/test.py b/ml_new/training/test.py index 1853074..37960b6 100644 --- a/ml_new/training/test.py +++ b/ml_new/training/test.py @@ -194,6 +194,46 @@ def load_model_from_experiment( return model, model_config +def safe_extract_aid(metadata_entry): + """Safely extract aid from metadata entry""" + if isinstance(metadata_entry, dict) and 'aid' in metadata_entry: + return metadata_entry['aid'] + return None + +def normalize_batch_metadata(metadata, expected_batch_size): + """ + Normalize batch metadata to ensure consistent structure + + Args: + metadata: Raw metadata from DataLoader (could be various formats) + expected_batch_size: Expected number of metadata entries + + Returns: + List of metadata dictionaries + """ + # Handle different metadata structures + if metadata is None: + return [{}] * expected_batch_size + + if isinstance(metadata, dict): + # Single metadata object - duplicate for entire batch + return [metadata] * expected_batch_size + + if isinstance(metadata, (list, tuple)): + if len(metadata) == expected_batch_size: + return list(metadata) + elif len(metadata) < expected_batch_size: + # Pad with empty dicts + padded = list(metadata) + [{}] * (expected_batch_size - len(metadata)) + return padded + else: + # Truncate to expected size + return list(metadata[:expected_batch_size]) + + # Unknown format - return empty dicts + logger.warning(f"Unknown metadata format: {type(metadata)}") + return [{}] * expected_batch_size + def evaluate_model( model, test_loader: DataLoader, @@ -213,9 +253,7 @@ def evaluate_model( Tuple of (metrics, predictions, probabilities, true_labels, fn_aids, fp_aids) """ model.eval() - criterion = torch.nn.BCEWithLogitsLoss() - total_loss = 0.0 all_predictions = [] all_labels = [] all_probabilities = [] @@ -230,10 +268,6 @@ def evaluate_model( # Forward pass outputs = model(embeddings) - loss = criterion(outputs.squeeze(), labels) - - # Collect statistics - total_loss += loss.item() # Get predictions and probabilities probabilities = torch.sigmoid(outputs).squeeze() @@ -244,23 +278,45 @@ def evaluate_model( all_probabilities.extend(probabilities.cpu().numpy()) # Collect metadata and track FN/FP - batch_metadata = metadata if isinstance(metadata, list) else [metadata] + batch_size = len(labels) + batch_metadata = normalize_batch_metadata(metadata, batch_size) all_metadata.extend(batch_metadata) # Track FN and FP aids for this batch + logger.debug(f"Batch {batch_idx}: labels shape {labels.shape}, predictions shape {predictions.shape}, metadata structure: {type(batch_metadata)}") + if len(batch_metadata) != len(labels): + logger.warning(f"Metadata length mismatch: {len(batch_metadata)} metadata entries vs {len(labels)} samples") + for i, (true_label, pred_label) in enumerate(zip(labels.cpu().numpy(), predictions.cpu().numpy())): - if isinstance(batch_metadata[i], dict) and 'aid' in batch_metadata[i]: - aid = batch_metadata[i]['aid'] - if true_label == 1 and pred_label == 0: # False Negative - fn_aids.append(aid) - elif true_label == 0 and pred_label == 1: # False Positive - fp_aids.append(aid) + try: + # Safely get metadata entry with bounds checking + if i >= len(batch_metadata): + logger.warning(f"Index {i} out of range for batch_metadata (length: {len(batch_metadata)})") + continue + + meta_entry = batch_metadata[i] + if not isinstance(meta_entry, dict): + logger.warning(f"Metadata entry {i} is not a dict: {type(meta_entry)}") + continue + + if 'aid' not in meta_entry: + logger.debug(f"No 'aid' key in metadata entry {i}") + continue + + aid = safe_extract_aid(meta_entry) + if aid is not None: + if true_label == 1 and pred_label == 0: # False Negative + fn_aids.append(aid) + elif true_label == 0 and pred_label == 1: # False Positive + fp_aids.append(aid) + + except Exception as e: + logger.warning(f"Error processing metadata entry {i}: {e}") + continue if (batch_idx + 1) % 10 == 0: logger.info(f"Processed {batch_idx + 1}/{len(test_loader)} batches") - # Calculate metrics - test_loss = total_loss / len(test_loader) test_accuracy = accuracy_score(all_labels, all_predictions) precision, recall, f1, _ = precision_recall_fscore_support( all_labels, all_predictions, average='binary', zero_division=0 @@ -279,7 +335,6 @@ def evaluate_model( tn, fp, fn, tp = 0, 0, 0, 0 metrics = { - 'loss': test_loss, 'accuracy': test_accuracy, 'precision': precision, 'recall': recall, @@ -578,8 +633,6 @@ def main(): if 'failed_requests' in metrics: logger.info(f"Failed API requests: {metrics['failed_requests']}") logger.info("-" * 50) - if 'loss' in metrics: - logger.info(f"Loss: {metrics['loss']:.4f}") logger.info(f"Accuracy: {metrics['accuracy']:.4f}") logger.info(f"Precision: {metrics['precision']:.4f}") logger.info(f"Recall: {metrics['recall']:.4f}") diff --git a/ml_new/training/train.py b/ml_new/training/train.py index eaecfa8..17df0a6 100644 --- a/ml_new/training/train.py +++ b/ml_new/training/train.py @@ -158,6 +158,29 @@ def parse_args(): help="Metric for model selection and early stopping" ) + # Loss function arguments + parser.add_argument( + "--loss-type", + type=str, + choices=["focal", "bce"], + default="focal", + help="Type of loss function to use" + ) + + parser.add_argument( + "--focal-alpha", + type=float, + default=1.0, + help="Alpha parameter for focal loss (class imbalance weighting)" + ) + + parser.add_argument( + "--focal-gamma", + type=float, + default=2.0, + help="Gamma parameter for focal loss (focusing parameter)" + ) + # Data arguments parser.add_argument( "--train-ratio", @@ -423,6 +446,9 @@ def main(): learning_rate=args.learning_rate, weight_decay=args.weight_decay, scheduler_type=scheduler_type, + loss_type=args.loss_type, + focal_alpha=args.focal_alpha, + focal_gamma=args.focal_gamma, device=device, save_dir=args.save_dir, experiment_name=args.experiment_name @@ -455,7 +481,10 @@ def main(): "weight_decay": args.weight_decay, "scheduler": args.scheduler, "early_stopping_patience": args.early_stopping_patience, - "validation_metric": args.validation_metric + "validation_metric": args.validation_metric, + "loss_type": args.loss_type, + "focal_alpha": args.focal_alpha, + "focal_gamma": args.focal_gamma }, "data_config": { "train_ratio": args.train_ratio, diff --git a/ml_new/training/trainer.py b/ml_new/training/trainer.py index 81247d8..1515db6 100644 --- a/ml_new/training/trainer.py +++ b/ml_new/training/trainer.py @@ -13,7 +13,7 @@ from typing import Dict, Any, Optional import json from datetime import datetime from pathlib import Path -from ml_new.training.models import EmbeddingClassifier +from ml_new.training.models import EmbeddingClassifier, FocalLoss from ml_new.config.logger_config import get_logger logger = get_logger(__name__) @@ -70,6 +70,12 @@ class ModelTrainer: self.optimizer = optimizer or optim.Adam(self.model.parameters(), lr=0.001) self.scheduler = scheduler + # Store loss function type for configuration + if criterion is None: + self.loss_type = 'bce' + else: + self.loss_type = getattr(criterion, 'loss_type', 'custom') + # Training configuration self.save_dir = Path(save_dir) self.save_dir.mkdir(parents=True, exist_ok=True) @@ -442,6 +448,10 @@ def create_trainer( learning_rate: float = 0.001, weight_decay: float = 1e-5, scheduler_type: Optional[str] = 'plateau', + loss_type: str = 'focal', # Default to focal loss + focal_alpha: float = 1.0, + focal_gamma: float = 2.0, + class_weights: Optional[torch.Tensor] = None, **kwargs ) -> ModelTrainer: """ @@ -455,11 +465,30 @@ def create_trainer( learning_rate: Initial learning rate weight_decay: L2 regularization scheduler_type: Learning rate scheduler type ('plateau', 'step', None) + loss_type: Type of loss function ('focal', 'bce') + focal_alpha: Alpha parameter for focal loss (class imbalance weighting) + focal_gamma: Gamma parameter for focal loss (focusing parameter) + class_weights: Optional tensor of class weights for additional balancing **kwargs: Additional trainer arguments Returns: Configured ModelTrainer instance """ + # Create loss function + if loss_type == 'focal': + criterion = FocalLoss( + alpha=focal_alpha, + gamma=focal_gamma, + reduction='mean', + class_weights=class_weights + ) + logger.info(f"Using Focal Loss with alpha={focal_alpha}, gamma={focal_gamma}") + elif loss_type == 'bce': + criterion = nn.BCEWithLogitsLoss() + logger.info("Using BCE With Logits Loss") + else: + raise ValueError(f"Unsupported loss type: {loss_type}") + # Create optimizer optimizer = optim.AdamW( model.parameters(), @@ -490,6 +519,7 @@ def create_trainer( val_loader=val_loader, test_loader=test_loader, optimizer=optimizer, + criterion=criterion, scheduler=scheduler, **kwargs )