add: dataset building API

commit 77668bbb52 (parent c89fd1ce67)
.gitignore (vendored, 2 changes)

@@ -47,4 +47,4 @@ temp/
 meili
-.turbo
+.turbo/
.kilocode/rules/common.md (new file, 5 lines)

@@ -0,0 +1,5 @@
# common.md

1. Always use bun as the package manager.

2. Always write comments in English.
ml_new/.gitignore (vendored, 2 changes)

@@ -1 +1 @@
-datasets
+datasets/
ml_new/training/README.md (new file, 140 lines)

@@ -0,0 +1,140 @@
# ML Training Service

A FastAPI-based ML training service for dataset building, embedding generation, and experiment management.

## Architecture

The service is organized into modular components:

```
ml_new/training/
├── main.py                  # FastAPI application entry point
├── models.py                # Pydantic data models
├── config_loader.py         # Configuration loading from TOML
├── database.py              # Database connection and operations
├── embedding_service.py     # Embedding generation service
├── dataset_service.py       # Dataset building logic
├── api_routes.py            # API endpoint definitions
├── embedding_models.toml    # Embedding model configurations
└── requirements.txt         # Python dependencies
```

## Key Components

### 1. Main Application (`main.py`)
- FastAPI app initialization
- CORS middleware configuration
- Service dependency injection
- Startup/shutdown event handlers

### 2. Data Models (`models.py`)
- `DatasetBuildRequest`: Request model for dataset building
- `DatasetBuildResponse`: Response model for dataset building
- `DatasetRecord`: Individual dataset record structure
- `EmbeddingModelInfo`: Embedding model configuration

### 3. Configuration (`config_loader.py`)
- Loads embedding model configurations from TOML
- Manages model parameters (dimensions, API endpoints, etc.)

### 4. Database Layer (`database.py`)
- PostgreSQL connection management
- CRUD operations for video metadata, user labels, and embeddings
- Optimized batch queries to avoid N+1 problems

### 5. Embedding Service (`embedding_service.py`)
- Integration with OpenAI-compatible embedding APIs
- Text preprocessing and checksum generation
- Batch embedding generation with rate limiting

### 6. Dataset Building (`dataset_service.py`)
- Complete dataset construction workflow:
  1. Pull raw text from database
  2. Text preprocessing (placeholder)
  3. Batch embedding generation with deduplication
  4. Embedding storage and caching
  5. Final dataset compilation with labels

### 7. API Routes (`api_routes.py`)
- `GET /v1/health`: Health check
- `GET /v1/models/embedding`: List available embedding models
- `POST /v1/dataset/build`: Build new dataset
- `GET /v1/dataset/{id}`: Retrieve built dataset
- `GET /v1/datasets`: List all datasets
- `DELETE /v1/dataset/{id}`: Delete dataset

## Dataset Building Flow

1. **Model Selection**: Choose embedding model from TOML configuration
2. **Data Retrieval**: Pull video metadata and user labels from PostgreSQL
3. **Text Processing**: Combine title, description, and tags
4. **Deduplication**: Generate checksums to avoid duplicate embeddings
5. **Batch Processing**: Generate embeddings for new texts only
6. **Storage**: Store embeddings in database with caching
7. **Final Assembly**: Combine embeddings with labels using a consensus mechanism (see the sketch after this list)
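Steps 4 and 7 carry most of the logic. A minimal sketch of both, mirroring the MD5-checksum and majority-vote code in `dataset_service.py` (the shape of `user_labels` follows `database.get_user_labels`):

```python
import hashlib

def text_checksum(text: str) -> str:
    # Step 4: MD5 over the combined title/description/tags text,
    # used as the dedup key into the internal.embeddings table
    return hashlib.md5(text.encode("utf-8")).hexdigest()

def consensus_label(user_labels: list) -> bool | None:
    # Step 7: majority vote over the latest label from each user;
    # None means no labels exist and the record is dropped
    if not user_labels:
        return None
    positive = sum(1 for lbl in user_labels if lbl["label"])
    return positive > len(user_labels) / 2

# Two of three annotators voted positive -> consensus is True
print(consensus_label([{"label": True}, {"label": True}, {"label": False}]))
```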
## Configuration

### Embedding Models (`embedding_models.toml`)

Model tables live under a top-level `[models]` section (this is what `config_loader.py` reads):

```toml
[models.text-embedding-3-large]
name = "text-embedding-3-large"
dimensions = 3072
type = "openai"
api_endpoint = "https://api.openai.com/v1/embeddings"
max_tokens = 8192
max_batch_size = 100
```

### Environment Variables
- `DATABASE_URL`: PostgreSQL connection string
- `OPENAI_API_KEY`: OpenAI API key for embedding generation (the variable consulted is configurable per model via `api_key_env`)
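For local development, these can be provided via a `.env` file, since both `database.py` and `embedding_service.py` call `load_dotenv()`; the values below are placeholders:

```bash
export DATABASE_URL="postgresql://user:password@localhost:5432/cvsa"
export OPENAI_API_KEY="sk-..."
# The qwen3-embedding model in embedding_models.toml reads ALIYUN_KEY instead
export ALIYUN_KEY="..."
```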
## Usage

### Start the Service
```bash
cd ml_new/training
python main.py
```

### Build a Dataset
```bash
curl -X POST "http://localhost:8322/v1/dataset/build" \
  -H "Content-Type: application/json" \
  -d '{
    "aid_list": [170001, 170002, 170003],
    "embedding_model": "text-embedding-3-large",
    "force_regenerate": false
  }'
```
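The endpoint schedules the build as a background task and returns immediately; per the `DatasetBuildResponse` model, the reply looks like this (the `dataset_id` UUID here is made up):

```json
{
  "dataset_id": "3f2b8c1e-5a47-4d0e-9c21-7f6a0d4b9e10",
  "total_records": 3,
  "status": "started",
  "message": "Dataset building started",
  "created_at": null
}
```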
### Check Health
```bash
curl "http://localhost:8322/v1/health"
```

### List Embedding Models
```bash
curl "http://localhost:8322/v1/models/embedding"
```

## Features

- **High Performance**: Optimized database queries with batch operations
- **Deduplication**: Text-level deduplication using MD5 checksums
- **Consensus Labels**: Majority-vote mechanism for user annotations
- **Batch Processing**: Efficient embedding generation and storage
- **Error Handling**: Comprehensive error handling and logging
- **Async Support**: Fully asynchronous operations for scalability
- **CORS Enabled**: Ready for frontend integration

## Production Considerations

- Replace in-memory dataset storage with a database
- Add authentication and authorization
- Implement rate limiting for API endpoints
- Add monitoring and metrics collection
- Configure proper logging levels
- Set up database connection pooling
- Add API documentation with OpenAPI/Swagger
ml_new/training/api_routes.py (new file, 196 lines)

@@ -0,0 +1,196 @@
"""
API routes for the ML training service
"""

import logging
import uuid

from fastapi import APIRouter, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse

from config_loader import config_loader
from models import DatasetBuildRequest, DatasetBuildResponse
from dataset_service import DatasetBuilder


logger = logging.getLogger(__name__)

# Create router
router = APIRouter(prefix="/v1")

# Global dataset builder instance (will be set by main.py)
dataset_builder: DatasetBuilder = None


def set_dataset_builder(builder: DatasetBuilder):
    """Set the global dataset builder instance"""
    global dataset_builder
    dataset_builder = builder


@router.get("/health")
async def health_check():
    """Health check endpoint"""
    if not dataset_builder:
        return JSONResponse(
            status_code=503,
            content={"status": "unavailable", "message": "Dataset builder not initialized"}
        )

    try:
        # Check embedding service health
        embedding_health = await dataset_builder.embedding_service.health_check()
    except Exception as e:
        embedding_health = {"status": "unhealthy", "error": str(e)}

    # Check database connection (pool should already be initialized)
    db_status = "disconnected"
    if dataset_builder.db_manager.is_connected:
        try:
            response = await dataset_builder.db_manager.pool.fetch("SELECT 1 FROM information_schema.tables")
            db_status = "connected" if response else "disconnected"
        except Exception as e:
            db_status = f"error: {str(e)}"

    return {
        "status": "healthy",
        "service": "ml-training-api",
        "embedding_service": embedding_health,
        "database": db_status,
        "available_models": list(config_loader.get_embedding_models().keys())
    }


@router.get("/models/embedding")
async def get_embedding_models():
    """Get available embedding models"""
    return {
        "models": {
            name: {
                "name": config.name,
                "dimensions": config.dimensions,
                "type": config.type,
                "api_endpoint": config.api_endpoint,
                "max_tokens": config.max_tokens,
                "max_batch_size": config.max_batch_size
            }
            for name, config in config_loader.get_embedding_models().items()
        }
    }


@router.post("/dataset/build", response_model=DatasetBuildResponse)
async def build_dataset_endpoint(request: DatasetBuildRequest, background_tasks: BackgroundTasks):
    """Build dataset endpoint"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    # Validate embedding model
    if request.embedding_model not in config_loader.get_embedding_models():
        raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")

    dataset_id = str(uuid.uuid4())
    # Start background task for dataset building
    background_tasks.add_task(
        dataset_builder.build_dataset,
        dataset_id,
        request.aid_list,
        request.embedding_model,
        request.force_regenerate
    )

    return DatasetBuildResponse(
        dataset_id=dataset_id,
        total_records=len(request.aid_list),
        status="started",
        message="Dataset building started"
    )


@router.get("/dataset/{dataset_id}")
async def get_dataset_endpoint(dataset_id: str):
    """Get built dataset by ID"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    if not dataset_builder.dataset_exists(dataset_id):
        raise HTTPException(status_code=404, detail="Dataset not found")

    dataset_info = dataset_builder.get_dataset(dataset_id)

    if "error" in dataset_info:
        raise HTTPException(status_code=500, detail=dataset_info["error"])

    return {
        "dataset_id": dataset_id,
        "dataset": dataset_info["dataset"],
        "stats": dataset_info["stats"],
        "created_at": dataset_info["created_at"]
    }


@router.get("/datasets")
async def list_datasets_endpoint():
    """List all built datasets"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    # Delegate to the service, which filters out error entries
    # and sorts newest first
    datasets = dataset_builder.list_datasets()
    return {"datasets": datasets}


@router.delete("/dataset/{dataset_id}")
async def delete_dataset_endpoint(dataset_id: str):
    """Delete a built dataset"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    if dataset_builder.delete_dataset(dataset_id):
        return {"message": f"Dataset {dataset_id} deleted successfully"}
    else:
        raise HTTPException(status_code=404, detail="Dataset not found")


@router.get("/datasets/stats")
async def get_dataset_stats_endpoint():
    """Get overall statistics about stored datasets"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    stats = dataset_builder.get_dataset_stats()
    return stats


@router.post("/datasets/cleanup")
async def cleanup_datasets_endpoint(max_age_days: int = 30):
    """Remove datasets older than specified days"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    await dataset_builder.cleanup_old_datasets(max_age_days)
    return {"message": f"Cleanup completed for datasets older than {max_age_days} days"}
ml_new/training/config_loader.py (new file, 85 lines)

@@ -0,0 +1,85 @@
"""
Configuration loader for embedding models and other settings
"""

import toml
import os
from typing import Dict
from pydantic import BaseModel
import logging

logger = logging.getLogger(__name__)


class EmbeddingModelConfig(BaseModel):
    name: str
    dimensions: int
    type: str
    api_endpoint: str = "https://api.openai.com/v1"
    max_tokens: int = 8191
    max_batch_size: int = 8
    api_key_env: str = "OPENAI_API_KEY"


class ConfigLoader:
    def __init__(self, config_path: str = None):
        if config_path is None:
            # Default to the embedding_models.toml file we created
            config_path = os.path.join(
                os.path.dirname(__file__), "embedding_models.toml"
            )

        self.config_path = config_path
        self.embedding_models: Dict[str, EmbeddingModelConfig] = {}
        self._load_config()

    def _load_config(self):
        """Load configuration from TOML file"""
        try:
            if not os.path.exists(self.config_path):
                logger.warning(f"Config file not found: {self.config_path}")
                return

            with open(self.config_path, "r", encoding="utf-8") as f:
                config_data = toml.load(f)

            # Load embedding models
            if "models" not in config_data:
                return

            for model_key, model_data in config_data["models"].items():
                self.embedding_models[model_key] = EmbeddingModelConfig(
                    **model_data
                )

            logger.info(
                f"Loaded {len(self.embedding_models)} embedding models from {self.config_path}"
            )

        except Exception as e:
            logger.error(f"Failed to load config from {self.config_path}: {e}")

    def get_embedding_models(self) -> Dict[str, EmbeddingModelConfig]:
        """Get all available embedding models"""
        return self.embedding_models.copy()

    def get_embedding_model(self, model_name: str) -> EmbeddingModelConfig:
        """Get specific embedding model config"""
        if model_name not in self.embedding_models:
            raise ValueError(
                f"Embedding model '{model_name}' not found in configuration"
            )
        return self.embedding_models[model_name]

    def list_model_names(self) -> list:
        """Get list of available model names"""
        return list(self.embedding_models.keys())

    def reload_config(self):
        """Reload configuration from file"""
        self.embedding_models = {}
        self._load_config()


# Global config loader instance
config_loader = ConfigLoader()
ml_new/training/database.py (new file, 284 lines)

@@ -0,0 +1,284 @@
"""
Database connection and operations for ML training service
"""

import os
import hashlib
from typing import List, Dict, Optional, Any
from datetime import datetime
import asyncpg
import logging
from config_loader import config_loader
from dotenv import load_dotenv

load_dotenv()

logger = logging.getLogger(__name__)

# Database configuration
DATABASE_URL = os.getenv("DATABASE_URL")


class DatabaseManager:
    def __init__(self):
        self.pool: Optional[asyncpg.Pool] = None

    async def connect(self):
        """Initialize database connection pool"""
        try:
            self.pool = await asyncpg.create_pool(DATABASE_URL, min_size=5, max_size=20)
            logger.info("Database connection pool initialized")
        except Exception as e:
            logger.error(f"Failed to connect to database: {e}")
            raise

    async def close(self):
        """Close database connection pool"""
        if self.pool:
            await self.pool.close()
            logger.info("Database connection pool closed")

    @property
    def is_connected(self) -> bool:
        """Check if database connection pool is initialized"""
        return self.pool is not None

    async def get_embedding_models(self):
        """Get available embedding models from config"""
        return config_loader.get_embedding_models()

    async def get_video_metadata(
        self, aid_list: List[int]
    ) -> Dict[int, Dict[str, Any]]:
        """Get video metadata for given AIDs"""
        if not aid_list:
            return {}

        async with self.pool.acquire() as conn:
            query = """
                SELECT aid, title, description, tags
                FROM bilibili_metadata
                WHERE aid = ANY($1::bigint[])
            """
            rows = await conn.fetch(query, aid_list)

            result = {}
            for row in rows:
                result[int(row["aid"])] = {
                    "aid": int(row["aid"]),
                    "title": row["title"] or "",
                    "description": row["description"] or "",
                    "tags": row["tags"] or "",
                }

            return result

    async def get_user_labels(
        self, aid_list: List[int]
    ) -> Dict[int, List[Dict[str, Any]]]:
        """Get user labels for given AIDs, only the latest label per user"""
        if not aid_list:
            return {}

        async with self.pool.acquire() as conn:
            query = """
                WITH latest_labels AS (
                    SELECT DISTINCT ON (aid, "user")
                        aid, "user", label, created_at
                    FROM internal.video_type_label
                    WHERE aid = ANY($1::bigint[])
                    ORDER BY aid, "user", created_at DESC
                )
                SELECT aid, "user", label, created_at
                FROM latest_labels
                ORDER BY aid, "user"
            """
            rows = await conn.fetch(query, aid_list)

            result = {}
            for row in rows:
                aid = int(row["aid"])
                if aid not in result:
                    result[aid] = []

                result[aid].append(
                    {
                        "user": row["user"],
                        "label": bool(row["label"]),
                        "created_at": row["created_at"].isoformat(),
                    }
                )

            return result

    async def get_existing_embeddings(
        self, checksums: List[str], model_name: str
    ) -> Dict[str, Dict[str, Any]]:
        """Get existing embeddings for given checksums and model"""
        if not checksums:
            return {}

        async with self.pool.acquire() as conn:
            query = """
                SELECT data_checksum, vec_2048, vec_1536, vec_1024, created_at
                FROM internal.embeddings
                WHERE model_name = $1 AND data_checksum = ANY($2::text[])
            """
            rows = await conn.fetch(query, model_name, checksums)

            result = {}
            for row in rows:
                checksum = row["data_checksum"]

                # Convert vector strings to lists if they exist
                vec_2048 = self._parse_vector_string(row["vec_2048"]) if row["vec_2048"] else None
                vec_1536 = self._parse_vector_string(row["vec_1536"]) if row["vec_1536"] else None
                vec_1024 = self._parse_vector_string(row["vec_1024"]) if row["vec_1024"] else None

                result[checksum] = {
                    "checksum": checksum,
                    "vec_2048": vec_2048,
                    "vec_1536": vec_1536,
                    "vec_1024": vec_1024,
                    "created_at": row["created_at"].isoformat(),
                }

            return result

    def _parse_vector_string(self, vector_str: str) -> List[float]:
        """Parse vector string format '[1.0,2.0,3.0]' back to list"""
        if not vector_str:
            return []

        try:
            # Remove brackets and split by comma
            vector_str = vector_str.strip()
            if vector_str.startswith('[') and vector_str.endswith(']'):
                vector_str = vector_str[1:-1]

            return [float(x.strip()) for x in vector_str.split(',') if x.strip()]
        except Exception as e:
            logger.warning(f"Failed to parse vector string '{vector_str}': {e}")
            return []

    async def insert_embeddings(self, embeddings_data: List[Dict[str, Any]]) -> None:
        """Batch insert embeddings into database"""
        if not embeddings_data:
            return

        async with self.pool.acquire() as conn:
            async with conn.transaction():
                for data in embeddings_data:
                    # Determine which vector column to use based on dimensions
                    vec_column = f"vec_{data['dimensions']}"

                    # Convert vector list to string format for PostgreSQL
                    vector_str = "[" + ",".join(map(str, data["vector"])) + "]"

                    query = f"""
                        INSERT INTO internal.embeddings
                        (model_name, data_checksum, {vec_column}, created_at)
                        VALUES ($1, $2, $3, $4)
                        ON CONFLICT (data_checksum) DO NOTHING
                    """

                    await conn.execute(
                        query,
                        data["model_name"],
                        data["checksum"],
                        vector_str,
                        datetime.now(),
                    )

    async def get_final_dataset(
        self, aid_list: List[int], model_name: str
    ) -> List[Dict[str, Any]]:
        """Get final dataset with embeddings and labels"""
        if not aid_list:
            return []

        # Get video metadata
        metadata = await self.get_video_metadata(aid_list)

        # Get user labels (latest per user)
        labels = await self.get_user_labels(aid_list)

        # Prepare text data for embedding
        text_data = []
        aid_to_text = {}

        for aid in aid_list:
            if aid in metadata:
                # Combine title, description, and tags for embedding
                text_parts = [
                    metadata[aid]["title"],
                    metadata[aid]["description"],
                    metadata[aid]["tags"],
                ]
                combined_text = " ".join(filter(None, text_parts))

                # Create checksum for deduplication
                checksum = hashlib.md5(combined_text.encode("utf-8")).hexdigest()

                text_data.append(
                    {"aid": aid, "text": combined_text, "checksum": checksum}
                )
                aid_to_text[checksum] = aid

        # Get existing embeddings
        checksums = [item["checksum"] for item in text_data]
        existing_embeddings = await self.get_existing_embeddings(checksums, model_name)

        # Prepare final dataset
        dataset = []

        for item in text_data:
            aid = item["aid"]
            checksum = item["checksum"]

            # Get embedding vector
            embedding_vector = None
            if checksum in existing_embeddings:
                # Use existing embedding
                emb_data = existing_embeddings[checksum]
                if emb_data["vec_1536"]:
                    embedding_vector = emb_data["vec_1536"]
                elif emb_data["vec_2048"]:
                    embedding_vector = emb_data["vec_2048"]
                elif emb_data["vec_1024"]:
                    embedding_vector = emb_data["vec_1024"]

            # Get labels for this aid
            aid_labels = labels.get(aid, [])

            # Determine final label using consensus (majority vote)
            if aid_labels:
                positive_votes = sum(1 for lbl in aid_labels if lbl["label"])
                final_label = positive_votes > len(aid_labels) / 2
            else:
                final_label = None  # No labels available

            # Check for inconsistent labels
            inconsistent = len(aid_labels) > 1 and (
                sum(1 for lbl in aid_labels if lbl["label"]) != 0
                and sum(1 for lbl in aid_labels if lbl["label"]) != len(aid_labels)
            )

            if embedding_vector and final_label is not None:
                dataset.append(
                    {
                        "aid": aid,
                        "embedding": embedding_vector,
                        "label": final_label,
                        "metadata": metadata.get(aid, {}),
                        "user_labels": aid_labels,
                        "inconsistent": inconsistent,
                        "text_checksum": checksum,
                    }
                )

        return dataset


# Global database manager instance
db_manager = DatabaseManager()
ml_new/training/dataset_service.py (new file, 373 lines)

@@ -0,0 +1,373 @@
"""
Dataset building service - handles the complete dataset construction flow
"""

import os
import json
import logging
import asyncio
from pathlib import Path
from typing import List, Dict, Any, Optional
from datetime import datetime

from database import DatabaseManager
from embedding_service import EmbeddingService
from config_loader import config_loader


logger = logging.getLogger(__name__)


class DatasetBuilder:
    """Service for building datasets with the specified flow"""

    def __init__(self, db_manager: DatabaseManager, embedding_service: EmbeddingService, storage_dir: str = "datasets"):
        self.db_manager = db_manager
        self.embedding_service = embedding_service
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(exist_ok=True)

        # Load existing datasets from file system
        self.dataset_storage: Dict[str, Dict] = self._load_all_datasets()

    def _get_dataset_file_path(self, dataset_id: str) -> Path:
        """Get file path for dataset"""
        return self.storage_dir / f"{dataset_id}.json"

    def _load_dataset_from_file(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """Load dataset from file"""
        file_path = self._get_dataset_file_path(dataset_id)
        if not file_path.exists():
            return None

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"Failed to load dataset {dataset_id} from file: {e}")
            return None

    def _save_dataset_to_file(self, dataset_id: str, dataset_data: Dict[str, Any]) -> bool:
        """Save dataset to file"""
        file_path = self._get_dataset_file_path(dataset_id)

        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(dataset_data, f, ensure_ascii=False, indent=2)
            return True
        except Exception as e:
            logger.error(f"Failed to save dataset {dataset_id} to file: {e}")
            return False

    def _load_all_datasets(self) -> Dict[str, Dict]:
        """Load all datasets from file system"""
        datasets = {}

        try:
            for file_path in self.storage_dir.glob("*.json"):
                dataset_id = file_path.stem
                dataset_data = self._load_dataset_from_file(dataset_id)
                if dataset_data:
                    datasets[dataset_id] = dataset_data
            logger.info(f"Loaded {len(datasets)} datasets from file system")
        except Exception as e:
            logger.error(f"Failed to load datasets from file system: {e}")

        return datasets

    async def cleanup_old_datasets(self, max_age_days: int = 30):
        """Remove datasets older than specified days"""
        try:
            cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 60 * 60)
            removed_count = 0

            for dataset_id in list(self.dataset_storage.keys()):
                dataset_info = self.dataset_storage[dataset_id]
                if "created_at" in dataset_info:
                    try:
                        created_time = datetime.fromisoformat(dataset_info["created_at"]).timestamp()
                        if created_time < cutoff_time:
                            # Remove from memory
                            del self.dataset_storage[dataset_id]

                            # Remove file
                            file_path = self._get_dataset_file_path(dataset_id)
                            if file_path.exists():
                                file_path.unlink()

                            removed_count += 1
                    except Exception as e:
                        logger.warning(f"Failed to process dataset {dataset_id} for cleanup: {e}")

            if removed_count > 0:
                logger.info(f"Cleaned up {removed_count} old datasets")

        except Exception as e:
            logger.error(f"Failed to cleanup old datasets: {e}")

    async def build_dataset(self, dataset_id: str, aid_list: List[int], embedding_model: str, force_regenerate: bool = False) -> str:
        """
        Build dataset with the specified flow:
        1. Select embedding model (from TOML config)
        2. Pull raw text from database
        3. Preprocess (placeholder for now)
        4. Batch get embeddings (deduplicate by hash, skip if already in embeddings table)
        5. Write to embeddings table
        6. Pull all needed embeddings to create dataset with format: embeddings, label
        """

        try:
            logger.info(f"Starting dataset building task {dataset_id}")

            EMBEDDING_MODELS = config_loader.get_embedding_models()

            # Get model configuration
            if embedding_model not in EMBEDDING_MODELS:
                raise ValueError(f"Invalid embedding model: {embedding_model}")

            model_config = EMBEDDING_MODELS[embedding_model]

            # Step 1: Get video metadata from database
            metadata = await self.db_manager.get_video_metadata(aid_list)

            # Step 2: Get user labels
            labels = await self.db_manager.get_user_labels(aid_list)

            # Step 3: Prepare text data and checksums
            text_data = []

            for aid in aid_list:
                if aid in metadata:
                    # Combine title, description, tags
                    combined_text = self.embedding_service.combine_video_text(
                        metadata[aid]['title'],
                        metadata[aid]['description'],
                        metadata[aid]['tags']
                    )

                    # Create checksum for deduplication
                    checksum = self.embedding_service.create_text_checksum(combined_text)

                    text_data.append({
                        'aid': aid,
                        'text': combined_text,
                        'checksum': checksum
                    })

            # Step 4: Check existing embeddings
            checksums = [item['checksum'] for item in text_data]
            existing_embeddings = await self.db_manager.get_existing_embeddings(checksums, embedding_model)

            # Step 5: Generate new embeddings for texts that don't have them
            new_embeddings_needed = []
            for item in text_data:
                if item['checksum'] not in existing_embeddings or force_regenerate:
                    new_embeddings_needed.append(item['text'])

            new_embeddings_count = 0
            if new_embeddings_needed:
                logger.info(f"Generating {len(new_embeddings_needed)} new embeddings")
                generated_embeddings = await self.embedding_service.generate_embeddings_batch(
                    new_embeddings_needed,
                    embedding_model
                )

                # Step 6: Store new embeddings in database
                embeddings_to_store = []
                for i, (text, embedding) in enumerate(zip(new_embeddings_needed, generated_embeddings)):
                    checksum = self.embedding_service.create_text_checksum(text)
                    embeddings_to_store.append({
                        'model_name': embedding_model,
                        'checksum': checksum,
                        'dimensions': model_config.dimensions,
                        'vector': embedding
                    })

                await self.db_manager.insert_embeddings(embeddings_to_store)
                new_embeddings_count = len(embeddings_to_store)

                # Update existing embeddings cache
                for emb_data in embeddings_to_store:
                    existing_embeddings[emb_data['checksum']] = {
                        'checksum': emb_data['checksum'],
                        f'vec_{model_config.dimensions}': emb_data['vector']
                    }

            # Step 7: Build final dataset
            dataset = []
            inconsistent_count = 0

            for item in text_data:
                aid = item['aid']
                checksum = item['checksum']

                # Get embedding vector
                embedding_vector = None
                if checksum in existing_embeddings:
                    vec_key = f'vec_{model_config.dimensions}'
                    if vec_key in existing_embeddings[checksum]:
                        embedding_vector = existing_embeddings[checksum][vec_key]

                # Get labels for this aid
                aid_labels = labels.get(aid, [])

                # Determine final label using consensus (majority vote)
                final_label = None
                if aid_labels:
                    positive_votes = sum(1 for lbl in aid_labels if lbl['label'])
                    final_label = positive_votes > len(aid_labels) / 2

                # Check for inconsistent labels
                inconsistent = len(aid_labels) > 1 and (
                    sum(1 for lbl in aid_labels if lbl['label']) != 0 and
                    sum(1 for lbl in aid_labels if lbl['label']) != len(aid_labels)
                )

                if inconsistent:
                    inconsistent_count += 1

                if embedding_vector and final_label is not None:
                    dataset.append({
                        'aid': aid,
                        'embedding': embedding_vector,
                        'label': final_label,
                        'metadata': metadata.get(aid, {}),
                        'user_labels': aid_labels,
                        'inconsistent': inconsistent,
                        'text_checksum': checksum
                    })

            reused_count = len(dataset) - new_embeddings_count

            logger.info(f"Dataset building completed: {len(dataset)} records, {new_embeddings_count} new, {reused_count} reused, {inconsistent_count} inconsistent")

            # Prepare dataset data
            dataset_data = {
                'dataset': dataset,
                'stats': {
                    'total_records': len(dataset),
                    'new_embeddings': new_embeddings_count,
                    'reused_embeddings': reused_count,
                    'inconsistent_labels': inconsistent_count,
                    'embedding_model': embedding_model
                },
                'created_at': datetime.now().isoformat()
            }

            # Save to file and memory cache
            if self._save_dataset_to_file(dataset_id, dataset_data):
                self.dataset_storage[dataset_id] = dataset_data
                logger.info(f"Dataset {dataset_id} saved to file system")
            else:
                logger.warning(f"Failed to save dataset {dataset_id} to file, keeping in memory only")
                self.dataset_storage[dataset_id] = dataset_data

            return dataset_id

        except Exception as e:
            logger.error(f"Dataset building failed for {dataset_id}: {str(e)}")

            # Store error information
            error_data = {
                'error': str(e),
                'created_at': datetime.now().isoformat()
            }

            # Try to save error to file as well
            self._save_dataset_to_file(dataset_id, error_data)
            self.dataset_storage[dataset_id] = error_data
            raise

    def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """Get built dataset by ID"""
        # First check memory cache
        if dataset_id in self.dataset_storage:
            return self.dataset_storage[dataset_id]

        # If not in memory, try to load from file
        dataset_data = self._load_dataset_from_file(dataset_id)
        if dataset_data:
            # Add to memory cache
            self.dataset_storage[dataset_id] = dataset_data
            return dataset_data

        return None

    def dataset_exists(self, dataset_id: str) -> bool:
        """Check if dataset exists"""
        # Check memory cache first
        if dataset_id in self.dataset_storage:
            return True

        # Check file system
        return self._get_dataset_file_path(dataset_id).exists()

    def delete_dataset(self, dataset_id: str) -> bool:
        """Delete dataset from both memory and file system"""
        try:
            # Remove from memory
            if dataset_id in self.dataset_storage:
                del self.dataset_storage[dataset_id]

            # Remove file
            file_path = self._get_dataset_file_path(dataset_id)
            if file_path.exists():
                file_path.unlink()
                logger.info(f"Dataset {dataset_id} deleted from file system")
                return True
            else:
                logger.warning(f"Dataset file {dataset_id} not found for deletion")
                return False

        except Exception as e:
            logger.error(f"Failed to delete dataset {dataset_id}: {e}")
            return False

    def list_datasets(self) -> List[Dict[str, Any]]:
        """List all datasets with their basic information"""
        datasets = []

        for dataset_id, dataset_info in self.dataset_storage.items():
            if "error" not in dataset_info:
                datasets.append({
                    "dataset_id": dataset_id,
                    "stats": dataset_info["stats"],
                    "created_at": dataset_info["created_at"]
                })

        # Sort by creation time (newest first)
        datasets.sort(key=lambda x: x["created_at"], reverse=True)

        return datasets

    def get_dataset_stats(self) -> Dict[str, Any]:
        """Get overall statistics about stored datasets"""
        total_datasets = len(self.dataset_storage)
        error_datasets = sum(1 for data in self.dataset_storage.values() if "error" in data)
        valid_datasets = total_datasets - error_datasets

        total_records = 0
        total_new_embeddings = 0
        total_reused_embeddings = 0

        for dataset_info in self.dataset_storage.values():
            if "stats" in dataset_info:
                stats = dataset_info["stats"]
                total_records += stats.get("total_records", 0)
                total_new_embeddings += stats.get("new_embeddings", 0)
                total_reused_embeddings += stats.get("reused_embeddings", 0)

        return {
            "total_datasets": total_datasets,
            "valid_datasets": valid_datasets,
            "error_datasets": error_datasets,
            "total_records": total_records,
            "total_new_embeddings": total_new_embeddings,
            "total_reused_embeddings": total_reused_embeddings,
            "storage_directory": str(self.storage_dir)
        }
ml_new/training/embedding_models.toml (new file, 12 lines)

@@ -0,0 +1,12 @@
# Embedding Models Configuration

model = "qwen3-embedding"

[models.qwen3-embedding]
name = "text-embedding-v4"
dimensions = 2048
type = "openai-compatible"
api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
max_tokens = 8192
max_batch_size = 10
api_key_env = "ALIYUN_KEY"
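For reference, a small sketch of how this file is consumed: `ConfigLoader` reads every table under `[models]` and exposes each entry by its key, so the service refers to the model as `qwen3-embedding` while API requests use the upstream name `text-embedding-v4`:

```python
from config_loader import config_loader

# "qwen3-embedding" is the table key under [models] in embedding_models.toml
cfg = config_loader.get_embedding_model("qwen3-embedding")
print(cfg.name, cfg.dimensions, cfg.api_key_env)
# -> text-embedding-v4 2048 ALIYUN_KEY
```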
ml_new/training/embedding_service.py (new file, 143 lines)

@@ -0,0 +1,143 @@
"""
Embedding service for generating embeddings using OpenAI-compatible API
"""
import asyncio
import hashlib
from typing import List, Dict, Any, Optional
import logging
from openai import AsyncOpenAI
import os
from config_loader import config_loader
from dotenv import load_dotenv

load_dotenv()

logger = logging.getLogger(__name__)


class EmbeddingService:
    def __init__(self):
        # Get configuration from config loader
        self.embedding_models = config_loader.get_embedding_models()

        # Initialize OpenAI client (will be configured per model)
        self.clients: Dict[str, AsyncOpenAI] = {}
        self._initialize_clients()

        # Rate limiting
        self.max_requests_per_minute = int(os.getenv("MAX_REQUESTS_PER_MINUTE", "100"))
        self.request_interval = 60.0 / self.max_requests_per_minute

    def _initialize_clients(self):
        """Initialize OpenAI clients for different models/endpoints"""
        for model_name, model_config in self.embedding_models.items():
            if model_config.type == "openai-compatible":
                # Get API key from environment variable specified in config
                api_key = os.getenv(model_config.api_key_env)

                self.clients[model_name] = AsyncOpenAI(
                    api_key=api_key,
                    base_url=model_config.api_endpoint
                )
                logger.info(f"Initialized client for model {model_name}")

    async def generate_embeddings_batch(
        self,
        texts: List[str],
        model: str,
        batch_size: Optional[int] = None
    ) -> List[List[float]]:
        """Generate embeddings for a batch of texts"""

        # Get model configuration
        if model not in self.embedding_models:
            raise ValueError(f"Model '{model}' not found in configuration")

        model_config = self.embedding_models[model]

        # Use model's max_batch_size if not specified
        if batch_size is None:
            batch_size = model_config.max_batch_size

        # Validate model and get expected dimensions
        expected_dims = model_config.dimensions

        if model not in self.clients:
            raise ValueError(f"No client configured for model '{model}'")

        client = self.clients[model]
        all_embeddings = []

        # Process in batches to avoid API limits
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]

            try:
                # Rate limiting
                if i > 0:
                    await asyncio.sleep(self.request_interval)

                # Generate embeddings
                response = await client.embeddings.create(
                    model=model_config.name,
                    input=batch,
                    dimensions=expected_dims
                )

                batch_embeddings = [data.embedding for data in response.data]
                all_embeddings.extend(batch_embeddings)

                logger.info(f"Generated embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")

            except Exception as e:
                logger.error(f"Error generating embeddings for batch {i//batch_size + 1}: {e}")
                # For now, fill with zeros as fallback (could implement retry logic)
                zero_embedding = [0.0] * expected_dims
                all_embeddings.extend([zero_embedding] * len(batch))

        return all_embeddings

    def create_text_checksum(self, text: str) -> str:
        """Create MD5 checksum for text deduplication"""
        return hashlib.md5(text.encode('utf-8')).hexdigest()

    def combine_video_text(self, title: str, description: str, tags: str) -> str:
        """Combine video metadata into a single text for embedding"""
        # Prefix each non-empty field with its label before joining
        parts = [
            "标题:" + title.strip() if title else "",
            "简介:" + description.strip() if description else "",
            "标签:" + tags.strip() if tags else ""
        ]

        # Filter out empty parts and join
        combined = '\n'.join(filter(None, parts))
        return combined

    async def health_check(self) -> Dict[str, Any]:
        """Check if embedding service is healthy"""
        try:
            # Test with a simple embedding using the first available model
            model_name = list(self.embedding_models.keys())[0]

            test_embedding = await self.generate_embeddings_batch(
                ["health check"],
                model_name,
                batch_size=1
            )

            return {
                "status": "healthy",
                "service": "embedding_service",
                "model": model_name,
                "dimensions": len(test_embedding[0]) if test_embedding else 0,
                "available_models": list(self.embedding_models.keys())
            }

        except Exception as e:
            return {
                "status": "unhealthy",
                "service": "embedding_service",
                "error": str(e)
            }


# Global embedding service instance
embedding_service = EmbeddingService()
ml_new/training/main.py (changed, -246 +113)

@@ -1,246 +1,113 @@
-from fastapi import FastAPI, HTTPException, BackgroundTasks
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-from typing import List, Dict, Optional, Any
-import asyncio
-import uuid
-import logging
-from datetime import datetime
-import json
+"""
+Main FastAPI application for ML training service
+"""
 
-# Setup logging
+import logging
+import uvicorn
+from contextlib import asynccontextmanager
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+from database import DatabaseManager
+from embedding_service import EmbeddingService
+from dataset_service import DatasetBuilder
+from api_routes import router, set_dataset_builder
+
+# Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Initialize FastAPI app
-app = FastAPI(
-    title="CVSA ML Training API",
-    version="1.0.0",
-    description="ML training service for video classification"
-)
-
-# Enable CORS for web UI
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["http://localhost:3000", "http://localhost:5173"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+# Global service instances
+db_manager = None
+embedding_service = None
+dataset_builder = None
 
-# Pydantic models
-class Hyperparameter(BaseModel):
-    name: str
-    type: str  # 'number', 'boolean', 'select'
-    value: Any
-    range: Optional[tuple] = None
-    options: Optional[List[str]] = None
-    description: Optional[str] = None
-
-class TrainingConfig(BaseModel):
-    learning_rate: float = 1e-4
-    batch_size: int = 32
-    epochs: int = 10
-    early_stop: bool = True
-    patience: int = 3
-    embedding_model: str = "text-embedding-3-small"
-
-class TrainingRequest(BaseModel):
-    experiment_name: str
-    config: TrainingConfig
-    dataset: Dict[str, Any]
-
-class TrainingStatus(BaseModel):
-    experiment_id: str
-    status: str  # 'pending', 'running', 'completed', 'failed'
-    progress: Optional[float] = None
-    current_epoch: Optional[int] = None
-    total_epochs: Optional[int] = None
-    metrics: Optional[Dict[str, float]] = None
-    error: Optional[str] = None
-
-class ExperimentResult(BaseModel):
-    experiment_id: str
-    experiment_name: str
-    config: TrainingConfig
-    metrics: Dict[str, float]
-    created_at: str
-    status: str
-
-class EmbeddingRequest(BaseModel):
-    texts: List[str]
-    model: str
-
-# In-memory storage for experiments (in production, use database)
-training_sessions: Dict[str, Dict] = {}
-experiments: Dict[str, ExperimentResult] = {}
-
-# Default hyperparameters that will be dynamically discovered
-DEFAULT_HYPERPARAMETERS = [
-    Hyperparameter(
-        name="learning_rate",
-        type="number",
-        value=1e-4,
-        range=(1e-6, 1e-2),
-        description="Learning rate for optimizer"
-    ),
-    Hyperparameter(
-        name="batch_size",
-        type="number",
-        value=32,
-        range=(8, 256),
-        description="Training batch size"
-    ),
-    Hyperparameter(
-        name="epochs",
-        type="number",
-        value=10,
-        range=(1, 100),
-        description="Number of training epochs"
-    ),
-    Hyperparameter(
-        name="early_stop",
-        type="boolean",
-        value=True,
-        description="Enable early stopping"
-    ),
-    Hyperparameter(
-        name="patience",
-        type="number",
-        value=3,
-        range=(1, 20),
-        description="Early stopping patience"
-    ),
-    Hyperparameter(
-        name="embedding_model",
-        type="select",
-        value="text-embedding-3-small",
-        options=["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
-        description="Embedding model to use"
-    )
-]
-
-@app.get("/")
-async def root():
-    return {"message": "CVSA ML Training API", "version": "1.0.0"}
-
-@app.get("/health")
-async def health_check():
-    return {"status": "healthy", "service": "ml-training-api"}
-
-@app.get("/hyperparameters", response_model=List[Hyperparameter])
-async def get_hyperparameters():
-    """Get all available hyperparameters for the current model"""
-    return DEFAULT_HYPERPARAMETERS
-
-@app.post("/train")
-async def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
-    """Start a new training experiment"""
-    experiment_id = str(uuid.uuid4())
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Application lifespan manager for startup and shutdown events"""
+    global db_manager, embedding_service, dataset_builder
 
-    # Store training session
-    training_sessions[experiment_id] = {
-        "experiment_id": experiment_id,
-        "experiment_name": request.experiment_name,
-        "config": request.config.dict(),
-        "dataset": request.dataset,
-        "status": "pending",
-        "created_at": datetime.now().isoformat()
-    }
+    # Startup
+    logger.info("Initializing services...")
 
-    # Start background training task
-    background_tasks.add_task(run_training, experiment_id, request)
-
-    return {"experiment_id": experiment_id}
-
-@app.get("/train/{experiment_id}/status", response_model=TrainingStatus)
-async def get_training_status(experiment_id: str):
-    """Get training status for an experiment"""
-    if experiment_id not in training_sessions:
-        raise HTTPException(status_code=404, detail="Experiment not found")
-
-    session = training_sessions[experiment_id]
-
-    return TrainingStatus(
-        experiment_id=experiment_id,
-        status=session.get("status", "unknown"),
-        progress=session.get("progress"),
-        current_epoch=session.get("current_epoch"),
-        total_epochs=session.get("total_epochs"),
-        metrics=session.get("metrics"),
-        error=session.get("error")
-    )
-
-@app.get("/experiments", response_model=List[ExperimentResult])
-async def list_experiments():
-    """List all experiments"""
-    return list(experiments.values())
-
-@app.get("/experiments/{experiment_id}", response_model=ExperimentResult)
-async def get_experiment(experiment_id: str):
-    """Get experiment details"""
-    if experiment_id not in experiments:
-        raise HTTPException(status_code=404, detail="Experiment not found")
-
-    return experiments[experiment_id]
-
-@app.post("/embeddings")
-async def generate_embeddings(request: EmbeddingRequest):
-    """Generate embeddings using OpenAI-compatible API"""
-    # This is a placeholder implementation
-    # In production, this would call actual embedding API
-    embeddings = []
-    for text in request.texts:
-        # Mock embedding generation
-        embedding = [0.1] * 1536  # Mock 1536-dimensional embedding
-        embeddings.append(embedding)
-
-    return embeddings
-
-async def run_training(experiment_id: str, request: TrainingRequest):
-    """Background task to run training"""
     try:
-        session = training_sessions[experiment_id]
-        session["status"] = "running"
-        session["total_epochs"] = request.config.epochs
+        # Database manager
+        db_manager = DatabaseManager()
+        await db_manager.connect()  # Initialize database connection pool
+        logger.info("Database manager initialized and connected")
 
-        # Simulate training process
-        for epoch in range(request.config.epochs):
-            session["current_epoch"] = epoch + 1
-            session["progress"] = (epoch + 1) / request.config.epochs
-
-            # Simulate training metrics
-            session["metrics"] = {
-                "loss": max(0.0, 1.0 - (epoch + 1) * 0.1),
-                "accuracy": min(0.95, 0.5 + (epoch + 1) * 0.05),
-                "val_loss": max(0.0, 0.8 - (epoch + 1) * 0.08),
-                "val_accuracy": min(0.92, 0.45 + (epoch + 1) * 0.04)
-            }
-
-            logger.info(f"Training epoch {epoch + 1}/{request.config.epochs}")
-            await asyncio.sleep(1)  # Simulate training time
+        # Embedding service
+        embedding_service = EmbeddingService()
+        logger.info("Embedding service initialized")
 
-        # Training completed
-        session["status"] = "completed"
-        final_metrics = session["metrics"]
+        # Dataset builder
+        dataset_builder = DatasetBuilder(db_manager, embedding_service)
+        logger.info("Dataset builder initialized")
 
-        # Store final experiment result
-        experiments[experiment_id] = ExperimentResult(
-            experiment_id=experiment_id,
-            experiment_name=request.experiment_name,
-            config=request.config,
-            metrics=final_metrics,
-            created_at=session["created_at"],
-            status="completed"
-        )
+        # Set global dataset builder instance
+        set_dataset_builder(dataset_builder)
 
-        logger.info(f"Training completed for experiment {experiment_id}")
+        logger.info("All services initialized successfully")
 
     except Exception as e:
-        session["status"] = "failed"
-        session["error"] = str(e)
-        logger.error(f"Training failed for experiment {experiment_id}: {str(e)}")
+        logger.error(f"Failed to initialize services: {e}")
+        raise
+
+    # Yield control to the application
+    yield
+
+    # Shutdown
+    logger.info("Shutting down services...")
+
+    try:
+        if db_manager:
+            await db_manager.close()
+            logger.info("Database connection pool closed")
+    except Exception as e:
+        logger.error(f"Error during shutdown: {e}")
+
+
+def create_app() -> FastAPI:
+    """Create and configure FastAPI application"""
+
+    # Create FastAPI app with lifespan manager
+    app = FastAPI(
+        title="ML Training Service",
+        description="ML training, dataset building, and experiment management service",
+        version="1.0.0",
+        lifespan=lifespan
+    )
+
+    # Add CORS middleware
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],  # Configure appropriately for production
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+
+    # Include API routes
+    app.include_router(router)
+
+    return app
+
+
+def main():
+    """Main entry point"""
+    app = create_app()
+
+    # Run the application
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=8322,
+        log_level="info",
+        access_log=True
+    )
+
 
 if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    main()
62
ml_new/training/models.py
Normal file
62
ml_new/training/models.py
Normal file
@ -0,0 +1,62 @@
"""
Data models for dataset building functionality
"""

from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
from datetime import datetime


class DatasetBuildRequest(BaseModel):
    """Request model for dataset building"""
    aid_list: List[int] = Field(..., description="List of video AIDs")
    embedding_model: str = Field(..., description="Embedding model name")
    force_regenerate: bool = Field(False, description="Whether to force regenerate embeddings")


class DatasetBuildResponse(BaseModel):
    """Response model for dataset building"""
    dataset_id: str
    total_records: int
    status: str
    message: str
    created_at: Optional[datetime] = None


class DatasetRecord(BaseModel):
    """Model for a single dataset record"""
    aid: int
    embedding: List[float]
    label: bool
    metadata: Dict[str, Any]
    user_labels: List[Dict[str, Any]]
    inconsistent: bool
    text_checksum: str


class DatasetInfo(BaseModel):
    """Model for dataset information"""
    dataset_id: str
    dataset: List[DatasetRecord]
    stats: Dict[str, Any]
    created_at: datetime


class DatasetBuildStats(BaseModel):
    """Statistics for dataset building process"""
    total_records: int
    new_embeddings: int
    reused_embeddings: int
    inconsistent_labels: int
    embedding_model: str
    processing_time: Optional[float] = None


class EmbeddingModelInfo(BaseModel):
    """Information about embedding models"""
    name: str
    dimensions: int
    type: str
    api_endpoint: Optional[str] = None
    max_tokens: Optional[int] = None
    max_batch_size: Optional[int] = None
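For orientation, a minimal sketch of how a build request would be constructed and validated; the aid values and model name are invented for the example, and real model names come from `embedding_models.toml`:

```python
from models import DatasetBuildRequest

req = DatasetBuildRequest(
    aid_list=[170001, 170002],        # example aids, not real data
    embedding_model="example-model",  # must match a model defined in embedding_models.toml
)
# force_regenerate defaults to False, so cached embeddings are reused when available.
print(req.dict())  # on Pydantic v2, prefer req.model_dump()
```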
@ -9,4 +9,8 @@ scikit-learn==1.3.2
pandas==2.1.3
openai==1.3.7
psycopg2-binary==2.9.9
sqlalchemy==2.0.23
asyncpg==0.29.0
toml==0.10.2
aiohttp==3.9.0
python-dotenv==1.1.0
26
mutagen.yml
Normal file
@ -0,0 +1,26 @@
# mutagen.yml
sync:
  development:
    alpha: "."
    beta: "root@cvsa-hk-02:/web/cvsa"
    maxStagingFileSize: "5MB"
    mode: "two-way-resolved"
    ignore:
      paths:
        - "**/node_modules/"
        - "*.log"
        - "/model"
        - "**/.DS_Store"
        - ".env"
        - ".env.*"
        - "**/logs/"
        - ".git/"
        - ".jj/"
        - "dist"
        - "**/build/"
        - "**/.react-router"
        - "redis/"
        - "temp/"
        - "mutagen.yml"
        - "/ml_new/"
        - "/ml/"
1
mutagen.yml.lock
Normal file
@ -0,0 +1 @@
proj_EPp5s45rolBZ729uBHaoBMzV51GCoGkpnJuTkK6owzH
267
packages/ml_panel/README.md
Normal file
@ -0,0 +1,267 @@
# CVSA ML Infrastructure Refactoring Project

## Project Overview

This project refactors the existing ML service infrastructure, migrating from the original `ml/filter` system to a new architecture with a decoupled frontend and backend. The main goal is to provide a modern web UI for ML training, experiment management, and data processing.

### Core Features

- **Data pipeline management**: Fetch and preprocess training data from the PostgreSQL database
- **Experiment management**: Training parameter configuration, experiment tracking, and result visualization
- **Hyperparameter tuning**: Dynamic hyperparameter configuration and adjustment
- **Data labeling interface**: Simple, easy-to-use tools for labeling and managing data
- **Model training**: Training a binary video classification model
- **Embedding management**: Support for multiple embedding models and vector dimensions

## Architecture

### Tech Stack

- **Frontend**: React + TypeScript + Vite + Tailwind CSS + shadcn/ui
- **Backend**: FastAPI + Python
- **Database**: PostgreSQL + Drizzle ORM
- **Vector database**: PostgreSQL pgvector
- **Package management**: Bun (TypeScript) + pip (Python)

### Layered Architecture

```
┌─────────────────┐    ┌──────────────┐    ┌──────────────────┐
│    Web UI       │    │   FastAPI    │    │    Database      │
│   (React TS)    │◄──►│   (Python)   │◄──►│  (PostgreSQL)    │
└─────────────────┘    └──────────────┘    └──────────────────┘
```

## Directory Structure

### Frontend project (`packages/ml_panel/`)

```
packages/ml_panel/
├── src/                 # Frontend application
│   ├── App.tsx          # Main application component
│   ├── main.tsx         # Application entry point
│   ├── index.css        # Global styles
│   └── lib/
│       └── utils.ts     # Frontend utility functions
├── lib/                 # Core library files
│   ├── types.ts         # Shared type definitions
│   ├── ml-client.ts     # ML API client
│   ├── data-pipeline/   # Data pipeline types
│   │   └── types.ts
│   └── index.ts         # Export file
├── package.json
├── vite.config.ts
└── tailwind.config.js
```

### Backend service (`ml_new/training/`)

```
ml_new/training/
├── main.py              # FastAPI main service
├── requirements.txt     # Python dependencies
└── ...                  # Other service files
```

### Database schema

Uses the existing definitions in `packages/core/drizzle/main/schema.ts`:

- `videoTypeLabelInInternal`: User-provided labels
- `embeddingsInInternal`: Embedding vector storage
- `bilibiliMetadata`: Video metadata

## Completed Work

### 1. Core type definitions

**File**: `packages/ml_panel/lib/types.ts`

- Defines the core data structures
- `DatasetRecord`: Dataset record
- `UserLabel`: User label
- `EmbeddingModel`: Embedding model configuration
- `TrainingConfig`: Training configuration
- `ExperimentResult`: Experiment result
- `InconsistentLabel`: Inconsistently labeled data

### 2. Data pipeline types

**File**: `packages/ml_panel/lib/data-pipeline/types.ts`

- `VideoMetadata`: Video metadata
- `VideoTypeLabel`: Label data
- `EmbeddingRecord`: Embedding record
- `DataPipelineConfig`: Pipeline configuration
- `ProcessedDataset`: Processed dataset

### 3. ML client

**File**: `packages/ml_panel/lib/ml-client.ts`

- `MLClient` class for communicating with FastAPI
- Hyperparameter retrieval and updates
- Training job start and status monitoring
- Experiment management
- Embedding generation interface

### 4. FastAPI service skeleton

**File**: `ml_new/training/main.py`

- Basic FastAPI application configuration
- CORS middleware configuration
- In-memory training session management
- Basic API endpoint definitions

### 5. Project configuration

- **Frontend**: `packages/ml_panel/package.json` - React + Vite + TypeScript configuration
- **Backend**: `ml_new/training/requirements.txt` - Python dependencies
- **Main project**: `packages/ml/package.json` - Monorepo workspace configuration

## Implementation Status

### Completed

- [x] Basic project scaffolding
- [x] Core type definitions
- [x] ML API client
- [x] FastAPI service skeleton
- [x] Frontend project configuration

### To Do

- [ ] Core data pipeline logic
- [ ] React UI components
- [ ] Full FastAPI service functionality
- [ ] Database connection and data retrieval logic
- [ ] User label processing
- [ ] Embedding vector management
- [ ] Label consistency checks
- [ ] Training job queue
- [ ] Experiment tracking and visualization
- [ ] Dynamic hyperparameter configuration
- [ ] Full end-to-end frontend/backend integration tests

## Data Flow Design

### 1. Dataset creation flow (high-RTT optimization)

```mermaid
graph TD
    A[Frontend clicks Create Dataset] --> B[Select embedding model]
    B --> C[Load model parameters from TOML config]
    C --> D[Bulk-pull raw text from database]
    D --> E[Preprocess text]
    E --> F[Compute text hashes and deduplicate]
    F --> G[Batch-query existing embeddings]
    G --> H[Identify texts needing new embeddings]
    H --> I[Batch-call embedding API]
    I --> J[Bulk-insert into embeddings table]
    J --> K[Fetch complete embeddings data]
    K --> L[Merge label data]
    L --> M[Build final dataset<br/>format: embeddings, label]
```
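The deduplication steps in the middle of this flow (hash, dedupe, batch lookup) are the main defense against redundant embedding API calls. A minimal sketch, assuming SHA-256 checksums and an `embeddings` table with `model_name`/`data_checksum` columns (the authoritative names live in the Drizzle schema):

```python
import hashlib

def checksum(text: str) -> str:
    # Stable content hash used as the deduplication key for embeddings.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

async def partition_texts(conn, texts: dict[int, str], model_name: str):
    """Split texts into already-embedded checksums and texts still needing embeddings."""
    sums = {aid: checksum(t) for aid, t in texts.items()}
    # One batched lookup instead of one query per text.
    rows = await conn.fetch(
        "SELECT data_checksum FROM embeddings "
        "WHERE model_name = $1 AND data_checksum = ANY($2::text[])",
        model_name, list(set(sums.values())),
    )
    cached = {r["data_checksum"] for r in rows}
    to_generate = {aid: t for aid, t in texts.items() if sums[aid] not in cached}
    return cached, to_generate
```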

### 2. Data retrieval flow (database optimization)

```mermaid
graph TD
    A[Remote PostgreSQL database<br/>RTT: 100ms] --> B[videoTypeLabelInInternal]
    A --> C[embeddingsInInternal]
    A --> D[bilibiliMetadata]
    B --> E[Batch-fetch each user's last label<br/>avoid per-row queries]
    C --> F[Batch-fetch embedding vectors<br/>query all dimensions at once]
    D --> G[Batch-fetch video metadata<br/>IN queries avoid N+1]
    E --> H[Label consistency check]
    F --> I[Vector data processing]
    G --> J[Data merging]
    H --> K[Dataset construction]
    I --> K
    J --> K
```
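As a concrete illustration of the batching above, here is a hedged sketch with asyncpg; the real table is defined as `bilibiliMetadata` in the Drizzle schema, so the snake_case name used here is an assumption:

```python
import asyncpg

async def fetch_metadata(pool: asyncpg.Pool, aids: list[int]) -> dict[int, dict]:
    # One IN-style query for all aids: a single 100ms round trip
    # instead of len(aids) round trips (the classic N+1 pattern).
    async with pool.acquire() as conn:
        rows = await conn.fetch(
            "SELECT aid, title, description, tags "
            "FROM bilibili_metadata WHERE aid = ANY($1::bigint[])",
            aids,
        )
    return {r["aid"]: dict(r) for r in rows}
```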

### 3. Training flow

```mermaid
graph TD
    A[Frontend configuration] --> B[Hyperparameter settings]
    B --> C[FastAPI receives request]
    C --> D[Data pipeline processing]
    D --> E[Model training]
    E --> F[Real-time status updates]
    F --> G[Result storage]
    G --> H[Frontend display]
```

## Technical Highlights

### 1. High-performance database design (RTT optimization)

- **Batch operations**: Avoid per-row query loops; use `IN` clauses and bulk `INSERT`/`UPDATE`

### 2. Embedding vector management

- **Multi-model support**: `embeddingsInInternal` stores vectors of different dimensions (2048/1536/1024), routed to a dimension-specific column as sketched below
- **Deduplication**: Use text hashes to avoid regenerating embeddings for identical texts
- **Batch processing**: Generate and store embeddings in batches
- **Caching strategy**: Prefer existing embeddings over regeneration

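A hedged sketch of that column routing, assuming snake_case column names and pgvector's text-literal input format; the authoritative schema is the Drizzle definition:

```python
# Assumed table/column names for illustration only.
DIM_COLUMNS = {2048: "vec_2048", 1536: "vec_1536", 1024: "vec_1024"}

async def store_embedding(conn, model_name: str, data_checksum: str, vec: list[float]) -> None:
    column = DIM_COLUMNS[len(vec)]  # raises KeyError for unsupported dimensions
    # pgvector accepts its text literal form, so serialize the list as "[x,y,...]".
    await conn.execute(
        f"INSERT INTO embeddings (model_name, data_checksum, {column}) "
        f"VALUES ($1, $2, $3::vector) ON CONFLICT DO NOTHING",
        model_name, data_checksum, "[" + ",".join(map(str, vec)) + "]",
    )
```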
### 3. Binary classification model architecture

- Migrating from the previous 3-class system to binary classification
- Input: precomputed embedding vectors (rather than raw text)
- Supports switching between multiple embedding models

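The exact classifier head is still to be implemented; as a placeholder illustration, a logistic-regression baseline over precomputed embeddings (scikit-learn is already in `requirements.txt`; `records` is an assumed list of dataset records shaped like `DatasetRecord`):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# X: precomputed embedding vectors (1024/1536/2048-dim), y: binary labels.
X = np.asarray([rec["embedding"] for rec in records], dtype=np.float32)
y = np.asarray([rec["label"] for rec in records])

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print(f"test accuracy: {clf.score(X_test, y_test):.3f}")
```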
### 4. Data consistency handling

- **User labels**: `videoTypeLabelInInternal` stores labels from multiple users
- **Last label wins**: Each user's most recent label is taken as their effective label
- **Consistency check**: Videos labeled differently by different users are flagged for manual review

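A minimal sketch of this rule, assuming label rows carry `aid`, `user`, `label`, and `created_at` fields as described above:

```python
from collections import defaultdict

def resolve_labels(rows: list[dict]) -> tuple[dict[int, bool], set[int]]:
    """Keep each user's most recent label per aid, then flag aids where users disagree."""
    last: dict[int, dict[str, bool]] = defaultdict(dict)
    for row in sorted(rows, key=lambda r: r["created_at"]):
        last[row["aid"]][row["user"]] = row["label"]  # later labels overwrite earlier ones

    consensus: dict[int, bool] = {}
    inconsistent: set[int] = set()
    for aid, per_user in last.items():
        values = set(per_user.values())
        if len(values) == 1:
            consensus[aid] = values.pop()
        else:
            inconsistent.add(aid)  # flagged for manual review
    return consensus, inconsistent
```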
## Development Roadmap

### Phase 1: Core functionality (priority)

1. **Data pipeline implementation**
   - Label data retrieval and consistency checks
   - Embedding generation and storage
   - Dataset construction logic

2. **FastAPI service completion**
   - Build the new model architecture (embedding vectors as input, direct binary classification head)
   - Migrate the existing ml/filter training logic
   - Expose hyperparameters dynamically
   - Integrate the OpenAI-compatible embedding API
   - Training job queue management

### Phase 2: User interface

1. **Dataset creation UI**
   - Embedding model selection
   - Data preview and filtering
   - Processing progress display

2. **Training parameter configuration UI**
   - Dynamic hyperparameter rendering
   - Parameter validation and constraints

3. **Experiment management and tracking**
   - Experiment history and comparison
   - Real-time training status monitoring
   - Result visualization

### Phase 3: Advanced features

1. **Automatic hyperparameter tuning**
2. **Model version management**
3. **Batch training support**
4. **Performance optimization**

## Notes

1. **Database performance**: The remote database has high RTT; avoid N+1 queries and use batch operations
2. **Label consistency**: Implement automatic detection of inconsistent labels
3. **Embedding model support**: Keep interfaces open for additional embedding models in the future
@ -1,47 +0,0 @@
// Data pipeline specific types
import type { DatasetRecord, UserLabel, EmbeddingModel, InconsistentLabel } from "../types";

// Database types from packages/core
export interface VideoMetadata {
  aid: number;
  title: string;
  description: string;
  tags: string;
  createdAt?: string;
}

export interface VideoTypeLabel {
  id: number;
  aid: number;
  label: boolean;
  user: string;
  createdAt: string;
}

export interface EmbeddingRecord {
  id: number;
  modelName: string;
  dataChecksum: string;
  vec2048?: number[];
  vec1536?: number[];
  vec1024?: number[];
  createdAt?: string;
}

export interface DataPipelineConfig {
  embeddingModels: EmbeddingModel[];
  batchSize: number;
  requireConsensus: boolean;
  maxInconsistentRatio: number;
}

export interface ProcessedDataset {
  records: DatasetRecord[];
  inconsistentLabels: InconsistentLabel[];
  statistics: {
    totalRecords: number;
    labeledRecords: number;
    inconsistentRecords: number;
    embeddingCoverage: Record<string, number>;
  };
}
@ -1,107 +0,0 @@
// ML Client for communicating with FastAPI service
import type { TrainingConfig, ExperimentResult } from './types';

export interface Hyperparameter {
  name: string;
  type: 'number' | 'boolean' | 'select';
  value: any;
  range?: [number, number];
  options?: string[];
  description?: string;
}

export interface TrainingRequest {
  experimentName: string;
  config: TrainingConfig;
  dataset: {
    aid: number[];
    embeddings: Record<string, number[]>;
    labels: Record<number, boolean>;
  };
}

export interface TrainingStatus {
  experimentId: string;
  status: 'pending' | 'running' | 'completed' | 'failed';
  progress?: number;
  currentEpoch?: number;
  totalEpochs?: number;
  metrics?: Record<string, number>;
  error?: string;
}

export class MLClient {
  private baseUrl: string;

  constructor(baseUrl: string = 'http://localhost:8000') {
    this.baseUrl = baseUrl;
  }

  // Get available hyperparameters from the model
  async getHyperparameters(): Promise<Hyperparameter[]> {
    const response = await fetch(`${this.baseUrl}/hyperparameters`);
    if (!response.ok) {
      throw new Error(`Failed to get hyperparameters: ${response.statusText}`);
    }
    return (await response.json()) as Hyperparameter[];
  }

  // Start a training experiment
  async startTraining(request: TrainingRequest): Promise<{ experimentId: string }> {
    const response = await fetch(`${this.baseUrl}/train`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify(request),
    });

    if (!response.ok) {
      throw new Error(`Failed to start training: ${response.statusText}`);
    }
    return (await response.json()) as { experimentId: string };
  }

  // Get training status
  async getTrainingStatus(experimentId: string): Promise<TrainingStatus> {
    const response = await fetch(`${this.baseUrl}/train/${experimentId}/status`);
    if (!response.ok) {
      throw new Error(`Failed to get training status: ${response.statusText}`);
    }
    return (await response.json()) as TrainingStatus;
  }

  // Get experiment results
  async getExperimentResult(experimentId: string): Promise<ExperimentResult> {
    const response = await fetch(`${this.baseUrl}/experiments/${experimentId}`);
    if (!response.ok) {
      throw new Error(`Failed to get experiment result: ${response.statusText}`);
    }
    return (await response.json()) as ExperimentResult;
  }

  // List all experiments
  async listExperiments(): Promise<ExperimentResult[]> {
    const response = await fetch(`${this.baseUrl}/experiments`);
    if (!response.ok) {
      throw new Error(`Failed to list experiments: ${response.statusText}`);
    }
    return (await response.json()) as ExperimentResult[];
  }

  // Generate embeddings using OpenAI-compatible API
  async generateEmbeddings(texts: string[], model: string): Promise<number[][]> {
    const response = await fetch(`${this.baseUrl}/embeddings`, {
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
      },
      body: JSON.stringify({ texts, model }),
    });

    if (!response.ok) {
      throw new Error(`Failed to generate embeddings: ${response.statusText}`);
    }
    return (await response.json()) as number[][];
  }
}
@ -1,54 +0,0 @@
// Shared ML types and interfaces
export interface DatasetRecord {
  aid: number;
  title: string;
  description: string;
  tags: string;
  embedding?: number[];
  label?: boolean;
  userLabels?: UserLabel[];
}

export interface UserLabel {
  user: string;
  label: boolean;
  createdAt: string;
}

export interface EmbeddingModel {
  name: string;
  dimensions: number;
  type: "openai-compatible" | "local";
  apiEndpoint?: string;
}

export interface TrainingConfig {
  learningRate: number;
  batchSize: number;
  epochs: number;
  earlyStop: boolean;
  patience?: number;
  embeddingModel: string;
}

export interface ExperimentResult {
  experimentId: string;
  config: TrainingConfig;
  metrics: {
    accuracy: number;
    precision: number;
    recall: number;
    f1: number;
  };
  createdAt: string;
  status: "running" | "completed" | "failed";
}

export interface InconsistentLabel {
  aid: number;
  title: string;
  description: string;
  tags: string;
  labels: UserLabel[];
  consensus?: boolean;
}