add: training script
parent 664784dd3e
commit c14c680228
ml_new/training/__init__.py (Normal file, 9 lines added)
@@ -0,0 +1,9 @@
"""
Training module for ML models
"""

from .models import EmbeddingClassifier
from .trainer import ModelTrainer
from .data_loader import DatasetLoader

__all__ = ['EmbeddingClassifier', 'ModelTrainer', 'DatasetLoader']
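Reviewer note: with these re-exports the public pieces can be imported straight from the package. A minimal sketch, assuming the repository root is on PYTHONPATH (nothing here beyond what the new modules define):

from ml_new.training import DatasetLoader, EmbeddingClassifier

loader = DatasetLoader()                     # defaults to training/datasets
model = EmbeddingClassifier(input_dim=2048)  # dimension of the stored embeddings
print(model.get_model_info()["total_parameters"])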
ml_new/training/data_loader.py (Normal file, 389 lines added)
@@ -0,0 +1,389 @@
"""
Data loader for embedding datasets
"""

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import json
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from ml_new.config.logger_config import get_logger

logger = get_logger(__name__)


class EmbeddingDataset(Dataset):
    """
    PyTorch Dataset for embedding-based classification
    """

    def __init__(
        self,
        embeddings: np.ndarray,
        labels: np.ndarray,
        metadata: Optional[List[Dict[str, Any]]] = None,
        transform: Optional[callable] = None,
        normalize: bool = True
    ):
        """
        Initialize embedding dataset

        Args:
            embeddings: Array of embedding vectors (n_samples, embedding_dim)
            labels: Array of binary labels (n_samples,)
            metadata: Optional list of metadata dictionaries
            transform: Optional transformation function
            normalize: Whether to normalize embeddings
        """
        assert len(embeddings) == len(labels), "Embeddings and labels must have same length"

        self.embeddings = embeddings.astype(np.float32)
        self.labels = labels.astype(np.int64)
        self.metadata = metadata or []
        self.transform = transform

        # Normalize embeddings if requested
        if normalize and len(embeddings) > 0:
            self.scaler = StandardScaler()
            self.embeddings = self.scaler.fit_transform(self.embeddings)
        else:
            self.scaler = None

        # Calculate class weights for balanced sampling
        self._calculate_class_weights()

    def _calculate_class_weights(self):
        """Calculate weights for each class for balanced sampling"""
        unique, counts = np.unique(self.labels, return_counts=True)
        total_samples = len(self.labels)

        self.class_weights = {}
        for class_label, count in zip(unique, counts):
            # Inverse frequency weighting
            weight = total_samples / (2 * count)
            self.class_weights[class_label] = weight

        logger.info(f"Class distribution: {dict(zip(unique, counts))}")
        logger.info(f"Class weights: {self.class_weights}")

    def __len__(self) -> int:
        return len(self.embeddings)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
        """
        Get a single sample from the dataset

        Returns:
            tuple: (embedding, label, metadata)
        """
        embedding = torch.from_numpy(self.embeddings[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        metadata = {}
        if self.metadata and idx < len(self.metadata):
            metadata = self.metadata[idx]

        if self.transform:
            embedding = self.transform(embedding)

        return embedding, label, metadata


class DatasetLoader:
    """
    Loader for embedding datasets stored in Parquet format
    """

    def __init__(self, datasets_dir: str = "training/datasets"):
        """
        Initialize dataset loader

        Args:
            datasets_dir: Directory containing dataset files
        """
        self.datasets_dir = Path(datasets_dir)
        self.datasets_dir.mkdir(parents=True, exist_ok=True)

    def load_dataset(self, dataset_id: str) -> Tuple[np.ndarray, np.ndarray, List[Dict[str, Any]]]:
        """
        Load a dataset by ID from Parquet files

        Args:
            dataset_id: Unique identifier for the dataset

        Returns:
            tuple: (embeddings, labels, metadata_list)
        """
        dataset_file = self.datasets_dir / f"{dataset_id}.parquet"
        metadata_file = self.datasets_dir / f"{dataset_id}.metadata.json"

        if not dataset_file.exists():
            raise FileNotFoundError(f"Dataset file not found: {dataset_file}")

        # Load metadata
        metadata = {}
        if metadata_file.exists():
            with open(metadata_file, 'r') as f:
                metadata = json.load(f)

        # Load data from Parquet
        logger.info(f"Loading dataset {dataset_id} from {dataset_file}")
        df = pd.read_parquet(dataset_file)

        # Extract embeddings (they might be stored as list or numpy array)
        embeddings = self._extract_embeddings(df)

        # Extract labels
        labels = df['label'].values.astype(np.int64)

        # Extract metadata
        metadata_list = []
        if 'metadata_json' in df.columns:
            for _, row in df.iterrows():
                meta = {}
                if pd.notna(row.get('metadata_json')):
                    try:
                        meta = json.loads(row['metadata_json'])
                    except (json.JSONDecodeError, TypeError):
                        meta = {}

                # Add other fields
                meta.update({
                    'aid': row.get('aid'),
                    'inconsistent': row.get('inconsistent', False),
                    'text_checksum': row.get('text_checksum')
                })
                metadata_list.append(meta)
        else:
            # Create basic metadata
            metadata_list = [{
                'aid': aid,
                'inconsistent': inconsistent,
                'text_checksum': checksum
            } for aid, inconsistent, checksum in zip(
                df.get('aid', []),
                df.get('inconsistent', [False] * len(df)),
                df.get('text_checksum', [''] * len(df))
            )]

        logger.info(f"Loaded dataset with {len(embeddings)} samples, {embeddings.shape[1]} embedding dimensions")

        return embeddings, labels, metadata_list

    def _extract_embeddings(self, df: pd.DataFrame) -> np.ndarray:
        """Extract embeddings from DataFrame, handling different storage formats"""
        embedding_col = None
        for col in ['embedding', 'embeddings', 'vec_2048', 'vec_1024']:
            if col in df.columns:
                embedding_col = col
                break

        if embedding_col is None:
            raise ValueError("No embedding column found in dataset")

        embeddings_data = df[embedding_col]

        # Handle different embedding storage formats
        if embeddings_data.dtype == 'object':
            # Likely stored as lists or numpy arrays
            embeddings = np.array([
                np.array(emb) if isinstance(emb, (list, np.ndarray)) else np.zeros(2048)
                for emb in embeddings_data
            ])
        else:
            # Already numpy array
            embeddings = embeddings_data.values

        # Ensure 2D array
        if embeddings.ndim == 1:
            # If embeddings are flattened, reshape
            embedding_dim = len(embeddings) // len(df)
            embeddings = embeddings.reshape(len(df), embedding_dim)

        return embeddings.astype(np.float32)

    def create_data_loaders(
        self,
        dataset_id: str,
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        batch_size: int = 32,
        num_workers: int = 4,
        random_state: int = 42,
        normalize: bool = True,
        use_weighted_sampler: bool = True
    ) -> Tuple[DataLoader, DataLoader, DataLoader, Dict[str, Any]]:
        """
        Create train, validation, and test data loaders

        Args:
            dataset_id: Dataset identifier
            train_ratio: Proportion of data for training
            val_ratio: Proportion of data for validation
            batch_size: Batch size for data loaders
            num_workers: Number of worker processes
            random_state: Random seed for reproducibility
            normalize: Whether to normalize embeddings
            use_weighted_sampler: Whether to use weighted random sampling

        Returns:
            tuple: (train_loader, val_loader, test_loader, dataset_info)
        """
        # Load dataset
        embeddings, labels, metadata = self.load_dataset(dataset_id)

        # Split data
        (
            train_emb, test_emb,
            train_lbl, test_lbl,
            train_meta, test_meta
        ) = train_test_split(
            embeddings, labels, metadata,
            test_size=1 - train_ratio,
            stratify=labels,
            random_state=random_state
        )

        # Split test into val and test
        val_size = val_ratio / (val_ratio + (1 - train_ratio - val_ratio))
        (
            val_emb, test_emb,
            val_lbl, test_lbl,
            val_meta, test_meta
        ) = train_test_split(
            test_emb, test_lbl, test_meta,
            test_size=1 - val_size,
            stratify=test_lbl,
            random_state=random_state
        )

        # Create datasets
        train_dataset = EmbeddingDataset(train_emb, train_lbl, train_meta, normalize=normalize)
        val_dataset = EmbeddingDataset(val_emb, val_lbl, val_meta, normalize=False)  # Don't re-normalize
        test_dataset = EmbeddingDataset(test_emb, test_lbl, test_meta, normalize=False)

        # Create samplers
        train_sampler = None
        if use_weighted_sampler and hasattr(train_dataset, 'class_weights'):
            # Create weighted sampler for balanced training
            sample_weights = [train_dataset.class_weights[label] for label in train_dataset.labels]
            train_sampler = WeightedRandomSampler(
                weights=sample_weights,
                num_samples=len(sample_weights),
                replacement=True
            )

        # Create data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            shuffle=(train_sampler is None),
            num_workers=num_workers
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

        # Dataset info
        dataset_info = {
            'dataset_id': dataset_id,
            'total_samples': len(embeddings),
            'embedding_dim': embeddings.shape[1],
            'train_samples': len(train_dataset),
            'val_samples': len(val_dataset),
            'test_samples': len(test_dataset),
            'train_ratio': len(train_dataset) / len(embeddings),
            'val_ratio': len(val_dataset) / len(embeddings),
            'test_ratio': len(test_dataset) / len(embeddings),
            'class_distribution': {
                'train': dict(zip(*np.unique(train_dataset.labels, return_counts=True))),
                'val': dict(zip(*np.unique(val_dataset.labels, return_counts=True))),
                'test': dict(zip(*np.unique(test_dataset.labels, return_counts=True)))
            },
            'normalize': normalize,
            'use_weighted_sampler': use_weighted_sampler
        }

        logger.info(f"Created data loaders: train={len(train_dataset)}, val={len(val_dataset)}, test={len(test_dataset)}")

        return train_loader, val_loader, test_loader, dataset_info

    def list_datasets(self) -> List[str]:
        """List all available datasets"""
        parquet_files = list(self.datasets_dir.glob("*.parquet"))
        return [f.stem for f in parquet_files]

    def get_dataset_info(self, dataset_id: str) -> Dict[str, Any]:
        """Get detailed information about a dataset"""
        metadata_file = self.datasets_dir / f"{dataset_id}.metadata.json"

        if metadata_file.exists():
            with open(metadata_file, 'r') as f:
                return json.load(f)

        # Fallback: load dataset and return basic info
        embeddings, labels, metadata = self.load_dataset(dataset_id)
        return {
            'dataset_id': dataset_id,
            'total_samples': len(embeddings),
            'embedding_dim': embeddings.shape[1],
            'class_distribution': dict(zip(*np.unique(labels, return_counts=True))),
            'file_format': 'parquet',
            'created_at': 'unknown'
        }


if __name__ == "__main__":
    # Test dataset loading
    loader = DatasetLoader()

    # List available datasets
    datasets = loader.list_datasets()
    print(f"Available datasets: {datasets}")

    if datasets:
        # Test loading first dataset
        dataset_id = datasets[0]
        print(f"\nTesting dataset: {dataset_id}")

        info = loader.get_dataset_info(dataset_id)
        print("Dataset info:", info)

        # Test creating data loaders
        try:
            train_loader, val_loader, test_loader, data_info = loader.create_data_loaders(
                dataset_id,
                batch_size=8,
                normalize=True
            )

            print("\nData loader info:")
            for key, value in data_info.items():
                print(f"  {key}: {value}")

            # Test single batch
            for batch_embeddings, batch_labels, batch_metadata in train_loader:
                print(f"\nBatch test:")
                print(f"  Embeddings shape: {batch_embeddings.shape}")
                print(f"  Labels shape: {batch_labels.shape}")
                print(f"  Sample labels: {batch_labels[:5].tolist()}")
                break

        except Exception as e:
            print(f"Error creating data loaders: {e}")
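Reviewer note: load_dataset() expects a Parquet file with an embedding column ('embedding', 'embeddings', 'vec_2048' or 'vec_1024'), a 'label' column, and optional 'metadata_json'/'aid'/'inconsistent'/'text_checksum' columns. A minimal sketch of producing and reading such a file; the dataset id "demo", the 128-dim vectors and the sample count are made up for illustration:

from ml_new.training.data_loader import DatasetLoader
import numpy as np
import pandas as pd

loader = DatasetLoader()  # creates training/datasets relative to the working dir

rng = np.random.default_rng(0)
df = pd.DataFrame({
    "embedding": [rng.normal(size=128).tolist() for _ in range(64)],  # object column of lists
    "label": rng.integers(0, 2, size=64),
    "aid": list(range(64)),
})
df.to_parquet(loader.datasets_dir / "demo.parquet")

embeddings, labels, meta = loader.load_dataset("demo")
print(embeddings.shape, labels.sum())  # (64, 128) and the number of positive labels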
ml_new/training/models.py (Normal file, 324 lines added)
@@ -0,0 +1,324 @@
"""
Neural network model for binary classification of video embeddings
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
from ml_new.config.logger_config import get_logger

logger = get_logger(__name__)


class EmbeddingClassifier(nn.Module):
    """
    Neural network model for binary classification of video embeddings

    Architecture:
    - Embedding layer (configurable input dimension)
    - Hidden layers with dropout and batch normalization
    - Binary classification head
    """

    def __init__(
        self,
        input_dim: int = 2048,
        hidden_dims: Optional[Tuple[int, ...]] = None,
        dropout_rate: float = 0.3,
        batch_norm: bool = True,
        activation: str = "relu"
    ):
        """
        Initialize the embedding classifier model

        Args:
            input_dim: Dimension of input embeddings (default: 2048 for qwen3-embedding)
            hidden_dims: Tuple of hidden layer dimensions (default: (512, 256, 128))
            dropout_rate: Dropout probability for regularization
            batch_norm: Whether to use batch normalization
            activation: Activation function ('relu', 'gelu', 'tanh')
        """
        super(EmbeddingClassifier, self).__init__()

        # Default hidden dimensions if not provided
        if hidden_dims is None:
            hidden_dims = (512, 256, 128)

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.dropout_rate = dropout_rate
        self.batch_norm = batch_norm

        # Build layers
        self.layers = nn.ModuleList()

        # Input dimension to first hidden layer
        prev_dim = input_dim

        for i, hidden_dim in enumerate(hidden_dims):
            # Linear layer
            linear_layer = nn.Linear(prev_dim, hidden_dim)
            self.layers.append(linear_layer)

            # Batch normalization (optional)
            if batch_norm:
                bn_layer = nn.BatchNorm1d(hidden_dim)
                self.layers.append(bn_layer)

            # Activation function
            activation_layer = self._get_activation(activation)
            self.layers.append(activation_layer)

            # Dropout
            dropout_layer = nn.Dropout(dropout_rate)
            self.layers.append(dropout_layer)

            prev_dim = hidden_dim

        # Binary classification head
        self.classifier = nn.Linear(prev_dim, 1)

        # Initialize weights
        self._initialize_weights()

        logger.info(f"Initialized EmbeddingClassifier: input_dim={input_dim}, hidden_dims={hidden_dims}")

    def _get_activation(self, activation: str) -> nn.Module:
        """Get activation function module"""
        activations = {
            'relu': nn.ReLU(),
            'gelu': nn.GELU(),
            'tanh': nn.Tanh(),
            'leaky_relu': nn.LeakyReLU(0.1),
            'elu': nn.ELU()
        }

        if activation not in activations:
            logger.warning(f"Unknown activation '{activation}', using ReLU")
            return nn.ReLU()

        return activations[activation]

    def _initialize_weights(self):
        """Initialize model weights using Xavier/Glorot initialization"""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the network

        Args:
            x: Input embeddings of shape (batch_size, input_dim)

        Returns:
            logits of shape (batch_size, 1)
        """
        # Ensure input is float tensor
        if not x.dtype == torch.float32:
            x = x.float()

        # Flatten input if it's multi-dimensional (shouldn't be for embeddings)
        if x.dim() > 2:
            x = x.view(x.size(0), -1)

        # Process through layers
        for layer in self.layers:
            x = layer(x)

        # Final classification layer
        logits = self.classifier(x)

        return logits

    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict class probabilities

        Args:
            x: Input embeddings

        Returns:
            Class probabilities of shape (batch_size, 2)
        """
        logits = self.forward(x)
        probabilities = torch.sigmoid(logits)

        # Convert to [negative_prob, positive_prob] format
        prob_0 = 1 - probabilities
        prob_1 = probabilities

        return torch.cat([prob_0, prob_1], dim=1)

    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """
        Predict class labels

        Args:
            x: Input embeddings
            threshold: Classification threshold

        Returns:
            Binary predictions of shape (batch_size,)
        """
        probabilities = self.predict_proba(x)
        predictions = (probabilities[:, 1] > threshold).long()
        return predictions

    def get_model_info(self) -> dict:
        """Get model architecture information"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        return {
            'model_class': self.__class__.__name__,
            'input_dim': self.input_dim,
            'hidden_dims': self.hidden_dims,
            'dropout_rate': self.dropout_rate,
            'batch_norm': self.batch_norm,
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'model_size_mb': total_params * 4 / (1024 * 1024)  # Assuming float32
        }

class AttentionEmbeddingClassifier(EmbeddingClassifier):
    """
    Enhanced classifier with self-attention mechanism
    """

    def __init__(
        self,
        input_dim: int = 2048,
        hidden_dims: Optional[Tuple[int, ...]] = None,
        dropout_rate: float = 0.3,
        batch_norm: bool = True,
        activation: str = "relu",
        attention_dim: int = 512
    ):
        # Build the classification layers on top of the projected attention
        # output (attention_dim features); otherwise their first Linear layer
        # would expect input_dim features and the forward pass would fail.
        super().__init__(attention_dim, hidden_dims, dropout_rate, batch_norm, activation)
        self.input_dim = input_dim  # report the true embedding dimension

        # Self-attention mechanism
        self.attention_dim = attention_dim
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=8,
            dropout=dropout_rate,
            batch_first=True
        )

        # Attention projection layer
        self.attention_projection = nn.Linear(input_dim, attention_dim)

        # Re-initialize attention weights
        self._initialize_attention_weights()

        logger.info(f"Initialized AttentionEmbeddingClassifier with attention_dim={attention_dim}")

    def _initialize_attention_weights(self):
        """Initialize attention mechanism weights"""
        for module in self.attention.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with attention mechanism

        Args:
            x: Input embeddings of shape (batch_size, input_dim)

        Returns:
            logits of shape (batch_size, 1)
        """
        # Ensure input is float tensor
        if not x.dtype == torch.float32:
            x = x.float()

        # Add sequence dimension for attention (batch_size, 1, input_dim)
        x_expanded = x.unsqueeze(1)

        # Apply self-attention
        attended, attention_weights = self.attention(x_expanded, x_expanded, x_expanded)

        # Remove sequence dimension (batch_size, input_dim)
        attended = attended.squeeze(1)

        # Project to attention dimension
        attended = self.attention_projection(attended)

        # Process through original classification layers
        for layer in self.layers:
            attended = layer(attended)

        # Final classification layer
        logits = self.classifier(attended)

        return logits

def create_model(
    model_type: str = "standard",
    input_dim: int = 2048,
    hidden_dims: Optional[Tuple[int, ...]] = None,
    **kwargs
) -> EmbeddingClassifier:
    """
    Factory function to create embedding classifier models

    Args:
        model_type: Type of model ('standard', 'attention')
        input_dim: Input embedding dimension
        hidden_dims: Hidden layer dimensions
        **kwargs: Additional model arguments

    Returns:
        Initialized model
    """
    if model_type == "standard":
        return EmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)
    elif model_type == "attention":
        return AttentionEmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)
    else:
        raise ValueError(f"Unknown model type: {model_type}")


if __name__ == "__main__":
    # Test model creation and forward pass
    model = create_model(
        model_type="standard",
        input_dim=2048,
        hidden_dims=(512, 256, 128),
        dropout_rate=0.3
    )

    # Print model info
    info = model.get_model_info()
    print("Model Information:")
    for key, value in info.items():
        print(f"  {key}: {value}")

    # Test forward pass
    batch_size = 8
    dummy_input = torch.randn(batch_size, 2048)

    with torch.no_grad():
        logits = model(dummy_input)
        probabilities = model.predict_proba(dummy_input)
        predictions = model.predict(dummy_input)

    print(f"\nTest Results:")
    print(f"  Input shape: {dummy_input.shape}")
    print(f"  Logits shape: {logits.shape}")
    print(f"  Probabilities shape: {probabilities.shape}")
    print(f"  Predictions shape: {predictions.shape}")
    print(f"  Sample predictions: {predictions.tolist()}")
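Reviewer note: a hand-computed sanity check on get_model_info() for the default standard configuration (2048 -> 512 -> 256 -> 128 -> 1 with batch norm): the Linear layers contribute 1,049,088 + 131,328 + 32,896 + 129 parameters and the BatchNorm1d layers 1,024 + 512 + 256, for 1,215,233 trainable parameters in total, roughly 4.6 MB at float32. A quick check against the factory:

from ml_new.training.models import create_model

m = create_model("standard", input_dim=2048, hidden_dims=(512, 256, 128))
print(m.get_model_info()["total_parameters"])  # expected: 1215233 (≈ 4.6 MB at float32)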
ml_new/training/train.py (Normal file, 519 lines added)
@@ -0,0 +1,519 @@
#!/usr/bin/env python3
"""
Main training script for embedding classification models
"""

import argparse
import json
import sys
from pathlib import Path
from datetime import datetime

import numpy as np

# Add the parent directory to the path to import ml_new modules
sys.path.append(str(Path(__file__).parent.parent))

from ml_new.training.models import create_model
from ml_new.training.data_loader import DatasetLoader
from ml_new.training.trainer import create_trainer
from ml_new.config.logger_config import get_logger


def json_safe_convert(obj):
    """Convert objects to JSON-serializable format"""
    if isinstance(obj, dict):
        return {str(k): json_safe_convert(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [json_safe_convert(item) for item in obj]
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif hasattr(obj, 'item'):  # numpy scalar
        return obj.item()
    else:
        return obj

logger = get_logger(__name__)


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Train embedding classification model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Data arguments
    parser.add_argument(
        "--dataset-id",
        type=str,
        required=True,
        help="ID of the dataset to use for training"
    )

    parser.add_argument(
        "--datasets-dir",
        type=str,
        default="training/datasets",
        help="Directory containing dataset files"
    )

    # Model arguments
    parser.add_argument(
        "--model-type",
        type=str,
        choices=["standard", "attention"],
        default="standard",
        help="Type of model architecture"
    )

    parser.add_argument(
        "--input-dim",
        type=int,
        default=2048,
        help="Input embedding dimension"
    )

    parser.add_argument(
        "--hidden-dims",
        type=int,
        nargs="+",
        default=[512, 256, 128],
        help="Hidden layer dimensions"
    )

    parser.add_argument(
        "--dropout-rate",
        type=float,
        default=0.3,
        help="Dropout rate for regularization"
    )

    parser.add_argument(
        "--batch-norm",
        action="store_true",
        default=True,
        help="Use batch normalization"
    )

    parser.add_argument(
        "--activation",
        type=str,
        choices=["relu", "gelu", "tanh", "leaky_relu", "elu"],
        default="relu",
        help="Activation function"
    )

    # Training arguments
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=100,
        help="Number of training epochs"
    )

    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for training"
    )

    parser.add_argument(
        "--learning-rate",
        type=float,
        default=0.001,
        help="Initial learning rate"
    )

    parser.add_argument(
        "--weight-decay",
        type=float,
        default=1e-5,
        help="L2 regularization weight decay"
    )

    parser.add_argument(
        "--scheduler",
        type=str,
        choices=["plateau", "step", "none"],
        default="plateau",
        help="Learning rate scheduler"
    )

    parser.add_argument(
        "--early-stopping-patience",
        type=int,
        default=15,
        help="Patience for early stopping"
    )

    parser.add_argument(
        "--min-delta",
        type=float,
        default=0.001,
        help="Minimum improvement for early stopping"
    )

    parser.add_argument(
        "--validation-metric",
        type=str,
        choices=["f1", "auc", "accuracy"],
        default="f1",
        help="Metric for model selection and early stopping"
    )

    # Data arguments
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=0.8,
        help="Proportion of data for training"
    )

    parser.add_argument(
        "--val-ratio",
        type=float,
        default=0.1,
        help="Proportion of data for validation"
    )

    parser.add_argument(
        "--normalize",
        action="store_true",
        default=False,
        help="Normalize embeddings during training"
    )

    parser.add_argument(
        "--use-weighted-sampler",
        action="store_true",
        default=False,
        help="Use weighted random sampling for balanced training"
    )

    # Output arguments
    parser.add_argument(
        "--save-dir",
        type=str,
        default="training/checkpoints",
        help="Directory to save model checkpoints"
    )

    parser.add_argument(
        "--experiment-name",
        type=str,
        default=None,
        help="Name for the experiment (auto-generated if not provided)"
    )

    parser.add_argument(
        "--save-every",
        type=int,
        default=10,
        help="Save checkpoint every N epochs"
    )

    # Other arguments
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of worker processes for data loading"
    )

    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device to use (auto, cpu, cuda)"
    )

    parser.add_argument(
        "--random-seed",
        type=int,
        default=42,
        help="Random seed for reproducibility"
    )

    parser.add_argument(
        "--list-datasets",
        action="store_true",
        help="List available datasets and exit"
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Test data loading without training"
    )

    return parser.parse_args()


def setup_device(device_arg: str):
    """Setup training device"""
    import torch

    if device_arg == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device_arg)

    logger.info(f"Using device: {device}")
    if device.type == "cuda":
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

    return device


def list_available_datasets(datasets_dir: str):
    """List and display information about available datasets"""
    loader = DatasetLoader(datasets_dir)
    datasets = loader.list_datasets()

    print(f"\nAvailable datasets in {datasets_dir}:")
    print("-" * 50)

    if not datasets:
        print("No datasets found.")
        return []

    for i, dataset_id in enumerate(datasets, 1):
        try:
            info = loader.get_dataset_info(dataset_id)
            print(f"{i}. {dataset_id}")
            print(f"   Samples: {info.get('total_samples', 'N/A')}")
            print(f"   Embedding dim: {info.get('embedding_dim', 'N/A')}")
            print(f"   Classes: {info.get('class_distribution', 'N/A')}")
            print(f"   Created: {info.get('created_at', 'N/A')}")
            print()
        except Exception as e:
            print(f"{i}. {dataset_id} (Error loading info: {e})")
            print()

    return datasets


def validate_arguments(args):
    """Validate command line arguments"""
    if args.train_ratio + args.val_ratio >= 1.0:
        raise ValueError("train_ratio + val_ratio must be less than 1.0")

    if args.batch_size <= 0:
        raise ValueError("batch_size must be positive")

    if args.learning_rate <= 0:
        raise ValueError("learning_rate must be positive")

    if args.dropout_rate < 0 or args.dropout_rate >= 1.0:
        raise ValueError("dropout_rate must be between 0 and 1")


def main():
    """Main training function"""
    args = parse_args()

    # Set random seeds for reproducibility
    import torch
    import numpy as np

    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    # Setup device
    device = setup_device(args.device)

    # List datasets if requested
    if args.list_datasets:
        list_available_datasets(args.datasets_dir)
        return

    # Validate arguments
    try:
        validate_arguments(args)
    except ValueError as e:
        logger.error(f"Invalid arguments: {e}")
        sys.exit(1)

    # Check if dataset exists
    loader = DatasetLoader(args.datasets_dir)
    datasets = loader.list_datasets()

    if args.dataset_id not in datasets:
        logger.error(f"Dataset '{args.dataset_id}' not found in {args.datasets_dir}")
        logger.info(f"Available datasets: {datasets}")
        sys.exit(1)

    # Create experiment name if not provided
    if args.experiment_name is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.experiment_name = f"{args.model_type}_{args.dataset_id}_{timestamp}"

    logger.info(f"Starting experiment: {args.experiment_name}")
    logger.info(f"Dataset: {args.dataset_id}")
    logger.info(f"Model: {args.model_type} with hidden dims {args.hidden_dims}")

    # Load dataset and create data loaders
    try:
        logger.info("Loading dataset and creating data loaders...")
        train_loader, val_loader, test_loader, data_info = loader.create_data_loaders(
            dataset_id=args.dataset_id,
            train_ratio=args.train_ratio,
            val_ratio=args.val_ratio,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            random_state=args.random_seed,
            normalize=args.normalize,
            use_weighted_sampler=args.use_weighted_sampler
        )

        logger.info("Data loaders created successfully")
        logger.info(f"  Training samples: {data_info['train_samples']}")
        logger.info(f"  Validation samples: {data_info['val_samples']}")
        logger.info(f"  Test samples: {data_info['test_samples']}")
        logger.info(f"  Training class distribution: {data_info['class_distribution']['train']}")

    except Exception as e:
        logger.error(f"Failed to create data loaders: {e}")
        sys.exit(1)

    # Create model
    try:
        logger.info("Creating model...")
        model = create_model(
            model_type=args.model_type,
            input_dim=args.input_dim,
            hidden_dims=tuple(args.hidden_dims),
            dropout_rate=args.dropout_rate,
            batch_norm=args.batch_norm,
            activation=args.activation
        )

        model_info = model.get_model_info()
        logger.info("Model created successfully")
        logger.info(f"  Parameters: {model_info['total_parameters']:,}")
        logger.info(f"  Model size: {model_info['model_size_mb']:.1f} MB")

    except Exception as e:
        logger.error(f"Failed to create model: {e}")
        sys.exit(1)

    # Test data loading (dry run)
    if args.dry_run:
        logger.info("Dry run: Testing data loading...")

        # Test one batch from each loader
        for name, loader_obj in [("train", train_loader), ("val", val_loader)]:
            for batch_embeddings, batch_labels, batch_metadata in loader_obj:
                logger.info(f"  {name} batch - Embeddings: {batch_embeddings.shape}, Labels: {batch_labels.shape}")
                logger.info(f"  Sample labels: {batch_labels[:5].tolist()}")
                break

        logger.info("Dry run completed successfully")
        return

    # Create trainer
    try:
        logger.info("Creating trainer...")

        # Determine scheduler type
        scheduler_type = None if args.scheduler == "none" else args.scheduler

        trainer = create_trainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            scheduler_type=scheduler_type,
            device=device,
            save_dir=args.save_dir,
            experiment_name=args.experiment_name
        )

        logger.info("Trainer created successfully")

    except Exception as e:
        logger.error(f"Failed to create trainer: {e}")
        sys.exit(1)

    # Save configuration
    config_path = Path(args.save_dir) / args.experiment_name / "config.json"
    config_path.parent.mkdir(parents=True, exist_ok=True)

    config = {
        "experiment_name": args.experiment_name,
        "dataset_id": args.dataset_id,
        "model_config": {
            "model_type": args.model_type,
            "input_dim": args.input_dim,
            "hidden_dims": args.hidden_dims,
            "dropout_rate": args.dropout_rate,
            "batch_norm": args.batch_norm,
            "activation": args.activation
        },
        "training_config": {
            "num_epochs": args.num_epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.learning_rate,
            "weight_decay": args.weight_decay,
            "scheduler": args.scheduler,
            "early_stopping_patience": args.early_stopping_patience,
            "validation_metric": args.validation_metric
        },
        "data_config": {
            "train_ratio": args.train_ratio,
            "val_ratio": args.val_ratio,
            "normalize": args.normalize,
            "use_weighted_sampler": args.use_weighted_sampler,
            "random_seed": args.random_seed
        },
        "dataset_info": json_safe_convert(data_info)
    }

    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)

    logger.info(f"Configuration saved to {config_path}")

    # Start training
    try:
        logger.info("Starting training...")
        results = trainer.train(
            num_epochs=args.num_epochs,
            save_every=args.save_every,
            early_stopping_patience=args.early_stopping_patience,
            min_delta=args.min_delta,
            validation_score=args.validation_metric
        )

        # Save results
        results_path = Path(args.save_dir) / args.experiment_name / "results.json"
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2, default=str)

        logger.info("Training completed successfully!")
        logger.info(f"Results saved to {results_path}")
        logger.info(f"Best validation {args.validation_metric}: {results['best_val_score']:.4f}")

        if results.get('test_metrics'):
            logger.info(f"Test accuracy: {results['test_metrics']['accuracy']:.4f}")
            logger.info(f"Test F1: {results['test_metrics']['f1']:.4f}")
            logger.info(f"Test AUC: {results['test_metrics']['auc']:.4f}")

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    except Exception as e:
        logger.error(f"Training failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
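Reviewer note: illustrative invocations of the script (the dataset id "demo" is a placeholder; every flag shown is defined in parse_args above):

python ml_new/training/train.py --dataset-id demo --dry-run
python ml_new/training/train.py --dataset-id demo --model-type standard \
    --hidden-dims 512 256 128 --batch-size 64 --num-epochs 50 \
    --use-weighted-sampler --normalize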
ml_new/training/trainer.py (Normal file, 537 lines added)
@@ -0,0 +1,537 @@
"""
Model trainer for embedding classification
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
import numpy as np
from typing import Dict, Any, Optional
import json
from datetime import datetime
from pathlib import Path
from ml_new.training.models import EmbeddingClassifier
from ml_new.config.logger_config import get_logger

logger = get_logger(__name__)


class ModelTrainer:
    """
    Trainer for embedding classification models
    """

    def __init__(
        self,
        model: EmbeddingClassifier,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: Optional[DataLoader] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        criterion: Optional[nn.Module] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        device: Optional[torch.device] = None,
        save_dir: str = "training/checkpoints",
        experiment_name: Optional[str] = None
    ):
        """
        Initialize model trainer

        Args:
            model: PyTorch model to train
            train_loader: Training data loader
            val_loader: Validation data loader
            test_loader: Test data loader (optional)
            optimizer: Optimizer instance (default: Adam)
            criterion: Loss function (default: BCEWithLogitsLoss)
            scheduler: Learning rate scheduler (optional)
            device: Device to use (default: auto-detect)
            save_dir: Directory to save checkpoints
            experiment_name: Name for the experiment
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader

        # Device setup
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        self.model.to(self.device)

        # Loss function and optimizer
        self.criterion = criterion or nn.BCEWithLogitsLoss()
        self.optimizer = optimizer or optim.Adam(self.model.parameters(), lr=0.001)
        self.scheduler = scheduler

        # Training configuration
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Experiment tracking
        self.experiment_name = experiment_name or f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.checkpoint_dir = self.save_dir / self.experiment_name
        self.checkpoint_dir.mkdir(exist_ok=True)

        # TensorBoard logging
        self.writer = SummaryWriter(log_dir=str(self.checkpoint_dir / "logs"))

        # Training state
        self.current_epoch = 0
        self.best_val_score = 0.0
        self.patience_counter = 0
        self.training_history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'val_auc': [],
            'learning_rates': []
        }

        logger.info(f"Initialized trainer with device: {self.device}")
        logger.info(f"Model info: {self.model.get_model_info()}")

    def train_epoch(self) -> Dict[str, float]:
        """Train for one epoch"""
        self.model.train()
        total_loss = 0.0
        all_predictions = []
        all_labels = []

        for batch_idx, (embeddings, labels, metadata) in enumerate(self.train_loader):
            embeddings = embeddings.to(self.device)
            labels = labels.to(self.device).float()

            # Zero gradients
            self.optimizer.zero_grad()

            # Forward pass
            outputs = self.model(embeddings)
            # Squeeze only the trailing class dimension so a final batch of
            # size 1 still yields a 1-D tensor
            loss = self.criterion(outputs.squeeze(-1), labels)

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # Collect statistics
            total_loss += loss.item()

            # Get predictions for metrics
            with torch.no_grad():
                predictions = torch.sigmoid(outputs).squeeze(-1)
                all_predictions.extend(predictions.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

            # Log progress
            if batch_idx % 5 == 0:
                logger.info(f"Batch {batch_idx}/{len(self.train_loader)}, Loss: {loss.item():.4f}")

        # Calculate epoch metrics
        epoch_loss = total_loss / len(self.train_loader)
        epoch_acc = accuracy_score(all_labels, (np.array(all_predictions) > 0.5).astype(int))

        return {
            'loss': epoch_loss,
            'accuracy': epoch_acc
        }

    def validate(self) -> Dict[str, float]:
        """Validate the model"""
        self.model.eval()
        total_loss = 0.0
        all_predictions = []
        all_labels = []
        all_probabilities = []

        with torch.no_grad():
            for embeddings, labels, metadata in self.val_loader:
                embeddings = embeddings.to(self.device)
                labels = labels.to(self.device).float()

                # Forward pass
                outputs = self.model(embeddings)
                loss = self.criterion(outputs.squeeze(-1), labels)

                # Collect statistics
                total_loss += loss.item()

                # Get predictions and probabilities
                probabilities = torch.sigmoid(outputs).squeeze(-1)
                predictions = (probabilities > 0.5).long()

                all_predictions.extend(predictions.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probabilities.extend(probabilities.cpu().numpy())

        # Calculate metrics
        val_loss = total_loss / len(self.val_loader)
        val_accuracy = accuracy_score(all_labels, all_predictions)

        # Calculate additional metrics
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_predictions, average='binary', zero_division=0
        )

        try:
            val_auc = roc_auc_score(all_labels, all_probabilities)
        except ValueError:
            val_auc = 0.0  # AUC not defined for single class

        # Confusion matrix
        cm = confusion_matrix(all_labels, all_predictions)
        tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)

        return {
            'loss': val_loss,
            'accuracy': val_accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'auc': val_auc,
            'true_negatives': tn,
            'false_positives': fp,
            'false_negatives': fn,
            'true_positives': tp
        }

    def train(
        self,
        num_epochs: int = 100,
        save_every: int = 10,
        early_stopping_patience: int = 15,
        min_delta: float = 0.001,
        validation_score: str = 'f1'
    ) -> Dict[str, Any]:
        """
        Train the model

        Args:
            num_epochs: Number of training epochs
            save_every: Save checkpoint every N epochs
            early_stopping_patience: Patience for early stopping
            min_delta: Minimum improvement for early stopping
            validation_score: Metric to use for early stopping ('f1', 'auc', 'accuracy')

        Returns:
            Training results dictionary
        """
        logger.info(f"Starting training for {num_epochs} epochs")

        for epoch in range(num_epochs):
            self.current_epoch = epoch

            # Training
            train_metrics = self.train_epoch()

            # Validation
            val_metrics = self.validate()

            # Learning rate scheduling
            if self.scheduler:
                if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_metrics['loss'])
                else:
                    self.scheduler.step()

            # Log metrics
            current_lr = self.optimizer.param_groups[0]['lr']

            # TensorBoard logging
            self.writer.add_scalar('Loss/Train', train_metrics['loss'], epoch)
            self.writer.add_scalar('Accuracy/Train', train_metrics['accuracy'], epoch)
            self.writer.add_scalar('Loss/Validation', val_metrics['loss'], epoch)
            self.writer.add_scalar('Accuracy/Validation', val_metrics['accuracy'], epoch)
            self.writer.add_scalar('F1/Validation', val_metrics['f1'], epoch)
            self.writer.add_scalar('AUC/Validation', val_metrics['auc'], epoch)
            self.writer.add_scalar('Learning_Rate', current_lr, epoch)

            # Update training history
            self.training_history['train_loss'].append(train_metrics['loss'])
            self.training_history['train_acc'].append(train_metrics['accuracy'])
            self.training_history['val_loss'].append(val_metrics['loss'])
            self.training_history['val_acc'].append(val_metrics['accuracy'])
            self.training_history['val_auc'].append(val_metrics['auc'])
            self.training_history['learning_rates'].append(current_lr)

            # Print progress
            logger.info(
                f"Epoch {epoch+1}/{num_epochs} - "
                f"Train Loss: {train_metrics['loss']:.4f}, Train Acc: {train_metrics['accuracy']:.4f} - "
                f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}, "
                f"Val F1: {val_metrics['f1']:.4f}, Val AUC: {val_metrics['auc']:.4f}"
            )

            # Check for best model
            val_score = val_metrics[validation_score]
            if val_score > self.best_val_score + min_delta:
                self.best_val_score = val_score
                self.patience_counter = 0

                # Save best model
                self.save_checkpoint(is_best=True)
                logger.info(f"New best {validation_score}: {val_score:.4f}")
            else:
                self.patience_counter += 1

            # Save checkpoint
            if (epoch + 1) % save_every == 0:
                self.save_checkpoint(is_best=False)

            # Early stopping
            if self.patience_counter >= early_stopping_patience:
                logger.info(f"Early stopping triggered after {epoch+1} epochs")
                break

        # Final evaluation on test set if available
        test_metrics = None
        if self.test_loader:
            test_metrics = self.evaluate_test_set()

        # Save final training state
        self.save_training_state()

        # Close TensorBoard writer
        self.writer.close()

        results = {
            'experiment_name': self.experiment_name,
            'best_val_score': self.best_val_score,
            'total_epochs': self.current_epoch + 1,
            'training_history': self.training_history,
            'final_val_metrics': self.validate(),
            'test_metrics': test_metrics,
            'model_info': self.model.get_model_info()
        }

        logger.info(f"Training completed. Best validation {validation_score}: {self.best_val_score:.4f}")
        return results

    def evaluate_test_set(self) -> Dict[str, float]:
        """Evaluate model on test set"""
        self.model.eval()
        total_loss = 0.0
        all_predictions = []
        all_labels = []
        all_probabilities = []

        with torch.no_grad():
            for embeddings, labels, metadata in self.test_loader:
                embeddings = embeddings.to(self.device)
                labels = labels.to(self.device).float()

                # Forward pass
                outputs = self.model(embeddings)
                loss = self.criterion(outputs.squeeze(-1), labels)

                # Collect statistics
                total_loss += loss.item()

                # Get predictions and probabilities
                probabilities = torch.sigmoid(outputs).squeeze(-1)
                predictions = (probabilities > 0.5).long()

                all_predictions.extend(predictions.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probabilities.extend(probabilities.cpu().numpy())

        # Calculate metrics
        test_loss = total_loss / len(self.test_loader)
        test_accuracy = accuracy_score(all_labels, all_predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_predictions, average='binary', zero_division=0
        )

        try:
            test_auc = roc_auc_score(all_labels, all_probabilities)
        except ValueError:
            test_auc = 0.0

        logger.info(f"Test Set Results - Loss: {test_loss:.4f}, Acc: {test_accuracy:.4f}, F1: {f1:.4f}, AUC: {test_auc:.4f}")

        return {
            'loss': test_loss,
            'accuracy': test_accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'auc': test_auc
        }

    def save_checkpoint(self, is_best: bool = False) -> None:
        """Save model checkpoint"""
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_score': self.best_val_score,
            'training_history': self.training_history,
            'model_config': {
                'input_dim': self.model.input_dim,
                'hidden_dims': self.model.hidden_dims,
                'dropout_rate': self.model.dropout_rate,
                'batch_norm': self.model.batch_norm
            }
        }

        if self.scheduler:
            checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

        # Save checkpoint
        checkpoint_path = self.checkpoint_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth"
        torch.save(checkpoint, checkpoint_path)

        # Save best model separately
        if is_best:
            best_model_path = self.checkpoint_dir / "best_model.pth"
            torch.save(checkpoint, best_model_path)
            logger.info(f"Saved best model to {best_model_path}")

    def save_training_state(self) -> None:
        """Save final training state"""
        state = {
            'experiment_name': self.experiment_name,
            'best_val_score': self.best_val_score,
            'total_epochs': self.current_epoch + 1,
            'training_history': self.training_history,
            'model_info': self.model.get_model_info(),
            'training_config': {
                'optimizer': self.optimizer.__class__.__name__,
                'criterion': self.criterion.__class__.__name__,
                'device': str(self.device)
            }
        }

        state_path = self.checkpoint_dir / "training_state.json"
        with open(state_path, 'w') as f:
            json.dump(state, f, indent=2)

        logger.info(f"Saved training state to {state_path}")

    def load_checkpoint(self, checkpoint_path: str, load_optimizer: bool = True) -> None:
        """Load model from checkpoint"""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])

        if load_optimizer:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.scheduler and 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        self.current_epoch = checkpoint['epoch']
        self.best_val_score = checkpoint.get('best_val_score', 0.0)
        self.training_history = checkpoint.get('training_history', {})

        logger.info(f"Loaded checkpoint from {checkpoint_path}")


def create_trainer(
    model: EmbeddingClassifier,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: Optional[DataLoader] = None,
    learning_rate: float = 0.001,
    weight_decay: float = 1e-5,
    scheduler_type: Optional[str] = 'plateau',
    **kwargs
) -> ModelTrainer:
    """
    Factory function to create a configured trainer

    Args:
        model: The model to train
        train_loader: Training data loader
        val_loader: Validation data loader
        test_loader: Test data loader (optional)
        learning_rate: Initial learning rate
        weight_decay: L2 regularization
        scheduler_type: Learning rate scheduler type ('plateau', 'step', None)
        **kwargs: Additional trainer arguments

    Returns:
        Configured ModelTrainer instance
    """
    # Create optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )

    # Create scheduler
    scheduler = None
    if scheduler_type == 'plateau':
        # ModelTrainer.train() steps this scheduler with the validation loss,
        # so monitor for a decreasing value
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=10
        )
    elif scheduler_type == 'step':
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=1,
            gamma=0.1
        )

    # Create trainer
    trainer = ModelTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        **kwargs
    )

    return trainer


if __name__ == "__main__":
    # Test trainer creation
    from ml_new.training.models import create_model
    from ml_new.training.data_loader import DatasetLoader

    # Create dummy model and data
    model = create_model(
        model_type="standard",
        input_dim=2048,
        hidden_dims=(512, 256, 128)
    )

    loader = DatasetLoader()
    datasets = loader.list_datasets()

    if datasets:
        train_loader, val_loader, test_loader, data_info = loader.create_data_loaders(
            datasets[0],
            batch_size=8,
            normalize=True
        )

        # Create trainer
        trainer = create_trainer(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            experiment_name="test_experiment"
        )

        # Test one epoch
        print("Testing trainer...")
        train_metrics = trainer.train_epoch()
        val_metrics = trainer.validate()

        print(f"Train metrics: {train_metrics}")
        print(f"Val metrics: {val_metrics}")
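Reviewer note: a minimal sketch of restoring a saved best_model.pth for inference, using only the keys written by save_checkpoint(); the experiment directory name is a placeholder and the sketch assumes a 'standard' (non-attention) model:

import torch
from ml_new.training.models import EmbeddingClassifier

# Placeholder run directory under training/checkpoints
ckpt = torch.load("training/checkpoints/standard_demo_20240101_000000/best_model.pth",
                  map_location="cpu")
model = EmbeddingClassifier(**ckpt["model_config"])  # input_dim, hidden_dims, dropout_rate, batch_norm
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

with torch.no_grad():
    probs = model.predict_proba(torch.randn(4, ckpt["model_config"]["input_dim"]))
print(probs.shape)  # (4, 2): [negative_prob, positive_prob] per sample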