diff --git a/ml_new/training/api_routes.py b/ml_new/training/api_routes.py index 937ead9..b49a522 100644 --- a/ml_new/training/api_routes.py +++ b/ml_new/training/api_routes.py @@ -188,18 +188,6 @@ async def get_dataset_stats_endpoint(): stats = dataset_builder.get_dataset_stats() return stats - -@router.post("/datasets/cleanup") -async def cleanup_datasets_endpoint(max_age_days: int = 30): - """Remove datasets older than specified days""" - - if not dataset_builder: - raise HTTPException(status_code=503, detail="Dataset builder not available") - - await dataset_builder.cleanup_old_datasets(max_age_days) - return {"message": f"Cleanup completed for datasets older than {max_age_days} days"} - - # Task Status Endpoints @router.get("/tasks/{task_id}", response_model=TaskStatusResponse) diff --git a/ml_new/training/dataset_service.py b/ml_new/training/dataset_service.py index bf7ff81..054b6c4 100644 --- a/ml_new/training/dataset_service.py +++ b/ml_new/training/dataset_service.py @@ -2,7 +2,6 @@ Dataset building service - handles the complete dataset construction flow """ -import json import asyncio from pathlib import Path from typing import List, Dict, Any, Optional @@ -14,6 +13,7 @@ from embedding_service import EmbeddingService from config_loader import config_loader from logger_config import get_logger from models import TaskStatus, DatasetBuildTaskStatus, TaskProgress +from dataset_storage_parquet import ParquetDatasetStorage logger = get_logger(__name__) @@ -28,63 +28,13 @@ class DatasetBuilder: self.storage_dir = Path(storage_dir) self.storage_dir.mkdir(exist_ok=True) - # Load existing datasets from file system - self.dataset_storage: Dict[str, Dict] = self._load_all_datasets() + self.storage = ParquetDatasetStorage(str(self.storage_dir)) # Task status tracking self._task_status_lock = threading.Lock() self.task_statuses: Dict[str, DatasetBuildTaskStatus] = {} self.running_tasks: Dict[str, asyncio.Task] = {} - - def _get_dataset_file_path(self, dataset_id: str) -> Path: - """Get file path for dataset""" - return self.storage_dir / f"{dataset_id}.json" - - - def _load_dataset_from_file(self, dataset_id: str) -> Optional[Dict[str, Any]]: - """Load dataset from file""" - file_path = self._get_dataset_file_path(dataset_id) - if not file_path.exists(): - return None - - try: - with open(file_path, 'r', encoding='utf-8') as f: - return json.load(f) - except Exception as e: - logger.error(f"Failed to load dataset {dataset_id} from file: {e}") - return None - - - def _save_dataset_to_file(self, dataset_id: str, dataset_data: Dict[str, Any]) -> bool: - """Save dataset to file""" - file_path = self._get_dataset_file_path(dataset_id) - - try: - with open(file_path, 'w', encoding='utf-8') as f: - json.dump(dataset_data, f, ensure_ascii=False, indent=2) - return True - except Exception as e: - logger.error(f"Failed to save dataset {dataset_id} to file: {e}") - return False - - - def _load_all_datasets(self) -> Dict[str, Dict]: - """Load all datasets from file system""" - datasets = {} - - try: - for file_path in self.storage_dir.glob("*.json"): - dataset_id = file_path.stem - dataset_data = self._load_dataset_from_file(dataset_id) - if dataset_data: - datasets[dataset_id] = dataset_data - logger.info(f"Loaded {len(datasets)} datasets from file system") - except Exception as e: - logger.error(f"Failed to load datasets from file system: {e}") - - return datasets - def _create_task_status(self, task_id: str, dataset_id: str, aid_list: List[int], embedding_model: str, force_regenerate: bool) -> 
DatasetBuildTaskStatus: @@ -100,7 +50,7 @@ class DatasetBuilder: created_at=datetime.now(), progress=TaskProgress( current_step="initialized", - total_steps=7, + total_steps=8, completed_steps=0, percentage=0.0, message="Task initialized" @@ -214,36 +164,6 @@ class DatasetBuilder: return cleaned_count - async def cleanup_old_datasets(self, max_age_days: int = 30): - """Remove datasets older than specified days""" - try: - cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 60 * 60) - removed_count = 0 - - for dataset_id in list(self.dataset_storage.keys()): - dataset_info = self.dataset_storage[dataset_id] - if "created_at" in dataset_info: - try: - created_time = datetime.fromisoformat(dataset_info["created_at"]).timestamp() - if created_time < cutoff_time: - # Remove from memory - del self.dataset_storage[dataset_id] - - # Remove file - file_path = self._get_dataset_file_path(dataset_id) - if file_path.exists(): - file_path.unlink() - - removed_count += 1 - except Exception as e: - logger.warning(f"Failed to process dataset {dataset_id} for cleanup: {e}") - - if removed_count > 0: - logger.info(f"Cleaned up {removed_count} old datasets") - - except Exception as e: - logger.error(f"Failed to cleanup old datasets: {e}") - async def build_dataset_with_task_tracking(self, task_id: str, dataset_id: str, aid_list: List[int], embedding_model: str, force_regenerate: bool = False, description: Optional[str] = None) -> str: @@ -272,15 +192,15 @@ class DatasetBuilder: raise ValueError(f"Invalid embedding model: {embedding_model}") model_config = EMBEDDING_MODELS[embedding_model] - self._update_task_progress(task_id, "getting_metadata", 1, "Retrieving video metadata from database") + self._update_task_progress(task_id, "getting_metadata", 1, "Retrieving video metadata from database", 3) # Step 2: Get video metadata from database metadata = await self.db_manager.get_video_metadata(aid_list) - self._update_task_progress(task_id, "getting_labels", 2, "Retrieving user labels from database") + self._update_task_progress(task_id, "getting_labels", 2, "Retrieving user labels from database", 15) # Step 3: Get user labels labels = await self.db_manager.get_user_labels(aid_list) - self._update_task_progress(task_id, "preparing_text", 3, "Preparing text data and checksums") + self._update_task_progress(task_id, "preparing_text", 3, "Preparing text data and checksums", 30) # Step 4: Prepare text data and checksums text_data = [] @@ -314,7 +234,7 @@ class DatasetBuilder: f"Prepared {i + 1}/{total_aids} text entries" ) - self._update_task_progress(task_id, "checking_embeddings", 4, "Checking existing embeddings") + self._update_task_progress(task_id, "checking_embeddings", 4, "Checking existing embeddings", 33) # Step 5: Check existing embeddings checksums = [item['checksum'] for item in text_data] @@ -332,13 +252,26 @@ class DatasetBuilder: task_id, "generating_embeddings", 5, - f"Generating {len(new_embeddings_needed)} new embeddings" + f"Generating {len(new_embeddings_needed)} new embeddings", + 50 ) + # Create progress callback for embedding generation + def embedding_progress_callback(current: int, total: int): + percentage = (current / total) * 25 + 50 # Map to progress range 5-6 + self._update_task_progress( + task_id, + "generating_embeddings", + 5, + f"Generated {current}/{total} embeddings", + percentage + ) + logger.info(f"Generating {len(new_embeddings_needed)} new embeddings") generated_embeddings = await self.embedding_service.generate_embeddings_batch( new_embeddings_needed, - 
embedding_model + embedding_model, + progress_callback=embedding_progress_callback ) # Step 7: Store new embeddings in database @@ -352,6 +285,14 @@ class DatasetBuilder: 'vector': embedding }) + self._update_task_progress( + task_id, + "inserting_embeddings", + 6, + f"Inserting {len(embeddings_to_store)} embeddings into database", + 75 + ) + await self.db_manager.insert_embeddings(embeddings_to_store) new_embeddings_count = len(embeddings_to_store) @@ -362,7 +303,7 @@ class DatasetBuilder: f'vec_{model_config.dimensions}': emb_data['vector'] } - self._update_task_progress(task_id, "building_dataset", 6, "Building final dataset") + self._update_task_progress(task_id, "building_dataset", 7, "Building final dataset", 95) # Step 8: Build final dataset dataset = [] @@ -410,12 +351,13 @@ class DatasetBuilder: # Update progress for dataset building if i % 10 == 0 or i == len(text_data) - 1: # Update every 10 items or at the end - progress_pct = 6 + (i + 1) / len(text_data) + progress_pct = 7 + (i + 1) / len(text_data) self._update_task_progress( task_id, "building_dataset", - min(6, int(progress_pct)), - f"Built {i + 1}/{len(text_data)} dataset records" + min(7, int(progress_pct)), + f"Built {i + 1}/{len(text_data)} dataset records", + 100 ) reused_count = len(dataset) - new_embeddings_count @@ -436,15 +378,13 @@ class DatasetBuilder: 'created_at': datetime.now().isoformat() } - self._update_task_progress(task_id, "saving_dataset", 7, "Saving dataset to storage") + self._update_task_progress(task_id, "saving_dataset", 8, "Saving dataset to storage") - # Save to file and memory cache - if self._save_dataset_to_file(dataset_id, dataset_data): - self.dataset_storage[dataset_id] = dataset_data - logger.info(f"Dataset {dataset_id} saved to file system") + # Save to file using efficient storage (Parquet format) + if self.storage.save_dataset(dataset_id, dataset_data['dataset'], description, dataset_data['stats']): + logger.info(f"Dataset {dataset_id} saved to efficient storage (Parquet format)") else: - logger.warning(f"Failed to save dataset {dataset_id} to file, keeping in memory only") - self.dataset_storage[dataset_id] = dataset_data + logger.warning(f"Failed to save dataset {dataset_id} to storage") # Update task status to completed result = { @@ -459,8 +399,8 @@ class DatasetBuilder: result=result, progress=TaskProgress( current_step="completed", - total_steps=7, - completed_steps=7, + total_steps=8, + completed_steps=8, percentage=100.0, message="Dataset building completed successfully" ) @@ -479,7 +419,7 @@ class DatasetBuilder: error_message=str(e), progress=TaskProgress( current_step="failed", - total_steps=7, + total_steps=8, completed_steps=0, percentage=0.0, message=f"Task failed: {str(e)}" @@ -492,9 +432,8 @@ class DatasetBuilder: 'created_at': datetime.now().isoformat() } - # Try to save error to file as well - self._save_dataset_to_file(dataset_id, error_data) - self.dataset_storage[dataset_id] = error_data + # Try to save error using new storage + self.storage.save_dataset(dataset_id, [], description=str(e), stats={'error': True}) raise @@ -523,91 +462,28 @@ class DatasetBuilder: def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]: """Get built dataset by ID""" - # First check memory cache - if dataset_id in self.dataset_storage: - return self.dataset_storage[dataset_id] - - # If not in memory, try to load from file - dataset_data = self._load_dataset_from_file(dataset_id) - if dataset_data: - # Add to memory cache - self.dataset_storage[dataset_id] = dataset_data - 
return dataset_data - - return None + # Use pure Parquet storage interface + return self.storage.load_dataset_full(dataset_id) def dataset_exists(self, dataset_id: str) -> bool: """Check if dataset exists""" - # Check memory cache first - if dataset_id in self.dataset_storage: - return True - - # Check file system - return self._get_dataset_file_path(dataset_id).exists() + return self.storage.dataset_exists(dataset_id) def delete_dataset(self, dataset_id: str) -> bool: """Delete dataset from both memory and file system""" - try: - # Remove from memory - if dataset_id in self.dataset_storage: - del self.dataset_storage[dataset_id] - - # Remove file - file_path = self._get_dataset_file_path(dataset_id) - if file_path.exists(): - file_path.unlink() - logger.info(f"Dataset {dataset_id} deleted from file system") - return True - else: - logger.warning(f"Dataset file {dataset_id} not found for deletion") - return False - - except Exception as e: - logger.error(f"Failed to delete dataset {dataset_id}: {e}") - return False + return self.storage.delete_dataset(dataset_id) def list_datasets(self) -> List[Dict[str, Any]]: """List all datasets with their basic information""" - datasets = [] - - self._load_all_datasets() - - for dataset_id, dataset_info in self.dataset_storage.items(): - if "error" not in dataset_info: - datasets.append({ - "dataset_id": dataset_id, - "stats": dataset_info["stats"], - "created_at": dataset_info["created_at"] - }) - - # Sort by creation time (newest first) - datasets.sort(key=lambda x: x["created_at"], reverse=True) - - return datasets + return self.storage.list_datasets() def get_dataset_stats(self) -> Dict[str, Any]: """Get overall statistics about stored datasets""" - total_datasets = len(self.dataset_storage) - error_datasets = sum(1 for data in self.dataset_storage.values() if "error" in data) - valid_datasets = total_datasets - error_datasets - - total_records = 0 - total_new_embeddings = 0 - total_reused_embeddings = 0 - - for dataset_info in self.dataset_storage.values(): - if "stats" in dataset_info: - stats = dataset_info["stats"] - total_records += stats.get("total_records", 0) - total_new_embeddings += stats.get("new_embeddings", 0) - total_reused_embeddings += stats.get("reused_embeddings", 0) - - return { - "total_datasets": total_datasets, - "valid_datasets": valid_datasets, - "error_datasets": error_datasets, - "total_records": total_records, - "total_new_embeddings": total_new_embeddings, - "total_reused_embeddings": total_reused_embeddings, - "storage_directory": str(self.storage_dir) - } \ No newline at end of file + return self.storage.get_dataset_stats() + + def load_dataset_metadata_fast(self, dataset_id: str) -> Optional[Dict[str, Any]]: + return self.storage.load_dataset_metadata(dataset_id) + + def load_dataset_partial(self, dataset_id: str, columns: Optional[List[str]] = None, + filters: Optional[Dict[str, Any]] = None): + return self.storage.load_dataset_partial(dataset_id, columns, filters) \ No newline at end of file diff --git a/ml_new/training/dataset_storage_parquet.py b/ml_new/training/dataset_storage_parquet.py new file mode 100644 index 0000000..26518ae --- /dev/null +++ b/ml_new/training/dataset_storage_parquet.py @@ -0,0 +1,393 @@ +""" +Efficient dataset storage using Parquet format for better space utilization and loading performance +""" + +import pandas as pd +import numpy as np +import json +import os +from pathlib import Path +from typing import List, Dict, Any, Optional, Union +from datetime import datetime +import pyarrow as 
pa +import pyarrow.parquet as pq +from logger_config import get_logger + +logger = get_logger(__name__) + +class ParquetDatasetStorage: + def __init__(self, storage_dir: str = "datasets"): + self.storage_dir = Path(storage_dir) + self.storage_dir.mkdir(exist_ok=True) + + # Parquet file extension + self.parquet_ext = ".parquet" + self.metadata_ext = ".metadata.json" + + # In-memory cache: only cache metadata to avoid large file memory usage + self.metadata_cache: Dict[str, Dict[str, Any]] = {} + self._load_metadata_cache() + + def _get_dataset_files(self, dataset_id: str) -> tuple[Path, Path]: + """Get file paths for the dataset""" + base_path = self.storage_dir / dataset_id + data_file = base_path.with_suffix(self.parquet_ext) + metadata_file = base_path.with_suffix(self.metadata_ext) + return data_file, metadata_file + + def _load_metadata_cache(self): + """Load metadata cache""" + try: + for metadata_file in self.storage_dir.glob("*.metadata.json"): + try: + # Remove ".metadata" suffix + dataset_id = metadata_file.stem[:-9] + with open(metadata_file, 'r', encoding='utf-8') as f: + metadata = json.load(f) + self.metadata_cache[dataset_id] = metadata + except Exception as e: + logger.warning(f"Failed to load metadata for {metadata_file}: {e}") + + logger.info(f"Loaded metadata for {len(self.metadata_cache)} datasets") + + except Exception as e: + logger.error(f"Failed to load metadata cache: {e}") + + def save_dataset(self, dataset_id: str, dataset: List[Dict[str, Any]], + description: Optional[str] = None, stats: Optional[Dict[str, Any]] = None) -> bool: + """ + Save dataset using Parquet format + + Args: + dataset_id: Dataset ID + dataset: Dataset content + description: Dataset description + stats: Dataset statistics + + Returns: + bool: Whether the save was successful + """ + try: + data_file, metadata_file = self._get_dataset_files(dataset_id) + + # Ensure directory exists + data_file.parent.mkdir(parents=True, exist_ok=True) + + # Prepare data: convert embedding vectors to numpy arrays + if not dataset: + logger.warning(f"Empty dataset for {dataset_id}") + return False + + # Analyze data structure + first_item = dataset[0] + embedding_dim = len(first_item.get('embedding', [])) + + # Build DataFrame + records = [] + for item in dataset: + record = { + 'aid': item.get('aid'), + 'label': item.get('label'), + 'inconsistent': item.get('inconsistent', False), + 'text_checksum': item.get('text_checksum'), + # Store embedding as a separate column + 'embedding': item.get('embedding', []), + # Store metadata as JSON string + 'metadata_json': json.dumps(item.get('metadata', {}), ensure_ascii=False), + 'user_labels_json': json.dumps(item.get('user_labels', []), ensure_ascii=False) + } + records.append(record) + + # Create DataFrame + df = pd.DataFrame(records) + + # Convert embedding column to numpy arrays + df['embedding'] = df['embedding'].apply(lambda x: np.array(x, dtype=np.float32) if x else np.array([], dtype=np.float32)) + + # Use PyArrow Schema for type safety + schema = pa.schema([ + ('aid', pa.int64()), + ('label', pa.bool_()), + ('inconsistent', pa.bool_()), + ('text_checksum', pa.string()), + ('embedding', pa.list_(pa.float32())), + ('metadata_json', pa.string()), + ('user_labels_json', pa.string()) + ]) + + # Convert to PyArrow Table + table = pa.Table.from_pandas(df, schema=schema) + + # Write Parquet file with efficient compression settings + pq.write_table( + table, + data_file, + compression='zstd', # Better compression ratio + compression_level=6, # Balance compression ratio and 
speed + use_dictionary=True, # Enable dictionary encoding + write_page_index=True, # Support fast metadata access + write_statistics=True # Enable statistics + ) + + # Save metadata + metadata = { + 'dataset_id': dataset_id, + 'description': description, + 'stats': stats or {}, + 'created_at': datetime.now().isoformat(), + 'file_format': 'parquet_v1', + 'embedding_dimension': embedding_dim, + 'total_records': len(dataset), + 'columns': list(df.columns), + 'file_size_bytes': data_file.stat().st_size, + 'compression': 'zstd' + } + + with open(metadata_file, 'w', encoding='utf-8') as f: + json.dump(metadata, f, ensure_ascii=False, indent=2) + + # Update cache + self.metadata_cache[dataset_id] = metadata + + logger.info(f"Saved dataset {dataset_id} to Parquet: {len(dataset)} records, {data_file.stat().st_size} bytes") + return True + + except Exception as e: + logger.error(f"Failed to save dataset {dataset_id}: {e}") + return False + + def load_dataset_metadata(self, dataset_id: str) -> Optional[Dict[str, Any]]: + """ + Quickly load dataset metadata (without loading the entire file) + + Args: + dataset_id: Dataset ID + + Returns: + Dict: Metadata, or None if not found + """ + # Check cache + if dataset_id in self.metadata_cache: + return self.metadata_cache[dataset_id] + + # Load from file + _, metadata_file = self._get_dataset_files(dataset_id) + if not metadata_file.exists(): + return None + + try: + with open(metadata_file, 'r', encoding='utf-8') as f: + metadata = json.load(f) + + # Update cache + self.metadata_cache[dataset_id] = metadata + return metadata + + except Exception as e: + logger.error(f"Failed to load metadata for {dataset_id}: {e}") + return None + + def load_dataset_partial(self, dataset_id: str, columns: Optional[List[str]] = None, + filters: Optional[Dict[str, Any]] = None) -> Optional[pd.DataFrame]: + """ + Partially load the dataset (only reading specified columns or rows meeting criteria) + + Args: + dataset_id: Dataset ID + columns: Columns to read, None to read all + filters: Filtering conditions, format {column: value} + + Returns: + pd.DataFrame: Loaded data, or None if failed + """ + data_file, _ = self._get_dataset_files(dataset_id) + if not data_file.exists(): + return None + + try: + # Read Parquet file, supporting column selection and filtering + if columns: + # Ensure necessary columns exist + all_columns = ['aid', 'label', 'inconsistent', 'text_checksum', 'embedding', 'metadata_json', 'user_labels_json'] + required_cols = ['aid', 'label', 'embedding'] # These columns are always required + columns = list(set(columns + required_cols)) + + # Filter out non-existent columns + columns = [col for col in columns if col in all_columns] + + # Use pyarrow to read, supporting filters + if filters: + # Build filter tuples in the DNF format accepted by pq.read_table's filters argument + expressions = [] + for col, value in filters.items(): + if col == 'label': + expressions.append((col, '==', value)) + elif col == 'aid': + expressions.append((col, '==', value)) + + if expressions: + # A flat list of tuples is AND-combined by pyarrow + filter_expr = expressions + else: + filter_expr = None + else: + filter_expr = None + + # Read data + if columns and filter_expr: + table = pq.read_table(data_file, columns=columns, filters=filter_expr) + elif columns: + table = pq.read_table(data_file, columns=columns) + elif filter_expr: + table = pq.read_table(data_file, filters=filter_expr) + else: + table = pq.read_table(data_file) + + # Convert to DataFrame + df = 
table.to_pandas() + + # Handle embedding column + if 'embedding' in df.columns: + df['embedding'] = df['embedding'].apply(lambda x: x.tolist() if hasattr(x, 'tolist') else list(x)) + + logger.info(f"Loaded partial dataset {dataset_id}: {len(df)} rows, {len(df.columns)} columns") + return df + + except Exception as e: + logger.error(f"Failed to load partial dataset {dataset_id}: {e}") + return None + + def load_dataset_full(self, dataset_id: str) -> Optional[Dict[str, Any]]: + """ + Fully load the dataset (maintaining backward compatibility format) + + Args: + dataset_id: Dataset ID + + Returns: + Dict: Full dataset data, or None if failed + """ + data_file, _ = self._get_dataset_files(dataset_id) + if not data_file.exists(): + return None + + try: + # Load metadata + metadata = self.load_dataset_metadata(dataset_id) + if not metadata: + return None + + # Load data + df = self.load_dataset_partial(dataset_id) + if df is None: + return None + + # Convert to original format + dataset = [] + for _, row in df.iterrows(): + record = { + 'aid': int(row['aid']), + 'embedding': row['embedding'], + 'label': bool(row['label']), + 'metadata': json.loads(row['metadata_json']) if row['metadata_json'] else {}, + 'user_labels': json.loads(row['user_labels_json']) if row['user_labels_json'] else [], + 'inconsistent': bool(row['inconsistent']), + 'text_checksum': row['text_checksum'] + } + dataset.append(record) + + return { + 'dataset': dataset, + 'description': metadata.get('description'), + 'stats': metadata.get('stats', {}), + 'created_at': metadata.get('created_at') + } + + except Exception as e: + logger.error(f"Failed to load full dataset {dataset_id}: {e}") + return None + + def dataset_exists(self, dataset_id: str) -> bool: + """Check if the dataset exists""" + data_file, _ = self._get_dataset_files(dataset_id) + return data_file.exists() + + def delete_dataset(self, dataset_id: str) -> bool: + """Delete a dataset""" + try: + data_file, metadata_file = self._get_dataset_files(dataset_id) + + # Delete files + if data_file.exists(): + data_file.unlink() + if metadata_file.exists(): + metadata_file.unlink() + + # Remove from cache + if dataset_id in self.metadata_cache: + del self.metadata_cache[dataset_id] + + logger.info(f"Deleted dataset {dataset_id}") + return True + + except Exception as e: + logger.error(f"Failed to delete dataset {dataset_id}: {e}") + return False + + def list_datasets(self) -> List[Dict[str, Any]]: + """List metadata for all datasets""" + datasets = [] + + for dataset_id, metadata in self.metadata_cache.items(): + datasets.append({ + "dataset_id": dataset_id, + "description": metadata.get("description"), + "stats": metadata.get("stats", {}), + "created_at": metadata.get("created_at"), + "total_records": metadata.get("total_records", 0), + "file_size_mb": round(metadata.get("file_size_bytes", 0) / (1024 * 1024), 2), + "embedding_dimension": metadata.get("embedding_dimension"), + "file_format": metadata.get("file_format") + }) + + # Sort by creation time descending + datasets.sort(key=lambda x: x["created_at"], reverse=True) + return datasets + + def get_dataset_stats(self) -> Dict[str, Any]: + """Get overall statistics""" + total_datasets = len(self.metadata_cache) + total_records = sum(m.get("total_records", 0) for m in self.metadata_cache.values()) + total_size_bytes = sum(m.get("file_size_bytes", 0) for m in self.metadata_cache.values()) + + return { + "total_datasets": total_datasets, + "total_records": total_records, + "total_size_mb": round(total_size_bytes / (1024 * 
1024), 2), + "average_size_mb": round(total_size_bytes / total_datasets / (1024 * 1024), 2) if total_datasets > 0 else 0, + "storage_directory": str(self.storage_dir), + "storage_format": "parquet_v1" + } + + def migrate_from_json(self, dataset_id: str, json_data: Dict[str, Any]) -> bool: + """ + Migrate a dataset from JSON format to Parquet format + + Args: + dataset_id: Dataset ID + json_data: Data in JSON format + + Returns: + bool: Migration success status + """ + try: + dataset = json_data.get('dataset', []) + description = json_data.get('description') + stats = json_data.get('stats') + + return self.save_dataset(dataset_id, dataset, description, stats) + + except Exception as e: + logger.error(f"Failed to migrate dataset {dataset_id} from JSON: {e}") + return False \ No newline at end of file diff --git a/ml_new/training/embedding_service.py b/ml_new/training/embedding_service.py index 72c0cdc..eadd254 100644 --- a/ml_new/training/embedding_service.py +++ b/ml_new/training/embedding_service.py @@ -125,7 +125,8 @@ class EmbeddingService: self, texts: List[str], model: str, - batch_size: Optional[int] = None + batch_size: Optional[int] = None, + progress_callback: Optional[callable] = None ) -> List[List[float]]: """Generate embeddings for a batch of texts""" @@ -137,9 +138,9 @@ class EmbeddingService: # Handle different model types if model_config.type == "legacy": - return self._generate_legacy_embeddings_batch(texts, model, batch_size) + return self._generate_legacy_embeddings_batch(texts, model, batch_size, progress_callback) elif model_config.type == "openai-compatible": - return await self._generate_openai_embeddings_batch(texts, model, batch_size) + return await self._generate_openai_embeddings_batch(texts, model, batch_size, progress_callback=progress_callback) else: raise ValueError(f"Unsupported model type: {model_config.type}") @@ -147,7 +148,8 @@ class EmbeddingService: self, texts: List[str], model: str, - batch_size: Optional[int] = None + batch_size: Optional[int] = None, + progress_callback: Optional[callable] = None ) -> List[List[float]]: """Generate embeddings using legacy ONNX model""" if model not in self.legacy_models: @@ -162,9 +164,13 @@ class EmbeddingService: expected_dims = model_config.dimensions all_embeddings = [] + # Calculate total batches for progress tracking + total_batches = (len(texts) + batch_size - 1) // batch_size + # Process in batches for i in range(0, len(texts), batch_size): batch = texts[i:i + batch_size] + batch_idx = i // batch_size + 1 try: # Generate embeddings using legacy method @@ -174,23 +180,33 @@ class EmbeddingService: batch_embeddings = embeddings.tolist() all_embeddings.extend(batch_embeddings) - logger.info(f"Generated legacy embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}") + logger.info(f"Generated legacy embeddings for batch {batch_idx}/{total_batches}") + + # Update progress if callback provided + if progress_callback: + progress_callback(batch_idx, total_batches) except Exception as e: - logger.error(f"Error generating legacy embeddings for batch {i//batch_size + 1}: {e}") + logger.error(f"Error generating legacy embeddings for batch {batch_idx}: {e}") # Fill with zeros as fallback zero_embedding = [0.0] * expected_dims all_embeddings.extend([zero_embedding] * len(batch)) + # Final progress update + if progress_callback: + progress_callback(total_batches, total_batches) + return all_embeddings async def _generate_openai_embeddings_batch( self, texts: List[str], model: str, - batch_size: 
Optional[int] = None + batch_size: Optional[int] = None, + max_concurrent: int = 6, + progress_callback: Optional[callable] = None ) -> List[List[float]]: - """Generate embeddings using OpenAI-compatible API""" + """Generate embeddings using OpenAI-compatible API with parallel requests""" model_config = self.embedding_models[model] # Use model's max_batch_size if not specified @@ -204,34 +220,71 @@ class EmbeddingService: raise ValueError(f"No client configured for model '{model}'") client = self.clients[model] - all_embeddings = [] - # Process in batches to avoid API limits - for i in range(0, len(texts), batch_size): - batch = texts[i:i + batch_size] - - try: - # Rate limiting - if i > 0: - await asyncio.sleep(self.request_interval) - - # Generate embeddings - response = await client.embeddings.create( - model=model_config.name, - input=batch, - dimensions=expected_dims - ) - - batch_embeddings = [data.embedding for data in response.data] - all_embeddings.extend(batch_embeddings) - - logger.info(f"Generated embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}") - - except Exception as e: - logger.error(f"Error generating embeddings for batch {i//batch_size + 1}: {e}") - # For now, fill with zeros as fallback (could implement retry logic) - zero_embedding = [0.0] * expected_dims - all_embeddings.extend([zero_embedding] * len(batch)) + # Split texts into batches + batches = [(i, texts[i:i + batch_size]) for i in range(0, len(texts), batch_size)] + total_batches = len(batches) + + # Track completed batches for progress reporting + completed_batches = 0 + completed_batches_lock = asyncio.Lock() + + # Semaphore for concurrency control + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_batch(batch_idx: int, batch: List[str]) -> tuple[int, List[List[float]]]: + """Process a single batch with semaphore-controlled concurrency""" + nonlocal completed_batches + async with semaphore: + try: + # Rate limiting + if batch_idx > 0: + await asyncio.sleep(self.request_interval) + + # Generate embeddings + response = await client.embeddings.create( + model=model_config.name, + input=batch, + dimensions=expected_dims + ) + + batch_embeddings = [data.embedding for data in response.data] + logger.info(f"Generated embeddings for batch {batch_idx//batch_size + 1}/{total_batches}") + + # Update completed batch count and progress + async with completed_batches_lock: + completed_batches += 1 + if progress_callback: + progress_callback(completed_batches, total_batches) + + return (batch_idx, batch_embeddings) + + except Exception as e: + logger.error(f"Error generating embeddings for batch {batch_idx//batch_size + 1}: {e}") + # Fill with zeros as fallback + zero_embedding = [0.0] * expected_dims + + # Update completed batch count and progress even for failed batches + async with completed_batches_lock: + completed_batches += 1 + if progress_callback: + progress_callback(completed_batches, total_batches) + + return (batch_idx, [zero_embedding] * len(batch)) + + # Process all batches concurrently with semaphore limit + tasks = [process_batch(i, batch) for i, batch in batches] + results = await asyncio.gather(*tasks) + + # Final progress update + if progress_callback: + progress_callback(total_batches, total_batches) + + # Sort results by batch index and flatten + results.sort(key=lambda x: x[0]) + all_embeddings = [] + for _, embeddings in results: + all_embeddings.extend(embeddings) return all_embeddings @@ -268,7 +321,8 @@ class EmbeddingService: test_embedding = await 
self.generate_embeddings_batch( ["health check"], model_name, - batch_size=1 + batch_size=1, + progress_callback=None ) return { diff --git a/ml_new/training/requirements.txt b/ml_new/training/requirements.txt index d14ecd3..c5b83d4 100644 --- a/ml_new/training/requirements.txt +++ b/ml_new/training/requirements.txt @@ -13,4 +13,5 @@ sqlalchemy==2.0.23 asyncpg==0.29.0 toml==0.10.2 aiohttp==3.9.0 -python-dotenv==1.1.0 \ No newline at end of file +python-dotenv==1.1.0 +pyarrow==15.0.0 \ No newline at end of file diff --git a/packages/ml_panel/src/components/TaskMonitor.tsx b/packages/ml_panel/src/components/TaskMonitor.tsx index 3e8cb56..344c17a 100644 --- a/packages/ml_panel/src/components/TaskMonitor.tsx +++ b/packages/ml_panel/src/components/TaskMonitor.tsx @@ -30,7 +30,7 @@ export function TaskMonitor() { const params = statusFilter === "all" ? {} : { status: statusFilter }; return apiClient.getTasks(params.status, 50); }, - refetchInterval: 5000 // Refresh every 5 seconds + refetchInterval: 500 // Refresh every 0.5 seconds }); const getStatusIcon = (status: string) => {
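
Reviewer note on migration: with the JSON-backed storage and the /datasets/cleanup endpoint removed, datasets written by the old DatasetBuilder are no longer picked up automatically. Below is a minimal one-off migration sketch using the migrate_from_json helper added in this diff; it assumes the script is run from ml_new/training/ (so dataset_storage_parquet is importable) and that the legacy *.json files still sit in the default "datasets" directory. Both are assumptions, not part of this change.

import json
from pathlib import Path

from dataset_storage_parquet import ParquetDatasetStorage

storage_dir = Path("datasets")  # same default directory DatasetBuilder uses
storage = ParquetDatasetStorage(str(storage_dir))

for json_file in storage_dir.glob("*.json"):
    # Skip the Parquet sidecar metadata files; they are not legacy datasets
    if json_file.name.endswith(".metadata.json"):
        continue
    dataset_id = json_file.stem
    if storage.dataset_exists(dataset_id):
        continue  # already migrated
    with open(json_file, "r", encoding="utf-8") as f:
        json_data = json.load(f)
    if storage.migrate_from_json(dataset_id, json_data):
        print(f"Migrated {dataset_id} to Parquet")
    else:
        print(f"Skipped {dataset_id} (empty or failed)")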
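
The load_dataset_partial path is what makes the Parquet switch pay off: callers can read only the columns and rows they need instead of deserializing an entire dataset. A minimal usage sketch follows, assuming a dataset that already exists in the store (the dataset ID below is a placeholder):

from dataset_storage_parquet import ParquetDatasetStorage

storage = ParquetDatasetStorage("datasets")

# Read only the training-relevant columns, restricted to positively labeled rows
df = storage.load_dataset_partial(
    "example-dataset-id",  # placeholder ID
    columns=["aid", "label", "embedding"],
    filters={"label": True},
)
if df is not None:
    print(f"{len(df)} rows, {len(df.columns)} columns")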