diff --git a/ml_new/training/__init__.py b/ml_new/training/__init__.py new file mode 100644 index 0000000..d78502f --- /dev/null +++ b/ml_new/training/__init__.py @@ -0,0 +1,9 @@ +""" +Training module for ML models +""" + +from .models import EmbeddingClassifier +from .trainer import ModelTrainer +from .data_loader import DatasetLoader + +__all__ = ['EmbeddingClassifier', 'ModelTrainer', 'DatasetLoader'] \ No newline at end of file diff --git a/ml_new/training/data_loader.py b/ml_new/training/data_loader.py new file mode 100644 index 0000000..2d2cf53 --- /dev/null +++ b/ml_new/training/data_loader.py @@ -0,0 +1,389 @@ +""" +Data loader for embedding datasets +""" + +import pandas as pd +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler +from typing import List, Dict, Any, Optional, Tuple +from pathlib import Path +import json +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from ml_new.config.logger_config import get_logger + +logger = get_logger(__name__) + + +class EmbeddingDataset(Dataset): + """ + PyTorch Dataset for embedding-based classification + """ + + def __init__( + self, + embeddings: np.ndarray, + labels: np.ndarray, + metadata: Optional[List[Dict[str, Any]]] = None, + transform: Optional[callable] = None, + normalize: bool = True + ): + """ + Initialize embedding dataset + + Args: + embeddings: Array of embedding vectors (n_samples, embedding_dim) + labels: Array of binary labels (n_samples,) + metadata: Optional list of metadata dictionaries + transform: Optional transformation function + normalize: Whether to normalize embeddings + """ + assert len(embeddings) == len(labels), "Embeddings and labels must have same length" + + self.embeddings = embeddings.astype(np.float32) + self.labels = labels.astype(np.int64) + self.metadata = metadata or [] + self.transform = transform + + # Normalize embeddings if requested + if normalize and len(embeddings) > 0: + self.scaler = StandardScaler() + self.embeddings = self.scaler.fit_transform(self.embeddings) + else: + self.scaler = None + + # Calculate class weights for balanced sampling + self._calculate_class_weights() + + def _calculate_class_weights(self): + """Calculate weights for each class for balanced sampling""" + unique, counts = np.unique(self.labels, return_counts=True) + total_samples = len(self.labels) + + self.class_weights = {} + for class_label, count in zip(unique, counts): + # Inverse frequency weighting + weight = total_samples / (2 * count) + self.class_weights[class_label] = weight + + logger.info(f"Class distribution: {dict(zip(unique, counts))}") + logger.info(f"Class weights: {self.class_weights}") + + def __len__(self) -> int: + return len(self.embeddings) + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: + """ + Get a single sample from the dataset + + Returns: + tuple: (embedding, label, metadata) + """ + embedding = torch.from_numpy(self.embeddings[idx]) + label = torch.tensor(self.labels[idx], dtype=torch.long) + + metadata = {} + if self.metadata and idx < len(self.metadata): + metadata = self.metadata[idx] + + if self.transform: + embedding = self.transform(embedding) + + return embedding, label, metadata + + +class DatasetLoader: + """ + Loader for embedding datasets stored in Parquet format + """ + + def __init__(self, datasets_dir: str = "training/datasets"): + """ + Initialize dataset loader + + Args: + datasets_dir: Directory containing dataset files 
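+
+        Example (hypothetical dataset id):
+            >>> loader = DatasetLoader("training/datasets")
+            >>> embeddings, labels, metadata = loader.load_dataset("demo_ds")
+            >>> embeddings.shape  # (n_samples, embedding_dim)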
+ """ + self.datasets_dir = Path(datasets_dir) + self.datasets_dir.mkdir(parents=True, exist_ok=True) + + def load_dataset(self, dataset_id: str) -> Tuple[np.ndarray, np.ndarray, List[Dict[str, Any]]]: + """ + Load a dataset by ID from Parquet files + + Args: + dataset_id: Unique identifier for the dataset + + Returns: + tuple: (embeddings, labels, metadata_list) + """ + dataset_file = self.datasets_dir / f"{dataset_id}.parquet" + metadata_file = self.datasets_dir / f"{dataset_id}.metadata.json" + + if not dataset_file.exists(): + raise FileNotFoundError(f"Dataset file not found: {dataset_file}") + + # Load metadata + metadata = {} + if metadata_file.exists(): + with open(metadata_file, 'r') as f: + metadata = json.load(f) + + # Load data from Parquet + logger.info(f"Loading dataset {dataset_id} from {dataset_file}") + df = pd.read_parquet(dataset_file) + + # Extract embeddings (they might be stored as list or numpy array) + embeddings = self._extract_embeddings(df) + + # Extract labels + labels = df['label'].values.astype(np.int64) + + # Extract metadata + metadata_list = [] + if 'metadata_json' in df.columns: + for _, row in df.iterrows(): + meta = {} + if pd.notna(row.get('metadata_json')): + try: + meta = json.loads(row['metadata_json']) + except (json.JSONDecodeError, TypeError): + meta = {} + + # Add other fields + meta.update({ + 'aid': row.get('aid'), + 'inconsistent': row.get('inconsistent', False), + 'text_checksum': row.get('text_checksum') + }) + metadata_list.append(meta) + else: + # Create basic metadata + metadata_list = [{ + 'aid': aid, + 'inconsistent': inconsistent, + 'text_checksum': checksum + } for aid, inconsistent, checksum in zip( + df.get('aid', []), + df.get('inconsistent', [False] * len(df)), + df.get('text_checksum', [''] * len(df)) + )] + + logger.info(f"Loaded dataset with {len(embeddings)} samples, {embeddings.shape[1]} embedding dimensions") + + return embeddings, labels, metadata_list + + def _extract_embeddings(self, df: pd.DataFrame) -> np.ndarray: + """Extract embeddings from DataFrame, handling different storage formats""" + embedding_col = None + for col in ['embedding', 'embeddings', 'vec_2048', 'vec_1024']: + if col in df.columns: + embedding_col = col + break + + if embedding_col is None: + raise ValueError("No embedding column found in dataset") + + embeddings_data = df[embedding_col] + + # Handle different embedding storage formats + if embeddings_data.dtype == 'object': + # Likely stored as lists or numpy arrays + embeddings = np.array([ + np.array(emb) if isinstance(emb, (list, np.ndarray)) else np.zeros(2048) + for emb in embeddings_data + ]) + else: + # Already numpy array + embeddings = embeddings_data.values + + # Ensure 2D array + if embeddings.ndim == 1: + # If embeddings are flattened, reshape + embedding_dim = len(embeddings) // len(df) + embeddings = embeddings.reshape(len(df), embedding_dim) + + return embeddings.astype(np.float32) + + def create_data_loaders( + self, + dataset_id: str, + train_ratio: float = 0.8, + val_ratio: float = 0.1, + batch_size: int = 32, + num_workers: int = 4, + random_state: int = 42, + normalize: bool = True, + use_weighted_sampler: bool = True + ) -> Tuple[DataLoader, DataLoader, DataLoader, Dict[str, Any]]: + """ + Create train, validation, and test data loaders + + Args: + dataset_id: Dataset identifier + train_ratio: Proportion of data for training + val_ratio: Proportion of data for validation + batch_size: Batch size for data loaders + num_workers: Number of worker processes + random_state: Random 
seed for reproducibility
+            normalize: Whether to normalize embeddings
+            use_weighted_sampler: Whether to use weighted random sampling
+
+        Returns:
+            tuple: (train_loader, val_loader, test_loader, dataset_info)
+        """
+        # Load dataset
+        embeddings, labels, metadata = self.load_dataset(dataset_id)
+
+        # Split off the training set; the remainder is shared by val and test
+        (
+            train_emb, test_emb,
+            train_lbl, test_lbl,
+            train_meta, test_meta
+        ) = train_test_split(
+            embeddings, labels, metadata,
+            test_size=1 - train_ratio,
+            stratify=labels,
+            random_state=random_state
+        )
+
+        # Split the held-out remainder into val and test; val_size is the
+        # fraction of the remainder that becomes validation
+        val_size = val_ratio / (1 - train_ratio)
+        (
+            val_emb, test_emb,
+            val_lbl, test_lbl,
+            val_meta, test_meta
+        ) = train_test_split(
+            test_emb, test_lbl, test_meta,
+            test_size=1 - val_size,
+            stratify=test_lbl,
+            random_state=random_state
+        )
+
+        # Create datasets. The scaler is fitted on the training split only and
+        # then applied to val/test, so no statistics leak across splits.
+        train_dataset = EmbeddingDataset(train_emb, train_lbl, train_meta, normalize=normalize)
+        if normalize and train_dataset.scaler is not None:
+            val_emb = train_dataset.scaler.transform(val_emb)
+            test_emb = train_dataset.scaler.transform(test_emb)
+        val_dataset = EmbeddingDataset(val_emb, val_lbl, val_meta, normalize=False)
+        test_dataset = EmbeddingDataset(test_emb, test_lbl, test_meta, normalize=False)
+
+        # Create samplers
+        train_sampler = None
+        if use_weighted_sampler and hasattr(train_dataset, 'class_weights'):
+            # Create weighted sampler for balanced training
+            sample_weights = [train_dataset.class_weights[label] for label in train_dataset.labels]
+            train_sampler = WeightedRandomSampler(
+                weights=sample_weights,
+                num_samples=len(sample_weights),
+                replacement=True
+            )
+
+        # Create data loaders
+        train_loader = DataLoader(
+            train_dataset,
+            batch_size=batch_size,
+            sampler=train_sampler,
+            shuffle=(train_sampler is None),
+            num_workers=num_workers
+        )
+
+        val_loader = DataLoader(
+            val_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=num_workers
+        )
+
+        test_loader = DataLoader(
+            test_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=num_workers
+        )
+
+        # Dataset info
+        dataset_info = {
+            'dataset_id': dataset_id,
+            'total_samples': len(embeddings),
+            'embedding_dim': embeddings.shape[1],
+            'train_samples': len(train_dataset),
+            'val_samples': len(val_dataset),
+            'test_samples': len(test_dataset),
+            'train_ratio': len(train_dataset) / len(embeddings),
+            'val_ratio': len(val_dataset) / len(embeddings),
+            'test_ratio': len(test_dataset) / len(embeddings),
+            'class_distribution': {
+                'train': dict(zip(*np.unique(train_dataset.labels, return_counts=True))),
+                'val': dict(zip(*np.unique(val_dataset.labels, return_counts=True))),
+                'test': dict(zip(*np.unique(test_dataset.labels, return_counts=True)))
+            },
+            'normalize': normalize,
+            'use_weighted_sampler': use_weighted_sampler
+        }
+
+        logger.info(f"Created data loaders: train={len(train_dataset)}, val={len(val_dataset)}, test={len(test_dataset)}")
+
+        return train_loader, val_loader, test_loader, dataset_info
+
+    def list_datasets(self) -> List[str]:
+        """List all available datasets"""
+        parquet_files = list(self.datasets_dir.glob("*.parquet"))
+        return [f.stem for f in parquet_files]
+
+    def get_dataset_info(self, dataset_id: str) -> Dict[str, Any]:
+        """Get detailed information about a dataset"""
+        metadata_file = self.datasets_dir / f"{dataset_id}.metadata.json"
+
+        if metadata_file.exists():
+            with open(metadata_file, 'r') as f:
+                return json.load(f)
+
+        # Fallback: load dataset and return basic info
+        embeddings, labels, metadata = self.load_dataset(dataset_id)
+        return {
+            'dataset_id': dataset_id,
+            'total_samples': len(embeddings),
+            'embedding_dim': embeddings.shape[1],
+
'class_distribution': dict(zip(*np.unique(labels, return_counts=True))), + 'file_format': 'parquet', + 'created_at': 'unknown' + } + + +if __name__ == "__main__": + # Test dataset loading + loader = DatasetLoader() + + # List available datasets + datasets = loader.list_datasets() + print(f"Available datasets: {datasets}") + + if datasets: + # Test loading first dataset + dataset_id = datasets[0] + print(f"\nTesting dataset: {dataset_id}") + + info = loader.get_dataset_info(dataset_id) + print("Dataset info:", info) + + # Test creating data loaders + try: + train_loader, val_loader, test_loader, data_info = loader.create_data_loaders( + dataset_id, + batch_size=8, + normalize=True + ) + + print("\nData loader info:") + for key, value in data_info.items(): + print(f" {key}: {value}") + + # Test single batch + for batch_embeddings, batch_labels, batch_metadata in train_loader: + print(f"\nBatch test:") + print(f" Embeddings shape: {batch_embeddings.shape}") + print(f" Labels shape: {batch_labels.shape}") + print(f" Sample labels: {batch_labels[:5].tolist()}") + break + + except Exception as e: + print(f"Error creating data loaders: {e}") \ No newline at end of file diff --git a/ml_new/training/models.py b/ml_new/training/models.py new file mode 100644 index 0000000..e481260 --- /dev/null +++ b/ml_new/training/models.py @@ -0,0 +1,324 @@ +""" +Neural network model for binary classification of video embeddings +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional +from ml_new.config.logger_config import get_logger + +logger = get_logger(__name__) + + +class EmbeddingClassifier(nn.Module): + """ + Neural network model for binary classification of video embeddings + + Architecture: + - Embedding layer (configurable input dimension) + - Hidden layers with dropout and batch normalization + - Binary classification head + """ + + def __init__( + self, + input_dim: int = 2048, + hidden_dims: Optional[Tuple[int, ...]] = None, + dropout_rate: float = 0.3, + batch_norm: bool = True, + activation: str = "relu" + ): + """ + Initialize the embedding classifier model + + Args: + input_dim: Dimension of input embeddings (default: 2048 for qwen3-embedding) + hidden_dims: Tuple of hidden layer dimensions (default: (512, 256, 128)) + dropout_rate: Dropout probability for regularization + batch_norm: Whether to use batch normalization + activation: Activation function ('relu', 'gelu', 'tanh') + """ + super(EmbeddingClassifier, self).__init__() + + # Default hidden dimensions if not provided + if hidden_dims is None: + hidden_dims = (512, 256, 128) + + self.input_dim = input_dim + self.hidden_dims = hidden_dims + self.dropout_rate = dropout_rate + self.batch_norm = batch_norm + + # Build layers + self.layers = nn.ModuleList() + + # Input dimension to first hidden layer + prev_dim = input_dim + + for i, hidden_dim in enumerate(hidden_dims): + # Linear layer + linear_layer = nn.Linear(prev_dim, hidden_dim) + self.layers.append(linear_layer) + + # Batch normalization (optional) + if batch_norm: + bn_layer = nn.BatchNorm1d(hidden_dim) + self.layers.append(bn_layer) + + # Activation function + activation_layer = self._get_activation(activation) + self.layers.append(activation_layer) + + # Dropout + dropout_layer = nn.Dropout(dropout_rate) + self.layers.append(dropout_layer) + + prev_dim = hidden_dim + + # Binary classification head + self.classifier = nn.Linear(prev_dim, 1) + + # Initialize weights + self._initialize_weights() + + logger.info(f"Initialized 
EmbeddingClassifier: input_dim={input_dim}, hidden_dims={hidden_dims}")
+
+    def _get_activation(self, activation: str) -> nn.Module:
+        """Get activation function module"""
+        activations = {
+            'relu': nn.ReLU(),
+            'gelu': nn.GELU(),
+            'tanh': nn.Tanh(),
+            'leaky_relu': nn.LeakyReLU(0.1),
+            'elu': nn.ELU()
+        }
+
+        if activation not in activations:
+            logger.warning(f"Unknown activation '{activation}', using ReLU")
+            return nn.ReLU()
+
+        return activations[activation]
+
+    def _initialize_weights(self):
+        """Initialize model weights using Xavier/Glorot initialization"""
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    nn.init.constant_(module.bias, 0)
+            elif isinstance(module, nn.BatchNorm1d):
+                nn.init.constant_(module.weight, 1)
+                nn.init.constant_(module.bias, 0)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass through the network
+
+        Args:
+            x: Input embeddings of shape (batch_size, input_dim)
+
+        Returns:
+            logits of shape (batch_size, 1)
+        """
+        # Ensure input is float tensor
+        if x.dtype != torch.float32:
+            x = x.float()
+
+        # Flatten input if it's multi-dimensional (shouldn't be for embeddings)
+        if x.dim() > 2:
+            x = x.view(x.size(0), -1)
+
+        # Process through layers
+        for layer in self.layers:
+            x = layer(x)
+
+        # Final classification layer
+        logits = self.classifier(x)
+
+        return logits
+
+    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Predict class probabilities
+
+        Args:
+            x: Input embeddings
+
+        Returns:
+            Class probabilities of shape (batch_size, 2)
+        """
+        logits = self.forward(x)
+        probabilities = torch.sigmoid(logits)
+
+        # Convert to [negative_prob, positive_prob] format
+        prob_0 = 1 - probabilities
+        prob_1 = probabilities
+
+        return torch.cat([prob_0, prob_1], dim=1)
+
+    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
+        """
+        Predict class labels
+
+        Args:
+            x: Input embeddings
+            threshold: Classification threshold
+
+        Returns:
+            Binary predictions of shape (batch_size,)
+        """
+        probabilities = self.predict_proba(x)
+        predictions = (probabilities[:, 1] > threshold).long()
+        return predictions
+
+    def get_model_info(self) -> dict:
+        """Get model architecture information"""
+        total_params = sum(p.numel() for p in self.parameters())
+        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+        return {
+            'model_class': self.__class__.__name__,
+            'input_dim': self.input_dim,
+            'hidden_dims': self.hidden_dims,
+            'dropout_rate': self.dropout_rate,
+            'batch_norm': self.batch_norm,
+            'total_parameters': total_params,
+            'trainable_parameters': trainable_params,
+            'model_size_mb': total_params * 4 / (1024 * 1024)  # Assuming float32
+        }
+
+
+class AttentionEmbeddingClassifier(EmbeddingClassifier):
+    """
+    Enhanced classifier with self-attention mechanism
+    """
+
+    def __init__(
+        self,
+        input_dim: int = 2048,
+        hidden_dims: Optional[Tuple[int, ...]] = None,
+        dropout_rate: float = 0.3,
+        batch_norm: bool = True,
+        activation: str = "relu",
+        attention_dim: int = 512
+    ):
+        super().__init__(input_dim, hidden_dims, dropout_rate, batch_norm, activation)
+
+        # Self-attention mechanism
+        self.attention_dim = attention_dim
+        self.attention = nn.MultiheadAttention(
+            embed_dim=input_dim,
+            num_heads=8,
+            dropout=dropout_rate,
+            batch_first=True
+        )
+
+        # Projection applied after attention. It must map back to input_dim:
+        # the MLP layers inherited from EmbeddingClassifier were built for
+        # input_dim features, and projecting to attention_dim here crashed
+        # forward() with a shape mismatch. attention_dim stays in the
+        # signature for compatibility but is currently unused.
+        self.attention_projection = nn.Linear(input_dim, input_dim)
+
+        # Re-initialize attention weights
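+        # NOTE: nn.MultiheadAttention keeps its query/key/value weights in
+        # packed in_proj_weight parameters rather than nn.Linear submodules,
+        # so the helper called below only re-initializes out_proj; PyTorch's
+        # default xavier init already covers in_proj_weight.
+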
self._initialize_attention_weights() + + logger.info(f"Initialized AttentionEmbeddingClassifier with attention_dim={attention_dim}") + + def _initialize_attention_weights(self): + """Initialize attention mechanism weights""" + for module in self.attention.modules(): + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with attention mechanism + + Args: + x: Input embeddings of shape (batch_size, input_dim) + + Returns: + logits of shape (batch_size, 1) + """ + # Ensure input is float tensor + if not x.dtype == torch.float32: + x = x.float() + + # Add sequence dimension for attention (batch_size, 1, input_dim) + x_expanded = x.unsqueeze(1) + + # Apply self-attention + attended, attention_weights = self.attention(x_expanded, x_expanded, x_expanded) + + # Remove sequence dimension (batch_size, input_dim) + attended = attended.squeeze(1) + + # Project to attention dimension + attended = self.attention_projection(attended) + + # Process through original classification layers + for layer in self.layers: + attended = layer(attended) + + # Final classification layer + logits = self.classifier(attended) + + return logits + + +def create_model( + model_type: str = "standard", + input_dim: int = 2048, + hidden_dims: Optional[Tuple[int, ...]] = None, + **kwargs +) -> EmbeddingClassifier: + """ + Factory function to create embedding classifier models + + Args: + model_type: Type of model ('standard', 'attention') + input_dim: Input embedding dimension + hidden_dims: Hidden layer dimensions + **kwargs: Additional model arguments + + Returns: + Initialized model + """ + if model_type == "standard": + return EmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs) + elif model_type == "attention": + return AttentionEmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs) + else: + raise ValueError(f"Unknown model type: {model_type}") + + +if __name__ == "__main__": + # Test model creation and forward pass + model = create_model( + model_type="standard", + input_dim=2048, + hidden_dims=(512, 256, 128), + dropout_rate=0.3 + ) + + # Print model info + info = model.get_model_info() + print("Model Information:") + for key, value in info.items(): + print(f" {key}: {value}") + + # Test forward pass + batch_size = 8 + dummy_input = torch.randn(batch_size, 2048) + + with torch.no_grad(): + logits = model(dummy_input) + probabilities = model.predict_proba(dummy_input) + predictions = model.predict(dummy_input) + + print(f"\nTest Results:") + print(f" Input shape: {dummy_input.shape}") + print(f" Logits shape: {logits.shape}") + print(f" Probabilities shape: {probabilities.shape}") + print(f" Predictions shape: {predictions.shape}") + print(f" Sample predictions: {predictions.tolist()}") \ No newline at end of file diff --git a/ml_new/training/train.py b/ml_new/training/train.py new file mode 100644 index 0000000..e4edcdb --- /dev/null +++ b/ml_new/training/train.py @@ -0,0 +1,519 @@ +#!/usr/bin/env python3 +""" +Main training script for embedding classification models +""" + +import argparse +import json +import sys +from pathlib import Path +from datetime import datetime + +import numpy as np + +# Add the parent directory to the path to import ml_new modules +sys.path.append(str(Path(__file__).parent.parent)) + +from ml_new.training.models import create_model +from ml_new.training.data_loader import DatasetLoader 
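+# NOTE: the sys.path.append above adds ml_new/ itself; for these `ml_new.*`
+# imports to resolve when the script is run directly, the repository root
+# (Path(__file__).parent.parent.parent) is likely the path that needs to be
+# appended instead.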
+from ml_new.training.trainer import create_trainer +from ml_new.config.logger_config import get_logger + + +def json_safe_convert(obj): + """Convert objects to JSON-serializable format""" + if isinstance(obj, dict): + return {str(k): json_safe_convert(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): + return [json_safe_convert(item) for item in obj] + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif hasattr(obj, 'item'): # numpy scalar + return obj.item() + else: + return obj + +logger = get_logger(__name__) + + +def parse_args(): + """Parse command line arguments""" + parser = argparse.ArgumentParser( + description="Train embedding classification model", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Data arguments + parser.add_argument( + "--dataset-id", + type=str, + required=True, + help="ID of the dataset to use for training" + ) + + parser.add_argument( + "--datasets-dir", + type=str, + default="training/datasets", + help="Directory containing dataset files" + ) + + # Model arguments + parser.add_argument( + "--model-type", + type=str, + choices=["standard", "attention"], + default="standard", + help="Type of model architecture" + ) + + parser.add_argument( + "--input-dim", + type=int, + default=2048, + help="Input embedding dimension" + ) + + parser.add_argument( + "--hidden-dims", + type=int, + nargs="+", + default=[512, 256, 128], + help="Hidden layer dimensions" + ) + + parser.add_argument( + "--dropout-rate", + type=float, + default=0.3, + help="Dropout rate for regularization" + ) + + parser.add_argument( + "--batch-norm", + action="store_true", + default=True, + help="Use batch normalization" + ) + + parser.add_argument( + "--activation", + type=str, + choices=["relu", "gelu", "tanh", "leaky_relu", "elu"], + default="relu", + help="Activation function" + ) + + # Training arguments + parser.add_argument( + "--num-epochs", + type=int, + default=100, + help="Number of training epochs" + ) + + parser.add_argument( + "--batch-size", + type=int, + default=32, + help="Batch size for training" + ) + + parser.add_argument( + "--learning-rate", + type=float, + default=0.001, + help="Initial learning rate" + ) + + parser.add_argument( + "--weight-decay", + type=float, + default=1e-5, + help="L2 regularization weight decay" + ) + + parser.add_argument( + "--scheduler", + type=str, + choices=["plateau", "step", "none"], + default="plateau", + help="Learning rate scheduler" + ) + + parser.add_argument( + "--early-stopping-patience", + type=int, + default=15, + help="Patience for early stopping" + ) + + parser.add_argument( + "--min-delta", + type=float, + default=0.001, + help="Minimum improvement for early stopping" + ) + + parser.add_argument( + "--validation-metric", + type=str, + choices=["f1", "auc", "accuracy"], + default="f1", + help="Metric for model selection and early stopping" + ) + + # Data arguments + parser.add_argument( + "--train-ratio", + type=float, + default=0.8, + help="Proportion of data for training" + ) + + parser.add_argument( + "--val-ratio", + type=float, + default=0.1, + help="Proportion of data for validation" + ) + + parser.add_argument( + "--normalize", + action="store_true", + default=False, + help="Normalize embeddings during training" + ) + + parser.add_argument( + "--use-weighted-sampler", + action="store_true", + default=False, + help="Use weighted random sampling for balanced training" + ) + + 
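+    # NOTE: --batch-norm above is declared with action="store_true" and
+    # default=True, so the flag can enable but never disable batch norm from
+    # the CLI. A minimal sketch of a toggleable flag, assuming Python 3.9+
+    # for argparse.BooleanOptionalAction:
+    #
+    #     parser.add_argument(
+    #         "--batch-norm",
+    #         action=argparse.BooleanOptionalAction,
+    #         default=True,
+    #         help="Use batch normalization (disable with --no-batch-norm)",
+    #     )
+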
# Output arguments + parser.add_argument( + "--save-dir", + type=str, + default="training/checkpoints", + help="Directory to save model checkpoints" + ) + + parser.add_argument( + "--experiment-name", + type=str, + default=None, + help="Name for the experiment (auto-generated if not provided)" + ) + + parser.add_argument( + "--save-every", + type=int, + default=10, + help="Save checkpoint every N epochs" + ) + + # Other arguments + parser.add_argument( + "--num-workers", + type=int, + default=4, + help="Number of worker processes for data loading" + ) + + parser.add_argument( + "--device", + type=str, + default="auto", + help="Device to use (auto, cpu, cuda)" + ) + + parser.add_argument( + "--random-seed", + type=int, + default=42, + help="Random seed for reproducibility" + ) + + parser.add_argument( + "--list-datasets", + action="store_true", + help="List available datasets and exit" + ) + + parser.add_argument( + "--dry-run", + action="store_true", + help="Test data loading without training" + ) + + return parser.parse_args() + + +def setup_device(device_arg: str): + """Setup training device""" + import torch + + if device_arg == "auto": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + device = torch.device(device_arg) + + logger.info(f"Using device: {device}") + if device.type == "cuda": + logger.info(f"GPU: {torch.cuda.get_device_name(0)}") + logger.info(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB") + + return device + + +def list_available_datasets(datasets_dir: str): + """List and display information about available datasets""" + loader = DatasetLoader(datasets_dir) + datasets = loader.list_datasets() + + print(f"\nAvailable datasets in {datasets_dir}:") + print("-" * 50) + + if not datasets: + print("No datasets found.") + return [] + + for i, dataset_id in enumerate(datasets, 1): + try: + info = loader.get_dataset_info(dataset_id) + print(f"{i}. {dataset_id}") + print(f" Samples: {info.get('total_samples', 'N/A')}") + print(f" Embedding dim: {info.get('embedding_dim', 'N/A')}") + print(f" Classes: {info.get('class_distribution', 'N/A')}") + print(f" Created: {info.get('created_at', 'N/A')}") + print() + except Exception as e: + print(f"{i}. 
{dataset_id} (Error loading info: {e})") + print() + + return datasets + + +def validate_arguments(args): + """Validate command line arguments""" + if args.train_ratio + args.val_ratio >= 1.0: + raise ValueError("train_ratio + val_ratio must be less than 1.0") + + if args.batch_size <= 0: + raise ValueError("batch_size must be positive") + + if args.learning_rate <= 0: + raise ValueError("learning_rate must be positive") + + if args.dropout_rate < 0 or args.dropout_rate >= 1.0: + raise ValueError("dropout_rate must be between 0 and 1") + + +def main(): + """Main training function""" + args = parse_args() + + # Set random seeds for reproducibility + import torch + import numpy as np + + torch.manual_seed(args.random_seed) + np.random.seed(args.random_seed) + + # Setup device + device = setup_device(args.device) + + # List datasets if requested + if args.list_datasets: + list_available_datasets(args.datasets_dir) + return + + # Validate arguments + try: + validate_arguments(args) + except ValueError as e: + logger.error(f"Invalid arguments: {e}") + sys.exit(1) + + # Check if dataset exists + loader = DatasetLoader(args.datasets_dir) + datasets = loader.list_datasets() + + if args.dataset_id not in datasets: + logger.error(f"Dataset '{args.dataset_id}' not found in {args.datasets_dir}") + logger.info(f"Available datasets: {datasets}") + sys.exit(1) + + # Create experiment name if not provided + if args.experiment_name is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + args.experiment_name = f"{args.model_type}_{args.dataset_id}_{timestamp}" + + logger.info(f"Starting experiment: {args.experiment_name}") + logger.info(f"Dataset: {args.dataset_id}") + logger.info(f"Model: {args.model_type} with hidden dims {args.hidden_dims}") + + # Load dataset and create data loaders + try: + logger.info("Loading dataset and creating data loaders...") + train_loader, val_loader, test_loader, data_info = loader.create_data_loaders( + dataset_id=args.dataset_id, + train_ratio=args.train_ratio, + val_ratio=args.val_ratio, + batch_size=args.batch_size, + num_workers=args.num_workers, + random_state=args.random_seed, + normalize=args.normalize, + use_weighted_sampler=args.use_weighted_sampler + ) + + logger.info("Data loaders created successfully") + logger.info(f" Training samples: {data_info['train_samples']}") + logger.info(f" Validation samples: {data_info['val_samples']}") + logger.info(f" Test samples: {data_info['test_samples']}") + logger.info(f" Training class distribution: {data_info['class_distribution']['train']}") + + except Exception as e: + logger.error(f"Failed to create data loaders: {e}") + sys.exit(1) + + # Create model + try: + logger.info("Creating model...") + model = create_model( + model_type=args.model_type, + input_dim=args.input_dim, + hidden_dims=tuple(args.hidden_dims), + dropout_rate=args.dropout_rate, + batch_norm=args.batch_norm, + activation=args.activation + ) + + model_info = model.get_model_info() + logger.info("Model created successfully") + logger.info(f" Parameters: {model_info['total_parameters']:,}") + logger.info(f" Model size: {model_info['model_size_mb']:.1f} MB") + + except Exception as e: + logger.error(f"Failed to create model: {e}") + sys.exit(1) + + # Test data loading (dry run) + if args.dry_run: + logger.info("Dry run: Testing data loading...") + + # Test one batch from each loader + for name, loader_obj in [("train", train_loader), ("val", val_loader)]: + for batch_embeddings, batch_labels, batch_metadata in loader_obj: + logger.info(f" {name} 
batch - Embeddings: {batch_embeddings.shape}, Labels: {batch_labels.shape}") + logger.info(f" Sample labels: {batch_labels[:5].tolist()}") + break + + logger.info("Dry run completed successfully") + return + + # Create trainer + try: + logger.info("Creating trainer...") + + # Determine scheduler type + scheduler_type = None if args.scheduler == "none" else args.scheduler + + trainer = create_trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + scheduler_type=scheduler_type, + device=device, + save_dir=args.save_dir, + experiment_name=args.experiment_name + ) + + logger.info("Trainer created successfully") + + except Exception as e: + logger.error(f"Failed to create trainer: {e}") + sys.exit(1) + + # Save configuration + config_path = Path(args.save_dir) / args.experiment_name / "config.json" + config_path.parent.mkdir(parents=True, exist_ok=True) + + config = { + "experiment_name": args.experiment_name, + "dataset_id": args.dataset_id, + "model_config": { + "model_type": args.model_type, + "input_dim": args.input_dim, + "hidden_dims": args.hidden_dims, + "dropout_rate": args.dropout_rate, + "batch_norm": args.batch_norm, + "activation": args.activation + }, + "training_config": { + "num_epochs": args.num_epochs, + "batch_size": args.batch_size, + "learning_rate": args.learning_rate, + "weight_decay": args.weight_decay, + "scheduler": args.scheduler, + "early_stopping_patience": args.early_stopping_patience, + "validation_metric": args.validation_metric + }, + "data_config": { + "train_ratio": args.train_ratio, + "val_ratio": args.val_ratio, + "normalize": args.normalize, + "use_weighted_sampler": args.use_weighted_sampler, + "random_seed": args.random_seed + }, + "dataset_info": json_safe_convert(data_info) + } + + with open(config_path, 'w') as f: + json.dump(config, f, indent=2) + + logger.info(f"Configuration saved to {config_path}") + + # Start training + try: + logger.info("Starting training...") + results = trainer.train( + num_epochs=args.num_epochs, + save_every=args.save_every, + early_stopping_patience=args.early_stopping_patience, + min_delta=args.min_delta, + validation_score=args.validation_metric + ) + + # Save results + results_path = Path(args.save_dir) / args.experiment_name / "results.json" + with open(results_path, 'w') as f: + json.dump(results, f, indent=2, default=str) + + logger.info(f"Training completed successfully!") + logger.info(f"Results saved to {results_path}") + logger.info(f"Best validation {args.validation_metric}: {results['best_val_score']:.4f}") + + if results.get('test_metrics'): + logger.info(f"Test accuracy: {results['test_metrics']['accuracy']:.4f}") + logger.info(f"Test F1: {results['test_metrics']['f1']:.4f}") + logger.info(f"Test AUC: {results['test_metrics']['auc']:.4f}") + + except KeyboardInterrupt: + logger.info("Training interrupted by user") + except Exception as e: + logger.error(f"Training failed: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ml_new/training/trainer.py b/ml_new/training/trainer.py new file mode 100644 index 0000000..e873bfe --- /dev/null +++ b/ml_new/training/trainer.py @@ -0,0 +1,537 @@ +""" +Model trainer for embedding classification +""" + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from sklearn.metrics import accuracy_score, 
precision_recall_fscore_support, roc_auc_score, confusion_matrix +import numpy as np +from typing import Dict, Any, Optional +import json +from datetime import datetime +from pathlib import Path +from ml_new.training.models import EmbeddingClassifier +from ml_new.config.logger_config import get_logger + +logger = get_logger(__name__) + + +class ModelTrainer: + """ + Trainer for embedding classification models + """ + + def __init__( + self, + model: EmbeddingClassifier, + train_loader: DataLoader, + val_loader: DataLoader, + test_loader: Optional[DataLoader] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + criterion: Optional[nn.Module] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + device: Optional[torch.device] = None, + save_dir: str = "training/checkpoints", + experiment_name: Optional[str] = None + ): + """ + Initialize model trainer + + Args: + model: PyTorch model to train + train_loader: Training data loader + val_loader: Validation data loader + test_loader: Test data loader (optional) + optimizer: Optimizer instance (default: Adam) + criterion: Loss function (default: BCEWithLogitsLoss) + scheduler: Learning rate scheduler (optional) + device: Device to use (default: auto-detect) + save_dir: Directory to save checkpoints + experiment_name: Name for the experiment + """ + self.model = model + self.train_loader = train_loader + self.val_loader = val_loader + self.test_loader = test_loader + + # Device setup + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = device + + self.model.to(self.device) + + # Loss function and optimizer + self.criterion = criterion or nn.BCEWithLogitsLoss() + self.optimizer = optimizer or optim.Adam(self.model.parameters(), lr=0.001) + self.scheduler = scheduler + + # Training configuration + self.save_dir = Path(save_dir) + self.save_dir.mkdir(parents=True, exist_ok=True) + + # Experiment tracking + self.experiment_name = experiment_name or f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + self.checkpoint_dir = self.save_dir / self.experiment_name + self.checkpoint_dir.mkdir(exist_ok=True) + + # TensorBoard logging + self.writer = SummaryWriter(log_dir=str(self.checkpoint_dir / "logs")) + + # Training state + self.current_epoch = 0 + self.best_val_score = 0.0 + self.patience_counter = 0 + self.training_history = { + 'train_loss': [], + 'train_acc': [], + 'val_loss': [], + 'val_acc': [], + 'val_auc': [], + 'learning_rates': [] + } + + logger.info(f"Initialized trainer with device: {self.device}") + logger.info(f"Model info: {self.model.get_model_info()}") + + def train_epoch(self) -> Dict[str, float]: + """Train for one epoch""" + self.model.train() + total_loss = 0.0 + all_predictions = [] + all_labels = [] + + for batch_idx, (embeddings, labels, metadata) in enumerate(self.train_loader): + embeddings = embeddings.to(self.device) + labels = labels.to(self.device).float() + + # Zero gradients + self.optimizer.zero_grad() + + # Forward pass + outputs = self.model(embeddings) + loss = self.criterion(outputs.squeeze(), labels) + + # Backward pass + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + self.optimizer.step() + + # Collect statistics + total_loss += loss.item() + + # Get predictions for metrics + with torch.no_grad(): + predictions = torch.sigmoid(outputs).squeeze() + all_predictions.extend(predictions.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + # Log progress + if batch_idx % 
5 == 0: + logger.info(f"Batch {batch_idx}/{len(self.train_loader)}, Loss: {loss.item():.4f}") + + # Calculate epoch metrics + epoch_loss = total_loss / len(self.train_loader) + epoch_acc = accuracy_score(all_labels, (np.array(all_predictions) > 0.5).astype(int)) + + return { + 'loss': epoch_loss, + 'accuracy': epoch_acc + } + + def validate(self) -> Dict[str, float]: + """Validate the model""" + self.model.eval() + total_loss = 0.0 + all_predictions = [] + all_labels = [] + all_probabilities = [] + + with torch.no_grad(): + for embeddings, labels, metadata in self.val_loader: + embeddings = embeddings.to(self.device) + labels = labels.to(self.device).float() + + # Forward pass + outputs = self.model(embeddings) + loss = self.criterion(outputs.squeeze(), labels) + + # Collect statistics + total_loss += loss.item() + + # Get predictions and probabilities + probabilities = torch.sigmoid(outputs).squeeze() + predictions = (probabilities > 0.5).long() + + all_predictions.extend(predictions.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + all_probabilities.extend(probabilities.cpu().numpy()) + + # Calculate metrics + val_loss = total_loss / len(self.val_loader) + val_accuracy = accuracy_score(all_labels, all_predictions) + + # Calculate additional metrics + precision, recall, f1, _ = precision_recall_fscore_support( + all_labels, all_predictions, average='binary', zero_division=0 + ) + + try: + val_auc = roc_auc_score(all_labels, all_probabilities) + except ValueError: + val_auc = 0.0 # AUC not defined for single class + + # Confusion matrix + cm = confusion_matrix(all_labels, all_predictions) + tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0) + + return { + 'loss': val_loss, + 'accuracy': val_accuracy, + 'precision': precision, + 'recall': recall, + 'f1': f1, + 'auc': val_auc, + 'true_negatives': tn, + 'false_positives': fp, + 'false_negatives': fn, + 'true_positives': tp + } + + def train( + self, + num_epochs: int = 100, + save_every: int = 10, + early_stopping_patience: int = 15, + min_delta: float = 0.001, + validation_score: str = 'f1' + ) -> Dict[str, Any]: + """ + Train the model + + Args: + num_epochs: Number of training epochs + save_every: Save checkpoint every N epochs + early_stopping_patience: Patience for early stopping + min_delta: Minimum improvement for early stopping + validation_score: Metric to use for early stopping ('f1', 'auc', 'accuracy') + + Returns: + Training results dictionary + """ + logger.info(f"Starting training for {num_epochs} epochs") + + for epoch in range(num_epochs): + self.current_epoch = epoch + + # Training + train_metrics = self.train_epoch() + + # Validation + val_metrics = self.validate() + + # Learning rate scheduling + if self.scheduler: + if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau): + self.scheduler.step(val_metrics['loss']) + else: + self.scheduler.step() + + # Log metrics + current_lr = self.optimizer.param_groups[0]['lr'] + + # TensorBoard logging + self.writer.add_scalar('Loss/Train', train_metrics['loss'], epoch) + self.writer.add_scalar('Accuracy/Train', train_metrics['accuracy'], epoch) + self.writer.add_scalar('Loss/Validation', val_metrics['loss'], epoch) + self.writer.add_scalar('Accuracy/Validation', val_metrics['accuracy'], epoch) + self.writer.add_scalar('F1/Validation', val_metrics['f1'], epoch) + self.writer.add_scalar('AUC/Validation', val_metrics['auc'], epoch) + self.writer.add_scalar('Learning_Rate', current_lr, epoch) + + # Update training history + 
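+            # NOTE: sklearn metrics arrive as numpy scalars; wrap these values
+            # in float() if the history is ever dumped with plain json.dump
+            # (save_training_state below does exactly that and would raise a
+            # TypeError on np.float64).
+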
self.training_history['train_loss'].append(train_metrics['loss']) + self.training_history['train_acc'].append(train_metrics['accuracy']) + self.training_history['val_loss'].append(val_metrics['loss']) + self.training_history['val_acc'].append(val_metrics['accuracy']) + self.training_history['val_auc'].append(val_metrics['auc']) + self.training_history['learning_rates'].append(current_lr) + + # Print progress + logger.info( + f"Epoch {epoch+1}/{num_epochs} - " + f"Train Loss: {train_metrics['loss']:.4f}, Train Acc: {train_metrics['accuracy']:.4f} - " + f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}, " + f"Val F1: {val_metrics['f1']:.4f}, Val AUC: {val_metrics['auc']:.4f}" + ) + + # Check for best model + val_score = val_metrics[validation_score] + if val_score > self.best_val_score + min_delta: + self.best_val_score = val_score + self.patience_counter = 0 + + # Save best model + self.save_checkpoint(is_best=True) + logger.info(f"New best {validation_score}: {val_score:.4f}") + else: + self.patience_counter += 1 + + # Save checkpoint + if (epoch + 1) % save_every == 0: + self.save_checkpoint(is_best=False) + + # Early stopping + if self.patience_counter >= early_stopping_patience: + logger.info(f"Early stopping triggered after {epoch+1} epochs") + break + + # Final evaluation on test set if available + test_metrics = None + if self.test_loader: + test_metrics = self.evaluate_test_set() + + # Save final training state + self.save_training_state() + + # Close TensorBoard writer + self.writer.close() + + results = { + 'experiment_name': self.experiment_name, + 'best_val_score': self.best_val_score, + 'total_epochs': self.current_epoch + 1, + 'training_history': self.training_history, + 'final_val_metrics': self.validate(), + 'test_metrics': test_metrics, + 'model_info': self.model.get_model_info() + } + + logger.info(f"Training completed. 
Best validation {validation_score}: {self.best_val_score:.4f}") + return results + + def evaluate_test_set(self) -> Dict[str, float]: + """Evaluate model on test set""" + self.model.eval() + total_loss = 0.0 + all_predictions = [] + all_labels = [] + all_probabilities = [] + + with torch.no_grad(): + for embeddings, labels, metadata in self.test_loader: + embeddings = embeddings.to(self.device) + labels = labels.to(self.device).float() + + # Forward pass + outputs = self.model(embeddings) + loss = self.criterion(outputs.squeeze(), labels) + + # Collect statistics + total_loss += loss.item() + + # Get predictions and probabilities + probabilities = torch.sigmoid(outputs).squeeze() + predictions = (probabilities > 0.5).long() + + all_predictions.extend(predictions.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + all_probabilities.extend(probabilities.cpu().numpy()) + + # Calculate metrics + test_loss = total_loss / len(self.test_loader) + test_accuracy = accuracy_score(all_labels, all_predictions) + precision, recall, f1, _ = precision_recall_fscore_support( + all_labels, all_predictions, average='binary', zero_division=0 + ) + + try: + test_auc = roc_auc_score(all_labels, all_probabilities) + except ValueError: + test_auc = 0.0 + + logger.info(f"Test Set Results - Loss: {test_loss:.4f}, Acc: {test_accuracy:.4f}, F1: {f1:.4f}, AUC: {test_auc:.4f}") + + return { + 'loss': test_loss, + 'accuracy': test_accuracy, + 'precision': precision, + 'recall': recall, + 'f1': f1, + 'auc': test_auc + } + + def save_checkpoint(self, is_best: bool = False) -> None: + """Save model checkpoint""" + checkpoint = { + 'epoch': self.current_epoch, + 'model_state_dict': self.model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'best_val_score': self.best_val_score, + 'training_history': self.training_history, + 'model_config': { + 'input_dim': self.model.input_dim, + 'hidden_dims': self.model.hidden_dims, + 'dropout_rate': self.model.dropout_rate, + 'batch_norm': self.model.batch_norm + } + } + + if self.scheduler: + checkpoint['scheduler_state_dict'] = self.scheduler.state_dict() + + # Save checkpoint + checkpoint_path = self.checkpoint_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth" + torch.save(checkpoint, checkpoint_path) + + # Save best model separately + if is_best: + best_model_path = self.checkpoint_dir / "best_model.pth" + torch.save(checkpoint, best_model_path) + logger.info(f"Saved best model to {best_model_path}") + + def save_training_state(self) -> None: + """Save final training state""" + state = { + 'experiment_name': self.experiment_name, + 'best_val_score': self.best_val_score, + 'total_epochs': self.current_epoch + 1, + 'training_history': self.training_history, + 'model_info': self.model.get_model_info(), + 'training_config': { + 'optimizer': self.optimizer.__class__.__name__, + 'criterion': self.criterion.__class__.__name__, + 'device': str(self.device) + } + } + + state_path = self.checkpoint_dir / "training_state.json" + with open(state_path, 'w') as f: + json.dump(state, f, indent=2) + + logger.info(f"Saved training state to {state_path}") + + def load_checkpoint(self, checkpoint_path: str, load_optimizer: bool = True) -> None: + """Load model from checkpoint""" + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + self.model.load_state_dict(checkpoint['model_state_dict']) + + if load_optimizer: + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + if self.scheduler and 'scheduler_state_dict' in checkpoint: + 
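+            # Scheduler state is only present when a scheduler was active at
+            # save time (see save_checkpoint), so older or scheduler-free
+            # checkpoints still load cleanly.
+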
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + + self.current_epoch = checkpoint['epoch'] + self.best_val_score = checkpoint.get('best_val_score', 0.0) + self.training_history = checkpoint.get('training_history', {}) + + logger.info(f"Loaded checkpoint from {checkpoint_path}") + + +def create_trainer( + model: EmbeddingClassifier, + train_loader: DataLoader, + val_loader: DataLoader, + test_loader: Optional[DataLoader] = None, + learning_rate: float = 0.001, + weight_decay: float = 1e-5, + scheduler_type: Optional[str] = 'plateau', + **kwargs +) -> ModelTrainer: + """ + Factory function to create a configured trainer + + Args: + model: The model to train + train_loader: Training data loader + val_loader: Validation data loader + test_loader: Test data loader (optional) + learning_rate: Initial learning rate + weight_decay: L2 regularization + scheduler_type: Learning rate scheduler type ('plateau', 'step', None) + **kwargs: Additional trainer arguments + + Returns: + Configured ModelTrainer instance + """ + # Create optimizer + optimizer = optim.AdamW( + model.parameters(), + lr=learning_rate, + weight_decay=weight_decay + ) + + # Create scheduler + scheduler = None + if scheduler_type == 'plateau': + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode='max', + factor=0.5, + patience=10 + ) + elif scheduler_type == 'step': + scheduler = optim.lr_scheduler.StepLR( + optimizer, + step_size=1, + gamma=0.1 + ) + + # Create trainer + trainer = ModelTrainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, + optimizer=optimizer, + scheduler=scheduler, + **kwargs + ) + + return trainer + + +if __name__ == "__main__": + # Test trainer creation + from ml_new.training.models import create_model + from ml_new.training.data_loader import DatasetLoader + + # Create dummy model and data + model = create_model( + model_type="standard", + input_dim=2048, + hidden_dims=(512, 256, 128) + ) + + loader = DatasetLoader() + datasets = loader.list_datasets() + + if datasets: + train_loader, val_loader, test_loader, data_info = loader.create_data_loaders( + datasets[0], + batch_size=8, + normalize=True + ) + + # Create trainer + trainer = create_trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, + experiment_name="test_experiment" + ) + + # Test one epoch + print("Testing trainer...") + train_metrics = trainer.train_epoch() + val_metrics = trainer.validate() + + print(f"Train metrics: {train_metrics}") + print(f"Val metrics: {val_metrics}") \ No newline at end of file
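+
+        # Close the TensorBoard writer opened by the trainer; train() closes
+        # it itself, but this smoke test calls train_epoch()/validate()
+        # directly.
+        trainer.writer.close()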