diff --git a/.gitignore b/.gitignore
index 4ea3796..3bb3b08 100644
--- a/.gitignore
+++ b/.gitignore
@@ -47,4 +47,4 @@ temp/
meili
-.turbo
\ No newline at end of file
+.turbo/
\ No newline at end of file
diff --git a/.kilocode/rules/common.md b/.kilocode/rules/common.md
new file mode 100644
index 0000000..0464778
--- /dev/null
+++ b/.kilocode/rules/common.md
@@ -0,0 +1,5 @@
+# common.md
+
+1. Always use bun as package manager.
+
+2. Always write comments in English.
diff --git a/ml_new/.gitignore b/ml_new/.gitignore
index c0190e1..f3c07f0 100644
--- a/ml_new/.gitignore
+++ b/ml_new/.gitignore
@@ -1 +1 @@
-datasets
\ No newline at end of file
+datasets/
\ No newline at end of file
diff --git a/ml_new/training/README.md b/ml_new/training/README.md
new file mode 100644
index 0000000..189315b
--- /dev/null
+++ b/ml_new/training/README.md
@@ -0,0 +1,140 @@
+# ML Training Service
+
+A FastAPI-based ML training service for dataset building, embedding generation, and experiment management.
+
+## Architecture
+
+The service is organized into modular components:
+
+```
+ml_new/training/
+├── main.py # FastAPI application entry point
+├── models.py # Pydantic data models
+├── config_loader.py # Configuration loading from TOML
+├── database.py # Database connection and operations
+├── embedding_service.py # Embedding generation service
+├── dataset_service.py # Dataset building logic
+├── api_routes.py # API endpoint definitions
+├── embedding_models.toml # Embedding model configurations
+└── requirements.txt # Python dependencies
+```
+
+## Key Components
+
+### 1. Main Application (`main.py`)
+- FastAPI app initialization
+- CORS middleware configuration
+- Service dependency injection
+- Startup/shutdown event handlers
+
+### 2. Data Models (`models.py`)
+- `DatasetBuildRequest`: Request model for dataset building
+- `DatasetBuildResponse`: Response model for dataset building
+- `DatasetRecord`: Individual dataset record structure
+- `EmbeddingModelInfo`: Embedding model configuration
+
+### 3. Configuration (`config_loader.py`)
+- Loads embedding model configurations from TOML
+- Manages model parameters (dimensions, API endpoints, etc.)
+
+### 4. Database Layer (`database.py`)
+- PostgreSQL connection management
+- CRUD operations for video metadata, user labels, and embeddings
+- Optimized batch queries to avoid N+1 problems
+
+### 5. Embedding Service (`embedding_service.py`)
+- Integration with OpenAI-compatible embedding APIs
+- Text preprocessing and checksum generation
+- Batch embedding generation with rate limiting
+
+### 6. Dataset Building (`dataset_service.py`)
+- Complete dataset construction workflow:
+ 1. Pull raw text from database
+ 2. Text preprocessing (placeholder)
+ 3. Batch embedding generation with deduplication
+ 4. Embedding storage and caching
+ 5. Final dataset compilation with labels
+
+### 7. API Routes (`api_routes.py`)
+- `GET /v1/health`: Health check
+- `GET /v1/models/embedding`: List available embedding models
+- `POST /v1/dataset/build`: Build new dataset
+- `GET /v1/dataset/{id}`: Retrieve built dataset
+- `GET /v1/datasets`: List all datasets
+- `DELETE /v1/dataset/{id}`: Delete dataset
+
+## Dataset Building Flow
+
+1. **Model Selection**: Choose embedding model from TOML configuration
+2. **Data Retrieval**: Pull video metadata and user labels from PostgreSQL
+3. **Text Processing**: Combine title, description, and tags
+4. **Deduplication**: Generate checksums to avoid duplicate embeddings
+5. **Batch Processing**: Generate embeddings for new texts only
+6. **Storage**: Store embeddings in database with caching
+7. **Final Assembly**: Combine embeddings with labels using consensus mechanism
+
+## Configuration
+
+### Embedding Models (`embedding_models.toml`)
+```toml
+[models.text-embedding-3-large]
+name = "text-embedding-3-large"
+dimensions = 3072
+type = "openai-compatible"
+api_endpoint = "https://api.openai.com/v1"
+max_tokens = 8192
+max_batch_size = 100
+```
+
+### Environment Variables
+- `DATABASE_URL`: PostgreSQL connection string
+- Embedding API key: read from the env var named by each model's `api_key_env` (e.g. `ALIYUN_KEY`; defaults to `OPENAI_API_KEY`)
+
+## Usage
+
+### Start the Service
+```bash
+cd ml_new/training
+python main.py
+```
+
+### Build a Dataset
+```bash
+curl -X POST "http://localhost:8322/v1/dataset/build" \
+ -H "Content-Type: application/json" \
+ -d '{
+ "aid_list": [170001, 170002, 170003],
+ "embedding_model": "text-embedding-3-large",
+ "force_regenerate": false
+ }'
+```
+
+### Check Health
+```bash
+curl "http://localhost:8322/v1/health"
+```
+
+### List Embedding Models
+```bash
+curl "http://localhost:8322/v1/models/embedding"
+```
+
+## Features
+
+- **High Performance**: Optimized database queries with batch operations
+- **Deduplication**: Text-level deduplication using MD5 checksums
+- **Consensus Labels**: Majority vote mechanism for user annotations
+- **Batch Processing**: Efficient embedding generation and storage
+- **Error Handling**: Comprehensive error handling and logging
+- **Async Support**: Fully asynchronous operations for scalability
+- **CORS Enabled**: Ready for frontend integration
+
+## Production Considerations
+
+- Replace in-memory dataset storage with database
+- Add authentication and authorization
+- Implement rate limiting for API endpoints
+- Add monitoring and metrics collection
+- Configure proper logging levels
+- Set up database connection pooling
+- Add API documentation with OpenAPI/Swagger
\ No newline at end of file
diff --git a/ml_new/training/api_routes.py b/ml_new/training/api_routes.py
new file mode 100644
index 0000000..ea68779
--- /dev/null
+++ b/ml_new/training/api_routes.py
@@ -0,0 +1,196 @@
+"""
+API routes for the ML training service
+"""
+
+import logging
+import uuid
+
+from fastapi import APIRouter, HTTPException, BackgroundTasks
+from fastapi.responses import JSONResponse
+
+from config_loader import config_loader
+from models import DatasetBuildRequest, DatasetBuildResponse
+from dataset_service import DatasetBuilder
+
+
+logger = logging.getLogger(__name__)
+
+# Create router
+router = APIRouter(prefix="/v1")
+
+# Global dataset builder instance (will be set by main.py)
+dataset_builder: DatasetBuilder = None
+
+
+def set_dataset_builder(builder: DatasetBuilder):
+ """Set the global dataset builder instance"""
+ global dataset_builder
+ dataset_builder = builder
+
+
+@router.get("/health")
+async def health_check():
+ """Health check endpoint"""
+ if not dataset_builder:
+ return JSONResponse(
+ status_code=503,
+ content={"status": "unavailable", "message": "Dataset builder not initialized"}
+ )
+
+ try:
+ # Check embedding service health
+ embedding_health = await dataset_builder.embedding_service.health_check()
+ except Exception as e:
+ embedding_health = {"status": "unhealthy", "error": str(e)}
+
+ # Check database connection (pool should already be initialized)
+ db_status = "disconnected"
+ if dataset_builder.db_manager.is_connected:
+ try:
+            response = await dataset_builder.db_manager.pool.fetch("SELECT 1")
+ db_status = "connected" if response else "disconnected"
+ except Exception as e:
+ db_status = f"error: {str(e)}"
+
+ return {
+ "status": "healthy",
+ "service": "ml-training-api",
+ "embedding_service": embedding_health,
+ "database": db_status,
+ "available_models": list(config_loader.get_embedding_models().keys())
+ }
+
+
+@router.get("/models/embedding")
+async def get_embedding_models():
+ """Get available embedding models"""
+ return {
+ "models": {
+ name: {
+ "name": config.name,
+ "dimensions": config.dimensions,
+ "type": config.type,
+ "api_endpoint": config.api_endpoint,
+ "max_tokens": config.max_tokens,
+ "max_batch_size": config.max_batch_size
+ }
+ for name, config in config_loader.get_embedding_models().items()
+ }
+ }
+
+
+@router.post("/dataset/build", response_model=DatasetBuildResponse)
+async def build_dataset_endpoint(request: DatasetBuildRequest, background_tasks: BackgroundTasks):
+ """Build dataset endpoint"""
+
+ if not dataset_builder:
+ raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+ # Validate embedding model
+ if request.embedding_model not in config_loader.get_embedding_models():
+ raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")
+
+ dataset_id = str(uuid.uuid4())
+ # Start background task for dataset building
+ background_tasks.add_task(
+ dataset_builder.build_dataset,
+ dataset_id,
+ request.aid_list,
+ request.embedding_model,
+ request.force_regenerate
+ )
+
+ return DatasetBuildResponse(
+ dataset_id=dataset_id,
+ total_records=len(request.aid_list),
+ status="started",
+ message="Dataset building started"
+ )
+
+
+@router.get("/dataset/{dataset_id}")
+async def get_dataset_endpoint(dataset_id: str):
+ """Get built dataset by ID"""
+
+ if not dataset_builder:
+ raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+ if not dataset_builder.dataset_exists(dataset_id):
+ raise HTTPException(status_code=404, detail="Dataset not found")
+
+ dataset_info = dataset_builder.get_dataset(dataset_id)
+
+ if "error" in dataset_info:
+ raise HTTPException(status_code=500, detail=dataset_info["error"])
+
+ return {
+ "dataset_id": dataset_id,
+ "dataset": dataset_info["dataset"],
+ "stats": dataset_info["stats"],
+ "created_at": dataset_info["created_at"]
+ }
+
+
+@router.get("/datasets")
+async def list_datasets():
+ """List all built datasets"""
+
+ if not dataset_builder:
+ raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+ datasets = []
+ for dataset_id, dataset_info in dataset_builder.dataset_storage.items():
+ if "error" not in dataset_info:
+ datasets.append({
+ "dataset_id": dataset_id,
+ "stats": dataset_info["stats"],
+ "created_at": dataset_info["created_at"]
+ })
+
+ return {"datasets": datasets}
+
+
+@router.delete("/dataset/{dataset_id}")
+async def delete_dataset_endpoint(dataset_id: str):
+ """Delete a built dataset"""
+
+ if not dataset_builder:
+ raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+ if dataset_builder.delete_dataset(dataset_id):
+ return {"message": f"Dataset {dataset_id} deleted successfully"}
+ else:
+ raise HTTPException(status_code=404, detail="Dataset not found")
+
+
+@router.get("/datasets/stats")
+async def get_dataset_stats_endpoint():
+ """Get overall statistics about stored datasets"""
+
+ if not dataset_builder:
+ raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+ stats = dataset_builder.get_dataset_stats()
+ return stats
+
+
+@router.post("/datasets/cleanup")
+async def cleanup_datasets_endpoint(max_age_days: int = 30):
+ """Remove datasets older than specified days"""
+
+ if not dataset_builder:
+ raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+ await dataset_builder.cleanup_old_datasets(max_age_days)
+ return {"message": f"Cleanup completed for datasets older than {max_age_days} days"}
\ No newline at end of file
diff --git a/ml_new/training/config_loader.py b/ml_new/training/config_loader.py
new file mode 100644
index 0000000..f7a93bb
--- /dev/null
+++ b/ml_new/training/config_loader.py
@@ -0,0 +1,85 @@
+"""
+Configuration loader for embedding models and other settings
+"""
+
+import toml
+import os
+from typing import Dict
+from pydantic import BaseModel
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingModelConfig(BaseModel):
+ name: str
+ dimensions: int
+ type: str
+ api_endpoint: str = "https://api.openai.com/v1"
+ max_tokens: int = 8191
+ max_batch_size: int = 8
+ api_key_env: str = "OPENAI_API_KEY"
+
+
+class ConfigLoader:
+ def __init__(self, config_path: str = None):
+ if config_path is None:
+ # Default to the embedding_models.toml file we created
+ config_path = os.path.join(
+ os.path.dirname(__file__), "embedding_models.toml"
+ )
+
+ self.config_path = config_path
+ self.embedding_models: Dict[str, EmbeddingModelConfig] = {}
+ self._load_config()
+
+ def _load_config(self):
+ """Load configuration from TOML file"""
+ try:
+ if not os.path.exists(self.config_path):
+ logger.warning(f"Config file not found: {self.config_path}")
+ return
+
+ with open(self.config_path, "r", encoding="utf-8") as f:
+ config_data = toml.load(f)
+
+ # Load embedding models
+ if "models" not in config_data:
+ return
+
+ for model_key, model_data in config_data["models"].items():
+ self.embedding_models[model_key] = EmbeddingModelConfig(
+ **model_data
+ )
+
+ logger.info(
+ f"Loaded {len(self.embedding_models)} embedding models from {self.config_path}"
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to load config from {self.config_path}: {e}")
+
+ def get_embedding_models(self) -> Dict[str, EmbeddingModelConfig]:
+ """Get all available embedding models"""
+ return self.embedding_models.copy()
+
+ def get_embedding_model(self, model_name: str) -> EmbeddingModelConfig:
+ """Get specific embedding model config"""
+ if model_name not in self.embedding_models:
+ raise ValueError(
+ f"Embedding model '{model_name}' not found in configuration"
+ )
+ return self.embedding_models[model_name]
+
+ def list_model_names(self) -> list:
+ """Get list of available model names"""
+ return list(self.embedding_models.keys())
+
+ def reload_config(self):
+ """Reload configuration from file"""
+ self.embedding_models = {}
+ self._load_config()
+
+
+# Global config loader instance
+config_loader = ConfigLoader()
diff --git a/ml_new/training/database.py b/ml_new/training/database.py
new file mode 100644
index 0000000..50f2d53
--- /dev/null
+++ b/ml_new/training/database.py
@@ -0,0 +1,284 @@
+"""
+Database connection and operations for ML training service
+"""
+
+import os
+import hashlib
+from typing import List, Dict, Optional, Any
+from datetime import datetime
+import asyncpg
+import logging
+from config_loader import config_loader
+from dotenv import load_dotenv
+
+load_dotenv()
+
+logger = logging.getLogger(__name__)
+
+# Database configuration
+DATABASE_URL = os.getenv("DATABASE_URL")
+
+class DatabaseManager:
+ def __init__(self):
+ self.pool: Optional[asyncpg.Pool] = None
+
+ async def connect(self):
+ """Initialize database connection pool"""
+ try:
+ self.pool = await asyncpg.create_pool(DATABASE_URL, min_size=5, max_size=20)
+
+ logger.info("Database connection pool initialized")
+ except Exception as e:
+ logger.error(f"Failed to connect to database: {e}")
+ raise
+
+ async def close(self):
+ """Close database connection pool"""
+ if self.pool:
+ await self.pool.close()
+ logger.info("Database connection pool closed")
+
+ @property
+ def is_connected(self) -> bool:
+ """Check if database connection pool is initialized"""
+ return self.pool is not None
+
+ async def get_embedding_models(self):
+ """Get available embedding models from config"""
+ return config_loader.get_embedding_models()
+
+ async def get_video_metadata(
+ self, aid_list: List[int]
+ ) -> Dict[int, Dict[str, Any]]:
+ """Get video metadata for given AIDs"""
+ if not aid_list:
+ return {}
+
+ async with self.pool.acquire() as conn:
+ query = """
+ SELECT aid, title, description, tags
+ FROM bilibili_metadata
+ WHERE aid = ANY($1::bigint[])
+ """
+ rows = await conn.fetch(query, aid_list)
+
+ result = {}
+ for row in rows:
+ result[int(row["aid"])] = {
+ "aid": int(row["aid"]),
+ "title": row["title"] or "",
+ "description": row["description"] or "",
+ "tags": row["tags"] or "",
+ }
+
+ return result
+
+ async def get_user_labels(
+ self, aid_list: List[int]
+ ) -> Dict[int, List[Dict[str, Any]]]:
+ """Get user labels for given AIDs, only the latest label per user"""
+ if not aid_list:
+ return {}
+
+ async with self.pool.acquire() as conn:
+ query = """
+ WITH latest_labels AS (
+ SELECT DISTINCT ON (aid, "user")
+ aid, "user", label, created_at
+ FROM internal.video_type_label
+ WHERE aid = ANY($1::bigint[])
+ ORDER BY aid, "user", created_at DESC
+ )
+ SELECT aid, "user", label, created_at
+ FROM latest_labels
+ ORDER BY aid, "user"
+ """
+ rows = await conn.fetch(query, aid_list)
+
+ result = {}
+ for row in rows:
+ aid = int(row["aid"])
+ if aid not in result:
+ result[aid] = []
+
+ result[aid].append(
+ {
+ "user": row["user"],
+ "label": bool(row["label"]),
+ "created_at": row["created_at"].isoformat(),
+ }
+ )
+
+ return result
+
+ async def get_existing_embeddings(
+ self, checksums: List[str], model_name: str
+ ) -> Dict[str, Dict[str, Any]]:
+ """Get existing embeddings for given checksums and model"""
+ if not checksums:
+ return {}
+
+ async with self.pool.acquire() as conn:
+ query = """
+ SELECT data_checksum, vec_2048, vec_1536, vec_1024, created_at
+ FROM internal.embeddings
+ WHERE model_name = $1 AND data_checksum = ANY($2::text[])
+ """
+ rows = await conn.fetch(query, model_name, checksums)
+
+ result = {}
+ for row in rows:
+ checksum = row["data_checksum"]
+
+ # Convert vector strings to lists if they exist
+ vec_2048 = self._parse_vector_string(row["vec_2048"]) if row["vec_2048"] else None
+ vec_1536 = self._parse_vector_string(row["vec_1536"]) if row["vec_1536"] else None
+ vec_1024 = self._parse_vector_string(row["vec_1024"]) if row["vec_1024"] else None
+
+ result[checksum] = {
+ "checksum": checksum,
+ "vec_2048": vec_2048,
+ "vec_1536": vec_1536,
+ "vec_1024": vec_1024,
+ "created_at": row["created_at"].isoformat(),
+ }
+
+ return result
+
+ def _parse_vector_string(self, vector_str: str) -> List[float]:
+ """Parse vector string format '[1.0,2.0,3.0]' back to list"""
+ if not vector_str:
+ return []
+
+ try:
+ # Remove brackets and split by comma
+ vector_str = vector_str.strip()
+ if vector_str.startswith('[') and vector_str.endswith(']'):
+ vector_str = vector_str[1:-1]
+
+ return [float(x.strip()) for x in vector_str.split(',') if x.strip()]
+ except Exception as e:
+ logger.warning(f"Failed to parse vector string '{vector_str}': {e}")
+ return []
+
+ async def insert_embeddings(self, embeddings_data: List[Dict[str, Any]]) -> None:
+ """Batch insert embeddings into database"""
+ if not embeddings_data:
+ return
+
+ async with self.pool.acquire() as conn:
+ async with conn.transaction():
+ for data in embeddings_data:
+ # Determine which vector column to use based on dimensions
+ vec_column = f"vec_{data['dimensions']}"
+
+ # Convert vector list to string format for PostgreSQL
+ vector_str = "[" + ",".join(map(str, data["vector"])) + "]"
+
+ query = f"""
+ INSERT INTO internal.embeddings
+ (model_name, data_checksum, {vec_column}, created_at)
+ VALUES ($1, $2, $3, $4)
+ ON CONFLICT (data_checksum) DO NOTHING
+ """
+
+ await conn.execute(
+ query,
+ data["model_name"],
+ data["checksum"],
+ vector_str,
+ datetime.now(),
+ )
+
+ async def get_final_dataset(
+ self, aid_list: List[int], model_name: str
+ ) -> List[Dict[str, Any]]:
+ """Get final dataset with embeddings and labels"""
+ if not aid_list:
+ return []
+
+ # Get video metadata
+ metadata = await self.get_video_metadata(aid_list)
+
+ # Get user labels (latest per user)
+ labels = await self.get_user_labels(aid_list)
+
+ # Prepare text data for embedding
+ text_data = []
+ aid_to_text = {}
+
+ for aid in aid_list:
+ if aid in metadata:
+ # Combine title, description, and tags for embedding
+ text_parts = [
+ metadata[aid]["title"],
+ metadata[aid]["description"],
+ metadata[aid]["tags"],
+ ]
+ combined_text = " ".join(filter(None, text_parts))
+
+ # Create checksum for deduplication
+ checksum = hashlib.md5(combined_text.encode("utf-8")).hexdigest()
+
+ text_data.append(
+ {"aid": aid, "text": combined_text, "checksum": checksum}
+ )
+ aid_to_text[checksum] = aid
+
+        # Get existing embeddings for the computed checksums
+        checksums = [item["checksum"] for item in text_data]
+        existing_embeddings = await self.get_existing_embeddings(checksums, model_name)
+
+        # Prepare final dataset
+ dataset = []
+
+ for item in text_data:
+ aid = item["aid"]
+ checksum = item["checksum"]
+
+ # Get embedding vector
+ embedding_vector = None
+ if checksum in existing_embeddings:
+ # Use existing embedding
+ emb_data = existing_embeddings[checksum]
+ if emb_data["vec_1536"]:
+ embedding_vector = emb_data["vec_1536"]
+ elif emb_data["vec_2048"]:
+ embedding_vector = emb_data["vec_2048"]
+ elif emb_data["vec_1024"]:
+ embedding_vector = emb_data["vec_1024"]
+
+ # Get labels for this aid
+ aid_labels = labels.get(aid, [])
+
+ # Determine final label using consensus (majority vote)
+ if aid_labels:
+ positive_votes = sum(1 for lbl in aid_labels if lbl["label"])
+ final_label = positive_votes > len(aid_labels) / 2
+ else:
+ final_label = None # No labels available
+
+ # Check for inconsistent labels
+ inconsistent = len(aid_labels) > 1 and (
+ sum(1 for lbl in aid_labels if lbl["label"]) != 0
+ and sum(1 for lbl in aid_labels if lbl["label"]) != len(aid_labels)
+ )
+
+ if embedding_vector and final_label is not None:
+ dataset.append(
+ {
+ "aid": aid,
+ "embedding": embedding_vector,
+ "label": final_label,
+ "metadata": metadata.get(aid, {}),
+ "user_labels": aid_labels,
+ "inconsistent": inconsistent,
+ "text_checksum": checksum,
+ }
+ )
+
+ return dataset
+
+
+# Global database manager instance
+db_manager = DatabaseManager()
diff --git a/ml_new/training/dataset_service.py b/ml_new/training/dataset_service.py
new file mode 100644
index 0000000..7977780
--- /dev/null
+++ b/ml_new/training/dataset_service.py
@@ -0,0 +1,373 @@
+"""
+Dataset building service - handles the complete dataset construction flow
+"""
+
+import os
+import json
+import logging
+import asyncio
+from pathlib import Path
+from typing import List, Dict, Any, Optional
+from datetime import datetime
+
+from database import DatabaseManager
+from embedding_service import EmbeddingService
+from config_loader import config_loader
+
+
+logger = logging.getLogger(__name__)
+
+
+class DatasetBuilder:
+ """Service for building datasets with the specified flow"""
+
+ def __init__(self, db_manager: DatabaseManager, embedding_service: EmbeddingService, storage_dir: str = "datasets"):
+ self.db_manager = db_manager
+ self.embedding_service = embedding_service
+ self.storage_dir = Path(storage_dir)
+ self.storage_dir.mkdir(exist_ok=True)
+
+ # Load existing datasets from file system
+ self.dataset_storage: Dict[str, Dict] = self._load_all_datasets()
+
+
+ def _get_dataset_file_path(self, dataset_id: str) -> Path:
+ """Get file path for dataset"""
+ return self.storage_dir / f"{dataset_id}.json"
+
+
+ def _load_dataset_from_file(self, dataset_id: str) -> Optional[Dict[str, Any]]:
+ """Load dataset from file"""
+ file_path = self._get_dataset_file_path(dataset_id)
+ if not file_path.exists():
+ return None
+
+ try:
+ with open(file_path, 'r', encoding='utf-8') as f:
+ return json.load(f)
+ except Exception as e:
+ logger.error(f"Failed to load dataset {dataset_id} from file: {e}")
+ return None
+
+
+ def _save_dataset_to_file(self, dataset_id: str, dataset_data: Dict[str, Any]) -> bool:
+ """Save dataset to file"""
+ file_path = self._get_dataset_file_path(dataset_id)
+
+ try:
+ with open(file_path, 'w', encoding='utf-8') as f:
+ json.dump(dataset_data, f, ensure_ascii=False, indent=2)
+ return True
+ except Exception as e:
+ logger.error(f"Failed to save dataset {dataset_id} to file: {e}")
+ return False
+
+
+ def _load_all_datasets(self) -> Dict[str, Dict]:
+ """Load all datasets from file system"""
+ datasets = {}
+
+ try:
+ for file_path in self.storage_dir.glob("*.json"):
+ dataset_id = file_path.stem
+ dataset_data = self._load_dataset_from_file(dataset_id)
+ if dataset_data:
+ datasets[dataset_id] = dataset_data
+ logger.info(f"Loaded {len(datasets)} datasets from file system")
+ except Exception as e:
+ logger.error(f"Failed to load datasets from file system: {e}")
+
+ return datasets
+
+
+ async def cleanup_old_datasets(self, max_age_days: int = 30):
+ """Remove datasets older than specified days"""
+ try:
+ cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 60 * 60)
+ removed_count = 0
+
+ for dataset_id in list(self.dataset_storage.keys()):
+ dataset_info = self.dataset_storage[dataset_id]
+ if "created_at" in dataset_info:
+ try:
+ created_time = datetime.fromisoformat(dataset_info["created_at"]).timestamp()
+ if created_time < cutoff_time:
+ # Remove from memory
+ del self.dataset_storage[dataset_id]
+
+ # Remove file
+ file_path = self._get_dataset_file_path(dataset_id)
+ if file_path.exists():
+ file_path.unlink()
+
+ removed_count += 1
+ except Exception as e:
+ logger.warning(f"Failed to process dataset {dataset_id} for cleanup: {e}")
+
+ if removed_count > 0:
+ logger.info(f"Cleaned up {removed_count} old datasets")
+
+ except Exception as e:
+ logger.error(f"Failed to cleanup old datasets: {e}")
+
+ async def build_dataset(self, dataset_id: str, aid_list: List[int], embedding_model: str, force_regenerate: bool = False) -> str:
+ """
+ Build dataset with the specified flow:
+ 1. Select embedding model (from TOML config)
+ 2. Pull raw text from database
+ 3. Preprocess (placeholder for now)
+ 4. Batch get embeddings (deduplicate by hash, skip if already in embeddings table)
+ 5. Write to embeddings table
+ 6. Pull all needed embeddings to create dataset with format: embeddings, label
+ """
+
+ try:
+ logger.info(f"Starting dataset building task {dataset_id}")
+
+ EMBEDDING_MODELS = config_loader.get_embedding_models()
+
+ # Get model configuration
+ if embedding_model not in EMBEDDING_MODELS:
+ raise ValueError(f"Invalid embedding model: {embedding_model}")
+
+ model_config = EMBEDDING_MODELS[embedding_model]
+
+ # Step 1: Get video metadata from database
+ metadata = await self.db_manager.get_video_metadata(aid_list)
+
+ # Step 2: Get user labels
+ labels = await self.db_manager.get_user_labels(aid_list)
+
+ # Step 3: Prepare text data and checksums
+ text_data = []
+
+ for aid in aid_list:
+ if aid in metadata:
+ # Combine title, description, tags
+ combined_text = self.embedding_service.combine_video_text(
+ metadata[aid]['title'],
+ metadata[aid]['description'],
+ metadata[aid]['tags']
+ )
+
+ # Create checksum for deduplication
+ checksum = self.embedding_service.create_text_checksum(combined_text)
+
+ text_data.append({
+ 'aid': aid,
+ 'text': combined_text,
+ 'checksum': checksum
+ })
+
+ # Step 4: Check existing embeddings
+ checksums = [item['checksum'] for item in text_data]
+ existing_embeddings = await self.db_manager.get_existing_embeddings(checksums, embedding_model)
+
+ # Step 5: Generate new embeddings for texts that don't have them
+            # Deduplicate by checksum so identical texts are embedded only once
+            new_embeddings_needed = []
+            seen_checksums = set()
+            for item in text_data:
+                if (item['checksum'] not in existing_embeddings or force_regenerate) and item['checksum'] not in seen_checksums:
+                    seen_checksums.add(item['checksum'])
+                    new_embeddings_needed.append(item['text'])
+
+ new_embeddings_count = 0
+ if new_embeddings_needed:
+ logger.info(f"Generating {len(new_embeddings_needed)} new embeddings")
+ generated_embeddings = await self.embedding_service.generate_embeddings_batch(
+ new_embeddings_needed,
+ embedding_model
+ )
+
+ # Step 6: Store new embeddings in database
+ embeddings_to_store = []
+ for i, (text, embedding) in enumerate(zip(new_embeddings_needed, generated_embeddings)):
+ checksum = self.embedding_service.create_text_checksum(text)
+ embeddings_to_store.append({
+ 'model_name': embedding_model,
+ 'checksum': checksum,
+ 'dimensions': model_config.dimensions,
+ 'vector': embedding
+ })
+
+ await self.db_manager.insert_embeddings(embeddings_to_store)
+ new_embeddings_count = len(embeddings_to_store)
+
+ # Update existing embeddings cache
+ for emb_data in embeddings_to_store:
+ existing_embeddings[emb_data['checksum']] = {
+ 'checksum': emb_data['checksum'],
+ f'vec_{model_config.dimensions}': emb_data['vector']
+ }
+
+ # Step 7: Build final dataset
+ dataset = []
+ inconsistent_count = 0
+
+ for item in text_data:
+ aid = item['aid']
+ checksum = item['checksum']
+
+ # Get embedding vector
+ embedding_vector = None
+ if checksum in existing_embeddings:
+ vec_key = f'vec_{model_config.dimensions}'
+ if vec_key in existing_embeddings[checksum]:
+ embedding_vector = existing_embeddings[checksum][vec_key]
+
+ # Get labels for this aid
+ aid_labels = labels.get(aid, [])
+
+ # Determine final label using consensus (majority vote)
+ final_label = None
+ if aid_labels:
+ positive_votes = sum(1 for lbl in aid_labels if lbl['label'])
+ final_label = positive_votes > len(aid_labels) / 2
+
+ # Check for inconsistent labels
+ inconsistent = len(aid_labels) > 1 and (
+ sum(1 for lbl in aid_labels if lbl['label']) != 0 and
+ sum(1 for lbl in aid_labels if lbl['label']) != len(aid_labels)
+ )
+
+ if inconsistent:
+ inconsistent_count += 1
+
+ if embedding_vector and final_label is not None:
+ dataset.append({
+ 'aid': aid,
+ 'embedding': embedding_vector,
+ 'label': final_label,
+ 'metadata': metadata.get(aid, {}),
+ 'user_labels': aid_labels,
+ 'inconsistent': inconsistent,
+ 'text_checksum': checksum
+ })
+
+ reused_count = len(dataset) - new_embeddings_count
+
+ logger.info(f"Dataset building completed: {len(dataset)} records, {new_embeddings_count} new, {reused_count} reused, {inconsistent_count} inconsistent")
+
+ # Prepare dataset data
+ dataset_data = {
+ 'dataset': dataset,
+ 'stats': {
+ 'total_records': len(dataset),
+ 'new_embeddings': new_embeddings_count,
+ 'reused_embeddings': reused_count,
+ 'inconsistent_labels': inconsistent_count,
+ 'embedding_model': embedding_model
+ },
+ 'created_at': datetime.now().isoformat()
+ }
+
+ # Save to file and memory cache
+ if self._save_dataset_to_file(dataset_id, dataset_data):
+ self.dataset_storage[dataset_id] = dataset_data
+ logger.info(f"Dataset {dataset_id} saved to file system")
+ else:
+ logger.warning(f"Failed to save dataset {dataset_id} to file, keeping in memory only")
+ self.dataset_storage[dataset_id] = dataset_data
+
+ return dataset_id
+
+ except Exception as e:
+ logger.error(f"Dataset building failed for {dataset_id}: {str(e)}")
+
+ # Store error information
+ error_data = {
+ 'error': str(e),
+ 'created_at': datetime.now().isoformat()
+ }
+
+ # Try to save error to file as well
+ self._save_dataset_to_file(dataset_id, error_data)
+ self.dataset_storage[dataset_id] = error_data
+ raise
+
+ def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
+ """Get built dataset by ID"""
+ # First check memory cache
+ if dataset_id in self.dataset_storage:
+ return self.dataset_storage[dataset_id]
+
+ # If not in memory, try to load from file
+ dataset_data = self._load_dataset_from_file(dataset_id)
+ if dataset_data:
+ # Add to memory cache
+ self.dataset_storage[dataset_id] = dataset_data
+ return dataset_data
+
+ return None
+
+ def dataset_exists(self, dataset_id: str) -> bool:
+ """Check if dataset exists"""
+ # Check memory cache first
+ if dataset_id in self.dataset_storage:
+ return True
+
+ # Check file system
+ return self._get_dataset_file_path(dataset_id).exists()
+
+ def delete_dataset(self, dataset_id: str) -> bool:
+ """Delete dataset from both memory and file system"""
+ try:
+ # Remove from memory
+ if dataset_id in self.dataset_storage:
+ del self.dataset_storage[dataset_id]
+
+ # Remove file
+ file_path = self._get_dataset_file_path(dataset_id)
+ if file_path.exists():
+ file_path.unlink()
+ logger.info(f"Dataset {dataset_id} deleted from file system")
+ return True
+ else:
+ logger.warning(f"Dataset file {dataset_id} not found for deletion")
+ return False
+
+ except Exception as e:
+ logger.error(f"Failed to delete dataset {dataset_id}: {e}")
+ return False
+
+ def list_datasets(self) -> List[Dict[str, Any]]:
+ """List all datasets with their basic information"""
+ datasets = []
+
+ for dataset_id, dataset_info in self.dataset_storage.items():
+ if "error" not in dataset_info:
+ datasets.append({
+ "dataset_id": dataset_id,
+ "stats": dataset_info["stats"],
+ "created_at": dataset_info["created_at"]
+ })
+
+ # Sort by creation time (newest first)
+ datasets.sort(key=lambda x: x["created_at"], reverse=True)
+
+ return datasets
+
+ def get_dataset_stats(self) -> Dict[str, Any]:
+ """Get overall statistics about stored datasets"""
+ total_datasets = len(self.dataset_storage)
+ error_datasets = sum(1 for data in self.dataset_storage.values() if "error" in data)
+ valid_datasets = total_datasets - error_datasets
+
+ total_records = 0
+ total_new_embeddings = 0
+ total_reused_embeddings = 0
+
+ for dataset_info in self.dataset_storage.values():
+ if "stats" in dataset_info:
+ stats = dataset_info["stats"]
+ total_records += stats.get("total_records", 0)
+ total_new_embeddings += stats.get("new_embeddings", 0)
+ total_reused_embeddings += stats.get("reused_embeddings", 0)
+
+ return {
+ "total_datasets": total_datasets,
+ "valid_datasets": valid_datasets,
+ "error_datasets": error_datasets,
+ "total_records": total_records,
+ "total_new_embeddings": total_new_embeddings,
+ "total_reused_embeddings": total_reused_embeddings,
+ "storage_directory": str(self.storage_dir)
+ }
\ No newline at end of file
diff --git a/ml_new/training/embedding_models.toml b/ml_new/training/embedding_models.toml
new file mode 100644
index 0000000..f650fd0
--- /dev/null
+++ b/ml_new/training/embedding_models.toml
@@ -0,0 +1,12 @@
+# Embedding Models Configuration
+
+model = "qwen3-embedding"
+
+[models.qwen3-embedding]
+name = "text-embedding-v4"
+dimensions = 2048
+type = "openai-compatible"
+api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
+max_tokens = 8192
+max_batch_size = 10
+api_key_env = "ALIYUN_KEY"
diff --git a/ml_new/training/embedding_service.py b/ml_new/training/embedding_service.py
new file mode 100644
index 0000000..8bc16b8
--- /dev/null
+++ b/ml_new/training/embedding_service.py
@@ -0,0 +1,143 @@
+"""
+Embedding service for generating embeddings using OpenAI-compatible API
+"""
+import asyncio
+import hashlib
+from typing import List, Dict, Any, Optional
+import logging
+from openai import AsyncOpenAI
+import os
+from config_loader import config_loader
+from dotenv import load_dotenv
+
+load_dotenv()
+
+logger = logging.getLogger(__name__)
+
+class EmbeddingService:
+ def __init__(self):
+ # Get configuration from config loader
+ self.embedding_models = config_loader.get_embedding_models()
+
+ # Initialize OpenAI client (will be configured per model)
+ self.clients: Dict[str, AsyncOpenAI] = {}
+ self._initialize_clients()
+
+ # Rate limiting
+ self.max_requests_per_minute = int(os.getenv("MAX_REQUESTS_PER_MINUTE", "100"))
+ self.request_interval = 60.0 / self.max_requests_per_minute
+
+ def _initialize_clients(self):
+ """Initialize OpenAI clients for different models/endpoints"""
+ for model_name, model_config in self.embedding_models.items():
+ if model_config.type == "openai-compatible":
+ # Get API key from environment variable specified in config
+ api_key = os.getenv(model_config.api_key_env)
+
+ self.clients[model_name] = AsyncOpenAI(
+ api_key=api_key,
+ base_url=model_config.api_endpoint
+ )
+ logger.info(f"Initialized client for model {model_name}")
+
+ async def generate_embeddings_batch(
+ self,
+ texts: List[str],
+ model: str,
+ batch_size: Optional[int] = None
+ ) -> List[List[float]]:
+ """Generate embeddings for a batch of texts"""
+
+ # Get model configuration
+ if model not in self.embedding_models:
+ raise ValueError(f"Model '{model}' not found in configuration")
+
+ model_config = self.embedding_models[model]
+
+ # Use model's max_batch_size if not specified
+ if batch_size is None:
+ batch_size = model_config.max_batch_size
+
+ # Validate model and get expected dimensions
+ expected_dims = model_config.dimensions
+
+ if model not in self.clients:
+ raise ValueError(f"No client configured for model '{model}'")
+
+ client = self.clients[model]
+ all_embeddings = []
+
+ # Process in batches to avoid API limits
+ for i in range(0, len(texts), batch_size):
+ batch = texts[i:i + batch_size]
+
+ try:
+ # Rate limiting
+ if i > 0:
+ await asyncio.sleep(self.request_interval)
+
+ # Generate embeddings
+ response = await client.embeddings.create(
+ model=model_config.name,
+ input=batch,
+ dimensions=expected_dims
+ )
+
+ batch_embeddings = [data.embedding for data in response.data]
+ all_embeddings.extend(batch_embeddings)
+
+ logger.info(f"Generated embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
+
+ except Exception as e:
+ logger.error(f"Error generating embeddings for batch {i//batch_size + 1}: {e}")
+ # For now, fill with zeros as fallback (could implement retry logic)
+ zero_embedding = [0.0] * expected_dims
+ all_embeddings.extend([zero_embedding] * len(batch))
+
+ return all_embeddings
+
+ def create_text_checksum(self, text: str) -> str:
+ """Create MD5 checksum for text deduplication"""
+ return hashlib.md5(text.encode('utf-8')).hexdigest()
+
+ def combine_video_text(self, title: str, description: str, tags: str) -> str:
+ """Combine video metadata into a single text for embedding"""
+        parts = [
+            "标题:" + title.strip() if title.strip() else "",
+            "简介:" + description.strip() if description.strip() else "",
+            "标签:" + tags.strip() if tags.strip() else ""
+        ]
+
+ # Filter out empty parts and join
+ combined = '\n'.join(filter(None, parts))
+ return combined
+
+ async def health_check(self) -> Dict[str, Any]:
+ """Check if embedding service is healthy"""
+ try:
+ # Test with a simple embedding using the first available model
+ model_name = list(self.embedding_models.keys())[0]
+
+ test_embedding = await self.generate_embeddings_batch(
+ ["health check"],
+ model_name,
+ batch_size=1
+ )
+
+ return {
+ "status": "healthy",
+ "service": "embedding_service",
+ "model": model_name,
+ "dimensions": len(test_embedding[0]) if test_embedding else 0,
+ "available_models": list(self.embedding_models.keys())
+ }
+
+ except Exception as e:
+ return {
+ "status": "unhealthy",
+ "service": "embedding_service",
+ "error": str(e)
+ }
+
+# Global embedding service instance
+embedding_service = EmbeddingService()
\ No newline at end of file
diff --git a/ml_new/training/main.py b/ml_new/training/main.py
index 93211e2..b510e3d 100644
--- a/ml_new/training/main.py
+++ b/ml_new/training/main.py
@@ -1,246 +1,113 @@
-from fastapi import FastAPI, HTTPException, BackgroundTasks
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-from typing import List, Dict, Optional, Any
-import asyncio
-import uuid
-import logging
-from datetime import datetime
-import json
+"""
+Main FastAPI application for ML training service
+"""
-# Setup logging
+import logging
+import uvicorn
+from contextlib import asynccontextmanager
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+from database import DatabaseManager
+from embedding_service import EmbeddingService
+from dataset_service import DatasetBuilder
+from api_routes import router, set_dataset_builder
+
+# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
-# Initialize FastAPI app
-app = FastAPI(
- title="CVSA ML Training API",
- version="1.0.0",
- description="ML training service for video classification"
-)
-# Enable CORS for web UI
-app.add_middleware(
- CORSMiddleware,
- allow_origins=["http://localhost:3000", "http://localhost:5173"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
-)
+# Global service instances
+db_manager = None
+embedding_service = None
+dataset_builder = None
-# Pydantic models
-class Hyperparameter(BaseModel):
- name: str
- type: str # 'number', 'boolean', 'select'
- value: Any
- range: Optional[tuple] = None
- options: Optional[List[str]] = None
- description: Optional[str] = None
-class TrainingConfig(BaseModel):
- learning_rate: float = 1e-4
- batch_size: int = 32
- epochs: int = 10
- early_stop: bool = True
- patience: int = 3
- embedding_model: str = "text-embedding-3-small"
-
-class TrainingRequest(BaseModel):
- experiment_name: str
- config: TrainingConfig
- dataset: Dict[str, Any]
-
-class TrainingStatus(BaseModel):
- experiment_id: str
- status: str # 'pending', 'running', 'completed', 'failed'
- progress: Optional[float] = None
- current_epoch: Optional[int] = None
- total_epochs: Optional[int] = None
- metrics: Optional[Dict[str, float]] = None
- error: Optional[str] = None
-
-class ExperimentResult(BaseModel):
- experiment_id: str
- experiment_name: str
- config: TrainingConfig
- metrics: Dict[str, float]
- created_at: str
- status: str
-
-class EmbeddingRequest(BaseModel):
- texts: List[str]
- model: str
-
-# In-memory storage for experiments (in production, use database)
-training_sessions: Dict[str, Dict] = {}
-experiments: Dict[str, ExperimentResult] = {}
-
-# Default hyperparameters that will be dynamically discovered
-DEFAULT_HYPERPARAMETERS = [
- Hyperparameter(
- name="learning_rate",
- type="number",
- value=1e-4,
- range=(1e-6, 1e-2),
- description="Learning rate for optimizer"
- ),
- Hyperparameter(
- name="batch_size",
- type="number",
- value=32,
- range=(8, 256),
- description="Training batch size"
- ),
- Hyperparameter(
- name="epochs",
- type="number",
- value=10,
- range=(1, 100),
- description="Number of training epochs"
- ),
- Hyperparameter(
- name="early_stop",
- type="boolean",
- value=True,
- description="Enable early stopping"
- ),
- Hyperparameter(
- name="patience",
- type="number",
- value=3,
- range=(1, 20),
- description="Early stopping patience"
- ),
- Hyperparameter(
- name="embedding_model",
- type="select",
- value="text-embedding-3-small",
- options=["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
- description="Embedding model to use"
- )
-]
-
-@app.get("/")
-async def root():
- return {"message": "CVSA ML Training API", "version": "1.0.0"}
-
-@app.get("/health")
-async def health_check():
- return {"status": "healthy", "service": "ml-training-api"}
-
-@app.get("/hyperparameters", response_model=List[Hyperparameter])
-async def get_hyperparameters():
- """Get all available hyperparameters for the current model"""
- return DEFAULT_HYPERPARAMETERS
-
-@app.post("/train")
-async def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
- """Start a new training experiment"""
- experiment_id = str(uuid.uuid4())
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ """Application lifespan manager for startup and shutdown events"""
+ global db_manager, embedding_service, dataset_builder
- # Store training session
- training_sessions[experiment_id] = {
- "experiment_id": experiment_id,
- "experiment_name": request.experiment_name,
- "config": request.config.dict(),
- "dataset": request.dataset,
- "status": "pending",
- "created_at": datetime.now().isoformat()
- }
+ # Startup
+ logger.info("Initializing services...")
- # Start background training task
- background_tasks.add_task(run_training, experiment_id, request)
-
- return {"experiment_id": experiment_id}
-
-@app.get("/train/{experiment_id}/status", response_model=TrainingStatus)
-async def get_training_status(experiment_id: str):
- """Get training status for an experiment"""
- if experiment_id not in training_sessions:
- raise HTTPException(status_code=404, detail="Experiment not found")
-
- session = training_sessions[experiment_id]
-
- return TrainingStatus(
- experiment_id=experiment_id,
- status=session.get("status", "unknown"),
- progress=session.get("progress"),
- current_epoch=session.get("current_epoch"),
- total_epochs=session.get("total_epochs"),
- metrics=session.get("metrics"),
- error=session.get("error")
- )
-
-@app.get("/experiments", response_model=List[ExperimentResult])
-async def list_experiments():
- """List all experiments"""
- return list(experiments.values())
-
-@app.get("/experiments/{experiment_id}", response_model=ExperimentResult)
-async def get_experiment(experiment_id: str):
- """Get experiment details"""
- if experiment_id not in experiments:
- raise HTTPException(status_code=404, detail="Experiment not found")
-
- return experiments[experiment_id]
-
-@app.post("/embeddings")
-async def generate_embeddings(request: EmbeddingRequest):
- """Generate embeddings using OpenAI-compatible API"""
- # This is a placeholder implementation
- # In production, this would call actual embedding API
- embeddings = []
- for text in request.texts:
- # Mock embedding generation
- embedding = [0.1] * 1536 # Mock 1536-dimensional embedding
- embeddings.append(embedding)
-
- return embeddings
-
-async def run_training(experiment_id: str, request: TrainingRequest):
- """Background task to run training"""
try:
- session = training_sessions[experiment_id]
- session["status"] = "running"
- session["total_epochs"] = request.config.epochs
+ # Database manager
+ db_manager = DatabaseManager()
+ await db_manager.connect() # Initialize database connection pool
+ logger.info("Database manager initialized and connected")
- # Simulate training process
- for epoch in range(request.config.epochs):
- session["current_epoch"] = epoch + 1
- session["progress"] = (epoch + 1) / request.config.epochs
-
- # Simulate training metrics
- session["metrics"] = {
- "loss": max(0.0, 1.0 - (epoch + 1) * 0.1),
- "accuracy": min(0.95, 0.5 + (epoch + 1) * 0.05),
- "val_loss": max(0.0, 0.8 - (epoch + 1) * 0.08),
- "val_accuracy": min(0.92, 0.45 + (epoch + 1) * 0.04)
- }
-
- logger.info(f"Training epoch {epoch + 1}/{request.config.epochs}")
- await asyncio.sleep(1) # Simulate training time
+ # Embedding service
+ embedding_service = EmbeddingService()
+ logger.info("Embedding service initialized")
- # Training completed
- session["status"] = "completed"
- final_metrics = session["metrics"]
+ # Dataset builder
+ dataset_builder = DatasetBuilder(db_manager, embedding_service)
+ logger.info("Dataset builder initialized")
- # Store final experiment result
- experiments[experiment_id] = ExperimentResult(
- experiment_id=experiment_id,
- experiment_name=request.experiment_name,
- config=request.config,
- metrics=final_metrics,
- created_at=session["created_at"],
- status="completed"
- )
+ # Set global dataset builder instance
+ set_dataset_builder(dataset_builder)
- logger.info(f"Training completed for experiment {experiment_id}")
+ logger.info("All services initialized successfully")
except Exception as e:
- session["status"] = "failed"
- session["error"] = str(e)
- logger.error(f"Training failed for experiment {experiment_id}: {str(e)}")
+ logger.error(f"Failed to initialize services: {e}")
+ raise
+
+ # Yield control to the application
+ yield
+
+ # Shutdown
+ logger.info("Shutting down services...")
+
+ try:
+ if db_manager:
+ await db_manager.close()
+ logger.info("Database connection pool closed")
+ except Exception as e:
+ logger.error(f"Error during shutdown: {e}")
+
+
+def create_app() -> FastAPI:
+ """Create and configure FastAPI application"""
+
+ # Create FastAPI app with lifespan manager
+ app = FastAPI(
+ title="ML Training Service",
+ description="ML training, dataset building, and experiment management service",
+ version="1.0.0",
+ lifespan=lifespan
+ )
+
+ # Add CORS middleware
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"], # Configure appropriately for production
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+
+ # Include API routes
+ app.include_router(router)
+
+ return app
+
+
+def main():
+ """Main entry point"""
+ app = create_app()
+
+ # Run the application
+ uvicorn.run(
+ app,
+ host="0.0.0.0",
+ port=8322,
+ log_level="info",
+ access_log=True
+ )
+
if __name__ == "__main__":
- import uvicorn
- uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
+ main()
\ No newline at end of file
diff --git a/ml_new/training/models.py b/ml_new/training/models.py
new file mode 100644
index 0000000..66f084b
--- /dev/null
+++ b/ml_new/training/models.py
@@ -0,0 +1,62 @@
+"""
+Data models for dataset building functionality
+"""
+
+from typing import List, Optional, Dict, Any
+from pydantic import BaseModel, Field
+from datetime import datetime
+
+
+class DatasetBuildRequest(BaseModel):
+ """Request model for dataset building"""
+ aid_list: List[int] = Field(..., description="List of video AIDs")
+ embedding_model: str = Field(..., description="Embedding model name")
+ force_regenerate: bool = Field(False, description="Whether to force regenerate embeddings")
+
+
+class DatasetBuildResponse(BaseModel):
+ """Response model for dataset building"""
+ dataset_id: str
+ total_records: int
+ status: str
+ message: str
+ created_at: Optional[datetime] = None
+
+
+class DatasetRecord(BaseModel):
+ """Model for a single dataset record"""
+ aid: int
+ embedding: List[float]
+ label: bool
+ metadata: Dict[str, Any]
+ user_labels: List[Dict[str, Any]]
+ inconsistent: bool
+ text_checksum: str
+
+
+class DatasetInfo(BaseModel):
+ """Model for dataset information"""
+ dataset_id: str
+ dataset: List[DatasetRecord]
+ stats: Dict[str, Any]
+ created_at: datetime
+
+
+class DatasetBuildStats(BaseModel):
+ """Statistics for dataset building process"""
+ total_records: int
+ new_embeddings: int
+ reused_embeddings: int
+ inconsistent_labels: int
+ embedding_model: str
+ processing_time: Optional[float] = None
+
+
+class EmbeddingModelInfo(BaseModel):
+ """Information about embedding models"""
+ name: str
+ dimensions: int
+ type: str
+ api_endpoint: Optional[str] = None
+ max_tokens: Optional[int] = None
+ max_batch_size: Optional[int] = None
\ No newline at end of file
diff --git a/ml_new/training/requirements.txt b/ml_new/training/requirements.txt
index 4cd74eb..d14ecd3 100644
--- a/ml_new/training/requirements.txt
+++ b/ml_new/training/requirements.txt
@@ -9,4 +9,8 @@ scikit-learn==1.3.2
pandas==2.1.3
openai==1.3.7
psycopg2-binary==2.9.9
-sqlalchemy==2.0.23
\ No newline at end of file
+sqlalchemy==2.0.23
+asyncpg==0.29.0
+toml==0.10.2
+aiohttp==3.9.0
+python-dotenv==1.1.0
\ No newline at end of file
diff --git a/mutagen.yml b/mutagen.yml
new file mode 100644
index 0000000..d343689
--- /dev/null
+++ b/mutagen.yml
@@ -0,0 +1,26 @@
+# mutagen.yml
+sync:
+ development:
+ alpha: "."
+ beta: "root@cvsa-hk-02:/web/cvsa"
+ maxStagingFileSize: "5MB"
+ mode: "two-way-resolved"
+ ignore:
+ paths:
+ - "**/node_modules/"
+ - "*.log"
+ - "/model"
+ - "**/.DS_Store"
+ - ".env"
+ - ".env.*"
+ - "**/logs/"
+ - ".git/"
+ - ".jj/"
+ - "dist"
+ - "**/build/"
+ - "**/.react-router"
+ - "redis/"
+ - "temp/"
+ - "mutagen.yml"
+ - "/ml_new/"
+ - "/ml/"
\ No newline at end of file
diff --git a/mutagen.yml.lock b/mutagen.yml.lock
new file mode 100644
index 0000000..fd38c71
--- /dev/null
+++ b/mutagen.yml.lock
@@ -0,0 +1 @@
+proj_EPp5s45rolBZ729uBHaoBMzV51GCoGkpnJuTkK6owzH
\ No newline at end of file
diff --git a/packages/ml_panel/README.md b/packages/ml_panel/README.md
new file mode 100644
index 0000000..e1ff783
--- /dev/null
+++ b/packages/ml_panel/README.md
@@ -0,0 +1,267 @@
+# CVSA ML 基础设施重构项目
+
+## 项目概述
+
+本项目旨在重构现有的 ML 服务基础设施,从原始的 `ml/filter` 系统迁移到新的前后端分离架构。主要目标是为 ML 训练、实验管理和数据处理提供一个现代化的 Web UI 界面。
+
+### 核心功能
+
+- **数据管线管理**: 从 PostgreSQL 数据库获取和预处理训练数据
+- **实验管理**: 训练参数配置、实验追踪和结果可视化
+- **超参数调优**: 动态超参数配置和调整
+- **数据标注界面**: 简单易用的数据标注和管理工具
+- **模型训练**: 2分类视频分类模型训练
+- **嵌入向量管理**: 支持多种嵌入模型和向量维度
+
+## 架构设计
+
+### 技术栈
+
+- **前端**: React + TypeScript + Vite + Tailwind CSS + shadcn/ui
+- **后端**: FastAPI + Python
+- **数据库**: PostgreSQL + Drizzle ORM
+- **向量数据库**: PostgreSQL pgvector
+- **包管理**: Bun (TypeScript) + pip (Python)
+
+### 分层架构
+
+```
+┌─────────────────┐ ┌──────────────┐ ┌──────────────────┐
+│ Web UI │ │ FastAPI │ │ Database │
+│ (React TS) │◄──►│ (Python) │◄──►│ (PostgreSQL) │
+└─────────────────┘ └──────────────┘ └──────────────────┘
+```
+
+## 目录结构
+
+### 前端项目 (`packages/ml_panel/`)
+
+```
+packages/ml_panel/
+├── src/ # 前端应用
+│ ├── App.tsx # 主应用组件
+│ ├── main.tsx # 应用入口
+│ ├── index.css # 全局样式
+│ └── lib/
+│ └── utils.ts # 前端工具函数
+├── lib/ # 核心库文件
+│ ├── types.ts # 共享类型定义
+│ ├── ml-client.ts # ML API 客户端
+│ ├── data-pipeline/ # 数据管线类型
+│ │ └── types.ts
+│ └── index.ts # 导出文件
+├── package.json
+├── vite.config.ts
+└── tailwind.config.js
+```
+
+### 后端服务 (`ml_new/training/`)
+
+```
+ml_new/training/
+├── main.py # FastAPI 主服务
+├── requirements.txt # Python 依赖
+└── ... # 其他服务文件
+```
+
+### 数据库 Schema
+
+使用现有的 `packages/core/drizzle/main/schema.ts` 中的定义:
+
+- `videoTypeLabelInInternal`: 用户标注数据
+- `embeddingsInInternal`: 嵌入向量存储
+- `bilibiliMetadata`: 视频元数据
+
+## 已完成的工作
+
+### 1. 核心类型定义
+
+**文件**: `packages/ml_panel/lib/types.ts`
+
+- 定义了核心数据结构
+- `DatasetRecord`: 数据集记录
+- `UserLabel`: 用户标注
+- `EmbeddingModel`: 嵌入模型配置
+- `TrainingConfig`: 训练配置
+- `ExperimentResult`: 实验结果
+- `InconsistentLabel`: 标注不一致数据
+
+### 2. 数据管线类型
+
+**文件**: `packages/ml_panel/lib/data-pipeline/types.ts`
+
+- `VideoMetadata`: 视频元数据
+- `VideoTypeLabel`: 标注数据
+- `EmbeddingRecord`: 嵌入记录
+- `DataPipelineConfig`: 管线配置
+- `ProcessedDataset`: 处理后的数据集
+
+### 3. ML 客户端
+
+**文件**: `packages/ml_panel/lib/ml-client.ts`
+
+- `MLClient` 类用于与 FastAPI 通信
+- 超参数获取和更新
+- 训练任务启动和状态监控
+- 实验管理
+- 嵌入生成接口
+
+### 4. FastAPI 服务框架
+
+**文件**: `ml_new/training/main.py`
+
+- 基础的 FastAPI 应用配置
+- CORS 中间件配置
+- 内存存储的训练会话管理
+- 基础的 API 端点定义
+
+### 5. 项目配置
+
+- **前端**: `packages/ml_panel/package.json` - React + Vite + TypeScript 配置
+- **后端**: `ml_new/training/requirements.txt` - Python 依赖
+- **主项目**: `packages/ml/package.json` - Monorepo 工作空间配置
+
+## 核心功能实现状态
+
+### 已完成
+
+- [x] 基础项目结构搭建
+- [x] 核心类型定义
+- [x] ML API 客户端
+- [x] FastAPI 服务框架
+- [x] 前端项目配置
+
+### 待实现
+
+- [ ] 数据管线核心逻辑实现
+- [ ] React UI 组件开发
+- [ ] FastAPI 服务功能完善
+- [ ] 数据库连接和数据获取逻辑
+- [ ] 用户标注数据处理
+- [ ] 嵌入向量管理
+- [ ] 标注一致性检查
+- [ ] 训练任务队列
+- [ ] 实验追踪和可视化
+- [ ] 超参数动态配置
+- [ ] 完整的前后端集成测试
+
+## 数据流程设计
+
+### 1. 数据集创建流程 (高 RTT 优化)
+
+```mermaid
+graph TD
+ A[前端点击创建数据集] --> B[选定嵌入模型]
+ B --> C[从 TOML 配置载入模型参数]
+ C --> D[从数据库批量拉取原始文本]
+ D --> E[文本预处理]
+ E --> F[计算文本 hash 并去重]
+ F --> G[批量查询已有嵌入]
+ G --> H[区分需要生成的新文本]
+ H --> I[批量调用嵌入 API]
+ I --> J[批量写入 embeddings 表]
+ J --> K[拉取完整 embeddings 数据]
+ K --> L[合并标签数据]
+    L --> M[构建最终数据集<br/>格式: embeddings, label]
+```
+
+### 2. 数据获取流程 (数据库优化)
+
+```mermaid
+graph TD
+    A[PostgreSQL 远程数据库<br/>RTT: 100ms] --> B[videoTypeLabelInInternal]
+    A --> C[embeddingsInInternal]
+    A --> D[bilibiliMetadata]
+    B --> E[批量获取用户最后一次标注<br/>避免循环查询]
+    C --> F[批量获取嵌入向量<br/>一次性查询所有维度]
+    D --> G[批量获取视频元数据<br/>IN 查询避免 N+1]
+ E --> H[标注一致性检查]
+ F --> I[向量数据处理]
+ G --> J[数据合并]
+ H --> K[数据集构建]
+ I --> K
+ J --> K
+```
+
+### 3. 训练流程
+
+```mermaid
+graph TD
+ A[前端配置] --> B[超参数设置]
+ B --> C[FastAPI 接收]
+ C --> D[数据管线处理]
+ D --> E[模型训练]
+ E --> F[实时状态更新]
+ F --> G[结果存储]
+ G --> H[前端展示]
+```
+
+## 技术要点
+
+### 1. 高性能数据库设计 (RTT 优化)
+
+- **批量操作**: 避免循环查询,使用 `IN` 语句和批量 `INSERT/UPDATE`
+
+### 2. 嵌入向量管理
+
+- **多模型支持**: `embeddingsInInternal` 支持不同维度的向量 (2048/1536/1024)
+- **去重机制**: 使用文本 hash 去重,避免重复生成嵌入
+- **批量处理**: 批量生成和存储嵌入向量
+- **缓存策略**: 优先使用已存在的嵌入向量
+
+### 3. 2分类模型架构
+
+- 从原有的 3分类系统迁移到 2分类
+- 输入: 预计算的嵌入向量 (而非原始文本)
+- 支持多种嵌入模型切换
+
+### 4. 数据一致性处理
+
+- **用户标注**: `videoTypeLabelInInternal` 存储多用户标注
+- **最后标注**: 获取每个用户的最后一次标注作为有效数据
+- **一致性检查**: 识别不同用户标注不一致的视频,标记为需要人工复核
+
+## 后续开发计划
+
+### Phase 1: 核心功能 (优先)
+
+1. **数据管线实现**
+ - 标注数据获取和一致性检查
+ - 嵌入向量生成和存储
+ - 数据集构建逻辑
+
+2. **FastAPI 服务完善**
+ - 构建新的模型架构(输入嵌入向量,直接二分类头)
+ - 迁移现有 ml/filter 训练逻辑
+ - 实现超参数动态暴露
+ - 集成 OpenAI 兼容嵌入 API
+ - 训练任务队列管理
+
+### Phase 2: 用户界面
+
+1. **数据集创建界面**
+ - 嵌入模型选择
+ - 数据预览和筛选
+ - 处理进度显示
+
+2. **训练参数配置界面**
+ - 超参数动态渲染
+ - 参数验证和约束
+
+3. **实验管理和追踪**
+ - 实验历史和比较
+ - 训练状态实时监控
+ - 结果可视化
+
+### Phase 3: 高级功能
+
+1. **超参数自动调优**
+2. **模型版本管理**
+3. **批量训练支持**
+4. **性能优化**
+
+## 注意事项
+
+1. **数据库性能**: 远程数据库 RTT 高,避免 N+1 查询,使用批量操作
+2. **标注一致性**: 实现自动的标注不一致检测
+3. **嵌入模型支持**: 为未来扩展多种嵌入模型预留接口
\ No newline at end of file
diff --git a/packages/ml_panel/lib/data-pipeline/types.ts b/packages/ml_panel/lib/data-pipeline/types.ts
deleted file mode 100644
index cbb1099..0000000
--- a/packages/ml_panel/lib/data-pipeline/types.ts
+++ /dev/null
@@ -1,47 +0,0 @@
-// Data pipeline specific types
-import type { DatasetRecord, UserLabel, EmbeddingModel, InconsistentLabel } from "../types";
-
-// Database types from packages/core
-export interface VideoMetadata {
- aid: number;
- title: string;
- description: string;
- tags: string;
- createdAt?: string;
-}
-
-export interface VideoTypeLabel {
- id: number;
- aid: number;
- label: boolean;
- user: string;
- createdAt: string;
-}
-
-export interface EmbeddingRecord {
- id: number;
- modelName: string;
- dataChecksum: string;
- vec2048?: number[];
- vec1536?: number[];
- vec1024?: number[];
- createdAt?: string;
-}
-
-export interface DataPipelineConfig {
- embeddingModels: EmbeddingModel[];
- batchSize: number;
- requireConsensus: boolean;
- maxInconsistentRatio: number;
-}
-
-export interface ProcessedDataset {
- records: DatasetRecord[];
- inconsistentLabels: InconsistentLabel[];
- statistics: {
- totalRecords: number;
- labeledRecords: number;
- inconsistentRecords: number;
- embeddingCoverage: Record;
- };
-}
diff --git a/packages/ml_panel/lib/index.ts b/packages/ml_panel/lib/index.ts
deleted file mode 100644
index e69de29..0000000
diff --git a/packages/ml_panel/lib/ml-client.ts b/packages/ml_panel/lib/ml-client.ts
deleted file mode 100644
index ff0294a..0000000
--- a/packages/ml_panel/lib/ml-client.ts
+++ /dev/null
@@ -1,107 +0,0 @@
-// ML Client for communicating with FastAPI service
-import type { TrainingConfig, ExperimentResult } from './types';
-
-export interface Hyperparameter {
- name: string;
- type: 'number' | 'boolean' | 'select';
- value: any;
- range?: [number, number];
- options?: string[];
- description?: string;
-}
-
-export interface TrainingRequest {
- experimentName: string;
- config: TrainingConfig;
- dataset: {
- aid: number[];
- embeddings: Record;
- labels: Record;
- };
-}
-
-export interface TrainingStatus {
- experimentId: string;
- status: 'pending' | 'running' | 'completed' | 'failed';
- progress?: number;
- currentEpoch?: number;
- totalEpochs?: number;
- metrics?: Record;
- error?: string;
-}
-
-export class MLClient {
- private baseUrl: string;
-
- constructor(baseUrl: string = 'http://localhost:8000') {
- this.baseUrl = baseUrl;
- }
-
- // Get available hyperparameters from the model
- async getHyperparameters(): Promise {
- const response = await fetch(`${this.baseUrl}/hyperparameters`);
- if (!response.ok) {
- throw new Error(`Failed to get hyperparameters: ${response.statusText}`);
- }
- return (await response.json()) as Hyperparameter[];
- }
-
- // Start a training experiment
- async startTraining(request: TrainingRequest): Promise<{ experimentId: string }> {
- const response = await fetch(`${this.baseUrl}/train`, {
- method: 'POST',
- headers: {
- 'Content-Type': 'application/json',
- },
- body: JSON.stringify(request),
- });
-
- if (!response.ok) {
- throw new Error(`Failed to start training: ${response.statusText}`);
- }
- return (await response.json()) as { experimentId: string };
- }
-
- // Get training status
- async getTrainingStatus(experimentId: string): Promise {
- const response = await fetch(`${this.baseUrl}/train/${experimentId}/status`);
- if (!response.ok) {
- throw new Error(`Failed to get training status: ${response.statusText}`);
- }
- return (await response.json()) as TrainingStatus;
- }
-
- // Get experiment results
- async getExperimentResult(experimentId: string): Promise {
- const response = await fetch(`${this.baseUrl}/experiments/${experimentId}`);
- if (!response.ok) {
- throw new Error(`Failed to get experiment result: ${response.statusText}`);
- }
- return (await response.json()) as ExperimentResult;
- }
-
- // List all experiments
- async listExperiments(): Promise {
- const response = await fetch(`${this.baseUrl}/experiments`);
- if (!response.ok) {
- throw new Error(`Failed to list experiments: ${response.statusText}`);
- }
- return (await response.json()) as ExperimentResult[];
- }
-
- // Generate embeddings using OpenAI-compatible API
- async generateEmbeddings(texts: string[], model: string): Promise {
- const response = await fetch(`${this.baseUrl}/embeddings`, {
- method: 'POST',
- headers: {
- 'Content-Type': 'application/json',
- },
- body: JSON.stringify({ texts, model }),
- });
-
- if (!response.ok) {
- throw new Error(`Failed to generate embeddings: ${response.statusText}`);
- }
- return (await response.json()) as number[][];
- }
-}
\ No newline at end of file
diff --git a/packages/ml_panel/lib/types.ts b/packages/ml_panel/lib/types.ts
deleted file mode 100644
index 6b826b2..0000000
--- a/packages/ml_panel/lib/types.ts
+++ /dev/null
@@ -1,54 +0,0 @@
-// Shared ML types and interfaces
-export interface DatasetRecord {
- aid: number;
- title: string;
- description: string;
- tags: string;
- embedding?: number[];
- label?: boolean;
- userLabels?: UserLabel[];
-}
-
-export interface UserLabel {
- user: string;
- label: boolean;
- createdAt: string;
-}
-
-export interface EmbeddingModel {
- name: string;
- dimensions: number;
- type: "openai-compatible" | "local";
- apiEndpoint?: string;
-}
-
-export interface TrainingConfig {
- learningRate: number;
- batchSize: number;
- epochs: number;
- earlyStop: boolean;
- patience?: number;
- embeddingModel: string;
-}
-
-export interface ExperimentResult {
- experimentId: string;
- config: TrainingConfig;
- metrics: {
- accuracy: number;
- precision: number;
- recall: number;
- f1: number;
- };
- createdAt: string;
- status: "running" | "completed" | "failed";
-}
-
-export interface InconsistentLabel {
- aid: number;
- title: string;
- description: string;
- tags: string;
- labels: UserLabel[];
- consensus?: boolean;
-}