"""
|
||
Neural network model for binary classification of video embeddings
|
||
"""
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from typing import Tuple, Optional
|
||
from ml_new.config.logger_config import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
class FocalLoss(nn.Module):
    """
    Focal Loss for binary classification.

    Focal Loss = -α_t * (1 - p_t)^γ * log(p_t)

    Where:
    - p_t is the predicted probability for the true class
    - α_t is the class-imbalance weighting factor (α for the positive class, 1 - α for the negative class)
    - γ (gamma) is the focusing parameter that down-weights the loss for well-classified examples
    """
    def __init__(
        self,
        alpha: float = 0.25,
        gamma: float = 2.0,
        reduction: str = 'mean',
        class_weights: Optional[torch.Tensor] = None
    ):
        """
        Initialize Focal Loss.

        Args:
            alpha: Weighting factor for the positive class; the negative class is
                weighted by 1 - alpha (default: 0.25, the value recommended in the
                original Focal Loss paper). Note that alpha = 1.0 would zero out the
                loss for negative examples.
            gamma: Focusing parameter (default: 2.0)
            reduction: Reduction method ('mean', 'sum', 'none')
            class_weights: Optional tensor of per-class weights for imbalance
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.class_weights = class_weights

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute focal loss.

        Args:
            inputs: Logits of shape (batch_size,) or (batch_size, 1)
            targets: Binary labels of shape (batch_size,) or (batch_size, 1)

        Returns:
            Focal loss value
        """
        # Ensure inputs and targets are 1-D
        if inputs.dim() > 1:
            inputs = inputs.squeeze()
        if targets.dim() > 1:
            targets = targets.squeeze()

        # Compute sigmoid probabilities
        prob = torch.sigmoid(inputs)

        # Convert targets to float for computation
        targets = targets.float()

        # Compute focal loss components
        # pt = p if y=1, pt = 1-p if y=0
        pt = torch.where(targets == 1, prob, 1 - prob)

        # focal weight = (1 - pt)^gamma
        focal_weight = (1 - pt).pow(self.gamma)

        # alpha weight = alpha for the positive class, 1 - alpha for the negative class
        alpha_weight = torch.where(targets == 1,
                                   torch.tensor(self.alpha, device=inputs.device),
                                   torch.tensor(1 - self.alpha, device=inputs.device))

        # Apply per-class weights if provided (moved to the input device to avoid mismatches)
        if self.class_weights is not None:
            alpha_weight = alpha_weight * self.class_weights.to(inputs.device)[targets.long()]

        # focal loss = -alpha * (1 - pt)^gamma * log(pt); clamp pt to avoid log(0)
        focal_loss = -alpha_weight * focal_weight * pt.clamp(min=1e-8).log()

        # Apply reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss
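
# Usage sketch (illustrative, not executed on import): FocalLoss is a drop-in
# replacement for nn.BCEWithLogitsLoss in a training loop. `model`, `optimizer`,
# and `loader` below are hypothetical names, not defined in this module:
#
#     criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='mean')
#     for embeddings, labels in loader:
#         optimizer.zero_grad()
#         logits = model(embeddings)        # shape (batch_size, 1)
#         loss = criterion(logits, labels)  # labels: 0/1 tensor, shape (batch_size,)
#         loss.backward()
#         optimizer.step()

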
class EmbeddingClassifier(nn.Module):
    """
    Neural network model for binary classification of video embeddings.

    Architecture:
    - Input of precomputed embeddings (configurable dimension)
    - Hidden layers with dropout and optional batch normalization
    - Binary classification head
    """

    def __init__(
        self,
        input_dim: int = 2048,
        hidden_dims: Optional[Tuple[int, ...]] = None,
        dropout_rate: float = 0.3,
        batch_norm: bool = True,
        activation: str = "relu"
    ):
        """
        Initialize the embedding classifier model.

        Args:
            input_dim: Dimension of input embeddings (default: 2048 for qwen3-embedding)
            hidden_dims: Tuple of hidden layer dimensions (default: (512, 256, 128))
            dropout_rate: Dropout probability for regularization
            batch_norm: Whether to use batch normalization
            activation: Activation function ('relu', 'gelu', 'tanh', 'leaky_relu', 'elu')
        """
        super().__init__()

        # Default hidden dimensions if not provided
        if hidden_dims is None:
            hidden_dims = (512, 256, 128)

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.dropout_rate = dropout_rate
        self.batch_norm = batch_norm

        # Build hidden layers
        self.layers = nn.ModuleList()

        # Input dimension to first hidden layer
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            # Linear layer
            self.layers.append(nn.Linear(prev_dim, hidden_dim))

            # Batch normalization (optional)
            if batch_norm:
                self.layers.append(nn.BatchNorm1d(hidden_dim))

            # Activation function
            self.layers.append(self._get_activation(activation))

            # Dropout
            self.layers.append(nn.Dropout(dropout_rate))

            prev_dim = hidden_dim

        # Binary classification head
        self.classifier = nn.Linear(prev_dim, 1)

        # Initialize weights
        self._initialize_weights()

        logger.info(f"Initialized EmbeddingClassifier: input_dim={input_dim}, hidden_dims={hidden_dims}")

    def _get_activation(self, activation: str) -> nn.Module:
        """Get activation function module"""
        activations = {
            'relu': nn.ReLU(),
            'gelu': nn.GELU(),
            'tanh': nn.Tanh(),
            'leaky_relu': nn.LeakyReLU(0.1),
            'elu': nn.ELU()
        }

        if activation not in activations:
            logger.warning(f"Unknown activation '{activation}', using ReLU")
            return nn.ReLU()

        return activations[activation]

    def _initialize_weights(self):
        """Initialize model weights using Xavier/Glorot initialization"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network.

        Args:
            x: Input embeddings of shape (batch_size, input_dim)

        Returns:
            Logits of shape (batch_size, 1)
        """
        # Ensure input is a float32 tensor
        if x.dtype != torch.float32:
            x = x.float()

        # Flatten input if it is multi-dimensional (should not happen for embeddings)
        if x.dim() > 2:
            x = x.reshape(x.size(0), -1)

        # Process through hidden layers
        for layer in self.layers:
            x = layer(x)

        # Final classification layer
        logits = self.classifier(x)

        return logits

    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict class probabilities.

        Args:
            x: Input embeddings

        Returns:
            Class probabilities of shape (batch_size, 2)
        """
        logits = self.forward(x)
        probabilities = torch.sigmoid(logits)

        # Convert to [negative_prob, positive_prob] format
        prob_0 = 1 - probabilities
        prob_1 = probabilities

        return torch.cat([prob_0, prob_1], dim=1)

    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """
        Predict class labels.

        Args:
            x: Input embeddings
            threshold: Classification threshold

        Returns:
            Binary predictions of shape (batch_size,)
        """
        probabilities = self.predict_proba(x)
        predictions = (probabilities[:, 1] > threshold).long()
        return predictions

    def get_model_info(self) -> dict:
        """Get model architecture information"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        return {
            'model_class': self.__class__.__name__,
            'input_dim': self.input_dim,
            'hidden_dims': self.hidden_dims,
            'dropout_rate': self.dropout_rate,
            'batch_norm': self.batch_norm,
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'model_size_mb': total_params * 4 / (1024 * 1024)  # Assuming float32 parameters
        }
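
    # Worked example (illustrative): with the defaults (input_dim=2048,
    # hidden_dims=(512, 256, 128), batch_norm=True) the parameter count is
    #   1,049,088 + 1,024 + 131,328 + 512 + 32,896 + 256 + 129 = 1,215,233
    # (three Linear + BatchNorm1d pairs plus the final Linear(128, 1)),
    # i.e. about 4.64 MB at 4 bytes per float32 parameter.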


def create_model(
    input_dim: int = 2048,
    hidden_dims: Optional[Tuple[int, ...]] = None,
    **kwargs
) -> EmbeddingClassifier:
    """
    Factory function to create embedding classifier models.

    Args:
        input_dim: Input embedding dimension
        hidden_dims: Hidden layer dimensions
        **kwargs: Additional model arguments

    Returns:
        Initialized model
    """
    return EmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)
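
# Example (illustrative): any extra keyword arguments are forwarded to EmbeddingClassifier, e.g.
#     model = create_model(input_dim=1024, hidden_dims=(256, 64), activation='gelu', batch_norm=False)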


if __name__ == "__main__":
    # Test model creation and forward pass
    model = create_model(
        input_dim=2048,
        hidden_dims=(512, 256, 128),
        dropout_rate=0.3
    )

    # Print model info
    info = model.get_model_info()
    print("Model Information:")
    for key, value in info.items():
        print(f"  {key}: {value}")

    # Test forward pass
    batch_size = 8
    dummy_input = torch.randn(batch_size, 2048)

    # Switch to eval mode so dropout is disabled and batch norm uses running statistics
    model.eval()
    with torch.no_grad():
        logits = model(dummy_input)
        probabilities = model.predict_proba(dummy_input)
        predictions = model.predict(dummy_input)

    print("\nTest Results:")
    print(f"  Input shape: {dummy_input.shape}")
    print(f"  Logits shape: {logits.shape}")
    print(f"  Probabilities shape: {probabilities.shape}")
    print(f"  Predictions shape: {predictions.shape}")
    print(f"  Sample predictions: {predictions.tolist()}")