""" Data loader for embedding datasets """ import pandas as pd import numpy as np import torch from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler from typing import List, Dict, Any, Optional, Tuple from pathlib import Path import json from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from ml_new.config.logger_config import get_logger logger = get_logger(__name__) class EmbeddingDataset(Dataset): """ PyTorch Dataset for embedding-based classification """ def __init__( self, embeddings: np.ndarray, labels: np.ndarray, metadata: Optional[List[Dict[str, Any]]] = None, transform: Optional[callable] = None, normalize: bool = True ): """ Initialize embedding dataset Args: embeddings: Array of embedding vectors (n_samples, embedding_dim) labels: Array of binary labels (n_samples,) metadata: Optional list of metadata dictionaries transform: Optional transformation function normalize: Whether to normalize embeddings """ assert len(embeddings) == len(labels), "Embeddings and labels must have same length" self.embeddings = embeddings.astype(np.float32) self.labels = labels.astype(np.int64) self.metadata = metadata or [] self.transform = transform # Normalize embeddings if requested if normalize and len(embeddings) > 0: self.scaler = StandardScaler() self.embeddings = self.scaler.fit_transform(self.embeddings) else: self.scaler = None # Calculate class weights for balanced sampling self._calculate_class_weights() def _calculate_class_weights(self): """Calculate weights for each class for balanced sampling""" unique, counts = np.unique(self.labels, return_counts=True) total_samples = len(self.labels) self.class_weights = {} for class_label, count in zip(unique, counts): # Inverse frequency weighting weight = total_samples / (2 * count) self.class_weights[class_label] = weight logger.info(f"Class distribution: {dict(zip(unique, counts))}") logger.info(f"Class weights: {self.class_weights}") def __len__(self) -> int: return len(self.embeddings) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: """ Get a single sample from the dataset Returns: tuple: (embedding, label, metadata) """ embedding = torch.from_numpy(self.embeddings[idx]) label = torch.tensor(self.labels[idx], dtype=torch.long) metadata = {} if self.metadata and idx < len(self.metadata): metadata = self.metadata[idx] if self.transform: embedding = self.transform(embedding) return embedding, label, metadata class DatasetLoader: """ Loader for embedding datasets stored in Parquet format """ def __init__(self, datasets_dir: str = "training/datasets"): """ Initialize dataset loader Args: datasets_dir: Directory containing dataset files """ self.datasets_dir = Path(datasets_dir) self.datasets_dir.mkdir(parents=True, exist_ok=True) def load_dataset(self, dataset_id: str) -> Tuple[np.ndarray, np.ndarray, List[Dict[str, Any]]]: """ Load a dataset by ID from Parquet files Args: dataset_id: Unique identifier for the dataset Returns: tuple: (embeddings, labels, metadata_list) """ dataset_file = self.datasets_dir / f"{dataset_id}.parquet" metadata_file = self.datasets_dir / f"{dataset_id}.metadata.json" if not dataset_file.exists(): raise FileNotFoundError(f"Dataset file not found: {dataset_file}") # Load metadata metadata = {} if metadata_file.exists(): with open(metadata_file, 'r') as f: metadata = json.load(f) # Load data from Parquet logger.info(f"Loading dataset {dataset_id} from {dataset_file}") df = pd.read_parquet(dataset_file) # Extract 


