1
0

add: training script

This commit is contained in:
alikia2x (寒寒) 2025-12-10 21:22:22 +08:00
parent 664784dd3e
commit c14c680228
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG Key ID: 56209E0CCD8420C6
5 changed files with 1778 additions and 0 deletions

View File

@ -0,0 +1,9 @@
"""
Training module for ML models
"""
from .models import EmbeddingClassifier
from .trainer import ModelTrainer
from .data_loader import DatasetLoader
__all__ = ['EmbeddingClassifier', 'ModelTrainer', 'DatasetLoader']

View File

@ -0,0 +1,389 @@
"""
Data loader for embedding datasets
"""
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import json
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from ml_new.config.logger_config import get_logger
logger = get_logger(__name__)
class EmbeddingDataset(Dataset):
"""
PyTorch Dataset for embedding-based classification
"""
def __init__(
self,
embeddings: np.ndarray,
labels: np.ndarray,
metadata: Optional[List[Dict[str, Any]]] = None,
transform: Optional[callable] = None,
normalize: bool = True
):
"""
Initialize embedding dataset
Args:
embeddings: Array of embedding vectors (n_samples, embedding_dim)
labels: Array of binary labels (n_samples,)
metadata: Optional list of metadata dictionaries
transform: Optional transformation function
normalize: Whether to normalize embeddings
"""
assert len(embeddings) == len(labels), "Embeddings and labels must have same length"
self.embeddings = embeddings.astype(np.float32)
self.labels = labels.astype(np.int64)
self.metadata = metadata or []
self.transform = transform
# Normalize embeddings if requested
if normalize and len(embeddings) > 0:
self.scaler = StandardScaler()
self.embeddings = self.scaler.fit_transform(self.embeddings)
else:
self.scaler = None
# Calculate class weights for balanced sampling
self._calculate_class_weights()
def _calculate_class_weights(self):
"""Calculate weights for each class for balanced sampling"""
unique, counts = np.unique(self.labels, return_counts=True)
total_samples = len(self.labels)
self.class_weights = {}
for class_label, count in zip(unique, counts):
# Inverse frequency weighting
weight = total_samples / (2 * count)
self.class_weights[class_label] = weight
logger.info(f"Class distribution: {dict(zip(unique, counts))}")
logger.info(f"Class weights: {self.class_weights}")
def __len__(self) -> int:
return len(self.embeddings)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Get a single sample from the dataset
Returns:
tuple: (embedding, label, metadata)
"""
embedding = torch.from_numpy(self.embeddings[idx])
label = torch.tensor(self.labels[idx], dtype=torch.long)
metadata = {}
if self.metadata and idx < len(self.metadata):
metadata = self.metadata[idx]
if self.transform:
embedding = self.transform(embedding)
return embedding, label, metadata
class DatasetLoader:
"""
Loader for embedding datasets stored in Parquet format
"""
def __init__(self, datasets_dir: str = "training/datasets"):
"""
Initialize dataset loader
Args:
datasets_dir: Directory containing dataset files
"""
self.datasets_dir = Path(datasets_dir)
self.datasets_dir.mkdir(parents=True, exist_ok=True)
def load_dataset(self, dataset_id: str) -> Tuple[np.ndarray, np.ndarray, List[Dict[str, Any]]]:
"""
Load a dataset by ID from Parquet files
Args:
dataset_id: Unique identifier for the dataset
Returns:
tuple: (embeddings, labels, metadata_list)
"""
dataset_file = self.datasets_dir / f"{dataset_id}.parquet"
metadata_file = self.datasets_dir / f"{dataset_id}.metadata.json"
if not dataset_file.exists():
raise FileNotFoundError(f"Dataset file not found: {dataset_file}")
# Load metadata
metadata = {}
if metadata_file.exists():
with open(metadata_file, 'r') as f:
metadata = json.load(f)
# Load data from Parquet
logger.info(f"Loading dataset {dataset_id} from {dataset_file}")
df = pd.read_parquet(dataset_file)
# Extract embeddings (they might be stored as list or numpy array)
embeddings = self._extract_embeddings(df)
# Extract labels
labels = df['label'].values.astype(np.int64)
# Extract metadata
metadata_list = []
if 'metadata_json' in df.columns:
for _, row in df.iterrows():
meta = {}
if pd.notna(row.get('metadata_json')):
try:
meta = json.loads(row['metadata_json'])
except (json.JSONDecodeError, TypeError):
meta = {}
# Add other fields
meta.update({
'aid': row.get('aid'),
'inconsistent': row.get('inconsistent', False),
'text_checksum': row.get('text_checksum')
})
metadata_list.append(meta)
else:
# Create basic metadata
metadata_list = [{
'aid': aid,
'inconsistent': inconsistent,
'text_checksum': checksum
} for aid, inconsistent, checksum in zip(
df.get('aid', []),
df.get('inconsistent', [False] * len(df)),
df.get('text_checksum', [''] * len(df))
)]
logger.info(f"Loaded dataset with {len(embeddings)} samples, {embeddings.shape[1]} embedding dimensions")
return embeddings, labels, metadata_list
def _extract_embeddings(self, df: pd.DataFrame) -> np.ndarray:
"""Extract embeddings from DataFrame, handling different storage formats"""
embedding_col = None
for col in ['embedding', 'embeddings', 'vec_2048', 'vec_1024']:
if col in df.columns:
embedding_col = col
break
if embedding_col is None:
raise ValueError("No embedding column found in dataset")
embeddings_data = df[embedding_col]
# Handle different embedding storage formats
if embeddings_data.dtype == 'object':
# Likely stored as lists or numpy arrays
embeddings = np.array([
np.array(emb) if isinstance(emb, (list, np.ndarray)) else np.zeros(2048)
for emb in embeddings_data
])
else:
# Already numpy array
embeddings = embeddings_data.values
# Ensure 2D array
if embeddings.ndim == 1:
# If embeddings are flattened, reshape
embedding_dim = len(embeddings) // len(df)
embeddings = embeddings.reshape(len(df), embedding_dim)
return embeddings.astype(np.float32)
def create_data_loaders(
self,
dataset_id: str,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
batch_size: int = 32,
num_workers: int = 4,
random_state: int = 42,
normalize: bool = True,
use_weighted_sampler: bool = True
) -> Tuple[DataLoader, DataLoader, DataLoader, Dict[str, Any]]:
"""
Create train, validation, and test data loaders
Args:
dataset_id: Dataset identifier
train_ratio: Proportion of data for training
val_ratio: Proportion of data for validation
batch_size: Batch size for data loaders
num_workers: Number of worker processes
random_state: Random seed for reproducibility
normalize: Whether to normalize embeddings
use_weighted_sampler: Whether to use weighted random sampling
Returns:
tuple: (train_loader, val_loader, test_loader, dataset_info)
"""
# Load dataset
embeddings, labels, metadata = self.load_dataset(dataset_id)
# Split data
(
train_emb, test_emb,
train_lbl, test_lbl,
train_meta, test_meta
) = train_test_split(
embeddings, labels, metadata,
test_size=1 - train_ratio,
stratify=labels,
random_state=random_state
)
# Split test into val and test
val_size = val_ratio / (val_ratio + (1 - train_ratio - val_ratio))
(
val_emb, test_emb,
val_lbl, test_lbl,
val_meta, test_meta
) = train_test_split(
test_emb, test_lbl, test_meta,
test_size=1 - val_size,
stratify=test_lbl,
random_state=random_state
)
# Create datasets
train_dataset = EmbeddingDataset(train_emb, train_lbl, train_meta, normalize=normalize)
val_dataset = EmbeddingDataset(val_emb, val_lbl, val_meta, normalize=False) # Don't re-normalize
test_dataset = EmbeddingDataset(test_emb, test_lbl, test_meta, normalize=False)
# Create samplers
train_sampler = None
if use_weighted_sampler and hasattr(train_dataset, 'class_weights'):
# Create weighted sampler for balanced training
sample_weights = [train_dataset.class_weights[label] for label in train_dataset.labels]
train_sampler = WeightedRandomSampler(
weights=sample_weights,
num_samples=len(sample_weights),
replacement=True
)
# Create data loaders
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
sampler=train_sampler,
shuffle=(train_sampler is None),
num_workers=num_workers
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers
)
# Dataset info
dataset_info = {
'dataset_id': dataset_id,
'total_samples': len(embeddings),
'embedding_dim': embeddings.shape[1],
'train_samples': len(train_dataset),
'val_samples': len(val_dataset),
'test_samples': len(test_dataset),
'train_ratio': len(train_dataset) / len(embeddings),
'val_ratio': len(val_dataset) / len(embeddings),
'test_ratio': len(test_dataset) / len(embeddings),
'class_distribution': {
'train': dict(zip(*np.unique(train_dataset.labels, return_counts=True))),
'val': dict(zip(*np.unique(val_dataset.labels, return_counts=True))),
'test': dict(zip(*np.unique(test_dataset.labels, return_counts=True)))
},
'normalize': normalize,
'use_weighted_sampler': use_weighted_sampler
}
logger.info(f"Created data loaders: train={len(train_dataset)}, val={len(val_dataset)}, test={len(test_dataset)}")
return train_loader, val_loader, test_loader, dataset_info
def list_datasets(self) -> List[str]:
"""List all available datasets"""
parquet_files = list(self.datasets_dir.glob("*.parquet"))
return [f.stem for f in parquet_files]
def get_dataset_info(self, dataset_id: str) -> Dict[str, Any]:
"""Get detailed information about a dataset"""
metadata_file = self.datasets_dir / f"{dataset_id}.metadata.json"
if metadata_file.exists():
with open(metadata_file, 'r') as f:
return json.load(f)
# Fallback: load dataset and return basic info
embeddings, labels, metadata = self.load_dataset(dataset_id)
return {
'dataset_id': dataset_id,
'total_samples': len(embeddings),
'embedding_dim': embeddings.shape[1],
'class_distribution': dict(zip(*np.unique(labels, return_counts=True))),
'file_format': 'parquet',
'created_at': 'unknown'
}
if __name__ == "__main__":
# Test dataset loading
loader = DatasetLoader()
# List available datasets
datasets = loader.list_datasets()
print(f"Available datasets: {datasets}")
if datasets:
# Test loading first dataset
dataset_id = datasets[0]
print(f"\nTesting dataset: {dataset_id}")
info = loader.get_dataset_info(dataset_id)
print("Dataset info:", info)
# Test creating data loaders
try:
train_loader, val_loader, test_loader, data_info = loader.create_data_loaders(
dataset_id,
batch_size=8,
normalize=True
)
print("\nData loader info:")
for key, value in data_info.items():
print(f" {key}: {value}")
# Test single batch
for batch_embeddings, batch_labels, batch_metadata in train_loader:
print(f"\nBatch test:")
print(f" Embeddings shape: {batch_embeddings.shape}")
print(f" Labels shape: {batch_labels.shape}")
print(f" Sample labels: {batch_labels[:5].tolist()}")
break
except Exception as e:
print(f"Error creating data loaders: {e}")

324
ml_new/training/models.py Normal file
View File

@ -0,0 +1,324 @@
"""
Neural network model for binary classification of video embeddings
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
from ml_new.config.logger_config import get_logger
logger = get_logger(__name__)
class EmbeddingClassifier(nn.Module):
"""
Neural network model for binary classification of video embeddings
Architecture:
- Embedding layer (configurable input dimension)
- Hidden layers with dropout and batch normalization
- Binary classification head
"""
def __init__(
self,
input_dim: int = 2048,
hidden_dims: Optional[Tuple[int, ...]] = None,
dropout_rate: float = 0.3,
batch_norm: bool = True,
activation: str = "relu"
):
"""
Initialize the embedding classifier model
Args:
input_dim: Dimension of input embeddings (default: 2048 for qwen3-embedding)
hidden_dims: Tuple of hidden layer dimensions (default: (512, 256, 128))
dropout_rate: Dropout probability for regularization
batch_norm: Whether to use batch normalization
activation: Activation function ('relu', 'gelu', 'tanh')
"""
super(EmbeddingClassifier, self).__init__()
# Default hidden dimensions if not provided
if hidden_dims is None:
hidden_dims = (512, 256, 128)
self.input_dim = input_dim
self.hidden_dims = hidden_dims
self.dropout_rate = dropout_rate
self.batch_norm = batch_norm
# Build layers
self.layers = nn.ModuleList()
# Input dimension to first hidden layer
prev_dim = input_dim
for i, hidden_dim in enumerate(hidden_dims):
# Linear layer
linear_layer = nn.Linear(prev_dim, hidden_dim)
self.layers.append(linear_layer)
# Batch normalization (optional)
if batch_norm:
bn_layer = nn.BatchNorm1d(hidden_dim)
self.layers.append(bn_layer)
# Activation function
activation_layer = self._get_activation(activation)
self.layers.append(activation_layer)
# Dropout
dropout_layer = nn.Dropout(dropout_rate)
self.layers.append(dropout_layer)
prev_dim = hidden_dim
# Binary classification head
self.classifier = nn.Linear(prev_dim, 1)
# Initialize weights
self._initialize_weights()
logger.info(f"Initialized EmbeddingClassifier: input_dim={input_dim}, hidden_dims={hidden_dims}")
def _get_activation(self, activation: str) -> nn.Module:
"""Get activation function module"""
activations = {
'relu': nn.ReLU(),
'gelu': nn.GELU(),
'tanh': nn.Tanh(),
'leaky_relu': nn.LeakyReLU(0.1),
'elu': nn.ELU()
}
if activation not in activations:
logger.warning(f"Unknown activation '{activation}', using ReLU")
return nn.ReLU()
return activations[activation]
def _initialize_weights(self):
"""Initialize model weights using Xavier/Glorot initialization"""
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.BatchNorm1d):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the network
Args:
x: Input embeddings of shape (batch_size, input_dim)
Returns:
logits of shape (batch_size, 1)
"""
# Ensure input is float tensor
if not x.dtype == torch.float32:
x = x.float()
# Flatten input if it's multi-dimensional (shouldn't be for embeddings)
if x.dim() > 2:
x = x.view(x.size(0), -1)
# Process through layers
for layer in self.layers:
x = layer(x)
# Final classification layer
logits = self.classifier(x)
return logits
def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
"""
Predict class probabilities
Args:
x: Input embeddings
Returns:
Class probabilities of shape (batch_size, 2)
"""
logits = self.forward(x)
probabilities = torch.sigmoid(logits)
# Convert to [negative_prob, positive_prob] format
prob_0 = 1 - probabilities
prob_1 = probabilities
return torch.cat([prob_0, prob_1], dim=1)
def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
"""
Predict class labels
Args:
x: Input embeddings
threshold: Classification threshold
Returns:
Binary predictions of shape (batch_size,)
"""
probabilities = self.predict_proba(x)
predictions = (probabilities[:, 1] > threshold).long()
return predictions
def get_model_info(self) -> dict:
"""Get model architecture information"""
total_params = sum(p.numel() for p in self.parameters())
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
return {
'model_class': self.__class__.__name__,
'input_dim': self.input_dim,
'hidden_dims': self.hidden_dims,
'dropout_rate': self.dropout_rate,
'batch_norm': self.batch_norm,
'total_parameters': total_params,
'trainable_parameters': trainable_params,
'model_size_mb': total_params * 4 / (1024 * 1024) # Assuming float32
}
class AttentionEmbeddingClassifier(EmbeddingClassifier):
"""
Enhanced classifier with self-attention mechanism
"""
def __init__(
self,
input_dim: int = 2048,
hidden_dims: Optional[Tuple[int, ...]] = None,
dropout_rate: float = 0.3,
batch_norm: bool = True,
activation: str = "relu",
attention_dim: int = 512
):
super().__init__(input_dim, hidden_dims, dropout_rate, batch_norm, activation)
# Self-attention mechanism
self.attention_dim = attention_dim
self.attention = nn.MultiheadAttention(
embed_dim=input_dim,
num_heads=8,
dropout=dropout_rate,
batch_first=True
)
# Attention projection layer
self.attention_projection = nn.Linear(input_dim, attention_dim)
# Re-initialize attention weights
self._initialize_attention_weights()
logger.info(f"Initialized AttentionEmbeddingClassifier with attention_dim={attention_dim}")
def _initialize_attention_weights(self):
"""Initialize attention mechanism weights"""
for module in self.attention.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass with attention mechanism
Args:
x: Input embeddings of shape (batch_size, input_dim)
Returns:
logits of shape (batch_size, 1)
"""
# Ensure input is float tensor
if not x.dtype == torch.float32:
x = x.float()
# Add sequence dimension for attention (batch_size, 1, input_dim)
x_expanded = x.unsqueeze(1)
# Apply self-attention
attended, attention_weights = self.attention(x_expanded, x_expanded, x_expanded)
# Remove sequence dimension (batch_size, input_dim)
attended = attended.squeeze(1)
# Project to attention dimension
attended = self.attention_projection(attended)
# Process through original classification layers
for layer in self.layers:
attended = layer(attended)
# Final classification layer
logits = self.classifier(attended)
return logits
def create_model(
model_type: str = "standard",
input_dim: int = 2048,
hidden_dims: Optional[Tuple[int, ...]] = None,
**kwargs
) -> EmbeddingClassifier:
"""
Factory function to create embedding classifier models
Args:
model_type: Type of model ('standard', 'attention')
input_dim: Input embedding dimension
hidden_dims: Hidden layer dimensions
**kwargs: Additional model arguments
Returns:
Initialized model
"""
if model_type == "standard":
return EmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)
elif model_type == "attention":
return AttentionEmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)
else:
raise ValueError(f"Unknown model type: {model_type}")
if __name__ == "__main__":
# Test model creation and forward pass
model = create_model(
model_type="standard",
input_dim=2048,
hidden_dims=(512, 256, 128),
dropout_rate=0.3
)
# Print model info
info = model.get_model_info()
print("Model Information:")
for key, value in info.items():
print(f" {key}: {value}")
# Test forward pass
batch_size = 8
dummy_input = torch.randn(batch_size, 2048)
with torch.no_grad():
logits = model(dummy_input)
probabilities = model.predict_proba(dummy_input)
predictions = model.predict(dummy_input)
print(f"\nTest Results:")
print(f" Input shape: {dummy_input.shape}")
print(f" Logits shape: {logits.shape}")
print(f" Probabilities shape: {probabilities.shape}")
print(f" Predictions shape: {predictions.shape}")
print(f" Sample predictions: {predictions.tolist()}")

