add: focal loss
parent 3d96f4986d
commit a6f0d8a27c
@@ -2,8 +2,8 @@
 Training module for ML models
 """

-from .models import EmbeddingClassifier
+from .models import EmbeddingClassifier, FocalLoss
 from .trainer import ModelTrainer
 from .data_loader import DatasetLoader

-__all__ = ['EmbeddingClassifier', 'ModelTrainer', 'DatasetLoader']
+__all__ = ['EmbeddingClassifier', 'FocalLoss', 'ModelTrainer', 'DatasetLoader']
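With the export list updated above, the new loss becomes available from the package root. A minimal sketch, assuming this file is ml_new/training/__init__.py as the absolute imports elsewhere in the commit suggest:

from ml_new.training import EmbeddingClassifier, FocalLoss, ModelTrainer, DatasetLoader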
@@ -10,6 +10,90 @@ from ml_new.config.logger_config import get_logger

logger = get_logger(__name__)


class FocalLoss(nn.Module):
    """
    Focal Loss for binary classification.

    Focal Loss = -α_t * (1 - p_t)^γ * log(p_t)

    Where:
    - p_t is the predicted probability for the true class
    - α_t is the weighting factor for class imbalance
    - γ (gamma) is the focusing parameter that reduces the relative loss for well-classified examples
    """

    def __init__(
        self,
        alpha: float = 1.0,
        gamma: float = 2.0,
        reduction: str = 'mean',
        class_weights: Optional[torch.Tensor] = None
    ):
        """
        Initialize Focal Loss.

        Args:
            alpha: Weighting factor for class imbalance (default: 1.0)
            gamma: Focusing parameter (default: 2.0)
            reduction: Reduction method ('mean', 'sum', 'none')
            class_weights: Optional tensor of class weights for imbalance
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.class_weights = class_weights

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute focal loss.

        Args:
            inputs: Logits of shape (batch_size,) or (batch_size, 1)
            targets: Binary labels of shape (batch_size,) or (batch_size, 1)

        Returns:
            Focal loss value
        """
        # Ensure inputs and targets are 1-D
        if inputs.dim() > 1:
            inputs = inputs.squeeze()
        if targets.dim() > 1:
            targets = targets.squeeze()

        # Compute sigmoid probabilities
        prob = torch.sigmoid(inputs)

        # Convert targets to float for computation
        targets = targets.float()

        # Compute focal loss components
        # pt = p if y == 1, pt = 1 - p if y == 0
        pt = torch.where(targets == 1, prob, 1 - prob)

        # focal weight = (1 - pt)^gamma
        focal_weight = (1 - pt).pow(self.gamma)

        # alpha weight = alpha for the positive class, 1 - alpha for the negative class
        alpha_weight = torch.where(targets == 1,
                                   torch.tensor(self.alpha, device=inputs.device),
                                   torch.tensor(1 - self.alpha, device=inputs.device))

        # Apply class weights if provided
        if self.class_weights is not None:
            alpha_weight = alpha_weight * self.class_weights[targets.long()]

        # focal loss = -alpha * (1 - pt)^gamma * log(pt)
        focal_loss = -alpha_weight * focal_weight * pt.log()

        # Apply reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss


class EmbeddingClassifier(nn.Module):
    """
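The class above can be exercised directly on a toy batch. A minimal sketch, assuming FocalLoss is importable from ml_new.training.models as this commit adds; note that in this implementation alpha weights the positive class and 1 - alpha the negative class, so alpha=0.25 here up-weights negatives:

import torch
from ml_new.training.models import FocalLoss  # import path taken from this commit

criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='none')

logits = torch.tensor([2.5, -0.3, -2.5, 0.3])   # two positive and two negative examples
targets = torch.tensor([1.0, 1.0, 0.0, 0.0])

per_example = criterion(logits, targets)
print(per_example)         # hard examples dominate; well-classified ones are near zero
print(per_example.mean())  # equivalent to constructing with reduction='mean'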
@@ -194,6 +194,46 @@ def load_model_from_experiment(
    return model, model_config


def safe_extract_aid(metadata_entry):
    """Safely extract aid from a metadata entry"""
    if isinstance(metadata_entry, dict) and 'aid' in metadata_entry:
        return metadata_entry['aid']
    return None


def normalize_batch_metadata(metadata, expected_batch_size):
    """
    Normalize batch metadata to ensure a consistent structure.

    Args:
        metadata: Raw metadata from the DataLoader (can arrive in various formats)
        expected_batch_size: Expected number of metadata entries

    Returns:
        List of metadata dictionaries
    """
    # Handle different metadata structures
    if metadata is None:
        return [{}] * expected_batch_size

    if isinstance(metadata, dict):
        # Single metadata object - duplicate for the entire batch
        return [metadata] * expected_batch_size

    if isinstance(metadata, (list, tuple)):
        if len(metadata) == expected_batch_size:
            return list(metadata)
        elif len(metadata) < expected_batch_size:
            # Pad with empty dicts
            padded = list(metadata) + [{}] * (expected_batch_size - len(metadata))
            return padded
        else:
            # Truncate to the expected size
            return list(metadata[:expected_batch_size])

    # Unknown format - return empty dicts
    logger.warning(f"Unknown metadata format: {type(metadata)}")
    return [{}] * expected_batch_size


def evaluate_model(
    model,
    test_loader: DataLoader,
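The helper's contract is easiest to see on concrete inputs. An illustrative sketch in plain Python, mirroring the branches above:

# None, a single dict, a short list, and an over-long list all normalize
# to a list of dicts with exactly expected_batch_size entries.
normalize_batch_metadata(None, 3)                      # -> [{}, {}, {}]
normalize_batch_metadata({'aid': 7}, 2)                # -> [{'aid': 7}, {'aid': 7}]
normalize_batch_metadata([{'aid': 1}], 3)              # -> [{'aid': 1}, {}, {}]
normalize_batch_metadata([{'aid': 1}, {'aid': 2}, {'aid': 3}], 2)  # -> [{'aid': 1}, {'aid': 2}]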
@@ -213,9 +253,7 @@ def evaluate_model(
        Tuple of (metrics, predictions, probabilities, true_labels, fn_aids, fp_aids)
    """
    model.eval()
    criterion = torch.nn.BCEWithLogitsLoss()

    total_loss = 0.0
    all_predictions = []
    all_labels = []
    all_probabilities = []
@@ -230,10 +268,6 @@ def evaluate_model(

            # Forward pass
            outputs = model(embeddings)
            loss = criterion(outputs.squeeze(), labels)

            # Collect statistics
            total_loss += loss.item()

            # Get predictions and probabilities
            probabilities = torch.sigmoid(outputs).squeeze()
@@ -244,23 +278,45 @@ def evaluate_model(
            all_probabilities.extend(probabilities.cpu().numpy())

            # Collect metadata and track FN/FP
-           batch_metadata = metadata if isinstance(metadata, list) else [metadata]
+           batch_size = len(labels)
+           batch_metadata = normalize_batch_metadata(metadata, batch_size)
            all_metadata.extend(batch_metadata)

            # Track FN and FP aids for this batch
+           logger.debug(f"Batch {batch_idx}: labels shape {labels.shape}, predictions shape {predictions.shape}, metadata structure: {type(batch_metadata)}")
+           if len(batch_metadata) != len(labels):
+               logger.warning(f"Metadata length mismatch: {len(batch_metadata)} metadata entries vs {len(labels)} samples")
+
            for i, (true_label, pred_label) in enumerate(zip(labels.cpu().numpy(), predictions.cpu().numpy())):
-               if isinstance(batch_metadata[i], dict) and 'aid' in batch_metadata[i]:
-                   aid = batch_metadata[i]['aid']
-                   if true_label == 1 and pred_label == 0:  # False Negative
-                       fn_aids.append(aid)
-                   elif true_label == 0 and pred_label == 1:  # False Positive
-                       fp_aids.append(aid)
+               try:
+                   # Safely get metadata entry with bounds checking
+                   if i >= len(batch_metadata):
+                       logger.warning(f"Index {i} out of range for batch_metadata (length: {len(batch_metadata)})")
+                       continue
+
+                   meta_entry = batch_metadata[i]
+                   if not isinstance(meta_entry, dict):
+                       logger.warning(f"Metadata entry {i} is not a dict: {type(meta_entry)}")
+                       continue
+
+                   if 'aid' not in meta_entry:
+                       logger.debug(f"No 'aid' key in metadata entry {i}")
+                       continue
+
+                   aid = safe_extract_aid(meta_entry)
+                   if aid is not None:
+                       if true_label == 1 and pred_label == 0:  # False Negative
+                           fn_aids.append(aid)
+                       elif true_label == 0 and pred_label == 1:  # False Positive
+                           fp_aids.append(aid)
+
+               except Exception as e:
+                   logger.warning(f"Error processing metadata entry {i}: {e}")
+                   continue

            if (batch_idx + 1) % 10 == 0:
                logger.info(f"Processed {batch_idx + 1}/{len(test_loader)} batches")

    # Calculate metrics
    test_loss = total_loss / len(test_loader)
    test_accuracy = accuracy_score(all_labels, all_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='binary', zero_division=0
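For reference, the FN/FP bookkeeping above reduces to a fixed rule on each (true label, predicted label) pair; a toy illustration in plain Python:

# Illustrative only: how each (true, predicted) pair is routed.
for true_label, pred_label in [(1, 0), (0, 1), (1, 1), (0, 0)]:
    if true_label == 1 and pred_label == 0:
        print((true_label, pred_label), '-> false negative, aid goes to fn_aids')
    elif true_label == 0 and pred_label == 1:
        print((true_label, pred_label), '-> false positive, aid goes to fp_aids')
    else:
        print((true_label, pred_label), '-> correct prediction, not tracked')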
@@ -279,7 +335,6 @@ def evaluate_model(
        tn, fp, fn, tp = 0, 0, 0, 0

    metrics = {
        'loss': test_loss,
        'accuracy': test_accuracy,
        'precision': precision,
        'recall': recall,
@@ -578,8 +633,6 @@ def main():
    if 'failed_requests' in metrics:
        logger.info(f"Failed API requests: {metrics['failed_requests']}")
    logger.info("-" * 50)
    if 'loss' in metrics:
        logger.info(f"Loss: {metrics['loss']:.4f}")
    logger.info(f"Accuracy: {metrics['accuracy']:.4f}")
    logger.info(f"Precision: {metrics['precision']:.4f}")
    logger.info(f"Recall: {metrics['recall']:.4f}")
@@ -158,6 +158,29 @@ def parse_args():
        help="Metric for model selection and early stopping"
    )

    # Loss function arguments
    parser.add_argument(
        "--loss-type",
        type=str,
        choices=["focal", "bce"],
        default="focal",
        help="Type of loss function to use"
    )

    parser.add_argument(
        "--focal-alpha",
        type=float,
        default=1.0,
        help="Alpha parameter for focal loss (class imbalance weighting)"
    )

    parser.add_argument(
        "--focal-gamma",
        type=float,
        default=2.0,
        help="Gamma parameter for focal loss (focusing parameter)"
    )

    # Data arguments
    parser.add_argument(
        "--train-ratio",
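A self-contained sketch of just the three loss-related options added above, showing how a typical invocation parses; the real parser in the script defines many more arguments:

import argparse

# Stand-alone reproduction of the new flags; names and defaults copied from the diff.
parser = argparse.ArgumentParser()
parser.add_argument("--loss-type", type=str, choices=["focal", "bce"], default="focal")
parser.add_argument("--focal-alpha", type=float, default=1.0)
parser.add_argument("--focal-gamma", type=float, default=2.0)

args = parser.parse_args(["--loss-type", "focal", "--focal-alpha", "0.25"])
print(args.loss_type, args.focal_alpha, args.focal_gamma)  # focal 0.25 2.0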
@@ -423,6 +446,9 @@ def main():
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        scheduler_type=scheduler_type,
        loss_type=args.loss_type,
        focal_alpha=args.focal_alpha,
        focal_gamma=args.focal_gamma,
        device=device,
        save_dir=args.save_dir,
        experiment_name=args.experiment_name
@@ -455,7 +481,10 @@ def main():
        "weight_decay": args.weight_decay,
        "scheduler": args.scheduler,
        "early_stopping_patience": args.early_stopping_patience,
-       "validation_metric": args.validation_metric
+       "validation_metric": args.validation_metric,
+       "loss_type": args.loss_type,
+       "focal_alpha": args.focal_alpha,
+       "focal_gamma": args.focal_gamma
    },
    "data_config": {
        "train_ratio": args.train_ratio,
@@ -13,7 +13,7 @@ from typing import Dict, Any, Optional
import json
from datetime import datetime
from pathlib import Path
-from ml_new.training.models import EmbeddingClassifier
+from ml_new.training.models import EmbeddingClassifier, FocalLoss
from ml_new.config.logger_config import get_logger

logger = get_logger(__name__)
@@ -70,6 +70,12 @@ class ModelTrainer:
        self.optimizer = optimizer or optim.Adam(self.model.parameters(), lr=0.001)
        self.scheduler = scheduler

        # Store loss function type for configuration
        if criterion is None:
            self.loss_type = 'bce'
        else:
            self.loss_type = getattr(criterion, 'loss_type', 'custom')

        # Training configuration
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
@@ -442,6 +448,10 @@ def create_trainer(
    learning_rate: float = 0.001,
    weight_decay: float = 1e-5,
    scheduler_type: Optional[str] = 'plateau',
    loss_type: str = 'focal',  # Default to focal loss
    focal_alpha: float = 1.0,
    focal_gamma: float = 2.0,
    class_weights: Optional[torch.Tensor] = None,
    **kwargs
) -> ModelTrainer:
    """
@@ -455,11 +465,30 @@ def create_trainer(
        learning_rate: Initial learning rate
        weight_decay: L2 regularization
        scheduler_type: Learning rate scheduler type ('plateau', 'step', None)
        loss_type: Type of loss function ('focal', 'bce')
        focal_alpha: Alpha parameter for focal loss (class imbalance weighting)
        focal_gamma: Gamma parameter for focal loss (focusing parameter)
        class_weights: Optional tensor of class weights for additional balancing
        **kwargs: Additional trainer arguments

    Returns:
        Configured ModelTrainer instance
    """
    # Create loss function
    if loss_type == 'focal':
        criterion = FocalLoss(
            alpha=focal_alpha,
            gamma=focal_gamma,
            reduction='mean',
            class_weights=class_weights
        )
        logger.info(f"Using Focal Loss with alpha={focal_alpha}, gamma={focal_gamma}")
    elif loss_type == 'bce':
        criterion = nn.BCEWithLogitsLoss()
        logger.info("Using BCE With Logits Loss")
    else:
        raise ValueError(f"Unsupported loss type: {loss_type}")

    # Create optimizer
    optimizer = optim.AdamW(
        model.parameters(),
@@ -490,6 +519,7 @@ def create_trainer(
        val_loader=val_loader,
        test_loader=test_loader,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        **kwargs
    )
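Taken together, a trainer configured for focal loss would be created roughly as follows. This is a hedged sketch: the loss-related keyword arguments come from the signature above, while the model and data loaders are assumed to be built elsewhere in the training script.

from ml_new.training.trainer import create_trainer  # import path assumed from this commit

# Hypothetical wiring; model, train_loader, val_loader and test_loader come from
# the surrounding training script and are not defined here.
trainer = create_trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    learning_rate=1e-3,
    weight_decay=1e-5,
    scheduler_type='plateau',
    loss_type='focal',      # or 'bce' to fall back to BCEWithLogitsLoss
    focal_alpha=0.25,
    focal_gamma=2.0,
)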