""" Neural network model for binary classification of video embeddings """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Tuple, Optional from ml_new.config.logger_config import get_logger logger = get_logger(__name__) class FocalLoss(nn.Module): """ Focal Loss for binary classification Focal Loss = -α_t * (1 - p_t)^γ * log(p_t) Where: - p_t is the predicted probability for the true class - α_t is the weighting factor for class imbalance - γ (gamma) is the focusing parameter that reduces the relative loss for well-classified examples """ def __init__( self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean', class_weights: Optional[torch.Tensor] = None ): """ Initialize Focal Loss Args: alpha: Weighting factor for class imbalance (default: 1.0) gamma: Focusing parameter (default: 2.0) reduction: Reduction method ('mean', 'sum', 'none') class_weights: Optional tensor of class weights for imbalance """ super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction self.class_weights = class_weights def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: """ Compute focal loss Args: inputs: Logits of shape (batch_size,) or (batch_size, 1) targets: Binary labels of shape (batch_size,) or (batch_size, 1) Returns: Focal loss value """ # Ensure inputs are properly shaped if inputs.dim() > 1: inputs = inputs.squeeze() if targets.dim() > 1: targets = targets.squeeze() # Compute sigmoid probabilities prob = torch.sigmoid(inputs) # Convert targets to float for computation targets = targets.float() # Compute focal loss components # pt = p if y=1, pt = 1-p if y=0 pt = torch.where(targets == 1, prob, 1 - prob) # focal weight = (1 - pt)^gamma focal_weight = (1 - pt).pow(self.gamma) # alpha weight = alpha for positive class, 1-alpha for negative class alpha_weight = torch.where(targets == 1, torch.tensor(self.alpha, device=inputs.device), torch.tensor(1 - self.alpha, device=inputs.device)) # Apply class weights if provided if self.class_weights is not None: alpha_weight = alpha_weight * self.class_weights[targets.long()] # focal loss = -alpha * (1-pt)^gamma * log(pt) focal_loss = -alpha_weight * focal_weight * pt.log() # Apply reduction if self.reduction == 'mean': return focal_loss.mean() elif self.reduction == 'sum': return focal_loss.sum() else: return focal_loss class EmbeddingClassifier(nn.Module): """ Neural network model for binary classification of video embeddings Architecture: - Embedding layer (configurable input dimension) - Hidden layers with dropout and batch normalization - Binary classification head """ def __init__( self, input_dim: int = 2048, hidden_dims: Optional[Tuple[int, ...]] = None, dropout_rate: float = 0.3, batch_norm: bool = True, activation: str = "relu" ): """ Initialize the embedding classifier model Args: input_dim: Dimension of input embeddings (default: 2048 for qwen3-embedding) hidden_dims: Tuple of hidden layer dimensions (default: (512, 256, 128)) dropout_rate: Dropout probability for regularization batch_norm: Whether to use batch normalization activation: Activation function ('relu', 'gelu', 'tanh') """ super(EmbeddingClassifier, self).__init__() # Default hidden dimensions if not provided if hidden_dims is None: hidden_dims = (512, 256, 128) self.input_dim = input_dim self.hidden_dims = hidden_dims self.dropout_rate = dropout_rate self.batch_norm = batch_norm # Build layers self.layers = nn.ModuleList() # Input dimension to first hidden layer prev_dim = 

        for i, hidden_dim in enumerate(hidden_dims):
            # Linear layer
            linear_layer = nn.Linear(prev_dim, hidden_dim)
            self.layers.append(linear_layer)

            # Batch normalization (optional)
            if batch_norm:
                bn_layer = nn.BatchNorm1d(hidden_dim)
                self.layers.append(bn_layer)

            # Activation function
            activation_layer = self._get_activation(activation)
            self.layers.append(activation_layer)

            # Dropout
            dropout_layer = nn.Dropout(dropout_rate)
            self.layers.append(dropout_layer)

            prev_dim = hidden_dim

        # Binary classification head
        self.classifier = nn.Linear(prev_dim, 1)

        # Initialize weights
        self._initialize_weights()

        logger.info(f"Initialized EmbeddingClassifier: input_dim={input_dim}, hidden_dims={hidden_dims}")

    def _get_activation(self, activation: str) -> nn.Module:
        """Get activation function module"""
        activations = {
            'relu': nn.ReLU(),
            'gelu': nn.GELU(),
            'tanh': nn.Tanh(),
            'leaky_relu': nn.LeakyReLU(0.1),
            'elu': nn.ELU()
        }

        if activation not in activations:
            logger.warning(f"Unknown activation '{activation}', using ReLU")
            return nn.ReLU()

        return activations[activation]

    def _initialize_weights(self):
        """Initialize model weights using Xavier/Glorot initialization"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network

        Args:
            x: Input embeddings of shape (batch_size, input_dim)

        Returns:
            Logits of shape (batch_size, 1)
        """
        # Ensure input is a float tensor
        if x.dtype != torch.float32:
            x = x.float()

        # Flatten input if it's multi-dimensional (shouldn't be for embeddings)
        if x.dim() > 2:
            x = x.view(x.size(0), -1)

        # Process through layers
        for layer in self.layers:
            x = layer(x)

        # Final classification layer
        logits = self.classifier(x)

        return logits

    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict class probabilities

        Args:
            x: Input embeddings

        Returns:
            Class probabilities of shape (batch_size, 2)
        """
        logits = self.forward(x)
        probabilities = torch.sigmoid(logits)

        # Convert to [negative_prob, positive_prob] format
        prob_0 = 1 - probabilities
        prob_1 = probabilities

        return torch.cat([prob_0, prob_1], dim=1)

    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """
        Predict class labels

        Args:
            x: Input embeddings
            threshold: Classification threshold

        Returns:
            Binary predictions of shape (batch_size,)
        """
        probabilities = self.predict_proba(x)
        predictions = (probabilities[:, 1] > threshold).long()
        return predictions

    def get_model_info(self) -> dict:
        """Get model architecture information"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        return {
            'model_class': self.__class__.__name__,
            'input_dim': self.input_dim,
            'hidden_dims': self.hidden_dims,
            'dropout_rate': self.dropout_rate,
            'batch_norm': self.batch_norm,
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'model_size_mb': total_params * 4 / (1024 * 1024)  # Assuming float32
        }


def create_model(
    input_dim: int = 2048,
    hidden_dims: Optional[Tuple[int, ...]] = None,
    **kwargs
) -> EmbeddingClassifier:
    """
    Factory function to create embedding classifier models

    Args:
        input_dim: Input embedding dimension
        hidden_dims: Hidden layer dimensions
        **kwargs: Additional model arguments

    Returns:
        Initialized model
    """
    return EmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)
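

# --- Illustrative training-step sketch ---
# A minimal example of how EmbeddingClassifier and FocalLoss might be combined
# during training. The function name, optimizer choice, and the alpha/gamma/lr
# values shown below are assumptions for illustration, not part of this
# module's established API.
def example_training_step(
    model: EmbeddingClassifier,
    criterion: FocalLoss,
    optimizer: torch.optim.Optimizer,
    embeddings: torch.Tensor,
    labels: torch.Tensor
) -> float:
    """Run a single optimization step and return the scalar loss value"""
    model.train()
    optimizer.zero_grad()
    logits = model(embeddings)        # (batch_size, 1)
    loss = criterion(logits, labels)  # FocalLoss squeezes shapes internally
    loss.backward()
    optimizer.step()
    return loss.item()


# Example usage (assumed values):
#   model = create_model(input_dim=2048)
#   criterion = FocalLoss(alpha=0.25, gamma=2.0)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   loss_value = example_training_step(model, criterion, optimizer, batch_embeddings, batch_labels)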


if __name__ == "__main__":
    # Test model creation and forward pass
    model = create_model(
        input_dim=2048,
        hidden_dims=(512, 256, 128),
        dropout_rate=0.3
    )

    # Print model info
    info = model.get_model_info()
    print("Model Information:")
    for key, value in info.items():
        print(f"  {key}: {value}")

    # Test forward pass
    batch_size = 8
    dummy_input = torch.randn(batch_size, 2048)

    # Switch to eval mode so dropout and batch norm behave deterministically
    # across the repeated forward passes below
    model.eval()

    with torch.no_grad():
        logits = model(dummy_input)
        probabilities = model.predict_proba(dummy_input)
        predictions = model.predict(dummy_input)

    print("\nTest Results:")
    print(f"  Input shape: {dummy_input.shape}")
    print(f"  Logits shape: {logits.shape}")
    print(f"  Probabilities shape: {probabilities.shape}")
    print(f"  Predictions shape: {predictions.shape}")
    print(f"  Sample predictions: {predictions.tolist()}")
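
    # Quick FocalLoss sanity check. alpha=0.25 and gamma=2.0 are the commonly
    # used focal-loss values; they are illustrative assumptions, not tuned settings.
    criterion = FocalLoss(alpha=0.25, gamma=2.0)
    dummy_labels = torch.randint(0, 2, (batch_size,))
    focal_loss_value = criterion(logits, dummy_labels)
    print(f"  Focal loss on random labels: {focal_loss_value.item():.4f}")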