519
ml_new/training/train.py Normal file
View File

@ -0,0 +1,519 @@
#!/usr/bin/env python3
"""
Main training script for embedding classification models
"""
import argparse
import json
import sys
from pathlib import Path
from datetime import datetime
import numpy as np
# Add the parent directory to the path to import ml_new modules
sys.path.append(str(Path(__file__).parent.parent))
from ml_new.training.models import create_model
from ml_new.training.data_loader import DatasetLoader
from ml_new.training.trainer import create_trainer
from ml_new.config.logger_config import get_logger
def json_safe_convert(obj):
"""Convert objects to JSON-serializable format"""
if isinstance(obj, dict):
return {str(k): json_safe_convert(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [json_safe_convert(item) for item in obj]
elif isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif hasattr(obj, 'item'): # numpy scalar
return obj.item()
else:
return obj
logger = get_logger(__name__)
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Train embedding classification model",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
# Data arguments
parser.add_argument(
"--dataset-id",
type=str,
required=True,
help="ID of the dataset to use for training"
)
parser.add_argument(
"--datasets-dir",
type=str,
default="training/datasets",
help="Directory containing dataset files"
)
# Model arguments
parser.add_argument(
"--model-type",
type=str,
choices=["standard", "attention"],
default="standard",
help="Type of model architecture"
)
parser.add_argument(
"--input-dim",
type=int,
default=2048,
help="Input embedding dimension"
)
parser.add_argument(
"--hidden-dims",
type=int,
nargs="+",
default=[512, 256, 128],
help="Hidden layer dimensions"
)
parser.add_argument(
"--dropout-rate",
type=float,
default=0.3,
help="Dropout rate for regularization"
)
parser.add_argument(
"--batch-norm",
action="store_true",
default=True,
help="Use batch normalization"
)
parser.add_argument(
"--activation",
type=str,
choices=["relu", "gelu", "tanh", "leaky_relu", "elu"],
default="relu",
help="Activation function"
)
# Training arguments
parser.add_argument(
"--num-epochs",
type=int,
default=100,
help="Number of training epochs"
)
parser.add_argument(
"--batch-size",
type=int,
default=32,
help="Batch size for training"
)
parser.add_argument(
"--learning-rate",
type=float,
default=0.001,
help="Initial learning rate"
)
parser.add_argument(
"--weight-decay",
type=float,
default=1e-5,
help="L2 regularization weight decay"
)
parser.add_argument(
"--scheduler",
type=str,
choices=["plateau", "step", "none"],
default="plateau",
help="Learning rate scheduler"
)
parser.add_argument(
"--early-stopping-patience",
type=int,
default=15,
help="Patience for early stopping"
)
parser.add_argument(
"--min-delta",
type=float,
default=0.001,
help="Minimum improvement for early stopping"
)
parser.add_argument(
"--validation-metric",
type=str,
choices=["f1", "auc", "accuracy"],
default="f1",
help="Metric for model selection and early stopping"
)
# Data arguments
parser.add_argument(
"--train-ratio",
type=float,
default=0.8,
help="Proportion of data for training"
)
parser.add_argument(
"--val-ratio",
type=float,
default=0.1,
help="Proportion of data for validation"
)
parser.add_argument(
"--normalize",
action="store_true",
default=False,
help="Normalize embeddings during training"
)
parser.add_argument(
"--use-weighted-sampler",
action="store_true",
default=False,
help="Use weighted random sampling for balanced training"
)
# Output arguments
parser.add_argument(
"--save-dir",
type=str,
default="training/checkpoints",
help="Directory to save model checkpoints"
)
parser.add_argument(
"--experiment-name",
type=str,
default=None,
help="Name for the experiment (auto-generated if not provided)"
)
parser.add_argument(
"--save-every",
type=int,
default=10,
help="Save checkpoint every N epochs"
)
# Other arguments
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of worker processes for data loading"
)
parser.add_argument(
"--device",
type=str,
default="auto",
help="Device to use (auto, cpu, cuda)"
)
parser.add_argument(
"--random-seed",
type=int,
default=42,
help="Random seed for reproducibility"
)
parser.add_argument(
"--list-datasets",
action="store_true",
help="List available datasets and exit"
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Test data loading without training"
)
return parser.parse_args()
def setup_device(device_arg: str):
"""Setup training device"""
import torch
if device_arg == "auto":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(device_arg)
logger.info(f"Using device: {device}")
if device.type == "cuda":
logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
logger.info(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
return device
def list_available_datasets(datasets_dir: str):
"""List and display information about available datasets"""
loader = DatasetLoader(datasets_dir)
datasets = loader.list_datasets()
print(f"\nAvailable datasets in {datasets_dir}:")
print("-" * 50)
if not datasets:
print("No datasets found.")
return []
for i, dataset_id in enumerate(datasets, 1):
try:
info = loader.get_dataset_info(dataset_id)
print(f"{i}. {dataset_id}")
print(f" Samples: {info.get('total_samples', 'N/A')}")
print(f" Embedding dim: {info.get('embedding_dim', 'N/A')}")
print(f" Classes: {info.get('class_distribution', 'N/A')}")
print(f" Created: {info.get('created_at', 'N/A')}")
print()
except Exception as e:
print(f"{i}. {dataset_id} (Error loading info: {e})")
print()
return datasets
def validate_arguments(args):
"""Validate command line arguments"""
if args.train_ratio + args.val_ratio >= 1.0:
raise ValueError("train_ratio + val_ratio must be less than 1.0")
if args.batch_size <= 0:
raise ValueError("batch_size must be positive")
if args.learning_rate <= 0:
raise ValueError("learning_rate must be positive")
if args.dropout_rate < 0 or args.dropout_rate >= 1.0:
raise ValueError("dropout_rate must be between 0 and 1")
def main():
"""Main training function"""
args = parse_args()
# Set random seeds for reproducibility
import torch
import numpy as np
torch.manual_seed(args.random_seed)
np.random.seed(args.random_seed)
# Setup device
device = setup_device(args.device)
# List datasets if requested
if args.list_datasets:
list_available_datasets(args.datasets_dir)
return
# Validate arguments
try:
validate_arguments(args)
except ValueError as e:
logger.error(f"Invalid arguments: {e}")
sys.exit(1)
# Check if dataset exists
loader = DatasetLoader(args.datasets_dir)
datasets = loader.list_datasets()
if args.dataset_id not in datasets:
logger.error(f"Dataset '{args.dataset_id}' not found in {args.datasets_dir}")
logger.info(f"Available datasets: {datasets}")
sys.exit(1)
# Create experiment name if not provided
if args.experiment_name is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
args.experiment_name = f"{args.model_type}_{args.dataset_id}_{timestamp}"
logger.info(f"Starting experiment: {args.experiment_name}")
logger.info(f"Dataset: {args.dataset_id}")
logger.info(f"Model: {args.model_type} with hidden dims {args.hidden_dims}")
# Load dataset and create data loaders
try:
logger.info("Loading dataset and creating data loaders...")
train_loader, val_loader, test_loader, data_info = loader.create_data_loaders(
dataset_id=args.dataset_id,
train_ratio=args.train_ratio,
val_ratio=args.val_ratio,
batch_size=args.batch_size,
num_workers=args.num_workers,
random_state=args.random_seed,
normalize=args.normalize,
use_weighted_sampler=args.use_weighted_sampler
)
logger.info("Data loaders created successfully")
logger.info(f" Training samples: {data_info['train_samples']}")
logger.info(f" Validation samples: {data_info['val_samples']}")
logger.info(f" Test samples: {data_info['test_samples']}")
logger.info(f" Training class distribution: {data_info['class_distribution']['train']}")
except Exception as e:
logger.error(f"Failed to create data loaders: {e}")
sys.exit(1)
# Create model
try:
logger.info("Creating model...")
model = create_model(
model_type=args.model_type,
input_dim=args.input_dim,
hidden_dims=tuple(args.hidden_dims),
dropout_rate=args.dropout_rate,
batch_norm=args.batch_norm,
activation=args.activation
)
model_info = model.get_model_info()
logger.info("Model created successfully")
logger.info(f" Parameters: {model_info['total_parameters']:,}")
logger.info(f" Model size: {model_info['model_size_mb']:.1f} MB")
except Exception as e:
logger.error(f"Failed to create model: {e}")
sys.exit(1)
# Test data loading (dry run)
if args.dry_run:
logger.info("Dry run: Testing data loading...")
# Test one batch from each loader
for name, loader_obj in [("train", train_loader), ("val", val_loader)]:
for batch_embeddings, batch_labels, batch_metadata in loader_obj:
logger.info(f" {name} batch - Embeddings: {batch_embeddings.shape}, Labels: {batch_labels.shape}")
logger.info(f" Sample labels: {batch_labels[:5].tolist()}")
break
logger.info("Dry run completed successfully")
return
# Create trainer
try:
logger.info("Creating trainer...")
# Determine scheduler type
scheduler_type = None if args.scheduler == "none" else args.scheduler
trainer = create_trainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
test_loader=test_loader,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
scheduler_type=scheduler_type,
device=device,
save_dir=args.save_dir,
experiment_name=args.experiment_name
)
logger.info("Trainer created successfully")
except Exception as e:
logger.error(f"Failed to create trainer: {e}")
sys.exit(1)
# Save configuration
config_path = Path(args.save_dir) / args.experiment_name / "config.json"
config_path.parent.mkdir(parents=True, exist_ok=True)
config = {
"experiment_name": args.experiment_name,
"dataset_id": args.dataset_id,
"model_config": {
"model_type": args.model_type,
"input_dim": args.input_dim,
"hidden_dims": args.hidden_dims,
"dropout_rate": args.dropout_rate,
"batch_norm": args.batch_norm,
"activation": args.activation
},
"training_config": {
"num_epochs": args.num_epochs,
"batch_size": args.batch_size,
"learning_rate": args.learning_rate,
"weight_decay": args.weight_decay,
"scheduler": args.scheduler,
"early_stopping_patience": args.early_stopping_patience,
"validation_metric": args.validation_metric
},
"data_config": {
"train_ratio": args.train_ratio,
"val_ratio": args.val_ratio,
"normalize": args.normalize,
"use_weighted_sampler": args.use_weighted_sampler,
"random_seed": args.random_seed
},
"dataset_info": json_safe_convert(data_info)
}
with open(config_path, 'w') as f:
json.dump(config, f, indent=2)
logger.info(f"Configuration saved to {config_path}")
# Start training
try:
logger.info("Starting training...")
results = trainer.train(
num_epochs=args.num_epochs,
save_every=args.save_every,
early_stopping_patience=args.early_stopping_patience,
min_delta=args.min_delta,
validation_score=args.validation_metric
)
# Save results
results_path = Path(args.save_dir) / args.experiment_name / "results.json"
with open(results_path, 'w') as f:
json.dump(results, f, indent=2, default=str)
logger.info(f"Training completed successfully!")
logger.info(f"Results saved to {results_path}")
logger.info(f"Best validation {args.validation_metric}: {results['best_val_score']:.4f}")
if results.get('test_metrics'):
logger.info(f"Test accuracy: {results['test_metrics']['accuracy']:.4f}")
logger.info(f"Test F1: {results['test_metrics']['f1']:.4f}")
logger.info(f"Test AUC: {results['test_metrics']['auc']:.4f}")
except KeyboardInterrupt:
logger.info("Training interrupted by user")
except Exception as e:
logger.error(f"Training failed: {e}")
sys.exit(1)
if __name__ == "__main__":
main()

537
ml_new/training/trainer.py Normal file
View File

@ -0,0 +1,537 @@
"""
Model trainer for embedding classification
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
import numpy as np
from typing import Dict, Any, Optional
import json
from datetime import datetime
from pathlib import Path
from ml_new.training.models import EmbeddingClassifier
from ml_new.config.logger_config import get_logger
logger = get_logger(__name__)
class ModelTrainer:
"""
Trainer for embedding classification models
"""
def __init__(
self,
model: EmbeddingClassifier,
train_loader: DataLoader,
val_loader: DataLoader,
test_loader: Optional[DataLoader] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
criterion: Optional[nn.Module] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
device: Optional[torch.device] = None,
save_dir: str = "training/checkpoints",
experiment_name: Optional[str] = None
):
"""
Initialize model trainer
Args:
model: PyTorch model to train
train_loader: Training data loader
val_loader: Validation data loader
test_loader: Test data loader (optional)
optimizer: Optimizer instance (default: Adam)
criterion: Loss function (default: BCEWithLogitsLoss)
scheduler: Learning rate scheduler (optional)
device: Device to use (default: auto-detect)
save_dir: Directory to save checkpoints
experiment_name: Name for the experiment
"""
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.test_loader = test_loader
# Device setup
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device
self.model.to(self.device)
# Loss function and optimizer
self.criterion = criterion or nn.BCEWithLogitsLoss()
self.optimizer = optimizer or optim.Adam(self.model.parameters(), lr=0.001)
self.scheduler = scheduler
# Training configuration
self.save_dir = Path(save_dir)
self.save_dir.mkdir(parents=True, exist_ok=True)
# Experiment tracking
self.experiment_name = experiment_name or f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
self.checkpoint_dir = self.save_dir / self.experiment_name
self.checkpoint_dir.mkdir(exist_ok=True)
# TensorBoard logging
self.writer = SummaryWriter(log_dir=str(self.checkpoint_dir / "logs"))
# Training state
self.current_epoch = 0
self.best_val_score = 0.0
self.patience_counter = 0
self.training_history = {
'train_loss': [],
'train_acc': [],
'val_loss': [],
'val_acc': [],
'val_auc': [],
'learning_rates': []
}
logger.info(f"Initialized trainer with device: {self.device}")
logger.info(f"Model info: {self.model.get_model_info()}")
def train_epoch(self) -> Dict[str, float]:
"""Train for one epoch"""
self.model.train()
total_loss = 0.0
all_predictions = []
all_labels = []
for batch_idx, (embeddings, labels, metadata) in enumerate(self.train_loader):
embeddings = embeddings.to(self.device)
labels = labels.to(self.device).float()
# Zero gradients
self.optimizer.zero_grad()
# Forward pass
outputs = self.model(embeddings)
loss = self.criterion(outputs.squeeze(), labels)
# Backward pass
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
# Collect statistics
total_loss += loss.item()
# Get predictions for metrics
with torch.no_grad():
predictions = torch.sigmoid(outputs).squeeze()
all_predictions.extend(predictions.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
# Log progress
if batch_idx % 5 == 0:
logger.info(f"Batch {batch_idx}/{len(self.train_loader)}, Loss: {loss.item():.4f}")
# Calculate epoch metrics
epoch_loss = total_loss / len(self.train_loader)
epoch_acc = accuracy_score(all_labels, (np.array(all_predictions) > 0.5).astype(int))
return {
'loss': epoch_loss,
'accuracy': epoch_acc
}
def validate(self) -> Dict[str, float]:
"""Validate the model"""
self.model.eval()
total_loss = 0.0
all_predictions = []
all_labels = []
all_probabilities = []
with torch.no_grad():
for embeddings, labels, metadata in self.val_loader:
embeddings = embeddings.to(self.device)
labels = labels.to(self.device).float()
# Forward pass
outputs = self.model(embeddings)
loss = self.criterion(outputs.squeeze(), labels)
# Collect statistics
total_loss += loss.item()
# Get predictions and probabilities
probabilities = torch.sigmoid(outputs).squeeze()
predictions = (probabilities > 0.5).long()
all_predictions.extend(predictions.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
all_probabilities.extend(probabilities.cpu().numpy())
# Calculate metrics
val_loss = total_loss / len(self.val_loader)
val_accuracy = accuracy_score(all_labels, all_predictions)
# Calculate additional metrics
precision, recall, f1, _ = precision_recall_fscore_support(
all_labels, all_predictions, average='binary', zero_division=0
)
try:
val_auc = roc_auc_score(all_labels, all_probabilities)
except ValueError:
val_auc = 0.0 # AUC not defined for single class
# Confusion matrix
cm = confusion_matrix(all_labels, all_predictions)
tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)
return {
'loss': val_loss,
'accuracy': val_accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
'auc': val_auc,
'true_negatives': tn,
'false_positives': fp,
'false_negatives': fn,
'true_positives': tp
}
def train(
self,
num_epochs: int = 100,
save_every: int = 10,
early_stopping_patience: int = 15,
min_delta: float = 0.001,
validation_score: str = 'f1'
) -> Dict[str, Any]:
"""
Train the model
Args:
num_epochs: Number of training epochs
save_every: Save checkpoint every N epochs
early_stopping_patience: Patience for early stopping
min_delta: Minimum improvement for early stopping
validation_score: Metric to use for early stopping ('f1', 'auc', 'accuracy')
Returns:
Training results dictionary
"""
logger.info(f"Starting training for {num_epochs} epochs")
for epoch in range(num_epochs):
self.current_epoch = epoch
# Training
train_metrics = self.train_epoch()
# Validation
val_metrics = self.validate()
# Learning rate scheduling
if self.scheduler:
if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
self.scheduler.step(val_metrics['loss'])
else:
self.scheduler.step()
# Log metrics
current_lr = self.optimizer.param_groups[0]['lr']
# TensorBoard logging
self.writer.add_scalar('Loss/Train', train_metrics['loss'], epoch)
self.writer.add_scalar('Accuracy/Train', train_metrics['accuracy'], epoch)
self.writer.add_scalar('Loss/Validation', val_metrics['loss'], epoch)
self.writer.add_scalar('Accuracy/Validation', val_metrics['accuracy'], epoch)
self.writer.add_scalar('F1/Validation', val_metrics['f1'], epoch)
self.writer.add_scalar('AUC/Validation', val_metrics['auc'], epoch)
self.writer.add_scalar('Learning_Rate', current_lr, epoch)
# Update training history
self.training_history['train_loss'].append(train_metrics['loss'])
self.training_history['train_acc'].append(train_metrics['accuracy'])
self.training_history['val_loss'].append(val_metrics['loss'])
self.training_history['val_acc'].append(val_metrics['accuracy'])
self.training_history['val_auc'].append(val_metrics['auc'])
self.training_history['learning_rates'].append(current_lr)
# Print progress
logger.info(
f"Epoch {epoch+1}/{num_epochs} - "
f"Train Loss: {train_metrics['loss']:.4f}, Train Acc: {train_metrics['accuracy']:.4f} - "
f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}, "
f"Val F1: {val_metrics['f1']:.4f}, Val AUC: {val_metrics['auc']:.4f}"
)
# Check for best model
val_score = val_metrics[validation_score]
if val_score > self.best_val_score + min_delta:
self.best_val_score = val_score
self.patience_counter = 0
# Save best model
self.save_checkpoint(is_best=True)
logger.info(f"New best {validation_score}: {val_score:.4f}")
else:
self.patience_counter += 1
# Save checkpoint
if (epoch + 1) % save_every == 0:
self.save_checkpoint(is_best=False)
# Early stopping
if self.patience_counter >= early_stopping_patience:
logger.info(f"Early stopping triggered after {epoch+1} epochs")
break
# Final evaluation on test set if available
test_metrics = None
if self.test_loader:
test_metrics = self.evaluate_test_set()
# Save final training state
self.save_training_state()
# Close TensorBoard writer
self.writer.close()
results = {
'experiment_name': self.experiment_name,
'best_val_score': self.best_val_score,
'total_epochs': self.current_epoch + 1,
'training_history': self.training_history,
'final_val_metrics': self.validate(),
'test_metrics': test_metrics,
'model_info': self.model.get_model_info()
}
logger.info(f"Training completed. Best validation {validation_score}: {self.best_val_score:.4f}")
return results
def evaluate_test_set(self) -> Dict[str, float]:
"""Evaluate model on test set"""
self.model.eval()
total_loss = 0.0
all_predictions = []
all_labels = []
all_probabilities = []
with torch.no_grad():
for embeddings, labels, metadata in self.test_loader:
embeddings = embeddings.to(self.device)
labels = labels.to(self.device).float()
# Forward pass
outputs = self.model(embeddings)
loss = self.criterion(outputs.squeeze(), labels)
# Collect statistics
total_loss += loss.item()
# Get predictions and probabilities
probabilities = torch.sigmoid(outputs).squeeze()
predictions = (probabilities > 0.5).long()
all_predictions.extend(predictions.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
all_probabilities.extend(probabilities.cpu().numpy())
# Calculate metrics
test_loss = total_loss / len(self.test_loader)
test_accuracy = accuracy_score(all_labels, all_predictions)
precision, recall, f1, _ = precision_recall_fscore_support(
all_labels, all_predictions, average='binary', zero_division=0
)
try:
test_auc = roc_auc_score(all_labels, all_probabilities)
except ValueError:
test_auc = 0.0
logger.info(f"Test Set Results - Loss: {test_loss:.4f}, Acc: {test_accuracy:.4f}, F1: {f1:.4f}, AUC: {test_auc:.4f}")
return {
'loss': test_loss,
'accuracy': test_accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
'auc': test_auc
}
def save_checkpoint(self, is_best: bool = False) -> None:
"""Save model checkpoint"""
checkpoint = {
'epoch': self.current_epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'best_val_score': self.best_val_score,
'training_history': self.training_history,
'model_config': {
'input_dim': self.model.input_dim,
'hidden_dims': self.model.hidden_dims,
'dropout_rate': self.model.dropout_rate,
'batch_norm': self.model.batch_norm
}
}
if self.scheduler:
checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
# Save checkpoint
checkpoint_path = self.checkpoint_dir / f"checkpoint_epoch_{self.current_epoch + 1}.pth"
torch.save(checkpoint, checkpoint_path)
# Save best model separately
if is_best:
best_model_path = self.checkpoint_dir / "best_model.pth"
torch.save(checkpoint, best_model_path)
logger.info(f"Saved best model to {best_model_path}")
def save_training_state(self) -> None:
"""Save final training state"""
state = {
'experiment_name': self.experiment_name,
'best_val_score': self.best_val_score,
'total_epochs': self.current_epoch + 1,
'training_history': self.training_history,
'model_info': self.model.get_model_info(),
'training_config': {
'optimizer': self.optimizer.__class__.__name__,
'criterion': self.criterion.__class__.__name__,
'device': str(self.device)
}
}
state_path = self.checkpoint_dir / "training_state.json"
with open(state_path, 'w') as f:
json.dump(state, f, indent=2)
logger.info(f"Saved training state to {state_path}")
def load_checkpoint(self, checkpoint_path: str, load_optimizer: bool = True) -> None:
"""Load model from checkpoint"""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
if load_optimizer:
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if self.scheduler and 'scheduler_state_dict' in checkpoint:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.current_epoch = checkpoint['epoch']
self.best_val_score = checkpoint.get('best_val_score', 0.0)
self.training_history = checkpoint.get('training_history', {})
logger.info(f"Loaded checkpoint from {checkpoint_path}")
def create_trainer(
model: EmbeddingClassifier,
train_loader: DataLoader,
val_loader: DataLoader,
test_loader: Optional[DataLoader] = None,
learning_rate: float = 0.001,
weight_decay: float = 1e-5,
scheduler_type: Optional[str] = 'plateau',
**kwargs
) -> ModelTrainer:
"""
Factory function to create a configured trainer
Args:
model: The model to train
train_loader: Training data loader
val_loader: Validation data loader
test_loader: Test data loader (optional)
learning_rate: Initial learning rate
weight_decay: L2 regularization
scheduler_type: Learning rate scheduler type ('plateau', 'step', None)
**kwargs: Additional trainer arguments
Returns:
Configured ModelTrainer instance
"""
# Create optimizer
optimizer = optim.AdamW(
model.parameters(),
lr=learning_rate,
weight_decay=weight_decay
)
# Create scheduler
scheduler = None
if scheduler_type == 'plateau':
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='max',
factor=0.5,
patience=10
)
elif scheduler_type == 'step':
scheduler = optim.lr_scheduler.StepLR(
optimizer,
step_size=1,
gamma=0.1
)
# Create trainer
trainer = ModelTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
test_loader=test_loader,
optimizer=optimizer,
scheduler=scheduler,
**kwargs
)
return trainer
if __name__ == "__main__":
# Test trainer creation
from ml_new.training.models import create_model
from ml_new.training.data_loader import DatasetLoader
# Create dummy model and data
model = create_model(
model_type="standard",
input_dim=2048,
hidden_dims=(512, 256, 128)
)
loader = DatasetLoader()
datasets = loader.list_datasets()
if datasets:
train_loader, val_loader, test_loader, data_info = loader.create_data_loaders(
datasets[0],
batch_size=8,
normalize=True
)
# Create trainer
trainer = create_trainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
test_loader=test_loader,
experiment_name="test_experiment"
)
# Test one epoch
print("Testing trainer...")
train_metrics = trainer.train_epoch()
val_metrics = trainer.validate()
print(f"Train metrics: {train_metrics}")
print(f"Val metrics: {val_metrics}")