diff --git a/.gitignore b/.gitignore index 4ea3796..3bb3b08 100644 --- a/.gitignore +++ b/.gitignore @@ -47,4 +47,4 @@ temp/ meili -.turbo \ No newline at end of file +.turbo/ \ No newline at end of file diff --git a/.kilocode/rules/common.md b/.kilocode/rules/common.md new file mode 100644 index 0000000..0464778 --- /dev/null +++ b/.kilocode/rules/common.md @@ -0,0 +1,5 @@ +# common.md + +1. Always use bun as package manager. + +2. Always write comments in English. diff --git a/ml_new/.gitignore b/ml_new/.gitignore index c0190e1..f3c07f0 100644 --- a/ml_new/.gitignore +++ b/ml_new/.gitignore @@ -1 +1 @@ -datasets \ No newline at end of file +datasets/ \ No newline at end of file diff --git a/ml_new/training/README.md b/ml_new/training/README.md new file mode 100644 index 0000000..189315b --- /dev/null +++ b/ml_new/training/README.md @@ -0,0 +1,140 @@ +# ML Training Service + +A FastAPI-based ML training service for dataset building, embedding generation, and experiment management. + +## Architecture + +The service is organized into modular components: + +``` +ml_new/training/ +├── main.py # FastAPI application entry point +├── models.py # Pydantic data models +├── config_loader.py # Configuration loading from TOML +├── database.py # Database connection and operations +├── embedding_service.py # Embedding generation service +├── dataset_service.py # Dataset building logic +├── api_routes.py # API endpoint definitions +├── embedding_models.toml # Embedding model configurations +└── requirements.txt # Python dependencies +``` + +## Key Components + +### 1. Main Application (`main.py`) +- FastAPI app initialization +- CORS middleware configuration +- Service dependency injection +- Startup/shutdown event handlers + +### 2. 
Data Models (`models.py`) +- `DatasetBuildRequest`: Request model for dataset building +- `DatasetBuildResponse`: Response model for dataset building +- `DatasetRecord`: Individual dataset record structure +- `EmbeddingModelInfo`: Embedding model configuration + +### 3. Configuration (`config_loader.py`) +- Loads embedding model configurations from TOML +- Manages model parameters (dimensions, API endpoints, etc.) + +### 4. Database Layer (`database.py`) +- PostgreSQL connection management +- CRUD operations for video metadata, user labels, and embeddings +- Optimized batch queries to avoid N+1 problems + +### 5. Embedding Service (`embedding_service.py`) +- Integration with OpenAI-compatible embedding APIs +- Text preprocessing and checksum generation +- Batch embedding generation with rate limiting + +### 6. Dataset Building (`dataset_service.py`) +- Complete dataset construction workflow: + 1. Pull raw text from database + 2. Text preprocessing (placeholder) + 3. Batch embedding generation with deduplication + 4. Embedding storage and caching + 5. Final dataset compilation with labels + +### 7. API Routes (`api_routes.py`) +- `/api/v1/health`: Health check +- `/api/v1/models/embedding`: List available embedding models +- `/api/v1/dataset/build`: Build new dataset +- `/api/v1/dataset/{id}`: Retrieve built dataset +- `/api/v1/datasets`: List all datasets +- `/api/v1/dataset/{id}`: Delete dataset + +## Dataset Building Flow + +1. **Model Selection**: Choose embedding model from TOML configuration +2. **Data Retrieval**: Pull video metadata and user labels from PostgreSQL +3. **Text Processing**: Combine title, description, and tags +4. **Deduplication**: Generate checksums to avoid duplicate embeddings +5. **Batch Processing**: Generate embeddings for new texts only +6. **Storage**: Store embeddings in database with caching +7. 
**Final Assembly**: Combine embeddings with labels using consensus mechanism + +## Configuration + +### Embedding Models (`embedding_models.toml`) +```toml +[text-embedding-3-large] +name = "text-embedding-3-large" +dimensions = 3072 +type = "openai" +api_endpoint = "https://api.openai.com/v1/embeddings" +max_tokens = 8192 +max_batch_size = 100 +``` + +### Environment Variables +- `DATABASE_URL`: PostgreSQL connection string +- `OPENAI_API_KEY`: OpenAI API key for embedding generation + +## Usage + +### Start the Service +```bash +cd ml_new/training +python main.py +``` + +### Build a Dataset +```bash +curl -X POST "http://localhost:8322/v1/dataset/build" \ + -H "Content-Type: application/json" \ + -d '{ + "aid_list": [170001, 170002, 170003], + "embedding_model": "text-embedding-3-large", + "force_regenerate": false + }' +``` + +### Check Health +```bash +curl "http://localhost:8322/v1/health" +``` + +### List Embedding Models +```bash +curl "http://localhost:8322/v1/models/embedding" +``` + +## Features + +- **High Performance**: Optimized database queries with batch operations +- **Deduplication**: Text-level deduplication using MD5 checksums +- **Consensus Labels**: Majority vote mechanism for user annotations +- **Batch Processing**: Efficient embedding generation and storage +- **Error Handling**: Comprehensive error handling and logging +- **Async Support**: Fully asynchronous operations for scalability +- **CORS Enabled**: Ready for frontend integration + +## Production Considerations + +- Replace in-memory dataset storage with database +- Add authentication and authorization +- Implement rate limiting for API endpoints +- Add monitoring and metrics collection +- Configure proper logging levels +- Set up database connection pooling +- Add API documentation with OpenAPI/Swagger \ No newline at end of file diff --git a/ml_new/training/api_routes.py b/ml_new/training/api_routes.py new file mode 100644 index 0000000..ea68779 --- /dev/null +++ 
b/ml_new/training/api_routes.py @@ -0,0 +1,196 @@ +""" +API routes for the ML training service +""" + +import logging +import uuid + +from fastapi import APIRouter, HTTPException, BackgroundTasks +from fastapi.responses import JSONResponse + +from config_loader import config_loader +from models import DatasetBuildRequest, DatasetBuildResponse +from dataset_service import DatasetBuilder + + +logger = logging.getLogger(__name__) + +# Create router +router = APIRouter(prefix="/v1") + +# Global dataset builder instance (will be set by main.py) +dataset_builder: DatasetBuilder = None + + +def set_dataset_builder(builder: DatasetBuilder): + """Set the global dataset builder instance""" + global dataset_builder + dataset_builder = builder + + +@router.get("/health") +async def health_check(): + """Health check endpoint""" + if not dataset_builder: + return JSONResponse( + status_code=503, + content={"status": "unavailable", "message": "Dataset builder not initialized"} + ) + + try: + # Check embedding service health + embedding_health = await dataset_builder.embedding_service.health_check() + except Exception as e: + embedding_health = {"status": "unhealthy", "error": str(e)} + + # Check database connection (pool should already be initialized) + db_status = "disconnected" + if dataset_builder.db_manager.is_connected: + try: + response = await dataset_builder.db_manager.pool.fetch("SELECT 1 FROM information_schema.tables") + db_status = "connected" if response else "disconnected" + except Exception as e: + db_status = f"error: {str(e)}" + + return { + "status": "healthy", + "service": "ml-training-api", + "embedding_service": embedding_health, + "database": db_status, + "available_models": list(config_loader.get_embedding_models().keys()) + } + + +@router.get("/models/embedding") +async def get_embedding_models(): + """Get available embedding models""" + return { + "models": { + name: { + "name": config.name, + "dimensions": config.dimensions, + "type": config.type, + 
"api_endpoint": config.api_endpoint, + "max_tokens": config.max_tokens, + "max_batch_size": config.max_batch_size + } + for name, config in config_loader.get_embedding_models().items() + } + } + + +@router.post("/dataset/build", response_model=DatasetBuildResponse) +async def build_dataset_endpoint(request: DatasetBuildRequest, background_tasks: BackgroundTasks): + """Build dataset endpoint""" + + if not dataset_builder: + raise HTTPException(status_code=503, detail="Dataset builder not available") + + # Validate embedding model + if request.embedding_model not in config_loader.get_embedding_models(): + raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}") + + dataset_id = str(uuid.uuid4()) + # Start background task for dataset building + background_tasks.add_task( + dataset_builder.build_dataset, + dataset_id, + request.aid_list, + request.embedding_model, + request.force_regenerate + ) + + return DatasetBuildResponse( + dataset_id=dataset_id, + total_records=len(request.aid_list), + status="started", + message="Dataset building started" + ) + + +@router.get("/dataset/{dataset_id}") +async def get_dataset_endpoint(dataset_id: str): + """Get built dataset by ID""" + + if not dataset_builder: + raise HTTPException(status_code=503, detail="Dataset builder not available") + + if not dataset_builder.dataset_exists(dataset_id): + raise HTTPException(status_code=404, detail="Dataset not found") + + dataset_info = dataset_builder.get_dataset(dataset_id) + + if "error" in dataset_info: + raise HTTPException(status_code=500, detail=dataset_info["error"]) + + return { + "dataset_id": dataset_id, + "dataset": dataset_info["dataset"], + "stats": dataset_info["stats"], + "created_at": dataset_info["created_at"] + } + + +@router.get("/datasets") +async def list_datasets(): + """List all built datasets""" + + if not dataset_builder: + raise HTTPException(status_code=503, detail="Dataset builder not available") + + datasets = [] + for 
dataset_id, dataset_info in dataset_builder.dataset_storage.items(): + if "error" not in dataset_info: + datasets.append({ + "dataset_id": dataset_id, + "stats": dataset_info["stats"], + "created_at": dataset_info["created_at"] + }) + + return {"datasets": datasets} + + +@router.delete("/dataset/{dataset_id}") +async def delete_dataset_endpoint(dataset_id: str): + """Delete a built dataset""" + + if not dataset_builder: + raise HTTPException(status_code=503, detail="Dataset builder not available") + + if dataset_builder.delete_dataset(dataset_id): + return {"message": f"Dataset {dataset_id} deleted successfully"} + else: + raise HTTPException(status_code=404, detail="Dataset not found") + + +@router.get("/datasets") +async def list_datasets_endpoint(): + """List all built datasets""" + + if not dataset_builder: + raise HTTPException(status_code=503, detail="Dataset builder not available") + + datasets = dataset_builder.list_datasets() + return {"datasets": datasets} + + +@router.get("/datasets/stats") +async def get_dataset_stats_endpoint(): + """Get overall statistics about stored datasets""" + + if not dataset_builder: + raise HTTPException(status_code=503, detail="Dataset builder not available") + + stats = dataset_builder.get_dataset_stats() + return stats + + +@router.post("/datasets/cleanup") +async def cleanup_datasets_endpoint(max_age_days: int = 30): + """Remove datasets older than specified days""" + + if not dataset_builder: + raise HTTPException(status_code=503, detail="Dataset builder not available") + + await dataset_builder.cleanup_old_datasets(max_age_days) + return {"message": f"Cleanup completed for datasets older than {max_age_days} days"} \ No newline at end of file diff --git a/ml_new/training/config_loader.py b/ml_new/training/config_loader.py new file mode 100644 index 0000000..f7a93bb --- /dev/null +++ b/ml_new/training/config_loader.py @@ -0,0 +1,85 @@ +""" +Configuration loader for embedding models and other settings +""" + +import toml 
+import os +from typing import Dict +from pydantic import BaseModel +import logging + +logger = logging.getLogger(__name__) + + +class EmbeddingModelConfig(BaseModel): + name: str + dimensions: int + type: str + api_endpoint: str = "https://api.openai.com/v1" + max_tokens: int = 8191 + max_batch_size: int = 8 + api_key_env: str = "OPENAI_API_KEY" + + +class ConfigLoader: + def __init__(self, config_path: str = None): + if config_path is None: + # Default to the embedding_models.toml file we created + config_path = os.path.join( + os.path.dirname(__file__), "embedding_models.toml" + ) + + self.config_path = config_path + self.embedding_models: Dict[str, EmbeddingModelConfig] = {} + self._load_config() + + def _load_config(self): + """Load configuration from TOML file""" + try: + if not os.path.exists(self.config_path): + logger.warning(f"Config file not found: {self.config_path}") + return + + with open(self.config_path, "r", encoding="utf-8") as f: + config_data = toml.load(f) + + # Load embedding models + if "models" not in config_data: + return + + for model_key, model_data in config_data["models"].items(): + self.embedding_models[model_key] = EmbeddingModelConfig( + **model_data + ) + + logger.info( + f"Loaded {len(self.embedding_models)} embedding models from {self.config_path}" + ) + + except Exception as e: + logger.error(f"Failed to load config from {self.config_path}: {e}") + + def get_embedding_models(self) -> Dict[str, EmbeddingModelConfig]: + """Get all available embedding models""" + return self.embedding_models.copy() + + def get_embedding_model(self, model_name: str) -> EmbeddingModelConfig: + """Get specific embedding model config""" + if model_name not in self.embedding_models: + raise ValueError( + f"Embedding model '{model_name}' not found in configuration" + ) + return self.embedding_models[model_name] + + def list_model_names(self) -> list: + """Get list of available model names""" + return list(self.embedding_models.keys()) + + def 
reload_config(self): + """Reload configuration from file""" + self.embedding_models = {} + self._load_config() + + +# Global config loader instance +config_loader = ConfigLoader() diff --git a/ml_new/training/database.py b/ml_new/training/database.py new file mode 100644 index 0000000..50f2d53 --- /dev/null +++ b/ml_new/training/database.py @@ -0,0 +1,284 @@ +""" +Database connection and operations for ML training service +""" + +import os +import hashlib +from typing import List, Dict, Optional, Any +from datetime import datetime +import asyncpg +import logging +from config_loader import config_loader +from dotenv import load_dotenv + +load_dotenv() + +logger = logging.getLogger(__name__) + +# Database configuration +DATABASE_URL = os.getenv("DATABASE_URL") + +class DatabaseManager: + def __init__(self): + self.pool: Optional[asyncpg.Pool] = None + + async def connect(self): + """Initialize database connection pool""" + try: + self.pool = await asyncpg.create_pool(DATABASE_URL, min_size=5, max_size=20) + + logger.info("Database connection pool initialized") + except Exception as e: + logger.error(f"Failed to connect to database: {e}") + raise + + async def close(self): + """Close database connection pool""" + if self.pool: + await self.pool.close() + logger.info("Database connection pool closed") + + @property + def is_connected(self) -> bool: + """Check if database connection pool is initialized""" + return self.pool is not None + + async def get_embedding_models(self): + """Get available embedding models from config""" + return config_loader.get_embedding_models() + + async def get_video_metadata( + self, aid_list: List[int] + ) -> Dict[int, Dict[str, Any]]: + """Get video metadata for given AIDs""" + if not aid_list: + return {} + + async with self.pool.acquire() as conn: + query = """ + SELECT aid, title, description, tags + FROM bilibili_metadata + WHERE aid = ANY($1::bigint[]) + """ + rows = await conn.fetch(query, aid_list) + + result = {} + for row in 
rows: + result[int(row["aid"])] = { + "aid": int(row["aid"]), + "title": row["title"] or "", + "description": row["description"] or "", + "tags": row["tags"] or "", + } + + return result + + async def get_user_labels( + self, aid_list: List[int] + ) -> Dict[int, List[Dict[str, Any]]]: + """Get user labels for given AIDs, only the latest label per user""" + if not aid_list: + return {} + + async with self.pool.acquire() as conn: + query = """ + WITH latest_labels AS ( + SELECT DISTINCT ON (aid, "user") + aid, "user", label, created_at + FROM internal.video_type_label + WHERE aid = ANY($1::bigint[]) + ORDER BY aid, "user", created_at DESC + ) + SELECT aid, "user", label, created_at + FROM latest_labels + ORDER BY aid, "user" + """ + rows = await conn.fetch(query, aid_list) + + result = {} + for row in rows: + aid = int(row["aid"]) + if aid not in result: + result[aid] = [] + + result[aid].append( + { + "user": row["user"], + "label": bool(row["label"]), + "created_at": row["created_at"].isoformat(), + } + ) + + return result + + async def get_existing_embeddings( + self, checksums: List[str], model_name: str + ) -> Dict[str, Dict[str, Any]]: + """Get existing embeddings for given checksums and model""" + if not checksums: + return {} + + async with self.pool.acquire() as conn: + query = """ + SELECT data_checksum, vec_2048, vec_1536, vec_1024, created_at + FROM internal.embeddings + WHERE model_name = $1 AND data_checksum = ANY($2::text[]) + """ + rows = await conn.fetch(query, model_name, checksums) + + result = {} + for row in rows: + checksum = row["data_checksum"] + + # Convert vector strings to lists if they exist + vec_2048 = self._parse_vector_string(row["vec_2048"]) if row["vec_2048"] else None + vec_1536 = self._parse_vector_string(row["vec_1536"]) if row["vec_1536"] else None + vec_1024 = self._parse_vector_string(row["vec_1024"]) if row["vec_1024"] else None + + result[checksum] = { + "checksum": checksum, + "vec_2048": vec_2048, + "vec_1536": vec_1536, + 
"vec_1024": vec_1024, + "created_at": row["created_at"].isoformat(), + } + + return result + + def _parse_vector_string(self, vector_str: str) -> List[float]: + """Parse vector string format '[1.0,2.0,3.0]' back to list""" + if not vector_str: + return [] + + try: + # Remove brackets and split by comma + vector_str = vector_str.strip() + if vector_str.startswith('[') and vector_str.endswith(']'): + vector_str = vector_str[1:-1] + + return [float(x.strip()) for x in vector_str.split(',') if x.strip()] + except Exception as e: + logger.warning(f"Failed to parse vector string '{vector_str}': {e}") + return [] + + async def insert_embeddings(self, embeddings_data: List[Dict[str, Any]]) -> None: + """Batch insert embeddings into database""" + if not embeddings_data: + return + + async with self.pool.acquire() as conn: + async with conn.transaction(): + for data in embeddings_data: + # Determine which vector column to use based on dimensions + vec_column = f"vec_{data['dimensions']}" + + # Convert vector list to string format for PostgreSQL + vector_str = "[" + ",".join(map(str, data["vector"])) + "]" + + query = f""" + INSERT INTO internal.embeddings + (model_name, data_checksum, {vec_column}, created_at) + VALUES ($1, $2, $3, $4) + ON CONFLICT (data_checksum) DO NOTHING + """ + + await conn.execute( + query, + data["model_name"], + data["checksum"], + vector_str, + datetime.now(), + ) + + async def get_final_dataset( + self, aid_list: List[int], model_name: str + ) -> List[Dict[str, Any]]: + """Get final dataset with embeddings and labels""" + if not aid_list: + return [] + + # Get video metadata + metadata = await self.get_video_metadata(aid_list) + + # Get user labels (latest per user) + labels = await self.get_user_labels(aid_list) + + # Prepare text data for embedding + text_data = [] + aid_to_text = {} + + for aid in aid_list: + if aid in metadata: + # Combine title, description, and tags for embedding + text_parts = [ + metadata[aid]["title"], + 
metadata[aid]["description"], + metadata[aid]["tags"], + ] + combined_text = " ".join(filter(None, text_parts)) + + # Create checksum for deduplication + checksum = hashlib.md5(combined_text.encode("utf-8")).hexdigest() + + text_data.append( + {"aid": aid, "text": combined_text, "checksum": checksum} + ) + aid_to_text[checksum] = aid + + # Get existing embeddings + checksums = [item["checksum"] for item in text_data] + existing_embeddings = await self.get_existing_embeddings(checksums, model_name) + + # Prepare final dataset + dataset = [] + + for item in text_data: + aid = item["aid"] + checksum = item["checksum"] + + # Get embedding vector + embedding_vector = None + if checksum in existing_embeddings: + # Use existing embedding + emb_data = existing_embeddings[checksum] + if emb_data["vec_1536"]: + embedding_vector = emb_data["vec_1536"] + elif emb_data["vec_2048"]: + embedding_vector = emb_data["vec_2048"] + elif emb_data["vec_1024"]: + embedding_vector = emb_data["vec_1024"] + + # Get labels for this aid + aid_labels = labels.get(aid, []) + + # Determine final label using consensus (majority vote) + if aid_labels: + positive_votes = sum(1 for lbl in aid_labels if lbl["label"]) + final_label = positive_votes > len(aid_labels) / 2 + else: + final_label = None # No labels available + + # Check for inconsistent labels + inconsistent = len(aid_labels) > 1 and ( + sum(1 for lbl in aid_labels if lbl["label"]) != 0 + and sum(1 for lbl in aid_labels if lbl["label"]) != len(aid_labels) + ) + + if embedding_vector and final_label is not None: + dataset.append( + { + "aid": aid, + "embedding": embedding_vector, + "label": final_label, + "metadata": metadata.get(aid, {}), + "user_labels": aid_labels, + "inconsistent": inconsistent, + "text_checksum": checksum, + } + ) + + return dataset + + +# Global database manager instance +db_manager = DatabaseManager() diff --git a/ml_new/training/dataset_service.py b/ml_new/training/dataset_service.py new file mode 100644
index 0000000..7977780 --- /dev/null +++ b/ml_new/training/dataset_service.py @@ -0,0 +1,373 @@ +""" +Dataset building service - handles the complete dataset construction flow +""" + +import os +import json +import logging +import asyncio +from pathlib import Path +from typing import List, Dict, Any, Optional +from datetime import datetime + +from database import DatabaseManager +from embedding_service import EmbeddingService +from config_loader import config_loader + + +logger = logging.getLogger(__name__) + + +class DatasetBuilder: + """Service for building datasets with the specified flow""" + + def __init__(self, db_manager: DatabaseManager, embedding_service: EmbeddingService, storage_dir: str = "datasets"): + self.db_manager = db_manager + self.embedding_service = embedding_service + self.storage_dir = Path(storage_dir) + self.storage_dir.mkdir(exist_ok=True) + + # Load existing datasets from file system + self.dataset_storage: Dict[str, Dict] = self._load_all_datasets() + + + def _get_dataset_file_path(self, dataset_id: str) -> Path: + """Get file path for dataset""" + return self.storage_dir / f"{dataset_id}.json" + + + def _load_dataset_from_file(self, dataset_id: str) -> Optional[Dict[str, Any]]: + """Load dataset from file""" + file_path = self._get_dataset_file_path(dataset_id) + if not file_path.exists(): + return None + + try: + with open(file_path, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logger.error(f"Failed to load dataset {dataset_id} from file: {e}") + return None + + + def _save_dataset_to_file(self, dataset_id: str, dataset_data: Dict[str, Any]) -> bool: + """Save dataset to file""" + file_path = self._get_dataset_file_path(dataset_id) + + try: + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(dataset_data, f, ensure_ascii=False, indent=2) + return True + except Exception as e: + logger.error(f"Failed to save dataset {dataset_id} to file: {e}") + return False + + + def _load_all_datasets(self) 
-> Dict[str, Dict]: + """Load all datasets from file system""" + datasets = {} + + try: + for file_path in self.storage_dir.glob("*.json"): + dataset_id = file_path.stem + dataset_data = self._load_dataset_from_file(dataset_id) + if dataset_data: + datasets[dataset_id] = dataset_data + logger.info(f"Loaded {len(datasets)} datasets from file system") + except Exception as e: + logger.error(f"Failed to load datasets from file system: {e}") + + return datasets + + + async def cleanup_old_datasets(self, max_age_days: int = 30): + """Remove datasets older than specified days""" + try: + cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 60 * 60) + removed_count = 0 + + for dataset_id in list(self.dataset_storage.keys()): + dataset_info = self.dataset_storage[dataset_id] + if "created_at" in dataset_info: + try: + created_time = datetime.fromisoformat(dataset_info["created_at"]).timestamp() + if created_time < cutoff_time: + # Remove from memory + del self.dataset_storage[dataset_id] + + # Remove file + file_path = self._get_dataset_file_path(dataset_id) + if file_path.exists(): + file_path.unlink() + + removed_count += 1 + except Exception as e: + logger.warning(f"Failed to process dataset {dataset_id} for cleanup: {e}") + + if removed_count > 0: + logger.info(f"Cleaned up {removed_count} old datasets") + + except Exception as e: + logger.error(f"Failed to cleanup old datasets: {e}") + + async def build_dataset(self, dataset_id: str, aid_list: List[int], embedding_model: str, force_regenerate: bool = False) -> str: + """ + Build dataset with the specified flow: + 1. Select embedding model (from TOML config) + 2. Pull raw text from database + 3. Preprocess (placeholder for now) + 4. Batch get embeddings (deduplicate by hash, skip if already in embeddings table) + 5. Write to embeddings table + 6. 
Pull all needed embeddings to create dataset with format: embeddings, label + """ + + try: + logger.info(f"Starting dataset building task {dataset_id}") + + EMBEDDING_MODELS = config_loader.get_embedding_models() + + # Get model configuration + if embedding_model not in EMBEDDING_MODELS: + raise ValueError(f"Invalid embedding model: {embedding_model}") + + model_config = EMBEDDING_MODELS[embedding_model] + + # Step 1: Get video metadata from database + metadata = await self.db_manager.get_video_metadata(aid_list) + + # Step 2: Get user labels + labels = await self.db_manager.get_user_labels(aid_list) + + # Step 3: Prepare text data and checksums + text_data = [] + + for aid in aid_list: + if aid in metadata: + # Combine title, description, tags + combined_text = self.embedding_service.combine_video_text( + metadata[aid]['title'], + metadata[aid]['description'], + metadata[aid]['tags'] + ) + + # Create checksum for deduplication + checksum = self.embedding_service.create_text_checksum(combined_text) + + text_data.append({ + 'aid': aid, + 'text': combined_text, + 'checksum': checksum + }) + + # Step 4: Check existing embeddings + checksums = [item['checksum'] for item in text_data] + existing_embeddings = await self.db_manager.get_existing_embeddings(checksums, embedding_model) + + # Step 5: Generate new embeddings for texts that don't have them + new_embeddings_needed = [] + for item in text_data: + if item['checksum'] not in existing_embeddings or force_regenerate: + new_embeddings_needed.append(item['text']) + + new_embeddings_count = 0 + if new_embeddings_needed: + logger.info(f"Generating {len(new_embeddings_needed)} new embeddings") + generated_embeddings = await self.embedding_service.generate_embeddings_batch( + new_embeddings_needed, + embedding_model + ) + + # Step 6: Store new embeddings in database + embeddings_to_store = [] + for i, (text, embedding) in enumerate(zip(new_embeddings_needed, generated_embeddings)): + checksum = 
self.embedding_service.create_text_checksum(text) + embeddings_to_store.append({ + 'model_name': embedding_model, + 'checksum': checksum, + 'dimensions': model_config.dimensions, + 'vector': embedding + }) + + await self.db_manager.insert_embeddings(embeddings_to_store) + new_embeddings_count = len(embeddings_to_store) + + # Update existing embeddings cache + for emb_data in embeddings_to_store: + existing_embeddings[emb_data['checksum']] = { + 'checksum': emb_data['checksum'], + f'vec_{model_config.dimensions}': emb_data['vector'] + } + + # Step 7: Build final dataset + dataset = [] + inconsistent_count = 0 + + for item in text_data: + aid = item['aid'] + checksum = item['checksum'] + + # Get embedding vector + embedding_vector = None + if checksum in existing_embeddings: + vec_key = f'vec_{model_config.dimensions}' + if vec_key in existing_embeddings[checksum]: + embedding_vector = existing_embeddings[checksum][vec_key] + + # Get labels for this aid + aid_labels = labels.get(aid, []) + + # Determine final label using consensus (majority vote) + final_label = None + if aid_labels: + positive_votes = sum(1 for lbl in aid_labels if lbl['label']) + final_label = positive_votes > len(aid_labels) / 2 + + # Check for inconsistent labels + inconsistent = len(aid_labels) > 1 and ( + sum(1 for lbl in aid_labels if lbl['label']) != 0 and + sum(1 for lbl in aid_labels if lbl['label']) != len(aid_labels) + ) + + if inconsistent: + inconsistent_count += 1 + + if embedding_vector and final_label is not None: + dataset.append({ + 'aid': aid, + 'embedding': embedding_vector, + 'label': final_label, + 'metadata': metadata.get(aid, {}), + 'user_labels': aid_labels, + 'inconsistent': inconsistent, + 'text_checksum': checksum + }) + + reused_count = len(dataset) - new_embeddings_count + + logger.info(f"Dataset building completed: {len(dataset)} records, {new_embeddings_count} new, {reused_count} reused, {inconsistent_count} inconsistent") + + # Prepare dataset data + dataset_data = { 
+ 'dataset': dataset, + 'stats': { + 'total_records': len(dataset), + 'new_embeddings': new_embeddings_count, + 'reused_embeddings': reused_count, + 'inconsistent_labels': inconsistent_count, + 'embedding_model': embedding_model + }, + 'created_at': datetime.now().isoformat() + } + + # Save to file and memory cache + if self._save_dataset_to_file(dataset_id, dataset_data): + self.dataset_storage[dataset_id] = dataset_data + logger.info(f"Dataset {dataset_id} saved to file system") + else: + logger.warning(f"Failed to save dataset {dataset_id} to file, keeping in memory only") + self.dataset_storage[dataset_id] = dataset_data + + return dataset_id + + except Exception as e: + logger.error(f"Dataset building failed for {dataset_id}: {str(e)}") + + # Store error information + error_data = { + 'error': str(e), + 'created_at': datetime.now().isoformat() + } + + # Try to save error to file as well + self._save_dataset_to_file(dataset_id, error_data) + self.dataset_storage[dataset_id] = error_data + raise + + def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]: + """Get built dataset by ID""" + # First check memory cache + if dataset_id in self.dataset_storage: + return self.dataset_storage[dataset_id] + + # If not in memory, try to load from file + dataset_data = self._load_dataset_from_file(dataset_id) + if dataset_data: + # Add to memory cache + self.dataset_storage[dataset_id] = dataset_data + return dataset_data + + return None + + def dataset_exists(self, dataset_id: str) -> bool: + """Check if dataset exists""" + # Check memory cache first + if dataset_id in self.dataset_storage: + return True + + # Check file system + return self._get_dataset_file_path(dataset_id).exists() + + def delete_dataset(self, dataset_id: str) -> bool: + """Delete dataset from both memory and file system""" + try: + # Remove from memory + if dataset_id in self.dataset_storage: + del self.dataset_storage[dataset_id] + + # Remove file + file_path = 
self._get_dataset_file_path(dataset_id) + if file_path.exists(): + file_path.unlink() + logger.info(f"Dataset {dataset_id} deleted from file system") + return True + else: + logger.warning(f"Dataset file {dataset_id} not found for deletion") + return False + + except Exception as e: + logger.error(f"Failed to delete dataset {dataset_id}: {e}") + return False + + def list_datasets(self) -> List[Dict[str, Any]]: + """List all datasets with their basic information""" + datasets = [] + + for dataset_id, dataset_info in self.dataset_storage.items(): + if "error" not in dataset_info: + datasets.append({ + "dataset_id": dataset_id, + "stats": dataset_info["stats"], + "created_at": dataset_info["created_at"] + }) + + # Sort by creation time (newest first) + datasets.sort(key=lambda x: x["created_at"], reverse=True) + + return datasets + + def get_dataset_stats(self) -> Dict[str, Any]: + """Get overall statistics about stored datasets""" + total_datasets = len(self.dataset_storage) + error_datasets = sum(1 for data in self.dataset_storage.values() if "error" in data) + valid_datasets = total_datasets - error_datasets + + total_records = 0 + total_new_embeddings = 0 + total_reused_embeddings = 0 + + for dataset_info in self.dataset_storage.values(): + if "stats" in dataset_info: + stats = dataset_info["stats"] + total_records += stats.get("total_records", 0) + total_new_embeddings += stats.get("new_embeddings", 0) + total_reused_embeddings += stats.get("reused_embeddings", 0) + + return { + "total_datasets": total_datasets, + "valid_datasets": valid_datasets, + "error_datasets": error_datasets, + "total_records": total_records, + "total_new_embeddings": total_new_embeddings, + "total_reused_embeddings": total_reused_embeddings, + "storage_directory": str(self.storage_dir) + } \ No newline at end of file diff --git a/ml_new/training/embedding_models.toml b/ml_new/training/embedding_models.toml new file mode 100644 index 0000000..f650fd0 --- /dev/null +++ 
b/ml_new/training/embedding_models.toml @@ -0,0 +1,12 @@ +# Embedding Models Configuration + +model = "qwen3-embedding" + +[models.qwen3-embedding] +name = "text-embedding-v4" +dimensions = 2048 +type = "openai-compatible" +api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1" +max_tokens = 8192 +max_batch_size = 10 +api_key_env = "ALIYUN_KEY" diff --git a/ml_new/training/embedding_service.py b/ml_new/training/embedding_service.py new file mode 100644 index 0000000..8bc16b8 --- /dev/null +++ b/ml_new/training/embedding_service.py @@ -0,0 +1,143 @@ +""" +Embedding service for generating embeddings using OpenAI-compatible API +""" +import asyncio +import hashlib +from typing import List, Dict, Any, Optional +import logging +from openai import AsyncOpenAI +import os +from config_loader import config_loader +from dotenv import load_dotenv + +load_dotenv() + +logger = logging.getLogger(__name__) + +class EmbeddingService: + def __init__(self): + # Get configuration from config loader + self.embedding_models = config_loader.get_embedding_models() + + # Initialize OpenAI client (will be configured per model) + self.clients: Dict[str, AsyncOpenAI] = {} + self._initialize_clients() + + # Rate limiting + self.max_requests_per_minute = int(os.getenv("MAX_REQUESTS_PER_MINUTE", "100")) + self.request_interval = 60.0 / self.max_requests_per_minute + + def _initialize_clients(self): + """Initialize OpenAI clients for different models/endpoints""" + for model_name, model_config in self.embedding_models.items(): + if model_config.type == "openai-compatible": + # Get API key from environment variable specified in config + api_key = os.getenv(model_config.api_key_env) + + self.clients[model_name] = AsyncOpenAI( + api_key=api_key, + base_url=model_config.api_endpoint + ) + logger.info(f"Initialized client for model {model_name}") + + async def generate_embeddings_batch( + self, + texts: List[str], + model: str, + batch_size: Optional[int] = None + ) -> List[List[float]]: 
+ """Generate embeddings for a batch of texts""" + + # Get model configuration + if model not in self.embedding_models: + raise ValueError(f"Model '{model}' not found in configuration") + + model_config = self.embedding_models[model] + + # Use model's max_batch_size if not specified + if batch_size is None: + batch_size = model_config.max_batch_size + + # Validate model and get expected dimensions + expected_dims = model_config.dimensions + + if model not in self.clients: + raise ValueError(f"No client configured for model '{model}'") + + client = self.clients[model] + all_embeddings = [] + + # Process in batches to avoid API limits + for i in range(0, len(texts), batch_size): + batch = texts[i:i + batch_size] + + try: + # Rate limiting + if i > 0: + await asyncio.sleep(self.request_interval) + + # Generate embeddings + response = await client.embeddings.create( + model=model_config.name, + input=batch, + dimensions=expected_dims + ) + + batch_embeddings = [data.embedding for data in response.data] + all_embeddings.extend(batch_embeddings) + + logger.info(f"Generated embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}") + + except Exception as e: + logger.error(f"Error generating embeddings for batch {i//batch_size + 1}: {e}") + # For now, fill with zeros as fallback (could implement retry logic) + zero_embedding = [0.0] * expected_dims + all_embeddings.extend([zero_embedding] * len(batch)) + + return all_embeddings + + def create_text_checksum(self, text: str) -> str: + """Create MD5 checksum for text deduplication""" + return hashlib.md5(text.encode('utf-8')).hexdigest() + + def combine_video_text(self, title: str, description: str, tags: str) -> str: + """Combine video metadata into a single text for embedding""" + parts = [ + title.strip() if "标题:"+title else "", + description.strip() if "简介:"+description else "", + tags.strip() if "标签:"+tags else "" + ] + + # Filter out empty parts and join + combined = '\n'.join(filter(None, parts)) + 
return combined + + async def health_check(self) -> Dict[str, Any]: + """Check if embedding service is healthy""" + try: + # Test with a simple embedding using the first available model + model_name = list(self.embedding_models.keys())[0] + + test_embedding = await self.generate_embeddings_batch( + ["health check"], + model_name, + batch_size=1 + ) + + return { + "status": "healthy", + "service": "embedding_service", + "model": model_name, + "dimensions": len(test_embedding[0]) if test_embedding else 0, + "available_models": list(self.embedding_models.keys()) + } + + except Exception as e: + return { + "status": "unhealthy", + "service": "embedding_service", + "error": str(e) + } + +# Global embedding service instance +embedding_service = EmbeddingService() \ No newline at end of file diff --git a/ml_new/training/main.py b/ml_new/training/main.py index 93211e2..b510e3d 100644 --- a/ml_new/training/main.py +++ b/ml_new/training/main.py @@ -1,246 +1,113 @@ -from fastapi import FastAPI, HTTPException, BackgroundTasks -from fastapi.middleware.cors import CORSMiddleware -from pydantic import BaseModel -from typing import List, Dict, Optional, Any -import asyncio -import uuid -import logging -from datetime import datetime -import json +""" +Main FastAPI application for ML training service +""" -# Setup logging +import logging +import uvicorn +from contextlib import asynccontextmanager +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from database import DatabaseManager +from embedding_service import EmbeddingService +from dataset_service import DatasetBuilder +from api_routes import router, set_dataset_builder + +# Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Initialize FastAPI app -app = FastAPI( - title="CVSA ML Training API", - version="1.0.0", - description="ML training service for video classification" -) -# Enable CORS for web UI -app.add_middleware( - CORSMiddleware, - 
allow_origins=["http://localhost:3000", "http://localhost:5173"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +# Global service instances +db_manager = None +embedding_service = None +dataset_builder = None -# Pydantic models -class Hyperparameter(BaseModel): - name: str - type: str # 'number', 'boolean', 'select' - value: Any - range: Optional[tuple] = None - options: Optional[List[str]] = None - description: Optional[str] = None -class TrainingConfig(BaseModel): - learning_rate: float = 1e-4 - batch_size: int = 32 - epochs: int = 10 - early_stop: bool = True - patience: int = 3 - embedding_model: str = "text-embedding-3-small" - -class TrainingRequest(BaseModel): - experiment_name: str - config: TrainingConfig - dataset: Dict[str, Any] - -class TrainingStatus(BaseModel): - experiment_id: str - status: str # 'pending', 'running', 'completed', 'failed' - progress: Optional[float] = None - current_epoch: Optional[int] = None - total_epochs: Optional[int] = None - metrics: Optional[Dict[str, float]] = None - error: Optional[str] = None - -class ExperimentResult(BaseModel): - experiment_id: str - experiment_name: str - config: TrainingConfig - metrics: Dict[str, float] - created_at: str - status: str - -class EmbeddingRequest(BaseModel): - texts: List[str] - model: str - -# In-memory storage for experiments (in production, use database) -training_sessions: Dict[str, Dict] = {} -experiments: Dict[str, ExperimentResult] = {} - -# Default hyperparameters that will be dynamically discovered -DEFAULT_HYPERPARAMETERS = [ - Hyperparameter( - name="learning_rate", - type="number", - value=1e-4, - range=(1e-6, 1e-2), - description="Learning rate for optimizer" - ), - Hyperparameter( - name="batch_size", - type="number", - value=32, - range=(8, 256), - description="Training batch size" - ), - Hyperparameter( - name="epochs", - type="number", - value=10, - range=(1, 100), - description="Number of training epochs" - ), - Hyperparameter( - 
name="early_stop", - type="boolean", - value=True, - description="Enable early stopping" - ), - Hyperparameter( - name="patience", - type="number", - value=3, - range=(1, 20), - description="Early stopping patience" - ), - Hyperparameter( - name="embedding_model", - type="select", - value="text-embedding-3-small", - options=["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"], - description="Embedding model to use" - ) -] - -@app.get("/") -async def root(): - return {"message": "CVSA ML Training API", "version": "1.0.0"} - -@app.get("/health") -async def health_check(): - return {"status": "healthy", "service": "ml-training-api"} - -@app.get("/hyperparameters", response_model=List[Hyperparameter]) -async def get_hyperparameters(): - """Get all available hyperparameters for the current model""" - return DEFAULT_HYPERPARAMETERS - -@app.post("/train") -async def start_training(request: TrainingRequest, background_tasks: BackgroundTasks): - """Start a new training experiment""" - experiment_id = str(uuid.uuid4()) +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan manager for startup and shutdown events""" + global db_manager, embedding_service, dataset_builder - # Store training session - training_sessions[experiment_id] = { - "experiment_id": experiment_id, - "experiment_name": request.experiment_name, - "config": request.config.dict(), - "dataset": request.dataset, - "status": "pending", - "created_at": datetime.now().isoformat() - } + # Startup + logger.info("Initializing services...") - # Start background training task - background_tasks.add_task(run_training, experiment_id, request) - - return {"experiment_id": experiment_id} - -@app.get("/train/{experiment_id}/status", response_model=TrainingStatus) -async def get_training_status(experiment_id: str): - """Get training status for an experiment""" - if experiment_id not in training_sessions: - raise HTTPException(status_code=404, detail="Experiment not 
found") - - session = training_sessions[experiment_id] - - return TrainingStatus( - experiment_id=experiment_id, - status=session.get("status", "unknown"), - progress=session.get("progress"), - current_epoch=session.get("current_epoch"), - total_epochs=session.get("total_epochs"), - metrics=session.get("metrics"), - error=session.get("error") - ) - -@app.get("/experiments", response_model=List[ExperimentResult]) -async def list_experiments(): - """List all experiments""" - return list(experiments.values()) - -@app.get("/experiments/{experiment_id}", response_model=ExperimentResult) -async def get_experiment(experiment_id: str): - """Get experiment details""" - if experiment_id not in experiments: - raise HTTPException(status_code=404, detail="Experiment not found") - - return experiments[experiment_id] - -@app.post("/embeddings") -async def generate_embeddings(request: EmbeddingRequest): - """Generate embeddings using OpenAI-compatible API""" - # This is a placeholder implementation - # In production, this would call actual embedding API - embeddings = [] - for text in request.texts: - # Mock embedding generation - embedding = [0.1] * 1536 # Mock 1536-dimensional embedding - embeddings.append(embedding) - - return embeddings - -async def run_training(experiment_id: str, request: TrainingRequest): - """Background task to run training""" try: - session = training_sessions[experiment_id] - session["status"] = "running" - session["total_epochs"] = request.config.epochs + # Database manager + db_manager = DatabaseManager() + await db_manager.connect() # Initialize database connection pool + logger.info("Database manager initialized and connected") - # Simulate training process - for epoch in range(request.config.epochs): - session["current_epoch"] = epoch + 1 - session["progress"] = (epoch + 1) / request.config.epochs - - # Simulate training metrics - session["metrics"] = { - "loss": max(0.0, 1.0 - (epoch + 1) * 0.1), - "accuracy": min(0.95, 0.5 + (epoch + 1) * 0.05), - 
"val_loss": max(0.0, 0.8 - (epoch + 1) * 0.08), - "val_accuracy": min(0.92, 0.45 + (epoch + 1) * 0.04) - } - - logger.info(f"Training epoch {epoch + 1}/{request.config.epochs}") - await asyncio.sleep(1) # Simulate training time + # Embedding service + embedding_service = EmbeddingService() + logger.info("Embedding service initialized") - # Training completed - session["status"] = "completed" - final_metrics = session["metrics"] + # Dataset builder + dataset_builder = DatasetBuilder(db_manager, embedding_service) + logger.info("Dataset builder initialized") - # Store final experiment result - experiments[experiment_id] = ExperimentResult( - experiment_id=experiment_id, - experiment_name=request.experiment_name, - config=request.config, - metrics=final_metrics, - created_at=session["created_at"], - status="completed" - ) + # Set global dataset builder instance + set_dataset_builder(dataset_builder) - logger.info(f"Training completed for experiment {experiment_id}") + logger.info("All services initialized successfully") except Exception as e: - session["status"] = "failed" - session["error"] = str(e) - logger.error(f"Training failed for experiment {experiment_id}: {str(e)}") + logger.error(f"Failed to initialize services: {e}") + raise + + # Yield control to the application + yield + + # Shutdown + logger.info("Shutting down services...") + + try: + if db_manager: + await db_manager.close() + logger.info("Database connection pool closed") + except Exception as e: + logger.error(f"Error during shutdown: {e}") + + +def create_app() -> FastAPI: + """Create and configure FastAPI application""" + + # Create FastAPI app with lifespan manager + app = FastAPI( + title="ML Training Service", + description="ML training, dataset building, and experiment management service", + version="1.0.0", + lifespan=lifespan + ) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Configure appropriately for production + allow_credentials=True, + 
allow_methods=["*"], + allow_headers=["*"], + ) + + # Include API routes + app.include_router(router) + + return app + + +def main(): + """Main entry point""" + app = create_app() + + # Run the application + uvicorn.run( + app, + host="0.0.0.0", + port=8322, + log_level="info", + access_log=True + ) + if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file + main() \ No newline at end of file diff --git a/ml_new/training/models.py b/ml_new/training/models.py new file mode 100644 index 0000000..66f084b --- /dev/null +++ b/ml_new/training/models.py @@ -0,0 +1,62 @@ +""" +Data models for dataset building functionality +""" + +from typing import List, Optional, Dict, Any +from pydantic import BaseModel, Field +from datetime import datetime + + +class DatasetBuildRequest(BaseModel): + """Request model for dataset building""" + aid_list: List[int] = Field(..., description="List of video AIDs") + embedding_model: str = Field(..., description="Embedding model name") + force_regenerate: bool = Field(False, description="Whether to force regenerate embeddings") + + +class DatasetBuildResponse(BaseModel): + """Response model for dataset building""" + dataset_id: str + total_records: int + status: str + message: str + created_at: Optional[datetime] = None + + +class DatasetRecord(BaseModel): + """Model for a single dataset record""" + aid: int + embedding: List[float] + label: bool + metadata: Dict[str, Any] + user_labels: List[Dict[str, Any]] + inconsistent: bool + text_checksum: str + + +class DatasetInfo(BaseModel): + """Model for dataset information""" + dataset_id: str + dataset: List[DatasetRecord] + stats: Dict[str, Any] + created_at: datetime + + +class DatasetBuildStats(BaseModel): + """Statistics for dataset building process""" + total_records: int + new_embeddings: int + reused_embeddings: int + inconsistent_labels: int + embedding_model: str + processing_time: Optional[float] = None + + +class 
EmbeddingModelInfo(BaseModel): + """Information about embedding models""" + name: str + dimensions: int + type: str + api_endpoint: Optional[str] = None + max_tokens: Optional[int] = None + max_batch_size: Optional[int] = None \ No newline at end of file diff --git a/ml_new/training/requirements.txt b/ml_new/training/requirements.txt index 4cd74eb..d14ecd3 100644 --- a/ml_new/training/requirements.txt +++ b/ml_new/training/requirements.txt @@ -9,4 +9,8 @@ scikit-learn==1.3.2 pandas==2.1.3 openai==1.3.7 psycopg2-binary==2.9.9 -sqlalchemy==2.0.23 \ No newline at end of file +sqlalchemy==2.0.23 +asyncpg==0.29.0 +toml==0.10.2 +aiohttp==3.9.0 +python-dotenv==1.1.0 \ No newline at end of file diff --git a/mutagen.yml b/mutagen.yml new file mode 100644 index 0000000..d343689 --- /dev/null +++ b/mutagen.yml @@ -0,0 +1,26 @@ +# mutagen.yml +sync: + development: + alpha: "." + beta: "root@cvsa-hk-02:/web/cvsa" + maxStagingFileSize: "5MB" + mode: "two-way-resolved" + ignore: + paths: + - "**/node_modules/" + - "*.log" + - "/model" + - "**/.DS_Store" + - ".env" + - ".env.*" + - "**/logs/" + - ".git/" + - ".jj/" + - "dist" + - "**/build/" + - "**/.react-router" + - "redis/" + - "temp/" + - "mutagen.yml" + - "/ml_new/" + - "/ml/" \ No newline at end of file diff --git a/mutagen.yml.lock b/mutagen.yml.lock new file mode 100644 index 0000000..fd38c71 --- /dev/null +++ b/mutagen.yml.lock @@ -0,0 +1 @@ +proj_EPp5s45rolBZ729uBHaoBMzV51GCoGkpnJuTkK6owzH \ No newline at end of file diff --git a/packages/ml_panel/README.md b/packages/ml_panel/README.md new file mode 100644 index 0000000..e1ff783 --- /dev/null +++ b/packages/ml_panel/README.md @@ -0,0 +1,267 @@ +# CVSA ML 基础设施重构项目 + +## 项目概述 + +本项目旨在重构现有的 ML 服务基础设施,从原始的 `ml/filter` 系统迁移到新的前后端分离架构。主要目标是为 ML 训练、实验管理和数据处理提供一个现代化的 Web UI 界面。 + +### 核心功能 + +- **数据管线管理**: 从 PostgreSQL 数据库获取和预处理训练数据 +- **实验管理**: 训练参数配置、实验追踪和结果可视化 +- **超参数调优**: 动态超参数配置和调整 +- **数据标注界面**: 简单易用的数据标注和管理工具 +- **模型训练**: 2分类视频分类模型训练 +- **嵌入向量管理**: 支持多种嵌入模型和向量维度 + +## 
架构设计 + +### 技术栈 + +- **前端**: React + TypeScript + Vite + Tailwind CSS + shadcn/ui +- **后端**: FastAPI + Python +- **数据库**: PostgreSQL + Drizzle ORM +- **向量数据库**: PostgreSQL pgvector +- **包管理**: Bun (TypeScript) + pip (Python) + +### 分层架构 + +``` +┌─────────────────┐ ┌──────────────┐ ┌──────────────────┐ +│ Web UI │ │ FastAPI │ │ Database │ +│ (React TS) │◄──►│ (Python) │◄──►│ (PostgreSQL) │ +└─────────────────┘ └──────────────┘ └──────────────────┘ +``` + +## 目录结构 + +### 前端项目 (`packages/ml_panel/`) + +``` +packages/ml_panel/ +├── src/ # 前端应用 +│ ├── App.tsx # 主应用组件 +│ ├── main.tsx # 应用入口 +│ ├── index.css # 全局样式 +│ └── lib/ +│ └── utils.ts # 前端工具函数 +├── lib/ # 核心库文件 +│ ├── types.ts # 共享类型定义 +│ ├── ml-client.ts # ML API 客户端 +│ ├── data-pipeline/ # 数据管线类型 +│ │ └── types.ts +│ └── index.ts # 导出文件 +├── package.json +├── vite.config.ts +└── tailwind.config.js +``` + +### 后端服务 (`ml_new/training/`) + +``` +ml_new/training/ +├── main.py # FastAPI 主服务 +├── requirements.txt # Python 依赖 +└── ... # 其他服务文件 +``` + +### 数据库 Schema + +使用现有的 `packages/core/drizzle/main/schema.ts` 中的定义: + +- `videoTypeLabelInInternal`: 用户标注数据 +- `embeddingsInInternal`: 嵌入向量存储 +- `bilibiliMetadata`: 视频元数据 + +## 已完成的工作 + +### 1. 核心类型定义 + +**文件**: `packages/ml_panel/lib/types.ts` + +- 定义了核心数据结构 +- `DatasetRecord`: 数据集记录 +- `UserLabel`: 用户标注 +- `EmbeddingModel`: 嵌入模型配置 +- `TrainingConfig`: 训练配置 +- `ExperimentResult`: 实验结果 +- `InconsistentLabel`: 标注不一致数据 + +### 2. 数据管线类型 + +**文件**: `packages/ml_panel/lib/data-pipeline/types.ts` + +- `VideoMetadata`: 视频元数据 +- `VideoTypeLabel`: 标注数据 +- `EmbeddingRecord`: 嵌入记录 +- `DataPipelineConfig`: 管线配置 +- `ProcessedDataset`: 处理后的数据集 + +### 3. ML 客户端 + +**文件**: `packages/ml_panel/lib/ml-client.ts` + +- `MLClient` 类用于与 FastAPI 通信 +- 超参数获取和更新 +- 训练任务启动和状态监控 +- 实验管理 +- 嵌入生成接口 + +### 4. FastAPI 服务框架 + +**文件**: `ml_new/training/main.py` + +- 基础的 FastAPI 应用配置 +- CORS 中间件配置 +- 内存存储的训练会话管理 +- 基础的 API 端点定义 + +### 5. 
项目配置 + +- **前端**: `packages/ml_panel/package.json` - React + Vite + TypeScript 配置 +- **后端**: `ml_new/training/requirements.txt` - Python 依赖 +- **主项目**: `packages/ml/package.json` - Monorepo 工作空间配置 + +## 核心功能实现状态 + +### 已完成 + +- [x] 基础项目结构搭建 +- [x] 核心类型定义 +- [x] ML API 客户端 +- [x] FastAPI 服务框架 +- [x] 前端项目配置 + +### 待实现 + +- [ ] 数据管线核心逻辑实现 +- [ ] React UI 组件开发 +- [ ] FastAPI 服务功能完善 +- [ ] 数据库连接和数据获取逻辑 +- [ ] 用户标注数据处理 +- [ ] 嵌入向量管理 +- [ ] 标注一致性检查 +- [ ] 训练任务队列 +- [ ] 实验追踪和可视化 +- [ ] 超参数动态配置 +- [ ] 完整的前后端集成测试 + +## 数据流程设计 + +### 1. 数据集创建流程 (高 RTT 优化) + +```mermaid +graph TD + A[前端点击创建数据集] --> B[选定嵌入模型] + B --> C[从 TOML 配置载入模型参数] + C --> D[从数据库批量拉取原始文本] + D --> E[文本预处理] + E --> F[计算文本 hash 并去重] + F --> G[批量查询已有嵌入] + G --> H[区分需要生成的新文本] + H --> I[批量调用嵌入 API] + I --> J[批量写入 embeddings 表] + J --> K[拉取完整 embeddings 数据] + K --> L[合并标签数据] + L --> M[构建最终数据集
格式: embeddings, label] +``` + +### 2. 数据获取流程 (数据库优化) + +```mermaid +graph TD + A[PostgreSQL 远程数据库
RTT: 100ms] --> B[videoTypeLabelInInternal] + A --> C[embeddingsInInternal] + A --> D[bilibiliMetadata] + B --> E[批量获取用户最后一次标注
避免循环查询] + C --> F[批量获取嵌入向量
一次性查询所有维度] + D --> G[批量获取视频元数据
IN 查询避免 N+1] + E --> H[标注一致性检查] + F --> I[向量数据处理] + G --> J[数据合并] + H --> K[数据集构建] + I --> K + J --> K +``` + +### 3. 训练流程 + +```mermaid +graph TD + A[前端配置] --> B[超参数设置] + B --> C[FastAPI 接收] + C --> D[数据管线处理] + D --> E[模型训练] + E --> F[实时状态更新] + F --> G[结果存储] + G --> H[前端展示] +``` + +## 技术要点 + +### 1. 高性能数据库设计 (RTT 优化) + +- **批量操作**: 避免循环查询,使用 `IN` 语句和批量 `INSERT/UPDATE` + +### 2. 嵌入向量管理 + +- **多模型支持**: `embeddingsInInternal` 支持不同维度的向量 (2048/1536/1024) +- **去重机制**: 使用文本 hash 去重,避免重复生成嵌入 +- **批量处理**: 批量生成和存储嵌入向量 +- **缓存策略**: 优先使用已存在的嵌入向量 + +### 3. 2分类模型架构 + +- 从原有的 3分类系统迁移到 2分类 +- 输入: 预计算的嵌入向量 (而非原始文本) +- 支持多种嵌入模型切换 + +### 4. 数据一致性处理 + +- **用户标注**: `videoTypeLabelInInternal` 存储多用户标注 +- **最后标注**: 获取每个用户的最后一次标注作为有效数据 +- **一致性检查**: 识别不同用户标注不一致的视频,标记为需要人工复核 + +## 后续开发计划 + +### Phase 1: 核心功能 (优先) + +1. **数据管线实现** + - 标注数据获取和一致性检查 + - 嵌入向量生成和存储 + - 数据集构建逻辑 + +2. **FastAPI 服务完善** + - 构建新的模型架构(输入嵌入向量,直接二分类头) + - 迁移现有 ml/filter 训练逻辑 + - 实现超参数动态暴露 + - 集成 OpenAI 兼容嵌入 API + - 训练任务队列管理 + +### Phase 2: 用户界面 + +1. **数据集创建界面** + - 嵌入模型选择 + - 数据预览和筛选 + - 处理进度显示 + +2. **训练参数配置界面** + - 超参数动态渲染 + - 参数验证和约束 + +3. **实验管理和追踪** + - 实验历史和比较 + - 训练状态实时监控 + - 结果可视化 + +### Phase 3: 高级功能 + +1. **超参数自动调优** +2. **模型版本管理** +3. **批量训练支持** +4. **性能优化** + +## 注意事项 + +1. **数据库性能**: 远程数据库 RTT 高,避免 N+1 查询,使用批量操作 +2. **标注一致性**: 实现自动的标注不一致检测 +3. 
**嵌入模型支持**: 为未来扩展多种嵌入模型预留接口 \ No newline at end of file diff --git a/packages/ml_panel/lib/data-pipeline/types.ts b/packages/ml_panel/lib/data-pipeline/types.ts deleted file mode 100644 index cbb1099..0000000 --- a/packages/ml_panel/lib/data-pipeline/types.ts +++ /dev/null @@ -1,47 +0,0 @@ -// Data pipeline specific types -import type { DatasetRecord, UserLabel, EmbeddingModel, InconsistentLabel } from "../types"; - -// Database types from packages/core -export interface VideoMetadata { - aid: number; - title: string; - description: string; - tags: string; - createdAt?: string; -} - -export interface VideoTypeLabel { - id: number; - aid: number; - label: boolean; - user: string; - createdAt: string; -} - -export interface EmbeddingRecord { - id: number; - modelName: string; - dataChecksum: string; - vec2048?: number[]; - vec1536?: number[]; - vec1024?: number[]; - createdAt?: string; -} - -export interface DataPipelineConfig { - embeddingModels: EmbeddingModel[]; - batchSize: number; - requireConsensus: boolean; - maxInconsistentRatio: number; -} - -export interface ProcessedDataset { - records: DatasetRecord[]; - inconsistentLabels: InconsistentLabel[]; - statistics: { - totalRecords: number; - labeledRecords: number; - inconsistentRecords: number; - embeddingCoverage: Record; - }; -} diff --git a/packages/ml_panel/lib/index.ts b/packages/ml_panel/lib/index.ts deleted file mode 100644 index e69de29..0000000 diff --git a/packages/ml_panel/lib/ml-client.ts b/packages/ml_panel/lib/ml-client.ts deleted file mode 100644 index ff0294a..0000000 --- a/packages/ml_panel/lib/ml-client.ts +++ /dev/null @@ -1,107 +0,0 @@ -// ML Client for communicating with FastAPI service -import type { TrainingConfig, ExperimentResult } from './types'; - -export interface Hyperparameter { - name: string; - type: 'number' | 'boolean' | 'select'; - value: any; - range?: [number, number]; - options?: string[]; - description?: string; -} - -export interface TrainingRequest { - experimentName: 
string; - config: TrainingConfig; - dataset: { - aid: number[]; - embeddings: Record; - labels: Record; - }; -} - -export interface TrainingStatus { - experimentId: string; - status: 'pending' | 'running' | 'completed' | 'failed'; - progress?: number; - currentEpoch?: number; - totalEpochs?: number; - metrics?: Record; - error?: string; -} - -export class MLClient { - private baseUrl: string; - - constructor(baseUrl: string = 'http://localhost:8000') { - this.baseUrl = baseUrl; - } - - // Get available hyperparameters from the model - async getHyperparameters(): Promise { - const response = await fetch(`${this.baseUrl}/hyperparameters`); - if (!response.ok) { - throw new Error(`Failed to get hyperparameters: ${response.statusText}`); - } - return (await response.json()) as Hyperparameter[]; - } - - // Start a training experiment - async startTraining(request: TrainingRequest): Promise<{ experimentId: string }> { - const response = await fetch(`${this.baseUrl}/train`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify(request), - }); - - if (!response.ok) { - throw new Error(`Failed to start training: ${response.statusText}`); - } - return (await response.json()) as { experimentId: string }; - } - - // Get training status - async getTrainingStatus(experimentId: string): Promise { - const response = await fetch(`${this.baseUrl}/train/${experimentId}/status`); - if (!response.ok) { - throw new Error(`Failed to get training status: ${response.statusText}`); - } - return (await response.json()) as TrainingStatus; - } - - // Get experiment results - async getExperimentResult(experimentId: string): Promise { - const response = await fetch(`${this.baseUrl}/experiments/${experimentId}`); - if (!response.ok) { - throw new Error(`Failed to get experiment result: ${response.statusText}`); - } - return (await response.json()) as ExperimentResult; - } - - // List all experiments - async listExperiments(): Promise { - const response 
= await fetch(`${this.baseUrl}/experiments`); - if (!response.ok) { - throw new Error(`Failed to list experiments: ${response.statusText}`); - } - return (await response.json()) as ExperimentResult[]; - } - - // Generate embeddings using OpenAI-compatible API - async generateEmbeddings(texts: string[], model: string): Promise { - const response = await fetch(`${this.baseUrl}/embeddings`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ texts, model }), - }); - - if (!response.ok) { - throw new Error(`Failed to generate embeddings: ${response.statusText}`); - } - return (await response.json()) as number[][]; - } -} \ No newline at end of file diff --git a/packages/ml_panel/lib/types.ts b/packages/ml_panel/lib/types.ts deleted file mode 100644 index 6b826b2..0000000 --- a/packages/ml_panel/lib/types.ts +++ /dev/null @@ -1,54 +0,0 @@ -// Shared ML types and interfaces -export interface DatasetRecord { - aid: number; - title: string; - description: string; - tags: string; - embedding?: number[]; - label?: boolean; - userLabels?: UserLabel[]; -} - -export interface UserLabel { - user: string; - label: boolean; - createdAt: string; -} - -export interface EmbeddingModel { - name: string; - dimensions: number; - type: "openai-compatible" | "local"; - apiEndpoint?: string; -} - -export interface TrainingConfig { - learningRate: number; - batchSize: number; - epochs: number; - earlyStop: boolean; - patience?: number; - embeddingModel: string; -} - -export interface ExperimentResult { - experimentId: string; - config: TrainingConfig; - metrics: { - accuracy: number; - precision: number; - recall: number; - f1: number; - }; - createdAt: string; - status: "running" | "completed" | "failed"; -} - -export interface InconsistentLabel { - aid: number; - title: string; - description: string; - tags: string; - labels: UserLabel[]; - consensus?: boolean; -}