diff --git a/.kilocode/rules/common.md b/.kilocode/rules/common.md
index 0464778..0028cae 100644
--- a/.kilocode/rules/common.md
+++ b/.kilocode/rules/common.md
@@ -2,4 +2,4 @@
 1. Always use bun as package manager.
-2. Always write comments in English.
+2. IMPORTANT: **Always write comments inside your code (through tool calls) in English, and respond (including descriptions of changes, requests, and questions to the user) in the same language as the user's query.**
diff --git a/ml_new/training/README.md b/ml_new/training/README.md
deleted file mode 100644
index 189315b..0000000
--- a/ml_new/training/README.md
+++ /dev/null
@@ -1,140 +0,0 @@
-# ML Training Service
-
-A FastAPI-based ML training service for dataset building, embedding generation, and experiment management.
-
-## Architecture
-
-The service is organized into modular components:
-
-```
-ml_new/training/
-├── main.py               # FastAPI application entry point
-├── models.py             # Pydantic data models
-├── config_loader.py      # Configuration loading from TOML
-├── database.py           # Database connection and operations
-├── embedding_service.py  # Embedding generation service
-├── dataset_service.py    # Dataset building logic
-├── api_routes.py         # API endpoint definitions
-├── embedding_models.toml # Embedding model configurations
-└── requirements.txt      # Python dependencies
-```
-
-## Key Components
-
-### 1. Main Application (`main.py`)
-- FastAPI app initialization
-- CORS middleware configuration
-- Service dependency injection
-- Startup/shutdown event handlers
-
-### 2. Data Models (`models.py`)
-- `DatasetBuildRequest`: Request model for dataset building
-- `DatasetBuildResponse`: Response model for dataset building
-- `DatasetRecord`: Individual dataset record structure
-- `EmbeddingModelInfo`: Embedding model configuration
-
-### 3. Configuration (`config_loader.py`)
-- Loads embedding model configurations from TOML
-- Manages model parameters (dimensions, API endpoints, etc.)
-
-### 4. Database Layer (`database.py`)
-- PostgreSQL connection management
-- CRUD operations for video metadata, user labels, and embeddings
-- Optimized batch queries to avoid N+1 problems
-
-### 5. Embedding Service (`embedding_service.py`)
-- Integration with OpenAI-compatible embedding APIs
-- Text preprocessing and checksum generation
-- Batch embedding generation with rate limiting
-
-### 6. Dataset Building (`dataset_service.py`)
-- Complete dataset construction workflow:
-  1. Pull raw text from database
-  2. Text preprocessing (placeholder)
-  3. Batch embedding generation with deduplication
-  4. Embedding storage and caching
-  5. Final dataset compilation with labels
-
-### 7. API Routes (`api_routes.py`)
-- `/api/v1/health`: Health check
-- `/api/v1/models/embedding`: List available embedding models
-- `/api/v1/dataset/build`: Build new dataset
-- `/api/v1/dataset/{id}`: Retrieve built dataset
-- `/api/v1/datasets`: List all datasets
-- `/api/v1/dataset/{id}`: Delete dataset
-
-## Dataset Building Flow
-
-1. **Model Selection**: Choose embedding model from TOML configuration
-2. **Data Retrieval**: Pull video metadata and user labels from PostgreSQL
-3. **Text Processing**: Combine title, description, and tags
-4. **Deduplication**: Generate checksums to avoid duplicate embeddings
-5. **Batch Processing**: Generate embeddings for new texts only
-6. **Storage**: Store embeddings in database with caching
-7. **Final Assembly**: Combine embeddings with labels using consensus mechanism
-
-## Configuration
-
-### Embedding Models (`embedding_models.toml`)
-```toml
-[text-embedding-3-large]
-name = "text-embedding-3-large"
-dimensions = 3072
-type = "openai"
-api_endpoint = "https://api.openai.com/v1/embeddings"
-max_tokens = 8192
-max_batch_size = 100
-```
-
-### Environment Variables
-- `DATABASE_URL`: PostgreSQL connection string
-- `OPENAI_API_KEY`: OpenAI API key for embedding generation
-
-## Usage
-
-### Start the Service
-```bash
-cd ml_new/training
-python main.py
-```
-
-### Build a Dataset
-```bash
-curl -X POST "http://localhost:8322/v1/dataset/build" \
-  -H "Content-Type: application/json" \
-  -d '{
-    "aid_list": [170001, 170002, 170003],
-    "embedding_model": "text-embedding-3-large",
-    "force_regenerate": false
-  }'
-```
-
-### Check Health
-```bash
-curl "http://localhost:8322/v1/health"
-```
-
-### List Embedding Models
-```bash
-curl "http://localhost:8322/v1/models/embedding"
-```
-
-## Features
-
-- **High Performance**: Optimized database queries with batch operations
-- **Deduplication**: Text-level deduplication using MD5 checksums
-- **Consensus Labels**: Majority vote mechanism for user annotations
-- **Batch Processing**: Efficient embedding generation and storage
-- **Error Handling**: Comprehensive error handling and logging
-- **Async Support**: Fully asynchronous operations for scalability
-- **CORS Enabled**: Ready for frontend integration
-
-## Production Considerations
-
-- Replace in-memory dataset storage with database
-- Add authentication and authorization
-- Implement rate limiting for API endpoints
-- Add monitoring and metrics collection
-- Configure proper logging levels
-- Set up database connection pooling
-- Add API documentation with OpenAPI/Swagger
\ No newline at end of file
diff --git a/ml_new/training/api_routes.py b/ml_new/training/api_routes.py
index ea68779..5c772bc 100644
--- a/ml_new/training/api_routes.py
+++ b/ml_new/training/api_routes.py
@@ -2,18 +2,19 @@
 API routes for the ML training service
 """
 
-import logging
 import uuid
+from typing import Optional
 
-from fastapi import APIRouter, HTTPException, BackgroundTasks
+from fastapi import APIRouter, HTTPException
 from fastapi.responses import JSONResponse
 
 from config_loader import config_loader
-from models import DatasetBuildRequest, DatasetBuildResponse
+from models import DatasetBuildRequest, DatasetBuildResponse, TaskStatus, TaskStatusResponse, TaskListResponse
 from dataset_service import DatasetBuilder
+from logger_config import get_logger
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 # Create router
 router = APIRouter(prefix="/v1")
@@ -80,8 +81,8 @@ async def get_embedding_models():
 
 
 @router.post("/dataset/build", response_model=DatasetBuildResponse)
-async def build_dataset_endpoint(request: DatasetBuildRequest, background_tasks: BackgroundTasks):
-    """Build dataset endpoint"""
+async def build_dataset_endpoint(request: DatasetBuildRequest):
+    """Build dataset endpoint with task tracking"""
 
     if not dataset_builder:
         raise HTTPException(status_code=503, detail="Dataset builder not available")
@@ -91,20 +92,22 @@ async def build_dataset_endpoint(request: DatasetBuildRequest, background_tasks:
         raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")
 
     dataset_id = str(uuid.uuid4())
-    # Start background task for dataset building
-    background_tasks.add_task(
-        dataset_builder.build_dataset,
+
+    # Start task-based dataset building
+    task_id = await dataset_builder.start_dataset_build_task(
         dataset_id,
         request.aid_list,
         request.embedding_model,
-        request.force_regenerate
+        request.force_regenerate,
+        request.description
     )
 
     return DatasetBuildResponse(
         dataset_id=dataset_id,
         total_records=len(request.aid_list),
         status="started",
-        message="Dataset building started"
+        message=f"Dataset building started with task ID: {task_id}",
+        description=request.description
     )
 
 
@@ -126,6 +129,7 @@ async def get_dataset_endpoint(dataset_id: str):
     return {
         "dataset_id": dataset_id,
         "dataset": dataset_info["dataset"],
+        "description": dataset_info.get("description"),
         "stats": dataset_info["stats"],
         "created_at": dataset_info["created_at"]
     }
 
 
@@ -143,6 +147,7 @@ async def list_datasets():
         if "error" not in dataset_info:
             datasets.append({
                 "dataset_id": dataset_id,
+                "description": dataset_info.get("description"),
                 "stats": dataset_info["stats"],
                 "created_at": dataset_info["created_at"]
             })
 
 
@@ -171,7 +176,16 @@ async def list_datasets_endpoint():
         raise HTTPException(status_code=503, detail="Dataset builder not available")
 
     datasets = dataset_builder.list_datasets()
-    return {"datasets": datasets}
+    # Add description to each dataset
+    datasets_with_description = []
+    for dataset in datasets:
+        dataset_info = dataset_builder.get_dataset(dataset["dataset_id"])
+        if dataset_info and "description" in dataset_info:
+            dataset["description"] = dataset_info["description"]
+        else:
+            dataset["description"] = None
+        datasets_with_description.append(dataset)
+    return {"datasets": datasets_with_description}
 
 
 @router.get("/datasets/stats")
@@ -193,4 +207,111 @@ async def cleanup_datasets_endpoint(max_age_days: int = 30):
         raise HTTPException(status_code=503, detail="Dataset builder not available")
 
     await dataset_builder.cleanup_old_datasets(max_age_days)
-    return {"message": f"Cleanup completed for datasets older than {max_age_days} days"}
\ No newline at end of file
+    return {"message": f"Cleanup completed for datasets older than {max_age_days} days"}
+
+
+# Task Status Endpoints
+
+# NOTE: /tasks/stats must be registered before /tasks/{task_id}, otherwise
+# FastAPI matches "stats" as a task_id and this endpoint always returns 404.
+@router.get("/tasks/stats")
+async def get_task_statistics_endpoint():
+    """Get statistics about all tasks"""
+
+    if not dataset_builder:
+        raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+    return dataset_builder.get_task_statistics()
+
+
+@router.get("/tasks/{task_id}", response_model=TaskStatusResponse)
+async def get_task_status_endpoint(task_id: str):
+    """Get status of a specific task"""
+
+    if not dataset_builder:
+        raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+    task_status = dataset_builder.get_task_status(task_id)
+    if not task_status:
+        raise HTTPException(status_code=404, detail="Task not found")
+
+    # Convert to response model
+    progress_dict = None
+    if task_status.progress:
+        progress_dict = {
+            "current_step": task_status.progress.current_step,
+            "total_steps": task_status.progress.total_steps,
+            "completed_steps": task_status.progress.completed_steps,
+            "percentage": task_status.progress.percentage,
+            "message": task_status.progress.message
+        }
+
+    return TaskStatusResponse(
+        task_id=task_status.task_id,
+        status=task_status.status,
+        progress=progress_dict,
+        result=task_status.result,
+        error=task_status.error_message,
+        created_at=task_status.created_at,
+        started_at=task_status.started_at,
+        completed_at=task_status.completed_at
+    )
+
+
+@router.get("/tasks", response_model=TaskListResponse)
+async def list_tasks_endpoint(status: Optional[TaskStatus] = None, limit: int = 50):
+    """List all tasks, optionally filtered by status"""
+
+    if not dataset_builder:
+        raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+    tasks = dataset_builder.list_tasks(status_filter=status)
+
+    # Limit results
+    if limit > 0:
+        tasks = tasks[:limit]
+
+    # Convert to response models
+    task_responses = []
+    for task_status in tasks:
+        progress_dict = None
+        if task_status.progress:
+            progress_dict = {
+                "current_step": task_status.progress.current_step,
+                "total_steps": task_status.progress.total_steps,
+                "completed_steps": task_status.progress.completed_steps,
+                "percentage": task_status.progress.percentage,
+                "message": task_status.progress.message
+            }
+
+        task_responses.append(TaskStatusResponse(
+            task_id=task_status.task_id,
+            status=task_status.status,
+            progress=progress_dict,
+            result=task_status.result,
+            error=task_status.error_message,
+            created_at=task_status.created_at,
+            started_at=task_status.started_at,
+            completed_at=task_status.completed_at
+        ))
+
+    # Get statistics
+    stats = dataset_builder.get_task_statistics()
+
+    return TaskListResponse(
+        tasks=task_responses,
+        total_count=stats["total_tasks"],
+        pending_count=stats["status_counts"][TaskStatus.PENDING],
+        running_count=stats["status_counts"][TaskStatus.RUNNING],
+        completed_count=stats["status_counts"][TaskStatus.COMPLETED],
+        failed_count=stats["status_counts"][TaskStatus.FAILED]
+    )
+
+
+@router.post("/tasks/cleanup")
+async def cleanup_tasks_endpoint(max_age_hours: int = 24):
+    """Clean up completed/failed tasks older than specified hours"""
+
+    if not dataset_builder:
+        raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+    cleaned_count = await dataset_builder.cleanup_completed_tasks(max_age_hours)
+    return {"message": f"Cleaned up {cleaned_count} tasks older than {max_age_hours} hours"}
\ No newline at end of file
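For reference, a minimal client sketch for the task-tracking endpoints added above. It assumes the service runs locally on port 8322 (as configured in `main.py`); the `requests` dependency and the two-second poll interval are illustrative, not part of this diff:

```python
import time

import requests

BASE = "http://localhost:8322/v1"

# Kick off a build; the payload fields mirror DatasetBuildRequest in models.py.
resp = requests.post(f"{BASE}/dataset/build", json={
    "aid_list": [170001, 170002, 170003],
    "embedding_model": "jina-embedding-v3-m2v",
    "force_regenerate": False,
    "description": "smoke test",
})
resp.raise_for_status()

# The task ID currently travels only inside the free-text message
# ("Dataset building started with task ID: <uuid>"), so parse it out.
task_id = resp.json()["message"].rsplit(": ", 1)[-1]

# Poll /tasks/{task_id} until the task reaches a terminal state.
while True:
    status = requests.get(f"{BASE}/tasks/{task_id}").json()
    print(status["status"], status.get("progress"))
    if status["status"] in ("completed", "failed", "cancelled"):
        break
    time.sleep(2)
```

Returning the task ID as a dedicated field of `DatasetBuildResponse`, rather than embedded in `message`, would spare clients the string parsing.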
diff --git a/ml_new/training/config_loader.py b/ml_new/training/config_loader.py
index f7a93bb..5977c2b 100644
--- a/ml_new/training/config_loader.py
+++ b/ml_new/training/config_loader.py
@@ -6,9 +6,9 @@ import toml
 import os
 from typing import Dict
 from pydantic import BaseModel
-import logging
+from logger_config import get_logger
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 
 class EmbeddingModelConfig(BaseModel):
@@ -19,6 +19,8 @@ class EmbeddingModelConfig(BaseModel):
     max_tokens: int = 8191
     max_batch_size: int = 8
     api_key_env: str = "OPENAI_API_KEY"
+    model_path: str = ""
+    tokenizer_name: str = ""
 
 
 class ConfigLoader:
@@ -31,6 +33,7 @@ class ConfigLoader:
 
         self.config_path = config_path
         self.embedding_models: Dict[str, EmbeddingModelConfig] = {}
+        self.selected_model: str = None
        self._load_config()
 
     def _load_config(self):
@@ -51,6 +54,8 @@ class ConfigLoader:
                 self.embedding_models[model_key] = EmbeddingModelConfig(
                     **model_data
                 )
+
+            self.selected_model = config_data.get("model", list(self.embedding_models.keys())[0])
 
             logger.info(
                 f"Loaded {len(self.embedding_models)} embedding models from {self.config_path}"
@@ -58,6 +63,10 @@ class ConfigLoader:
 
         except Exception as e:
             logger.error(f"Failed to load config from {self.config_path}: {e}")
+
+    def get_selected_model(self) -> str:
+        """Get selected model for health check"""
+        return self.selected_model
 
     def get_embedding_models(self) -> Dict[str, EmbeddingModelConfig]:
         """Get all available embedding models"""
diff --git a/ml_new/training/database.py b/ml_new/training/database.py
index 50f2d53..c0a73de 100644
--- a/ml_new/training/database.py
+++ b/ml_new/training/database.py
@@ -7,13 +7,13 @@
 import hashlib
 from typing import List, Dict, Optional, Any
 from datetime import datetime
 import asyncpg
-import logging
 
 from config_loader import config_loader
 from dotenv import load_dotenv
+from logger_config import get_logger
 
-load_dotenv() 
+load_dotenv()
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 # Database configuration
 DATABASE_URL = os.getenv("DATABASE_URL")
@@ -120,7 +120,7 @@ class DatabaseManager:
 
         async with self.pool.acquire() as conn:
             query = """
-                SELECT data_checksum, vec_2048, vec_1536, vec_1024, created_at
+                SELECT data_checksum, dimensions, vec_2048, vec_1536, vec_1024, created_at
                 FROM internal.embeddings
                 WHERE model_name = $1 AND data_checksum = ANY($2::text[])
             """
@@ -177,14 +177,15 @@ class DatabaseManager:
 
                     query = f"""
                         INSERT INTO internal.embeddings
-                        (model_name, data_checksum, {vec_column}, created_at)
-                        VALUES ($1, $2, $3, $4)
-                        ON CONFLICT (data_checksum) DO NOTHING
+                        (model_name, dimensions, data_checksum, {vec_column}, created_at)
+                        VALUES ($1, $2, $3, $4, $5)
+                        ON CONFLICT (model_name, dimensions, data_checksum) DO NOTHING
                     """
 
                     await conn.execute(
                         query,
                         data["model_name"],
+                        data["dimensions"],
                         data["checksum"],
                         vector_str,
                         datetime.now(),
                     )
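The new `ON CONFLICT (model_name, dimensions, data_checksum)` target only resolves if `internal.embeddings` carries a matching unique constraint; with only the old single-column uniqueness in place, the INSERT would fail. A migration sketch under that assumption — the constraint names are hypothetical and should be adapted to the actual schema:

```python
import asyncio
import os

import asyncpg

async def migrate() -> None:
    conn = await asyncpg.connect(os.getenv("DATABASE_URL"))
    try:
        # Drop the old single-column uniqueness, if present
        # (name assumes PostgreSQL's default constraint naming).
        await conn.execute(
            "ALTER TABLE internal.embeddings "
            "DROP CONSTRAINT IF EXISTS embeddings_data_checksum_key"
        )
        # Add the composite uniqueness the new ON CONFLICT target requires.
        await conn.execute(
            "ALTER TABLE internal.embeddings "
            "ADD CONSTRAINT embeddings_model_dims_checksum_key "
            "UNIQUE (model_name, dimensions, data_checksum)"
        )
    finally:
        await conn.close()

asyncio.run(migrate())
```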
diff --git a/ml_new/training/dataset_service.py b/ml_new/training/dataset_service.py
index 7977780..fc55e17 100644
--- a/ml_new/training/dataset_service.py
+++ b/ml_new/training/dataset_service.py
@@ -2,20 +2,21 @@
 Dataset building service - handles the complete dataset construction flow
 """
 
-import os
 import json
-import logging
 import asyncio
 from pathlib import Path
 from typing import List, Dict, Any, Optional
 from datetime import datetime
+import threading
 
 from database import DatabaseManager
 from embedding_service import EmbeddingService
 from config_loader import config_loader
+from logger_config import get_logger
+from models import TaskStatus, DatasetBuildTaskStatus, TaskProgress
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 
 class DatasetBuilder:
@@ -29,6 +30,11 @@ class DatasetBuilder:
 
         # Load existing datasets from file system
        self.dataset_storage: Dict[str, Dict] = self._load_all_datasets()
+
+        # Task status tracking
+        self._task_status_lock = threading.Lock()
+        self.task_statuses: Dict[str, DatasetBuildTaskStatus] = {}
+        self.running_tasks: Dict[str, asyncio.Task] = {}
 
     def _get_dataset_file_path(self, dataset_id: str) -> Path:
@@ -80,6 +86,134 @@ class DatasetBuilder:
 
         return datasets
 
+    def _create_task_status(self, task_id: str, dataset_id: str, aid_list: List[int],
+                            embedding_model: str, force_regenerate: bool) -> DatasetBuildTaskStatus:
+        """Create initial task status"""
+        with self._task_status_lock:
+            task_status = DatasetBuildTaskStatus(
+                task_id=task_id,
+                status=TaskStatus.PENDING,
+                dataset_id=dataset_id,
+                aid_list=aid_list,
+                embedding_model=embedding_model,
+                force_regenerate=force_regenerate,
+                created_at=datetime.now(),
+                progress=TaskProgress(
+                    current_step="initialized",
+                    total_steps=7,
+                    completed_steps=0,
+                    percentage=0.0,
+                    message="Task initialized"
+                )
+            )
+            self.task_statuses[task_id] = task_status
+            return task_status
+
+    def _update_task_status(self, task_id: str, **kwargs):
+        """Update task status with new values"""
+        with self._task_status_lock:
+            if task_id in self.task_statuses:
+                task_status = self.task_statuses[task_id]
+                for key, value in kwargs.items():
+                    if hasattr(task_status, key):
+                        setattr(task_status, key, value)
+                self.task_statuses[task_id] = task_status
+
+    def _update_task_progress(self, task_id: str, current_step: str, completed_steps: int,
+                              message: str = None, percentage: float = None):
+        """Update task progress"""
+        with self._task_status_lock:
+            if task_id in self.task_statuses:
+                task_status = self.task_statuses[task_id]
+                if percentage is not None:
+                    progress_percentage = percentage
+                else:
+                    progress_percentage = (completed_steps / task_status.progress.total_steps) * 100 if task_status.progress else 0.0
+
+                task_status.progress = TaskProgress(
+                    current_step=current_step,
+                    total_steps=task_status.progress.total_steps if task_status.progress else 7,
+                    completed_steps=completed_steps,
+                    percentage=progress_percentage,
+                    message=message
+                )
+                self.task_statuses[task_id] = task_status
+
+    def get_task_status(self, task_id: str) -> Optional[DatasetBuildTaskStatus]:
+        """Get task status by task ID"""
+        with self._task_status_lock:
+            return self.task_statuses.get(task_id)
+
+    def list_tasks(self, status_filter: Optional[TaskStatus] = None) -> List[DatasetBuildTaskStatus]:
+        """List all tasks, optionally filtered by status"""
+        with self._task_status_lock:
+            tasks = list(self.task_statuses.values())
+            if status_filter:
+                tasks = [task for task in tasks if task.status == status_filter]
+            # Sort by creation time (newest first)
+            tasks.sort(key=lambda x: x.created_at, reverse=True)
+            return tasks
+
+    def get_task_statistics(self) -> Dict[str, Any]:
+        """Get statistics about all tasks"""
+        with self._task_status_lock:
+            total_tasks = len(self.task_statuses)
+            status_counts = {
+                TaskStatus.PENDING: 0,
+                TaskStatus.RUNNING: 0,
+                TaskStatus.COMPLETED: 0,
+                TaskStatus.FAILED: 0,
+                TaskStatus.CANCELLED: 0
+            }
+
+            for task_status in self.task_statuses.values():
+                status_counts[task_status.status] += 1
+
+            return {
+                "total_tasks": total_tasks,
+                "status_counts": status_counts,
+                "running_tasks": status_counts[TaskStatus.RUNNING]
+            }
+
+    async def cleanup_completed_tasks(self, max_age_hours: int = 24):
+        """Clean up completed/failed tasks older than specified hours"""
+        cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
+        cleaned_count = 0
+
+        with self._task_status_lock:
+            tasks_to_remove = []
+
+            for task_id, task_status in self.task_statuses.items():
+                if task_status.completed_at:
+                    try:
+                        completed_time = task_status.completed_at.timestamp()
+                        if completed_time < cutoff_time:
+                            tasks_to_remove.append(task_id)
+                    except Exception as e:
+                        logger.warning(f"Failed to check completion time for task {task_id}: {e}")
+
+            for task_id in tasks_to_remove:
+                # Remove from task statuses
+                del self.task_statuses[task_id]
+
+                # Remove from running tasks if still there
+                if task_id in self.running_tasks:
+                    del self.running_tasks[task_id]
+
+                cleaned_count += 1
+
+        if cleaned_count > 0:
+            logger.info(f"Cleaned up {cleaned_count} old tasks")
+
+        return cleaned_count
+
     async def cleanup_old_datasets(self, max_age_days: int = 30):
         """Remove datasets older than specified days"""
         try:
@@ -110,43 +244,54 @@ class DatasetBuilder:
 
         except Exception as e:
             logger.error(f"Failed to cleanup old datasets: {e}")
 
-    async def build_dataset(self, dataset_id: str, aid_list: List[int], embedding_model: str, force_regenerate: bool = False) -> str:
+    async def build_dataset_with_task_tracking(self, task_id: str, dataset_id: str, aid_list: List[int],
+                                               embedding_model: str, force_regenerate: bool = False,
+                                               description: Optional[str] = None) -> str:
         """
-        Build dataset with the specified flow:
-        1. Select embedding model (from TOML config)
-        2. Pull raw text from database
-        3. Preprocess (placeholder for now)
-        4. Batch get embeddings (deduplicate by hash, skip if already in embeddings table)
-        5. Write to embeddings table
-        6. Pull all needed embeddings to create dataset with format: embeddings, label
+        Build dataset with task status tracking
+
+        Steps:
+        1. Initialize task status
+        2. Select embedding model (from TOML config)
+        3. Pull raw text from database
+        4. Preprocess (placeholder for now)
+        5. Batch get embeddings (deduplicate by hash, skip if already in embeddings table)
+        6. Write to embeddings table
+        7. Pull all needed embeddings to create dataset with format: embeddings, label
         """
+        # Update task status to running
+        self._update_task_status(task_id, status=TaskStatus.RUNNING, started_at=datetime.now())
+
         try:
-            logger.info(f"Starting dataset building task {dataset_id}")
+            logger.info(f"Starting dataset building task {dataset_id} (task_id: {task_id})")
 
+            # Step 1: Get model configuration
             EMBEDDING_MODELS = config_loader.get_embedding_models()
-
-            # Get model configuration
             if embedding_model not in EMBEDDING_MODELS:
                 raise ValueError(f"Invalid embedding model: {embedding_model}")
 
             model_config = EMBEDDING_MODELS[embedding_model]
+            self._update_task_progress(task_id, "getting_metadata", 1, "Retrieving video metadata from database")
 
-            # Step 1: Get video metadata from database
+            # Step 2: Get video metadata from database
             metadata = await self.db_manager.get_video_metadata(aid_list)
+            self._update_task_progress(task_id, "getting_labels", 2, "Retrieving user labels from database")
 
-            # Step 2: Get user labels
+            # Step 3: Get user labels
             labels = await self.db_manager.get_user_labels(aid_list)
+            self._update_task_progress(task_id, "preparing_text", 3, "Preparing text data and checksums")
 
-            # Step 3: Prepare text data and checksums
+            # Step 4: Prepare text data and checksums
             text_data = []
+            total_aids = len(aid_list)
 
-            for aid in aid_list:
+            for i, aid in enumerate(aid_list):
                 if aid in metadata:
                     # Combine title, description, tags
                     combined_text = self.embedding_service.combine_video_text(
                         metadata[aid]['title'],
-                        metadata[aid]['description'], 
+                        metadata[aid]['description'],
                         metadata[aid]['tags']
                     )
@@ -158,12 +303,24 @@ class DatasetBuilder:
                         'text': combined_text,
                         'checksum': checksum
                     })
+
+                    # Update progress for text preparation
+                    if i % 10 == 0 or i == total_aids - 1:  # Update every 10 items or at the end
+                        progress_pct = 3 + (i + 1) / total_aids
+                        self._update_task_progress(
+                            task_id,
+                            "preparing_text",
+                            min(3, int(progress_pct)),
+                            f"Prepared {i + 1}/{total_aids} text entries"
+                        )
 
-            # Step 4: Check existing embeddings
+            self._update_task_progress(task_id, "checking_embeddings", 4, "Checking existing embeddings")
+
+            # Step 5: Check existing embeddings
             checksums = [item['checksum'] for item in text_data]
             existing_embeddings = await self.db_manager.get_existing_embeddings(checksums, embedding_model)
 
-            # Step 5: Generate new embeddings for texts that don't have them
+            # Step 6: Generate new embeddings for texts that don't have them
             new_embeddings_needed = []
             for item in text_data:
                 if item['checksum'] not in existing_embeddings or force_regenerate:
@@ -171,13 +328,20 @@ class DatasetBuilder:
 
             new_embeddings_count = 0
             if new_embeddings_needed:
+                self._update_task_progress(
+                    task_id,
+                    "generating_embeddings",
+                    5,
+                    f"Generating {len(new_embeddings_needed)} new embeddings"
+                )
+
                 logger.info(f"Generating {len(new_embeddings_needed)} new embeddings")
                 generated_embeddings = await self.embedding_service.generate_embeddings_batch(
-                    new_embeddings_needed, 
+                    new_embeddings_needed,
                     embedding_model
                 )
 
-                # Step 6: Store new embeddings in database
+                # Step 7: Store new embeddings in database
                 embeddings_to_store = []
                 for i, (text, embedding) in enumerate(zip(new_embeddings_needed, generated_embeddings)):
                     checksum = self.embedding_service.create_text_checksum(text)
@@ -198,11 +362,13 @@ class DatasetBuilder:
                     f'vec_{model_config.dimensions}': emb_data['vector']
                 }
 
-            # Step 7: Build final dataset
+            self._update_task_progress(task_id, "building_dataset", 6, "Building final dataset")
+
+            # Step 8: Build final dataset
             dataset = []
             inconsistent_count = 0
 
-            for item in text_data:
+            for i, item in enumerate(text_data):
                 aid = item['aid']
                 checksum = item['checksum']
@@ -241,6 +407,16 @@ class DatasetBuilder:
                     'inconsistent': inconsistent,
                     'text_checksum': checksum
                 })
+
+                # Update progress for dataset building
+                if i % 10 == 0 or i == len(text_data) - 1:  # Update every 10 items or at the end
+                    progress_pct = 6 + (i + 1) / len(text_data)
+                    self._update_task_progress(
+                        task_id,
+                        "building_dataset",
+                        min(6, int(progress_pct)),
+                        f"Built {i + 1}/{len(text_data)} dataset records"
+                    )
 
             reused_count = len(dataset) - new_embeddings_count
@@ -249,6 +425,7 @@ class DatasetBuilder:
             # Prepare dataset data
             dataset_data = {
                 'dataset': dataset,
+                'description': description,
                 'stats': {
                     'total_records': len(dataset),
                     'new_embeddings': new_embeddings_count,
@@ -259,6 +436,8 @@ class DatasetBuilder:
                 'created_at': datetime.now().isoformat()
             }
 
+            self._update_task_progress(task_id, "saving_dataset", 7, "Saving dataset to storage")
+
             # Save to file and memory cache
             if self._save_dataset_to_file(dataset_id, dataset_data):
                 self.dataset_storage[dataset_id] = dataset_data
@@ -267,11 +446,46 @@ class DatasetBuilder:
                 logger.warning(f"Failed to save dataset {dataset_id} to file, keeping in memory only")
                 self.dataset_storage[dataset_id] = dataset_data
 
+            # Update task status to completed
+            result = {
+                'dataset_id': dataset_id,
+                'stats': dataset_data['stats']
+            }
+
+            self._update_task_status(
+                task_id,
+                status=TaskStatus.COMPLETED,
+                completed_at=datetime.now(),
+                result=result,
+                progress=TaskProgress(
+                    current_step="completed",
+                    total_steps=7,
+                    completed_steps=7,
+                    percentage=100.0,
+                    message="Dataset building completed successfully"
+                )
+            )
+
             return dataset_id
 
         except Exception as e:
             logger.error(f"Dataset building failed for {dataset_id}: {str(e)}")
 
+            # Update task status to failed
+            self._update_task_status(
+                task_id,
+                status=TaskStatus.FAILED,
+                completed_at=datetime.now(),
+                error_message=str(e),
+                progress=TaskProgress(
+                    current_step="failed",
+                    total_steps=7,
+                    completed_steps=0,
+                    percentage=0.0,
+                    message=f"Task failed: {str(e)}"
+                )
+            )
+
             # Store error information
             error_data = {
                 'error': str(e),
@@ -283,6 +497,30 @@ class DatasetBuilder:
             self.dataset_storage[dataset_id] = error_data
 
             raise
+
+    async def start_dataset_build_task(self, dataset_id: str, aid_list: List[int],
+                                       embedding_model: str, force_regenerate: bool = False,
+                                       description: Optional[str] = None) -> str:
+        """
+        Start a dataset building task and return task ID for status tracking
+        """
+        import uuid
+        task_id = str(uuid.uuid4())
+
+        # Create task status
+        task_status = self._create_task_status(task_id, dataset_id, aid_list, embedding_model, force_regenerate)
+
+        # Start the actual task
+        task = asyncio.create_task(
+            self.build_dataset_with_task_tracking(task_id, dataset_id, aid_list, embedding_model, force_regenerate, description)
+        )
+
+        # Store the running task
+        with self._task_status_lock:
+            self.running_tasks[task_id] = task
+
+        return task_id
+
     def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
         """Get built dataset by ID"""
         # First check memory cache
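For reference, a sketch of driving the task API in-process (e.g. from a test), assuming a `DatasetBuilder` wired to its database and embedding services the way `main.py`'s lifespan handler does it; the dataset ID and AIDs are placeholders:

```python
import asyncio

from dataset_service import DatasetBuilder

async def demo(builder: DatasetBuilder):
    task_id = await builder.start_dataset_build_task(
        dataset_id="demo-dataset",  # placeholder ID
        aid_list=[170001, 170002],
        embedding_model="jina-embedding-v3-m2v",
        description="demo build",
    )
    # TaskStatus is a str-based enum, so plain-string comparison works.
    while builder.get_task_status(task_id).status not in ("completed", "failed"):
        await asyncio.sleep(1)
    return builder.get_dataset("demo-dataset")
```

Note that `start_dataset_build_task` calls `asyncio.create_task` internally, so it must be awaited from inside a running event loop — FastAPI request handlers satisfy this automatically.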
diff --git a/ml_new/training/embedding_models.toml b/ml_new/training/embedding_models.toml
index f650fd0..f9ba291 100644
--- a/ml_new/training/embedding_models.toml
+++ b/ml_new/training/embedding_models.toml
@@ -1,6 +1,6 @@
 # Embedding Models Configuration
 
-model = "qwen3-embedding"
+model = "jina-embedding-v3-m2v"
 
 [models.qwen3-embedding]
 name = "text-embedding-v4"
@@ -10,3 +10,11 @@ api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
 max_tokens = 8192
 max_batch_size = 10
 api_key_env = "ALIYUN_KEY"
+
+[models.jina-embedding-v3-m2v]
+name = "jina-embedding-v3-m2v-1024"
+dimensions = 1024
+type = "legacy"
+model_path = "../../model/embedding/model.onnx"
+tokenizer_name = "jinaai/jina-embeddings-v3"
+max_batch_size = 128
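Assuming the loader keys models by their `[models.*]` table name, as the existing qwen3 entry suggests, the new entry should surface through `config_loader` like this (a sketch; the assertions only restate the TOML values above):

```python
from config_loader import config_loader

models = config_loader.get_embedding_models()
cfg = models["jina-embedding-v3-m2v"]
assert cfg.type == "legacy"
assert cfg.dimensions == 1024
assert cfg.model_path.endswith("model.onnx")

# The top-level `model` key now also drives the health check's default model.
assert config_loader.get_selected_model() == "jina-embedding-v3-m2v"
```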
diff --git a/ml_new/training/embedding_service.py b/ml_new/training/embedding_service.py
index 8bc16b8..72c0cdc 100644
--- a/ml_new/training/embedding_service.py
+++ b/ml_new/training/embedding_service.py
@@ -1,18 +1,22 @@
 """
-Embedding service for generating embeddings using OpenAI-compatible API
+Embedding service for generating embeddings using OpenAI-compatible API and legacy methods
 """
 
 import asyncio
 import hashlib
 from typing import List, Dict, Any, Optional
-import logging
 from openai import AsyncOpenAI
 import os
 from config_loader import config_loader
 from dotenv import load_dotenv
+import torch
+import numpy as np
+from transformers import AutoTokenizer
+import onnxruntime as ort
+from logger_config import get_logger
 
-load_dotenv() 
+load_dotenv()
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 
 class EmbeddingService:
     def __init__(self):
@@ -21,7 +25,10 @@ class EmbeddingService:
 
         # Initialize OpenAI client (will be configured per model)
         self.clients: Dict[str, AsyncOpenAI] = {}
+        self.legacy_models: Dict[str, Dict[str, Any]] = {}
+
         self._initialize_clients()
+        self._initialize_legacy_models()
 
         # Rate limiting
         self.max_requests_per_minute = int(os.getenv("MAX_REQUESTS_PER_MINUTE", "100"))
@@ -30,19 +37,93 @@ class EmbeddingService:
     def _initialize_clients(self):
         """Initialize OpenAI clients for different models/endpoints"""
         for model_name, model_config in self.embedding_models.items():
-            if model_config.type == "openai-compatible":
-                # Get API key from environment variable specified in config
-                api_key = os.getenv(model_config.api_key_env)
+            if model_config.type != "openai-compatible":
+                continue
+
+            # Get API key from environment variable specified in config
+            api_key = os.getenv(model_config.api_key_env)
+
+            self.clients[model_name] = AsyncOpenAI(
+                api_key=api_key,
+                base_url=model_config.api_endpoint
+            )
+            logger.info(f"Initialized client for model {model_name}")
+
+    def _initialize_legacy_models(self):
+        """Initialize legacy ONNX models for embedding generation"""
+        for model_name, model_config in self.embedding_models.items():
+            if model_config.type != "legacy":
+                continue
+            try:
+                # Load tokenizer
+                tokenizer = AutoTokenizer.from_pretrained(model_config.tokenizer_name)
 
-                self.clients[model_name] = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=model_config.api_endpoint
-                )
-                logger.info(f"Initialized client for model {model_name}")
+
+                # Load ONNX model
+                session = ort.InferenceSession(model_config.model_path)
+
+                self.legacy_models[model_name] = {
+                    "tokenizer": tokenizer,
+                    "session": session,
+                    "config": model_config
+                }
+                logger.info(f"Initialized legacy model {model_name}")
+
+            except Exception as e:
+                logger.error(f"Failed to initialize legacy model {model_name}: {e}")
+
+    def get_jina_embeddings_1024(self, texts: List[str], model_name: str) -> np.ndarray:
+        """Generate embeddings using legacy Jina method (same as ml/api/main.py)"""
+        if model_name not in self.legacy_models:
+            raise ValueError(f"Legacy model '{model_name}' not initialized")
+
+        legacy_model = self.legacy_models[model_name]
+        tokenizer = legacy_model["tokenizer"]
+        session = legacy_model["session"]
+
+        # Encode inputs using tokenizer
+        encoded_inputs = tokenizer(
+            texts,
+            add_special_tokens=False,  # Don't add special tokens (consistent with JS)
+            return_attention_mask=False,
+            return_tensors=None  # Return native Python lists for easier processing
+        )
+        input_ids = encoded_inputs["input_ids"]  # Shape: [batch_size, seq_len_i] (variable length per sample)
+
+        # Calculate offsets (consistent with JS cumsum logic)
+        # Get token length for each sample first
+        lengths = [len(ids) for ids in input_ids]
+        # Calculate cumulative sum (exclude last sample)
+        cumsum = []
+        current_sum = 0
+        for l in lengths[:-1]:  # Only accumulate first n-1 samples
+            current_sum += l
+            cumsum.append(current_sum)
+        # Build offsets: start with 0, followed by cumulative sums
+        offsets = [0] + cumsum  # Shape: [batch_size]
+
+        # Flatten input_ids to 1D array
+        flattened_input_ids = []
+        for ids in input_ids:
+            flattened_input_ids.extend(ids)  # Directly concatenate all token ids
+        flattened_input_ids = np.array(flattened_input_ids, dtype=np.int64)
+
+        # Prepare ONNX inputs (consistent tensor shapes with JS)
+        inputs = {
+            "input_ids": ort.OrtValue.ortvalue_from_numpy(flattened_input_ids),
+            "offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64))
+        }
+
+        # Run model inference
+        outputs = session.run(None, inputs)
+        embeddings = outputs[0]  # Assume first output is embeddings, shape: [batch_size, embedding_dim]
+
+        return torch.tensor(embeddings, dtype=torch.float32).numpy()
 
     async def generate_embeddings_batch(
-        self, 
-        texts: List[str], 
+        self,
+        texts: List[str],
         model: str,
         batch_size: Optional[int] = None
     ) -> List[List[float]]:
@@ -54,6 +135,64 @@ class EmbeddingService:
 
         model_config = self.embedding_models[model]
 
+        # Handle different model types
+        if model_config.type == "legacy":
+            return self._generate_legacy_embeddings_batch(texts, model, batch_size)
+        elif model_config.type == "openai-compatible":
+            return await self._generate_openai_embeddings_batch(texts, model, batch_size)
+        else:
+            raise ValueError(f"Unsupported model type: {model_config.type}")
+
+    def _generate_legacy_embeddings_batch(
+        self,
+        texts: List[str],
+        model: str,
+        batch_size: Optional[int] = None
+    ) -> List[List[float]]:
+        """Generate embeddings using legacy ONNX model"""
+        if model not in self.legacy_models:
+            raise ValueError(f"Legacy model '{model}' not initialized")
+
+        model_config = self.embedding_models[model]
+
+        # Use model's max_batch_size if not specified
+        if batch_size is None:
+            batch_size = model_config.max_batch_size
+
+        expected_dims = model_config.dimensions
+        all_embeddings = []
+
+        # Process in batches
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i:i + batch_size]
+
+            try:
+                # Generate embeddings using legacy method
+                embeddings = self.get_jina_embeddings_1024(batch, model)
+
+                # Convert to list of lists (expected format)
+                batch_embeddings = embeddings.tolist()
+                all_embeddings.extend(batch_embeddings)
+
+                logger.info(f"Generated legacy embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
1}") + + except Exception as e: + logger.error(f"Error generating legacy embeddings for batch {i//batch_size + 1}: {e}") + # Fill with zeros as fallback + zero_embedding = [0.0] * expected_dims + all_embeddings.extend([zero_embedding] * len(batch)) + + return all_embeddings + + async def _generate_openai_embeddings_batch( + self, + texts: List[str], + model: str, + batch_size: Optional[int] = None + ) -> List[List[float]]: + """Generate embeddings using OpenAI-compatible API""" + model_config = self.embedding_models[model] + # Use model's max_batch_size if not specified if batch_size is None: batch_size = model_config.max_batch_size @@ -115,11 +254,19 @@ class EmbeddingService: async def health_check(self) -> Dict[str, Any]: """Check if embedding service is healthy""" try: + if not self.embedding_models: + return { + "status": "unhealthy", + "service": "embedding_service", + "error": "No embedding models configured" + } + # Test with a simple embedding using the first available model - model_name = list(self.embedding_models.keys())[0] + model_name = config_loader.get_selected_model() + model_config = self.embedding_models[model_name] test_embedding = await self.generate_embeddings_batch( - ["health check"], + ["health check"], model_name, batch_size=1 ) @@ -128,13 +275,16 @@ class EmbeddingService: "status": "healthy", "service": "embedding_service", "model": model_name, + "model_type": model_config.type, "dimensions": len(test_embedding[0]) if test_embedding else 0, - "available_models": list(self.embedding_models.keys()) + "available_models": list(self.embedding_models.keys()), + "legacy_models": list(self.legacy_models.keys()), + "openai_clients": list(self.clients.keys()) } except Exception as e: return { - "status": "unhealthy", + "status": "unhealthy", "service": "embedding_service", "error": str(e) } diff --git a/ml_new/training/logger_config.py b/ml_new/training/logger_config.py new file mode 100644 index 0000000..4970d61 --- /dev/null +++ b/ml_new/training/logger_config.py @@ -0,0 +1,168 @@ +""" +Unified logging configuration for ml_new training project +Provides colorful level formatting with [level]: [msg] format +""" +import logging +import sys + +class ColorfulFormatter(logging.Formatter): + """Custom formatter with colorful level names and [level]: [msg] format""" + + # ANSI color codes for different log levels + COLORS = { + 'DEBUG': '\033[36m', # Cyan + 'INFO': '\033[32m', # Green + 'WARNING': '\033[33m', # Yellow + 'ERROR': '\033[31m', # Red + 'CRITICAL': '\033[35m', # Magenta + } + RESET = '\033[0m' # Reset color + + def format(self, record): + # Extract level name + level_name = record.levelname + + # Apply color to level name + color = self.COLORS.get(level_name, self.RESET) + colored_level = f"{color}{level_name}{self.RESET}" + + # Create new format: [colored_level]: [msg] + # Get the message part only (without level name, module, etc.) 
diff --git a/ml_new/training/logger_config.py b/ml_new/training/logger_config.py
new file mode 100644
index 0000000..4970d61
--- /dev/null
+++ b/ml_new/training/logger_config.py
@@ -0,0 +1,168 @@
+"""
+Unified logging configuration for ml_new training project
+Provides colorful level formatting with [level]: [msg] format
+"""
+import logging
+import sys
+
+
+class ColorfulFormatter(logging.Formatter):
+    """Custom formatter with colorful level names and [level]: [msg] format"""
+
+    # ANSI color codes for different log levels
+    COLORS = {
+        'DEBUG': '\033[36m',     # Cyan
+        'INFO': '\033[32m',      # Green
+        'WARNING': '\033[33m',   # Yellow
+        'ERROR': '\033[31m',     # Red
+        'CRITICAL': '\033[35m',  # Magenta
+    }
+    RESET = '\033[0m'  # Reset color
+
+    def format(self, record):
+        # Extract level name
+        level_name = record.levelname
+
+        # Apply color to level name
+        color = self.COLORS.get(level_name, self.RESET)
+        colored_level = f"{color}{level_name}{self.RESET}"
+
+        # Create new format: [colored_level]: [msg]
+        # Get the message part only (without level name, module, etc.)
+        if hasattr(record, 'message'):
+            message = record.message
+        else:
+            message = record.getMessage()
+
+        # Return in the requested format
+        return f"{colored_level}: {message}"
+
+
+class ColoredConsoleHandler(logging.StreamHandler):
+    """Console handler that outputs to stderr with colors"""
+
+    def __init__(self, stream=None):
+        if stream is None:
+            stream = sys.stderr
+        super().__init__(stream)
+        self.setFormatter(ColorfulFormatter())
+
+
+def setup_logger(
+    name: str = None,
+    level: str = "INFO",
+    log_file: str = None,
+    console_output: bool = True,
+    file_output: bool = False
+) -> logging.Logger:
+    """
+    Setup a unified logger with colorful formatting
+
+    Args:
+        name: Logger name (defaults to project name)
+        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+        log_file: Optional log file path
+        console_output: Whether to output to console
+        file_output: Whether to output to file
+
+    Returns:
+        Configured logger instance
+    """
+
+    # Create logger
+    logger = logging.getLogger(name)
+    logger.setLevel(getattr(logging, level.upper()))
+
+    # Clear existing handlers to avoid duplicates
+    logger.handlers.clear()
+
+    # Console handler with colorful formatting
+    if console_output:
+        console_handler = ColoredConsoleHandler()
+        logger.addHandler(console_handler)
+
+    # File handler with standard formatting (no colors in files)
+    if file_output and log_file:
+        file_handler = logging.FileHandler(log_file)
+        file_formatter = logging.Formatter(
+            '[%(levelname)s]: %(message)s',
+            datefmt='%Y-%m-%d %H:%M:%S'
+        )
+        file_handler.setFormatter(file_formatter)
+        logger.addHandler(file_handler)
+
+    # Prevent propagation to root logger to avoid duplicate messages
+    logger.propagate = False
+
+    return logger
+
+
+def get_logger(name: str = None) -> logging.Logger:
+    """
+    Get a logger instance with unified configuration
+
+    Args:
+        name: Logger name (defaults to 'ml_new.training')
+
+    Returns:
+        Logger instance
+    """
+    if name is None:
+        name = 'ml_new.training'
+
+    logger = logging.getLogger(name)
+
+    # Only configure if not already configured
+    if not logger.handlers:
+        return setup_logger(name=name)
+
+    return logger
+
+
+# Project-wide logger configuration
+def configure_project_logger():
+    """Configure the main project logger with all modules"""
+
+    # Configure root logger for the project
+    root_logger = logging.getLogger('ml_new')
+    root_logger.setLevel(logging.INFO)
+
+    # Clear existing handlers
+    root_logger.handlers.clear()
+
+    # Add console handler
+    console_handler = ColoredConsoleHandler()
+    root_logger.addHandler(console_handler)
+
+    # Prevent propagation to avoid duplicates
+    root_logger.propagate = False
+
+    return root_logger
+
+
+# Convenience function for quick setup
+def quick_setup(level: str = "INFO") -> logging.Logger:
+    """
+    Quick setup for individual modules
+
+    Args:
+        level: Logging level
+
+    Returns:
+        Logger instance for the calling module
+    """
+    # Get the calling module name
+    import inspect
+    frame = inspect.currentframe().f_back
+    module_name = frame.f_globals.get('__name__', 'ml_new.training')
+
+    return setup_logger(name=module_name, level=level)
+
+
+if __name__ == "__main__":
+    # Test the logger configuration
+    logger = get_logger('test_logger')
+    logger.debug("This is a debug message")
+    logger.info("This is an info message")
+    logger.warning("This is a warning message")
+    logger.error("This is an error message")
+    logger.critical("This is a critical message")
\ No newline at end of file
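Module-level usage of the new helper, mirroring what this diff applies across `api_routes.py`, `database.py`, and the other modules:

```python
from logger_config import get_logger

logger = get_logger(__name__)
logger.info("service started")  # rendered as a green "INFO: service started"
```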
diff --git a/ml_new/training/main.py b/ml_new/training/main.py
index b510e3d..fa3d930 100644
--- a/ml_new/training/main.py
+++ b/ml_new/training/main.py
@@ -2,7 +2,6 @@
 Main FastAPI application for ML training service
 """
 
-import logging
 import uvicorn
 from contextlib import asynccontextmanager
 from fastapi import FastAPI
@@ -12,10 +11,9 @@
 from database import DatabaseManager
 from embedding_service import EmbeddingService
 from dataset_service import DatasetBuilder
 from api_routes import router, set_dataset_builder
+from logger_config import get_logger
 
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)
 
 
 # Global service instances
@@ -105,7 +103,7 @@ def main():
         host="0.0.0.0",
         port=8322,
         log_level="info",
-        access_log=True
+        access_log=True,
     )
diff --git a/ml_new/training/models.py b/ml_new/training/models.py
index 66f084b..ee027d5 100644
--- a/ml_new/training/models.py
+++ b/ml_new/training/models.py
@@ -12,6 +12,7 @@ class DatasetBuildRequest(BaseModel):
     aid_list: List[int] = Field(..., description="List of video AIDs")
     embedding_model: str = Field(..., description="Embedding model name")
     force_regenerate: bool = Field(False, description="Whether to force regenerate embeddings")
+    description: Optional[str] = Field(None, description="Optional description for the dataset")
 
 
 class DatasetBuildResponse(BaseModel):
@@ -20,6 +21,7 @@ class DatasetBuildResponse(BaseModel):
     total_records: int
     status: str
     message: str
+    description: Optional[str] = None
     created_at: Optional[datetime] = None
 
 
@@ -59,4 +61,67 @@ class EmbeddingModelInfo(BaseModel):
     type: str
     api_endpoint: Optional[str] = None
     max_tokens: Optional[int] = None
-    max_batch_size: Optional[int] = None
\ No newline at end of file
+    max_batch_size: Optional[int] = None
+
+
+from typing import List, Optional, Dict, Any
+from pydantic import BaseModel, Field
+from datetime import datetime
+from enum import Enum
+
+
+class TaskStatus(str, Enum):
+    """Task status enumeration"""
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    CANCELLED = "cancelled"
+
+
+class TaskProgress(BaseModel):
+    """Progress information for a task"""
+    current_step: str
+    total_steps: int
+    completed_steps: int
+    percentage: float
+    message: Optional[str] = None
+    estimated_time_remaining: Optional[float] = None
+
+
+class DatasetBuildTaskStatus(BaseModel):
+    """Status model for dataset building task"""
+    task_id: str
+    status: TaskStatus
+    dataset_id: Optional[str] = None
+    aid_list: List[int]
+    embedding_model: str
+    force_regenerate: bool
+    progress: Optional[TaskProgress] = None
+    error_message: Optional[str] = None
+    created_at: datetime
+    started_at: Optional[datetime] = None
+    completed_at: Optional[datetime] = None
+    result: Optional[Dict[str, Any]] = None
+
+
+class TaskStatusResponse(BaseModel):
+    """Response model for task status endpoint"""
+    task_id: str
+    status: TaskStatus
+    progress: Optional[Dict[str, Any]] = None
+    result: Optional[Dict[str, Any]] = None
+    error: Optional[str] = None
+    created_at: datetime
+    started_at: Optional[datetime] = None
+    completed_at: Optional[datetime] = None
+
+
+class TaskListResponse(BaseModel):
+    """Response model for listing tasks"""
+    tasks: List[TaskStatusResponse]
+    total_count: int
+    pending_count: int
+    running_count: int
+    completed_count: int
+    failed_count: int
\ No newline at end of file
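For orientation, an illustrative payload shape (not captured from a real run) for `GET /v1/tasks/{task_id}` while a build is in flight, matching `TaskStatusResponse` and the seven-step progress convention in `dataset_service.py`:

```python
example_response = {
    "task_id": "4f0c9e9e-0000-0000-0000-000000000000",  # hypothetical UUID
    "status": "running",
    "progress": {
        "current_step": "generating_embeddings",
        "total_steps": 7,
        "completed_steps": 5,
        "percentage": 71.4,  # ~5/7, as computed by _update_task_progress
        "message": "Generating 42 new embeddings",
    },
    "result": None,
    "error": None,
    "created_at": "2025-01-01T12:00:00",
    "started_at": "2025-01-01T12:00:01",
    "completed_at": None,
}
```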