add: focal loss

alikia2x (寒寒) 2025-12-11 01:58:08 +08:00
parent 3d96f4986d
commit a6f0d8a27c
WARNING! Although there is a key with this ID in the database, it does not verify this commit! This commit is SUSPICIOUS.
GPG Key ID: 56209E0CCD8420C6
5 changed files with 218 additions and 22 deletions

View File

@@ -2,8 +2,8 @@
Training module for ML models
"""
from .models import EmbeddingClassifier
from .models import EmbeddingClassifier, FocalLoss
from .trainer import ModelTrainer
from .data_loader import DatasetLoader
__all__ = ['EmbeddingClassifier', 'ModelTrainer', 'DatasetLoader']
__all__ = ['EmbeddingClassifier', 'FocalLoss', 'ModelTrainer', 'DatasetLoader']
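Since the package __init__ now re-exports FocalLoss, it can be imported from the package root alongside the other training components. A one-line usage note (the ml_new.training path is inferred from the absolute imports elsewhere in this commit):
from ml_new.training import EmbeddingClassifier, FocalLoss, ModelTrainer, DatasetLoader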

View File

@@ -10,6 +10,90 @@ from ml_new.config.logger_config import get_logger
logger = get_logger(__name__)
class FocalLoss(nn.Module):
"""
Focal Loss for binary classification
Focal Loss = -α_t * (1 - p_t)^γ * log(p_t)
Where:
- p_t is the predicted probability for the true class
- α_t is the weighting factor for class imbalance
- γ (gamma) is the focusing parameter that reduces the relative loss for well-classified examples
"""
def __init__(
self,
alpha: float = 1.0,
gamma: float = 2.0,
reduction: str = 'mean',
class_weights: Optional[torch.Tensor] = None
):
"""
Initialize Focal Loss
Args:
alpha: Weight applied to the positive class; the negative class is weighted by 1 - alpha (default: 1.0)
gamma: Focusing parameter (default: 2.0)
reduction: Reduction method ('mean', 'sum', 'none')
class_weights: Optional tensor of class weights for imbalance
"""
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
self.class_weights = class_weights
def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""
Compute focal loss
Args:
inputs: Logits of shape (batch_size,) or (batch_size, 1)
targets: Binary labels of shape (batch_size,) or (batch_size, 1)
Returns:
Focal loss value
"""
# Ensure inputs are properly shaped
if inputs.dim() > 1:
inputs = inputs.squeeze()
if targets.dim() > 1:
targets = targets.squeeze()
# Compute sigmoid probabilities
prob = torch.sigmoid(inputs)
# Convert targets to float for computation
targets = targets.float()
# Compute focal loss components
# pt = p if y=1, pt = 1-p if y=0
pt = torch.where(targets == 1, prob, 1 - prob)
# focal weight = (1 - pt)^gamma
focal_weight = (1 - pt).pow(self.gamma)
# alpha weight = alpha for positive class, 1-alpha for negative class
alpha_weight = torch.where(targets == 1,
torch.tensor(self.alpha, device=inputs.device),
torch.tensor(1 - self.alpha, device=inputs.device))
# Apply class weights if provided
if self.class_weights is not None:
alpha_weight = alpha_weight * self.class_weights[targets.long()]
# focal loss = -alpha * (1-pt)^gamma * log(pt)
focal_loss = -alpha_weight * focal_weight * pt.log()
# Apply reduction
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
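# Editor's sketch (standalone, not part of this commit): sanity checks against
# the class above. With gamma=0 the modulating factor (1 - p_t)^gamma is 1, so
# the loss reduces to alpha-weighted BCE; alpha=0.5 then gives exactly half of
# nn.BCEWithLogitsLoss. With gamma=2, confidently-classified examples are
# suppressed relative to hard ones.
import torch
from ml_new.training.models import FocalLoss

logits = torch.tensor([4.0, 0.1, -3.5, -0.2])   # easy pos, hard pos, easy neg, hard neg
targets = torch.tensor([1.0, 1.0, 0.0, 0.0])

focal = FocalLoss(alpha=0.5, gamma=0.0)(logits, targets)
bce = torch.nn.BCEWithLogitsLoss()(logits, targets)
assert torch.isclose(focal, 0.5 * bce)          # gamma=0  ->  0.5 * BCE

per_example = FocalLoss(alpha=0.5, gamma=2.0, reduction='none')(logits, targets)
print(per_example)   # the two "easy" entries come out orders of magnitude smaller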
class EmbeddingClassifier(nn.Module):
"""

View File

@@ -194,6 +194,46 @@ def load_model_from_experiment(
return model, model_config
def safe_extract_aid(metadata_entry):
"""Safely extract aid from metadata entry"""
if isinstance(metadata_entry, dict) and 'aid' in metadata_entry:
return metadata_entry['aid']
return None
def normalize_batch_metadata(metadata, expected_batch_size):
"""
Normalize batch metadata to ensure consistent structure
Args:
metadata: Raw metadata from DataLoader (could be various formats)
expected_batch_size: Expected number of metadata entries
Returns:
List of metadata dictionaries
"""
# Handle different metadata structures
if metadata is None:
return [{}] * expected_batch_size
if isinstance(metadata, dict):
# Single metadata object - duplicate for entire batch
return [metadata] * expected_batch_size
if isinstance(metadata, (list, tuple)):
if len(metadata) == expected_batch_size:
return list(metadata)
elif len(metadata) < expected_batch_size:
# Pad with empty dicts
padded = list(metadata) + [{}] * (expected_batch_size - len(metadata))
return padded
else:
# Truncate to expected size
return list(metadata[:expected_batch_size])
# Unknown format - return empty dicts
logger.warning(f"Unknown metadata format: {type(metadata)}")
return [{}] * expected_batch_size
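# Illustrative behaviour of the helper above (editor's note, not part of the
# commit): whatever shape the DataLoader hands back, downstream indexing stays
# in bounds because the result always has exactly expected_batch_size entries.
#   normalize_batch_metadata({'aid': 7}, 3)               -> [{'aid': 7}, {'aid': 7}, {'aid': 7}]
#   normalize_batch_metadata([{'aid': 7}], 3)             -> [{'aid': 7}, {}, {}]
#   normalize_batch_metadata([{'aid': 1}, {'aid': 2}], 1) -> [{'aid': 1}]
#   normalize_batch_metadata(None, 2)                     -> [{}, {}]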
def evaluate_model(
model,
test_loader: DataLoader,
@@ -213,9 +253,7 @@ def evaluate_model(
Tuple of (metrics, predictions, probabilities, true_labels, fn_aids, fp_aids)
"""
model.eval()
criterion = torch.nn.BCEWithLogitsLoss()
total_loss = 0.0
all_predictions = []
all_labels = []
all_probabilities = []
@@ -230,10 +268,6 @@
# Forward pass
outputs = model(embeddings)
loss = criterion(outputs.squeeze(), labels)
# Collect statistics
total_loss += loss.item()
# Get predictions and probabilities
probabilities = torch.sigmoid(outputs).squeeze()
@@ -244,23 +278,45 @@ def evaluate_model(
all_probabilities.extend(probabilities.cpu().numpy())
# Collect metadata and track FN/FP
batch_metadata = metadata if isinstance(metadata, list) else [metadata]
batch_size = len(labels)
batch_metadata = normalize_batch_metadata(metadata, batch_size)
all_metadata.extend(batch_metadata)
# Track FN and FP aids for this batch
logger.debug(f"Batch {batch_idx}: labels shape {labels.shape}, predictions shape {predictions.shape}, metadata structure: {type(batch_metadata)}")
if len(batch_metadata) != len(labels):
logger.warning(f"Metadata length mismatch: {len(batch_metadata)} metadata entries vs {len(labels)} samples")
for i, (true_label, pred_label) in enumerate(zip(labels.cpu().numpy(), predictions.cpu().numpy())):
if isinstance(batch_metadata[i], dict) and 'aid' in batch_metadata[i]:
aid = batch_metadata[i]['aid']
if true_label == 1 and pred_label == 0: # False Negative
fn_aids.append(aid)
elif true_label == 0 and pred_label == 1: # False Positive
fp_aids.append(aid)
try:
# Safely get metadata entry with bounds checking
if i >= len(batch_metadata):
logger.warning(f"Index {i} out of range for batch_metadata (length: {len(batch_metadata)})")
continue
meta_entry = batch_metadata[i]
if not isinstance(meta_entry, dict):
logger.warning(f"Metadata entry {i} is not a dict: {type(meta_entry)}")
continue
if 'aid' not in meta_entry:
logger.debug(f"No 'aid' key in metadata entry {i}")
continue
aid = safe_extract_aid(meta_entry)
if aid is not None:
if true_label == 1 and pred_label == 0: # False Negative
fn_aids.append(aid)
elif true_label == 0 and pred_label == 1: # False Positive
fp_aids.append(aid)
except Exception as e:
logger.warning(f"Error processing metadata entry {i}: {e}")
continue
if (batch_idx + 1) % 10 == 0:
logger.info(f"Processed {batch_idx + 1}/{len(test_loader)} batches")
# Calculate metrics
test_loss = total_loss / len(test_loader)
test_accuracy = accuracy_score(all_labels, all_predictions)
precision, recall, f1, _ = precision_recall_fscore_support(
all_labels, all_predictions, average='binary', zero_division=0
@@ -279,7 +335,6 @@ def evaluate_model(
tn, fp, fn, tp = 0, 0, 0, 0
metrics = {
'loss': test_loss,
'accuracy': test_accuracy,
'precision': precision,
'recall': recall,
@@ -578,8 +633,6 @@ def main():
if 'failed_requests' in metrics:
logger.info(f"Failed API requests: {metrics['failed_requests']}")
logger.info("-" * 50)
if 'loss' in metrics:
logger.info(f"Loss: {metrics['loss']:.4f}")
logger.info(f"Accuracy: {metrics['accuracy']:.4f}")
logger.info(f"Precision: {metrics['precision']:.4f}")
logger.info(f"Recall: {metrics['recall']:.4f}")

View File

@@ -158,6 +158,29 @@ def parse_args():
help="Metric for model selection and early stopping"
)
# Loss function arguments
parser.add_argument(
"--loss-type",
type=str,
choices=["focal", "bce"],
default="focal",
help="Type of loss function to use"
)
parser.add_argument(
"--focal-alpha",
type=float,
default=1.0,
help="Alpha parameter for focal loss (class imbalance weighting)"
)
parser.add_argument(
"--focal-gamma",
type=float,
default=2.0,
help="Gamma parameter for focal loss (focusing parameter)"
)
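# Example invocation (hypothetical entry-point name; only the flags are taken
# from this diff): python train.py --loss-type focal --focal-alpha 0.25 --focal-gamma 2.0
# --focal-alpha and --focal-gamma only affect training when --loss-type is "focal";
# with --loss-type bce the trainer falls back to nn.BCEWithLogitsLoss.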
# Data arguments
parser.add_argument(
"--train-ratio",
@@ -423,6 +446,9 @@ def main():
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
scheduler_type=scheduler_type,
loss_type=args.loss_type,
focal_alpha=args.focal_alpha,
focal_gamma=args.focal_gamma,
device=device,
save_dir=args.save_dir,
experiment_name=args.experiment_name
@@ -455,7 +481,10 @@ def main():
"weight_decay": args.weight_decay,
"scheduler": args.scheduler,
"early_stopping_patience": args.early_stopping_patience,
"validation_metric": args.validation_metric
"validation_metric": args.validation_metric,
"loss_type": args.loss_type,
"focal_alpha": args.focal_alpha,
"focal_gamma": args.focal_gamma
},
"data_config": {
"train_ratio": args.train_ratio,

View File

@@ -13,7 +13,7 @@ from typing import Dict, Any, Optional
import json
from datetime import datetime
from pathlib import Path
from ml_new.training.models import EmbeddingClassifier
from ml_new.training.models import EmbeddingClassifier, FocalLoss
from ml_new.config.logger_config import get_logger
logger = get_logger(__name__)
@@ -70,6 +70,12 @@ class ModelTrainer:
self.optimizer = optimizer or optim.Adam(self.model.parameters(), lr=0.001)
self.scheduler = scheduler
# Record which loss is in use: default to BCE when none is given, detect
# FocalLoss explicitly, otherwise fall back to whatever the criterion reports
if criterion is None:
self.loss_type = 'bce'
elif isinstance(criterion, FocalLoss):
self.loss_type = 'focal'
else:
self.loss_type = getattr(criterion, 'loss_type', 'custom')
# Training configuration
self.save_dir = Path(save_dir)
self.save_dir.mkdir(parents=True, exist_ok=True)
@@ -442,6 +448,10 @@ def create_trainer(
learning_rate: float = 0.001,
weight_decay: float = 1e-5,
scheduler_type: Optional[str] = 'plateau',
loss_type: str = 'focal', # Default to focal loss
focal_alpha: float = 1.0,
focal_gamma: float = 2.0,
class_weights: Optional[torch.Tensor] = None,
**kwargs
) -> ModelTrainer:
"""
@@ -455,11 +465,30 @@
learning_rate: Initial learning rate
weight_decay: L2 regularization
scheduler_type: Learning rate scheduler type ('plateau', 'step', None)
loss_type: Type of loss function ('focal', 'bce')
focal_alpha: Alpha parameter for focal loss (class imbalance weighting)
focal_gamma: Gamma parameter for focal loss (focusing parameter)
class_weights: Optional tensor of class weights for additional balancing
**kwargs: Additional trainer arguments
Returns:
Configured ModelTrainer instance
"""
# Create loss function
if loss_type == 'focal':
criterion = FocalLoss(
alpha=focal_alpha,
gamma=focal_gamma,
reduction='mean',
class_weights=class_weights
)
logger.info(f"Using Focal Loss with alpha={focal_alpha}, gamma={focal_gamma}")
elif loss_type == 'bce':
criterion = nn.BCEWithLogitsLoss()
logger.info("Using BCE With Logits Loss")
else:
raise ValueError(f"Unsupported loss type: {loss_type}")
# Create optimizer
optimizer = optim.AdamW(
model.parameters(),
@@ -490,6 +519,7 @@ def create_trainer(
val_loader=val_loader,
test_loader=test_loader,
optimizer=optimizer,
criterion=criterion,
scheduler=scheduler,
**kwargs
)
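Because FocalLoss consumes raw logits and binary targets just like nn.BCEWithLogitsLoss, it drops into an ordinary training step unchanged. A minimal standalone sketch (editor's example, assuming only the commit's ml_new.training.models module; the Linear layer is a stand-in for EmbeddingClassifier, and the AdamW settings mirror the defaults shown above):
import torch
import torch.nn as nn
from ml_new.training.models import FocalLoss

model = nn.Linear(16, 1)                               # stand-in for EmbeddingClassifier
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = FocalLoss(alpha=0.25, gamma=2.0)           # same knobs create_trainer exposes

embeddings = torch.randn(32, 16)                       # dummy batch
labels = torch.randint(0, 2, (32,)).float()

optimizer.zero_grad()
loss = criterion(model(embeddings).squeeze(), labels)  # logits in, scalar loss out
loss.backward()
optimizer.step()
print(f"focal loss on this batch: {loss.item():.4f}")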