class DatasetLoader:
    """
    Loader for embedding datasets stored in Parquet format
    """

    def __init__(self, datasets_dir: str = "training/datasets"):
        """
        Initialize dataset loader

        Args:
            datasets_dir: Directory containing dataset files
        """
        self.datasets_dir = Path(datasets_dir)
        self.datasets_dir.mkdir(parents=True, exist_ok=True)

    def load_dataset(self, dataset_id: str) -> Tuple[np.ndarray, np.ndarray, List[Dict[str, Any]]]:
        """
        Load a dataset by ID from Parquet files

        Args:
            dataset_id: Unique identifier for the dataset

        Returns:
            tuple: (embeddings, labels, metadata_list)
        """
        dataset_file = self.datasets_dir / f"{dataset_id}.parquet"
        metadata_file = self.datasets_dir / f"{dataset_id}.metadata.json"

        if not dataset_file.exists():
            raise FileNotFoundError(f"Dataset file not found: {dataset_file}")

        # Load dataset-level metadata (not currently used downstream)
        dataset_metadata = {}
        if metadata_file.exists():
            with open(metadata_file, 'r') as f:
                dataset_metadata = json.load(f)

        # Load data from Parquet
        logger.info(f"Loading dataset {dataset_id} from {dataset_file}")
        df = pd.read_parquet(dataset_file)

        # Extract embeddings (they might be stored as list or numpy array)
        embeddings = self._extract_embeddings(df)

        # Extract labels
        labels = df['label'].values.astype(np.int64)

        # Extract per-sample metadata
        metadata_list = []
        if 'metadata_json' in df.columns:
            for _, row in df.iterrows():
                meta = {}
                if pd.notna(row.get('metadata_json')):
                    try:
                        meta = json.loads(row['metadata_json'])
                    except (json.JSONDecodeError, TypeError):
                        meta = {}
                # Add other fields
                meta.update({
                    'aid': row.get('aid'),
                    'inconsistent': row.get('inconsistent', False),
                    'text_checksum': row.get('text_checksum')
                })
                metadata_list.append(meta)
        else:
            # Create basic metadata
            metadata_list = [{
                'aid': aid,
                'inconsistent': inconsistent,
                'text_checksum': checksum
            } for aid, inconsistent, checksum in zip(
                df.get('aid', []),
                df.get('inconsistent', [False] * len(df)),
                df.get('text_checksum', [''] * len(df))
            )]

        logger.info(
            f"Loaded dataset with {len(embeddings)} samples, "
            f"{embeddings.shape[1]} embedding dimensions"
        )

        return embeddings, labels, metadata_list

    def _extract_embeddings(self, df: pd.DataFrame) -> np.ndarray:
        """Extract embeddings from DataFrame, handling different storage formats"""
        embedding_col = None
        for col in ['embedding', 'embeddings', 'vec_2048', 'vec_1024']:
            if col in df.columns:
                embedding_col = col
                break

        if embedding_col is None:
            raise ValueError("No embedding column found in dataset")

        embeddings_data = df[embedding_col]

        # Handle different embedding storage formats
        if embeddings_data.dtype == 'object':
            # Likely stored as lists or numpy arrays
            embeddings = np.array([
                np.array(emb) if isinstance(emb, (list, np.ndarray)) else np.zeros(2048)
                for emb in embeddings_data
            ])
        else:
            # Already a numeric column
            embeddings = embeddings_data.values

        # Ensure 2D array
        if embeddings.ndim == 1:
            # If embeddings are flattened, reshape
            embedding_dim = len(embeddings) // len(df)
            embeddings = embeddings.reshape(len(df), embedding_dim)

        return embeddings.astype(np.float32)

    def create_data_loaders(
        self,
        dataset_id: str,
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        batch_size: int = 32,
        num_workers: int = 4,
        random_state: int = 42,
        normalize: bool = True,
        use_weighted_sampler: bool = True
    ) -> Tuple[DataLoader, DataLoader, DataLoader, Dict[str, Any]]:
        """
        Create train, validation, and test data loaders

        Args:
            dataset_id: Dataset identifier
            train_ratio: Proportion of data for training
            val_ratio: Proportion of data for validation
            batch_size: Batch size for data loaders
            num_workers: Number of worker processes
            random_state: Random seed for reproducibility
            normalize: Whether to normalize embeddings
            use_weighted_sampler: Whether to use weighted random sampling

        Returns:
            tuple: (train_loader, val_loader, test_loader, dataset_info)
        """
        # Load dataset
        embeddings, labels, metadata = self.load_dataset(dataset_id)

        # Split off the training set; the remainder is split into val and test below
        (
            train_emb, holdout_emb,
            train_lbl, holdout_lbl,
            train_meta, holdout_meta
        ) = train_test_split(
            embeddings, labels, metadata,
            test_size=1 - train_ratio,
            stratify=labels,
            random_state=random_state
        )

        # Split the holdout into validation and test sets
        # (val_ratio is expressed as a fraction of the full dataset)
        val_fraction_of_holdout = val_ratio / (1 - train_ratio)
        (
            val_emb, test_emb,
            val_lbl, test_lbl,
            val_meta, test_meta
        ) = train_test_split(
            holdout_emb, holdout_lbl, holdout_meta,
            test_size=1 - val_fraction_of_holdout,
            stratify=holdout_lbl,
            random_state=random_state
        )

        # Create datasets. Only the training set fits a scaler; val/test reuse the
        # training statistics instead of being re-normalized (or left unscaled).
        train_dataset = EmbeddingDataset(train_emb, train_lbl, train_meta, normalize=normalize)
        if normalize and train_dataset.scaler is not None:
            val_emb = train_dataset.scaler.transform(val_emb).astype(np.float32)
            test_emb = train_dataset.scaler.transform(test_emb).astype(np.float32)
        val_dataset = EmbeddingDataset(val_emb, val_lbl, val_meta, normalize=False)
        test_dataset = EmbeddingDataset(test_emb, test_lbl, test_meta, normalize=False)

        # Create samplers
        train_sampler = None
        if use_weighted_sampler and hasattr(train_dataset, 'class_weights'):
            # Create weighted sampler for balanced training
            sample_weights = [train_dataset.class_weights[label] for label in train_dataset.labels]
            train_sampler = WeightedRandomSampler(
                weights=sample_weights,
                num_samples=len(sample_weights),
                replacement=True
            )

        # Create data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=train_sampler,
            shuffle=(train_sampler is None),
            num_workers=num_workers
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

        # Dataset info
        dataset_info = {
            'dataset_id': dataset_id,
            'total_samples': len(embeddings),
            'embedding_dim': embeddings.shape[1],
            'train_samples': len(train_dataset),
            'val_samples': len(val_dataset),
            'test_samples': len(test_dataset),
            'train_ratio': len(train_dataset) / len(embeddings),
            'val_ratio': len(val_dataset) / len(embeddings),
            'test_ratio': len(test_dataset) / len(embeddings),
            'class_distribution': {
                'train': dict(zip(*np.unique(train_dataset.labels, return_counts=True))),
                'val': dict(zip(*np.unique(val_dataset.labels, return_counts=True))),
                'test': dict(zip(*np.unique(test_dataset.labels, return_counts=True)))
            },
            'normalize': normalize,
            'use_weighted_sampler': use_weighted_sampler
        }

        logger.info(
            f"Created data loaders: train={len(train_dataset)}, "
            f"val={len(val_dataset)}, test={len(test_dataset)}"
        )

        return train_loader, val_loader, test_loader, dataset_info

    def list_datasets(self) -> List[str]:
        """List all available datasets"""
        parquet_files = list(self.datasets_dir.glob("*.parquet"))
        return [f.stem for f in parquet_files]

    def get_dataset_info(self, dataset_id: str) -> Dict[str, Any]:
        """Get detailed information about a dataset"""
        metadata_file = self.datasets_dir / f"{dataset_id}.metadata.json"

        if metadata_file.exists():
            with open(metadata_file, 'r') as f:
                return json.load(f)

        # Fallback: load dataset and return basic info
        embeddings, labels, metadata = self.load_dataset(dataset_id)
        return {
            'dataset_id': dataset_id,
            'total_samples': len(embeddings),
            'embedding_dim': embeddings.shape[1],
            'class_distribution': dict(zip(*np.unique(labels, return_counts=True))),
            'file_format': 'parquet',
            'created_at': 'unknown'
        }


if __name__ == "__main__":
    # Test dataset loading
    loader = DatasetLoader()

    # List available datasets
    datasets = loader.list_datasets()
    print(f"Available datasets: {datasets}")

    if datasets:
        # Test loading the first dataset
        dataset_id = datasets[0]
        print(f"\nTesting dataset: {dataset_id}")

        info = loader.get_dataset_info(dataset_id)
        print("Dataset info:", info)

        # Test creating data loaders
        try:
            train_loader, val_loader, test_loader, data_info = loader.create_data_loaders(
                dataset_id,
                batch_size=8,
                normalize=True
            )

            print("\nData loader info:")
            for key, value in data_info.items():
                print(f"  {key}: {value}")

            # Test a single batch
            for batch_embeddings, batch_labels, batch_metadata in train_loader:
                print("\nBatch test:")
                print(f"  Embeddings shape: {batch_embeddings.shape}")
                print(f"  Labels shape: {batch_labels.shape}")
                print(f"  Sample labels: {batch_labels[:5].tolist()}")
                break

        except Exception as e:
            print(f"Error creating data loaders: {e}")