# cvsa/ml_new/training/models.py
"""
Neural network model for binary classification of video embeddings
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
from ml_new.config.logger_config import get_logger
logger = get_logger(__name__)


class FocalLoss(nn.Module):
    """
    Focal Loss for binary classification.

    Focal Loss = -α_t * (1 - p_t)^γ * log(p_t)

    Where:
    - p_t is the predicted probability for the true class
    - α_t is the class-imbalance weight: alpha for the positive class,
      1 - alpha for the negative class
    - γ (gamma) is the focusing parameter that reduces the relative loss
      for well-classified examples
    """

    def __init__(
        self,
        alpha: float = 0.25,
        gamma: float = 2.0,
        reduction: str = 'mean',
        class_weights: Optional[torch.Tensor] = None
    ):
        """
        Initialize Focal Loss.

        Args:
            alpha: Weight of the positive class; the negative class is
                weighted 1 - alpha (default: 0.25, the conventional value;
                alpha=1.0 would zero out the negative-class term entirely)
            gamma: Focusing parameter (default: 2.0)
            reduction: Reduction method ('mean', 'sum', 'none')
            class_weights: Optional tensor of per-class weights for imbalance
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.class_weights = class_weights
    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute focal loss.

        Args:
            inputs: Logits of shape (batch_size,) or (batch_size, 1)
            targets: Binary labels of shape (batch_size,) or (batch_size, 1)

        Returns:
            Focal loss value
        """
        # Drop a trailing singleton dimension without collapsing batch size 1
        if inputs.dim() > 1:
            inputs = inputs.squeeze(-1)
        if targets.dim() > 1:
            targets = targets.squeeze(-1)

        # Convert targets to float for computation
        targets = targets.float()

        # Sigmoid probabilities, used only for the focal weight
        prob = torch.sigmoid(inputs)

        # pt = p if y=1, pt = 1-p if y=0
        pt = torch.where(targets == 1, prob, 1 - prob)

        # focal weight = (1 - pt)^gamma
        focal_weight = (1 - pt).pow(self.gamma)

        # alpha weight = alpha for positive class, 1-alpha for negative class
        alpha_weight = torch.where(
            targets == 1,
            torch.tensor(self.alpha, device=inputs.device),
            torch.tensor(1 - self.alpha, device=inputs.device)
        )

        # Apply per-class weights if provided
        if self.class_weights is not None:
            alpha_weight = alpha_weight * self.class_weights.to(inputs.device)[targets.long()]

        # -log(pt) computed via BCE-with-logits, which is numerically stable
        # (avoids log(0) when the sigmoid saturates)
        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # focal loss = alpha_t * (1-pt)^gamma * (-log(pt))
        focal_loss = alpha_weight * focal_weight * ce_loss

        # Apply reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss
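
# Illustrative usage sketch (comment only, not executed on import): FocalLoss
# expects raw logits, not probabilities, and 0/1 targets. Variable names here
# are hypothetical placeholders.
#
#     criterion = FocalLoss(alpha=0.25, gamma=2.0)
#     logits = classifier(embeddings)     # shape (batch_size, 1) or (batch_size,)
#     loss = criterion(logits, labels)    # labels: 0/1 tensor, shape (batch_size,)
#     loss.backward()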


class EmbeddingClassifier(nn.Module):
    """
    Neural network model for binary classification of video embeddings.

    Architecture:
    - Embedding input layer (configurable input dimension)
    - Hidden layers with dropout and batch normalization
    - Binary classification head
    """

    def __init__(
        self,
        input_dim: int = 2048,
        hidden_dims: Optional[Tuple[int, ...]] = None,
        dropout_rate: float = 0.3,
        batch_norm: bool = True,
        activation: str = "relu"
    ):
        """
        Initialize the embedding classifier model.

        Args:
            input_dim: Dimension of input embeddings (default: 2048 for qwen3-embedding)
            hidden_dims: Tuple of hidden layer dimensions (default: (512, 256, 128))
            dropout_rate: Dropout probability for regularization
            batch_norm: Whether to use batch normalization
            activation: Activation function ('relu', 'gelu', 'tanh', 'leaky_relu', 'elu')
        """
        super().__init__()

        # Default hidden dimensions if not provided
        if hidden_dims is None:
            hidden_dims = (512, 256, 128)

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.dropout_rate = dropout_rate
        self.batch_norm = batch_norm

        # Build hidden layers: Linear -> [BatchNorm] -> Activation -> Dropout
        self.layers = nn.ModuleList()
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            self.layers.append(nn.Linear(prev_dim, hidden_dim))
            if batch_norm:
                self.layers.append(nn.BatchNorm1d(hidden_dim))
            self.layers.append(self._get_activation(activation))
            self.layers.append(nn.Dropout(dropout_rate))
            prev_dim = hidden_dim

        # Binary classification head
        self.classifier = nn.Linear(prev_dim, 1)

        # Initialize weights
        self._initialize_weights()

        logger.info(
            f"Initialized EmbeddingClassifier: input_dim={input_dim}, hidden_dims={hidden_dims}"
        )
    def _get_activation(self, activation: str) -> nn.Module:
        """Get activation function module"""
        activations = {
            'relu': nn.ReLU(),
            'gelu': nn.GELU(),
            'tanh': nn.Tanh(),
            'leaky_relu': nn.LeakyReLU(0.1),
            'elu': nn.ELU()
        }
        if activation not in activations:
            logger.warning(f"Unknown activation '{activation}', using ReLU")
            return nn.ReLU()
        return activations[activation]
    def _initialize_weights(self):
        """Initialize model weights using Xavier/Glorot initialization"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network.

        Args:
            x: Input embeddings of shape (batch_size, input_dim)

        Returns:
            Logits of shape (batch_size, 1)
        """
        # Ensure input is a float tensor
        if x.dtype != torch.float32:
            x = x.float()

        # Flatten input if it is multi-dimensional (shouldn't happen for embeddings)
        if x.dim() > 2:
            x = x.view(x.size(0), -1)

        # Process through hidden layers
        for layer in self.layers:
            x = layer(x)

        # Final classification layer
        logits = self.classifier(x)
        return logits
    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict class probabilities.

        Args:
            x: Input embeddings

        Returns:
            Class probabilities of shape (batch_size, 2)
        """
        logits = self.forward(x)
        prob_1 = torch.sigmoid(logits)
        # Return in [negative_prob, positive_prob] format
        return torch.cat([1 - prob_1, prob_1], dim=1)
    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """
        Predict class labels.

        Args:
            x: Input embeddings
            threshold: Classification threshold

        Returns:
            Binary predictions of shape (batch_size,)
        """
        probabilities = self.predict_proba(x)
        return (probabilities[:, 1] > threshold).long()
    def get_model_info(self) -> dict:
        """Get model architecture information"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            'model_class': self.__class__.__name__,
            'input_dim': self.input_dim,
            'hidden_dims': self.hidden_dims,
            'dropout_rate': self.dropout_rate,
            'batch_norm': self.batch_norm,
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'model_size_mb': total_params * 4 / (1024 * 1024)  # assuming float32
        }
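
# Minimal training-step sketch wiring EmbeddingClassifier to FocalLoss (an
# assumed setup, not the project's actual training loop; `train_loader` is a
# hypothetical DataLoader of (embeddings, labels) batches):
#
#     model = EmbeddingClassifier(input_dim=2048)
#     criterion = FocalLoss(alpha=0.25, gamma=2.0)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
#     for embeddings, labels in train_loader:
#         optimizer.zero_grad()
#         loss = criterion(model(embeddings), labels)
#         loss.backward()
#         optimizer.step()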


def create_model(
    input_dim: int = 2048,
    hidden_dims: Optional[Tuple[int, ...]] = None,
    **kwargs
) -> EmbeddingClassifier:
    """
    Factory function to create embedding classifier models.

    Args:
        input_dim: Input embedding dimension
        hidden_dims: Hidden layer dimensions
        **kwargs: Additional EmbeddingClassifier arguments
            (dropout_rate, batch_norm, activation)

    Returns:
        Initialized model
    """
    return EmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)
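
# create_model forwards extra keyword arguments straight to EmbeddingClassifier;
# the values below are purely illustrative:
#
#     model = create_model(input_dim=1024, hidden_dims=(256, 64),
#                          activation='gelu', batch_norm=False)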


if __name__ == "__main__":
    # Smoke test: model creation and forward pass
    model = create_model(
        input_dim=2048,
        hidden_dims=(512, 256, 128),
        dropout_rate=0.3
    )
    # Switch to eval mode so dropout/batch-norm make the test deterministic
    model.eval()

    # Print model info
    info = model.get_model_info()
    print("Model Information:")
    for key, value in info.items():
        print(f"  {key}: {value}")

    # Test forward pass
    batch_size = 8
    dummy_input = torch.randn(batch_size, 2048)
    with torch.no_grad():
        logits = model(dummy_input)
        probabilities = model.predict_proba(dummy_input)
        predictions = model.predict(dummy_input)

    print("\nTest Results:")
    print(f"  Input shape: {dummy_input.shape}")
    print(f"  Logits shape: {logits.shape}")
    print(f"  Probabilities shape: {probabilities.shape}")
    print(f"  Predictions shape: {predictions.shape}")
    print(f"  Sample predictions: {predictions.tolist()}")