
add: dataset building API

alikia2x (寒寒) 2025-12-10 04:11:08 +08:00
parent c89fd1ce67
commit 77668bbb52
GPG Key ID: 56209E0CCD8420C6
20 changed files with 1696 additions and 439 deletions

.gitignore vendored

@ -47,4 +47,4 @@ temp/
 meili
-.turbo
+.turbo/

common.md Normal file

@ -0,0 +1,5 @@
# common.md
1. Always use bun as package manager.
2. Always write comments in English.

ml_new/.gitignore vendored

@ -1 +1 @@
-datasets
+datasets/

ml_new/training/README.md Normal file

@ -0,0 +1,140 @@
# ML Training Service
A FastAPI-based ML training service for dataset building, embedding generation, and experiment management.
## Architecture
The service is organized into modular components:
```
ml_new/training/
├── main.py # FastAPI application entry point
├── models.py # Pydantic data models
├── config_loader.py # Configuration loading from TOML
├── database.py # Database connection and operations
├── embedding_service.py # Embedding generation service
├── dataset_service.py # Dataset building logic
├── api_routes.py # API endpoint definitions
├── embedding_models.toml # Embedding model configurations
└── requirements.txt # Python dependencies
```
## Key Components
### 1. Main Application (`main.py`)
- FastAPI app initialization
- CORS middleware configuration
- Service dependency injection
- Startup/shutdown event handlers
### 2. Data Models (`models.py`)
- `DatasetBuildRequest`: Request model for dataset building
- `DatasetBuildResponse`: Response model for dataset building
- `DatasetRecord`: Individual dataset record structure
- `EmbeddingModelInfo`: Embedding model configuration
### 3. Configuration (`config_loader.py`)
- Loads embedding model configurations from TOML
- Manages model parameters (dimensions, API endpoints, etc.)
### 4. Database Layer (`database.py`)
- PostgreSQL connection management
- CRUD operations for video metadata, user labels, and embeddings
- Optimized batch queries to avoid N+1 problems
### 5. Embedding Service (`embedding_service.py`)
- Integration with OpenAI-compatible embedding APIs
- Text preprocessing and checksum generation
- Batch embedding generation with rate limiting
### 6. Dataset Building (`dataset_service.py`)
- Complete dataset construction workflow:
1. Pull raw text from database
2. Text preprocessing (placeholder)
3. Batch embedding generation with deduplication
4. Embedding storage and caching
5. Final dataset compilation with labels
### 7. API Routes (`api_routes.py`)
- `GET /v1/health`: Health check
- `GET /v1/models/embedding`: List available embedding models
- `POST /v1/dataset/build`: Build a new dataset
- `GET /v1/dataset/{id}`: Retrieve a built dataset
- `GET /v1/datasets`: List all datasets
- `GET /v1/datasets/stats`: Aggregate statistics for stored datasets
- `POST /v1/datasets/cleanup`: Remove datasets older than a given age
- `DELETE /v1/dataset/{id}`: Delete a dataset
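For example, a minimal Python client sketch for driving these endpoints (assuming the service runs locally on port 8322 as configured in `main.py`; the `requests` dependency is an assumption and is not part of `requirements.txt`):

```python
import time
import requests

BASE = "http://localhost:8322/v1"

# Kick off a build; the endpoint returns immediately with a dataset_id
# while the actual work runs as a FastAPI background task
resp = requests.post(f"{BASE}/dataset/build", json={
    "aid_list": [170001, 170002],
    "embedding_model": "qwen3-embedding",
    "force_regenerate": False,
})
resp.raise_for_status()
dataset_id = resp.json()["dataset_id"]

# Poll until the background task has stored the dataset (404 until then)
while True:
    r = requests.get(f"{BASE}/dataset/{dataset_id}")
    if r.status_code == 200:
        print(r.json()["stats"])
        break
    time.sleep(2)
```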
## Dataset Building Flow
1. **Model Selection**: Choose embedding model from TOML configuration
2. **Data Retrieval**: Pull video metadata and user labels from PostgreSQL
3. **Text Processing**: Combine title, description, and tags
4. **Deduplication**: Generate checksums to avoid duplicate embeddings
5. **Batch Processing**: Generate embeddings for new texts only
6. **Storage**: Store embeddings in database with caching
7. **Final Assembly**: Combine embeddings with labels using a majority-vote consensus (see the sketch below)
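A condensed sketch of steps 4 and 7, mirroring the checksum and majority-vote logic in `embedding_service.py` and `dataset_service.py` (the sample values are illustrative only):

```python
import hashlib
from typing import Optional

def text_checksum(text: str) -> str:
    # MD5 over the combined title/description/tags text (step 4)
    return hashlib.md5(text.encode("utf-8")).hexdigest()

def consensus_label(user_labels: list[dict]) -> Optional[bool]:
    # Majority vote over the latest label per user (step 7);
    # None means the video has no labels and is dropped from the dataset
    if not user_labels:
        return None
    positive = sum(1 for lbl in user_labels if lbl["label"])
    return positive > len(user_labels) / 2

# Only texts whose checksum is not already cached get sent to the embedding API
cached = {"9e107d9d372bb6826bd81d3542a419d6"}  # checksums already in internal.embeddings
texts = ["some video title ...", "another video title ..."]
to_embed = [t for t in texts if text_checksum(t) not in cached]
```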
## Configuration
### Embedding Models (`embedding_models.toml`)
```toml
[models.text-embedding-3-large]
name = "text-embedding-3-large"
dimensions = 3072
type = "openai-compatible"
api_endpoint = "https://api.openai.com/v1"
max_tokens = 8192
max_batch_size = 100
api_key_env = "OPENAI_API_KEY"
```
### Environment Variables
- `DATABASE_URL`: PostgreSQL connection string
- `OPENAI_API_KEY`: default API key for embedding generation; each model can point at its own variable via `api_key_env` (e.g. `ALIYUN_KEY`)
- `MAX_REQUESTS_PER_MINUTE`: optional cap on embedding API requests (default: 100)
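A minimal `.env` sketch (placeholder values; variable names follow the defaults above and `embedding_models.toml`):

```bash
DATABASE_URL=postgresql://user:password@localhost:5432/cvsa
ALIYUN_KEY=sk-xxxx                # api_key_env for the qwen3-embedding model
MAX_REQUESTS_PER_MINUTE=100
```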
## Usage
### Start the Service
```bash
cd ml_new/training
python main.py
```
### Build a Dataset
```bash
curl -X POST "http://localhost:8322/v1/dataset/build" \
-H "Content-Type: application/json" \
-d '{
"aid_list": [170001, 170002, 170003],
"embedding_model": "text-embedding-3-large",
"force_regenerate": false
}'
```
### Check Health
```bash
curl "http://localhost:8322/v1/health"
```
### List Embedding Models
```bash
curl "http://localhost:8322/v1/models/embedding"
```
## Features
- **High Performance**: Optimized database queries with batch operations
- **Deduplication**: Text-level deduplication using MD5 checksums
- **Consensus Labels**: Majority vote mechanism for user annotations
- **Batch Processing**: Efficient embedding generation and storage
- **Error Handling**: Comprehensive error handling and logging
- **Async Support**: Fully asynchronous operations for scalability
- **CORS Enabled**: Ready for frontend integration
## Production Considerations
- Replace the in-memory/JSON-file dataset storage with a database
- Add authentication and authorization
- Implement rate limiting for API endpoints
- Add monitoring and metrics collection
- Configure proper logging levels
- Set up database connection pooling
- Add API documentation with OpenAPI/Swagger

ml_new/training/api_routes.py Normal file

@ -0,0 +1,196 @@
"""
API routes for the ML training service
"""
import logging
import uuid
from typing import Optional
from fastapi import APIRouter, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse
from config_loader import config_loader
from models import DatasetBuildRequest, DatasetBuildResponse
from dataset_service import DatasetBuilder
logger = logging.getLogger(__name__)
# Create router
router = APIRouter(prefix="/v1")
# Global dataset builder instance (will be set by main.py)
dataset_builder: Optional[DatasetBuilder] = None
def set_dataset_builder(builder: DatasetBuilder):
"""Set the global dataset builder instance"""
global dataset_builder
dataset_builder = builder
@router.get("/health")
async def health_check():
"""Health check endpoint"""
if not dataset_builder:
return JSONResponse(
status_code=503,
content={"status": "unavailable", "message": "Dataset builder not initialized"}
)
try:
# Check embedding service health
embedding_health = await dataset_builder.embedding_service.health_check()
except Exception as e:
embedding_health = {"status": "unhealthy", "error": str(e)}
    # Check database connection (pool should already be initialized)
    db_status = "disconnected"
    if dataset_builder.db_manager.is_connected:
        try:
            result = await dataset_builder.db_manager.pool.fetchval("SELECT 1")
            db_status = "connected" if result == 1 else "disconnected"
        except Exception as e:
            db_status = f"error: {e}"
return {
"status": "healthy",
"service": "ml-training-api",
"embedding_service": embedding_health,
"database": db_status,
"available_models": list(config_loader.get_embedding_models().keys())
}
@router.get("/models/embedding")
async def get_embedding_models():
"""Get available embedding models"""
return {
"models": {
name: {
"name": config.name,
"dimensions": config.dimensions,
"type": config.type,
"api_endpoint": config.api_endpoint,
"max_tokens": config.max_tokens,
"max_batch_size": config.max_batch_size
}
for name, config in config_loader.get_embedding_models().items()
}
}
@router.post("/dataset/build", response_model=DatasetBuildResponse)
async def build_dataset_endpoint(request: DatasetBuildRequest, background_tasks: BackgroundTasks):
"""Build dataset endpoint"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
# Validate embedding model
if request.embedding_model not in config_loader.get_embedding_models():
raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")
dataset_id = str(uuid.uuid4())
# Start background task for dataset building
background_tasks.add_task(
dataset_builder.build_dataset,
dataset_id,
request.aid_list,
request.embedding_model,
request.force_regenerate
)
return DatasetBuildResponse(
dataset_id=dataset_id,
total_records=len(request.aid_list),
status="started",
message="Dataset building started"
)
@router.get("/dataset/{dataset_id}")
async def get_dataset_endpoint(dataset_id: str):
"""Get built dataset by ID"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
if not dataset_builder.dataset_exists(dataset_id):
raise HTTPException(status_code=404, detail="Dataset not found")
dataset_info = dataset_builder.get_dataset(dataset_id)
if "error" in dataset_info:
raise HTTPException(status_code=500, detail=dataset_info["error"])
return {
"dataset_id": dataset_id,
"dataset": dataset_info["dataset"],
"stats": dataset_info["stats"],
"created_at": dataset_info["created_at"]
}
@router.get("/datasets")
async def list_datasets():
"""List all built datasets"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
datasets = []
for dataset_id, dataset_info in dataset_builder.dataset_storage.items():
if "error" not in dataset_info:
datasets.append({
"dataset_id": dataset_id,
"stats": dataset_info["stats"],
"created_at": dataset_info["created_at"]
})
return {"datasets": datasets}
@router.delete("/dataset/{dataset_id}")
async def delete_dataset_endpoint(dataset_id: str):
"""Delete a built dataset"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
if dataset_builder.delete_dataset(dataset_id):
return {"message": f"Dataset {dataset_id} deleted successfully"}
else:
raise HTTPException(status_code=404, detail="Dataset not found")
@router.get("/datasets")
async def list_datasets_endpoint():
"""List all built datasets"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
datasets = dataset_builder.list_datasets()
return {"datasets": datasets}
@router.get("/datasets/stats")
async def get_dataset_stats_endpoint():
"""Get overall statistics about stored datasets"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
stats = dataset_builder.get_dataset_stats()
return stats
@router.post("/datasets/cleanup")
async def cleanup_datasets_endpoint(max_age_days: int = 30):
"""Remove datasets older than specified days"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
await dataset_builder.cleanup_old_datasets(max_age_days)
return {"message": f"Cleanup completed for datasets older than {max_age_days} days"}

ml_new/training/config_loader.py Normal file

@ -0,0 +1,85 @@
"""
Configuration loader for embedding models and other settings
"""
import toml
import os
from typing import Dict, Optional
from pydantic import BaseModel
import logging
logger = logging.getLogger(__name__)
class EmbeddingModelConfig(BaseModel):
name: str
dimensions: int
type: str
api_endpoint: str = "https://api.openai.com/v1"
max_tokens: int = 8191
max_batch_size: int = 8
api_key_env: str = "OPENAI_API_KEY"
class ConfigLoader:
    def __init__(self, config_path: Optional[str] = None):
if config_path is None:
# Default to the embedding_models.toml file we created
config_path = os.path.join(
os.path.dirname(__file__), "embedding_models.toml"
)
self.config_path = config_path
self.embedding_models: Dict[str, EmbeddingModelConfig] = {}
self._load_config()
def _load_config(self):
"""Load configuration from TOML file"""
try:
if not os.path.exists(self.config_path):
logger.warning(f"Config file not found: {self.config_path}")
return
with open(self.config_path, "r", encoding="utf-8") as f:
config_data = toml.load(f)
# Load embedding models
if "models" not in config_data:
return
for model_key, model_data in config_data["models"].items():
self.embedding_models[model_key] = EmbeddingModelConfig(
**model_data
)
logger.info(
f"Loaded {len(self.embedding_models)} embedding models from {self.config_path}"
)
except Exception as e:
logger.error(f"Failed to load config from {self.config_path}: {e}")
def get_embedding_models(self) -> Dict[str, EmbeddingModelConfig]:
"""Get all available embedding models"""
return self.embedding_models.copy()
def get_embedding_model(self, model_name: str) -> EmbeddingModelConfig:
"""Get specific embedding model config"""
if model_name not in self.embedding_models:
raise ValueError(
f"Embedding model '{model_name}' not found in configuration"
)
return self.embedding_models[model_name]
def list_model_names(self) -> list:
"""Get list of available model names"""
return list(self.embedding_models.keys())
def reload_config(self):
"""Reload configuration from file"""
self.embedding_models = {}
self._load_config()
# Global config loader instance
config_loader = ConfigLoader()

ml_new/training/database.py Normal file

@ -0,0 +1,284 @@
"""
Database connection and operations for ML training service
"""
import os
import hashlib
from typing import List, Dict, Optional, Any
from datetime import datetime
import asyncpg
import logging
from config_loader import config_loader
from dotenv import load_dotenv
load_dotenv()
logger = logging.getLogger(__name__)
# Database configuration
DATABASE_URL = os.getenv("DATABASE_URL")
class DatabaseManager:
def __init__(self):
self.pool: Optional[asyncpg.Pool] = None
async def connect(self):
"""Initialize database connection pool"""
try:
self.pool = await asyncpg.create_pool(DATABASE_URL, min_size=5, max_size=20)
logger.info("Database connection pool initialized")
except Exception as e:
logger.error(f"Failed to connect to database: {e}")
raise
async def close(self):
"""Close database connection pool"""
if self.pool:
await self.pool.close()
logger.info("Database connection pool closed")
@property
def is_connected(self) -> bool:
"""Check if database connection pool is initialized"""
return self.pool is not None
async def get_embedding_models(self):
"""Get available embedding models from config"""
return config_loader.get_embedding_models()
async def get_video_metadata(
self, aid_list: List[int]
) -> Dict[int, Dict[str, Any]]:
"""Get video metadata for given AIDs"""
if not aid_list:
return {}
async with self.pool.acquire() as conn:
query = """
SELECT aid, title, description, tags
FROM bilibili_metadata
WHERE aid = ANY($1::bigint[])
"""
rows = await conn.fetch(query, aid_list)
result = {}
for row in rows:
result[int(row["aid"])] = {
"aid": int(row["aid"]),
"title": row["title"] or "",
"description": row["description"] or "",
"tags": row["tags"] or "",
}
return result
async def get_user_labels(
self, aid_list: List[int]
) -> Dict[int, List[Dict[str, Any]]]:
"""Get user labels for given AIDs, only the latest label per user"""
if not aid_list:
return {}
async with self.pool.acquire() as conn:
query = """
WITH latest_labels AS (
SELECT DISTINCT ON (aid, "user")
aid, "user", label, created_at
FROM internal.video_type_label
WHERE aid = ANY($1::bigint[])
ORDER BY aid, "user", created_at DESC
)
SELECT aid, "user", label, created_at
FROM latest_labels
ORDER BY aid, "user"
"""
rows = await conn.fetch(query, aid_list)
result = {}
for row in rows:
aid = int(row["aid"])
if aid not in result:
result[aid] = []
result[aid].append(
{
"user": row["user"],
"label": bool(row["label"]),
"created_at": row["created_at"].isoformat(),
}
)
return result
async def get_existing_embeddings(
self, checksums: List[str], model_name: str
) -> Dict[str, Dict[str, Any]]:
"""Get existing embeddings for given checksums and model"""
if not checksums:
return {}
async with self.pool.acquire() as conn:
query = """
SELECT data_checksum, vec_2048, vec_1536, vec_1024, created_at
FROM internal.embeddings
WHERE model_name = $1 AND data_checksum = ANY($2::text[])
"""
rows = await conn.fetch(query, model_name, checksums)
result = {}
for row in rows:
checksum = row["data_checksum"]
# Convert vector strings to lists if they exist
vec_2048 = self._parse_vector_string(row["vec_2048"]) if row["vec_2048"] else None
vec_1536 = self._parse_vector_string(row["vec_1536"]) if row["vec_1536"] else None
vec_1024 = self._parse_vector_string(row["vec_1024"]) if row["vec_1024"] else None
result[checksum] = {
"checksum": checksum,
"vec_2048": vec_2048,
"vec_1536": vec_1536,
"vec_1024": vec_1024,
"created_at": row["created_at"].isoformat(),
}
return result
def _parse_vector_string(self, vector_str: str) -> List[float]:
"""Parse vector string format '[1.0,2.0,3.0]' back to list"""
if not vector_str:
return []
try:
# Remove brackets and split by comma
vector_str = vector_str.strip()
if vector_str.startswith('[') and vector_str.endswith(']'):
vector_str = vector_str[1:-1]
return [float(x.strip()) for x in vector_str.split(',') if x.strip()]
except Exception as e:
logger.warning(f"Failed to parse vector string '{vector_str}': {e}")
return []
async def insert_embeddings(self, embeddings_data: List[Dict[str, Any]]) -> None:
"""Batch insert embeddings into database"""
if not embeddings_data:
return
async with self.pool.acquire() as conn:
async with conn.transaction():
for data in embeddings_data:
# Determine which vector column to use based on dimensions
vec_column = f"vec_{data['dimensions']}"
# Convert vector list to string format for PostgreSQL
vector_str = "[" + ",".join(map(str, data["vector"])) + "]"
query = f"""
INSERT INTO internal.embeddings
(model_name, data_checksum, {vec_column}, created_at)
VALUES ($1, $2, $3, $4)
ON CONFLICT (data_checksum) DO NOTHING
"""
await conn.execute(
query,
data["model_name"],
data["checksum"],
vector_str,
datetime.now(),
)
async def get_final_dataset(
self, aid_list: List[int], model_name: str
) -> List[Dict[str, Any]]:
"""Get final dataset with embeddings and labels"""
if not aid_list:
return []
# Get video metadata
metadata = await self.get_video_metadata(aid_list)
# Get user labels (latest per user)
labels = await self.get_user_labels(aid_list)
# Prepare text data for embedding
text_data = []
aid_to_text = {}
for aid in aid_list:
if aid in metadata:
# Combine title, description, and tags for embedding
text_parts = [
metadata[aid]["title"],
metadata[aid]["description"],
metadata[aid]["tags"],
]
combined_text = " ".join(filter(None, text_parts))
# Create checksum for deduplication
checksum = hashlib.md5(combined_text.encode("utf-8")).hexdigest()
text_data.append(
{"aid": aid, "text": combined_text, "checksum": checksum}
)
aid_to_text[checksum] = aid
        # Get existing embeddings for these checksums
        checksums = [item["checksum"] for item in text_data]
        existing_embeddings = await self.get_existing_embeddings(checksums, model_name)
        # Prepare final dataset
dataset = []
for item in text_data:
aid = item["aid"]
checksum = item["checksum"]
# Get embedding vector
embedding_vector = None
if checksum in existing_embeddings:
# Use existing embedding
emb_data = existing_embeddings[checksum]
if emb_data["vec_1536"]:
embedding_vector = emb_data["vec_1536"]
elif emb_data["vec_2048"]:
embedding_vector = emb_data["vec_2048"]
elif emb_data["vec_1024"]:
embedding_vector = emb_data["vec_1024"]
# Get labels for this aid
aid_labels = labels.get(aid, [])
# Determine final label using consensus (majority vote)
if aid_labels:
positive_votes = sum(1 for lbl in aid_labels if lbl["label"])
final_label = positive_votes > len(aid_labels) / 2
else:
final_label = None # No labels available
# Check for inconsistent labels
inconsistent = len(aid_labels) > 1 and (
sum(1 for lbl in aid_labels if lbl["label"]) != 0
and sum(1 for lbl in aid_labels if lbl["label"]) != len(aid_labels)
)
if embedding_vector and final_label is not None:
dataset.append(
{
"aid": aid,
"embedding": embedding_vector,
"label": final_label,
"metadata": metadata.get(aid, {}),
"user_labels": aid_labels,
"inconsistent": inconsistent,
"text_checksum": checksum,
}
)
return dataset
# Global database manager instance
db_manager = DatabaseManager()

ml_new/training/dataset_service.py Normal file

@ -0,0 +1,373 @@
"""
Dataset building service - handles the complete dataset construction flow
"""
import os
import json
import logging
import asyncio
from pathlib import Path
from typing import List, Dict, Any, Optional
from datetime import datetime
from database import DatabaseManager
from embedding_service import EmbeddingService
from config_loader import config_loader
logger = logging.getLogger(__name__)
class DatasetBuilder:
"""Service for building datasets with the specified flow"""
def __init__(self, db_manager: DatabaseManager, embedding_service: EmbeddingService, storage_dir: str = "datasets"):
self.db_manager = db_manager
self.embedding_service = embedding_service
self.storage_dir = Path(storage_dir)
self.storage_dir.mkdir(exist_ok=True)
# Load existing datasets from file system
self.dataset_storage: Dict[str, Dict] = self._load_all_datasets()
def _get_dataset_file_path(self, dataset_id: str) -> Path:
"""Get file path for dataset"""
return self.storage_dir / f"{dataset_id}.json"
def _load_dataset_from_file(self, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Load dataset from file"""
file_path = self._get_dataset_file_path(dataset_id)
if not file_path.exists():
return None
try:
with open(file_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Failed to load dataset {dataset_id} from file: {e}")
return None
def _save_dataset_to_file(self, dataset_id: str, dataset_data: Dict[str, Any]) -> bool:
"""Save dataset to file"""
file_path = self._get_dataset_file_path(dataset_id)
try:
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(dataset_data, f, ensure_ascii=False, indent=2)
return True
except Exception as e:
logger.error(f"Failed to save dataset {dataset_id} to file: {e}")
return False
def _load_all_datasets(self) -> Dict[str, Dict]:
"""Load all datasets from file system"""
datasets = {}
try:
for file_path in self.storage_dir.glob("*.json"):
dataset_id = file_path.stem
dataset_data = self._load_dataset_from_file(dataset_id)
if dataset_data:
datasets[dataset_id] = dataset_data
logger.info(f"Loaded {len(datasets)} datasets from file system")
except Exception as e:
logger.error(f"Failed to load datasets from file system: {e}")
return datasets
async def cleanup_old_datasets(self, max_age_days: int = 30):
"""Remove datasets older than specified days"""
try:
cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 60 * 60)
removed_count = 0
for dataset_id in list(self.dataset_storage.keys()):
dataset_info = self.dataset_storage[dataset_id]
if "created_at" in dataset_info:
try:
created_time = datetime.fromisoformat(dataset_info["created_at"]).timestamp()
if created_time < cutoff_time:
# Remove from memory
del self.dataset_storage[dataset_id]
# Remove file
file_path = self._get_dataset_file_path(dataset_id)
if file_path.exists():
file_path.unlink()
removed_count += 1
except Exception as e:
logger.warning(f"Failed to process dataset {dataset_id} for cleanup: {e}")
if removed_count > 0:
logger.info(f"Cleaned up {removed_count} old datasets")
except Exception as e:
logger.error(f"Failed to cleanup old datasets: {e}")
async def build_dataset(self, dataset_id: str, aid_list: List[int], embedding_model: str, force_regenerate: bool = False) -> str:
"""
Build dataset with the specified flow:
1. Select embedding model (from TOML config)
2. Pull raw text from database
3. Preprocess (placeholder for now)
4. Batch get embeddings (deduplicate by hash, skip if already in embeddings table)
5. Write to embeddings table
6. Pull all needed embeddings to create dataset with format: embeddings, label
"""
try:
logger.info(f"Starting dataset building task {dataset_id}")
EMBEDDING_MODELS = config_loader.get_embedding_models()
# Get model configuration
if embedding_model not in EMBEDDING_MODELS:
raise ValueError(f"Invalid embedding model: {embedding_model}")
model_config = EMBEDDING_MODELS[embedding_model]
# Step 1: Get video metadata from database
metadata = await self.db_manager.get_video_metadata(aid_list)
# Step 2: Get user labels
labels = await self.db_manager.get_user_labels(aid_list)
# Step 3: Prepare text data and checksums
text_data = []
for aid in aid_list:
if aid in metadata:
# Combine title, description, tags
combined_text = self.embedding_service.combine_video_text(
metadata[aid]['title'],
metadata[aid]['description'],
metadata[aid]['tags']
)
# Create checksum for deduplication
checksum = self.embedding_service.create_text_checksum(combined_text)
text_data.append({
'aid': aid,
'text': combined_text,
'checksum': checksum
})
# Step 4: Check existing embeddings
checksums = [item['checksum'] for item in text_data]
existing_embeddings = await self.db_manager.get_existing_embeddings(checksums, embedding_model)
# Step 5: Generate new embeddings for texts that don't have them
new_embeddings_needed = []
for item in text_data:
if item['checksum'] not in existing_embeddings or force_regenerate:
new_embeddings_needed.append(item['text'])
new_embeddings_count = 0
if new_embeddings_needed:
logger.info(f"Generating {len(new_embeddings_needed)} new embeddings")
generated_embeddings = await self.embedding_service.generate_embeddings_batch(
new_embeddings_needed,
embedding_model
)
# Step 6: Store new embeddings in database
embeddings_to_store = []
for i, (text, embedding) in enumerate(zip(new_embeddings_needed, generated_embeddings)):
checksum = self.embedding_service.create_text_checksum(text)
embeddings_to_store.append({
'model_name': embedding_model,
'checksum': checksum,
'dimensions': model_config.dimensions,
'vector': embedding
})
await self.db_manager.insert_embeddings(embeddings_to_store)
new_embeddings_count = len(embeddings_to_store)
# Update existing embeddings cache
for emb_data in embeddings_to_store:
existing_embeddings[emb_data['checksum']] = {
'checksum': emb_data['checksum'],
f'vec_{model_config.dimensions}': emb_data['vector']
}
# Step 7: Build final dataset
dataset = []
inconsistent_count = 0
for item in text_data:
aid = item['aid']
checksum = item['checksum']
# Get embedding vector
embedding_vector = None
if checksum in existing_embeddings:
vec_key = f'vec_{model_config.dimensions}'
if vec_key in existing_embeddings[checksum]:
embedding_vector = existing_embeddings[checksum][vec_key]
# Get labels for this aid
aid_labels = labels.get(aid, [])
# Determine final label using consensus (majority vote)
final_label = None
if aid_labels:
positive_votes = sum(1 for lbl in aid_labels if lbl['label'])
final_label = positive_votes > len(aid_labels) / 2
# Check for inconsistent labels
inconsistent = len(aid_labels) > 1 and (
sum(1 for lbl in aid_labels if lbl['label']) != 0 and
sum(1 for lbl in aid_labels if lbl['label']) != len(aid_labels)
)
if inconsistent:
inconsistent_count += 1
if embedding_vector and final_label is not None:
dataset.append({
'aid': aid,
'embedding': embedding_vector,
'label': final_label,
'metadata': metadata.get(aid, {}),
'user_labels': aid_labels,
'inconsistent': inconsistent,
'text_checksum': checksum
})
reused_count = len(dataset) - new_embeddings_count
logger.info(f"Dataset building completed: {len(dataset)} records, {new_embeddings_count} new, {reused_count} reused, {inconsistent_count} inconsistent")
# Prepare dataset data
dataset_data = {
'dataset': dataset,
'stats': {
'total_records': len(dataset),
'new_embeddings': new_embeddings_count,
'reused_embeddings': reused_count,
'inconsistent_labels': inconsistent_count,
'embedding_model': embedding_model
},
'created_at': datetime.now().isoformat()
}
# Save to file and memory cache
if self._save_dataset_to_file(dataset_id, dataset_data):
self.dataset_storage[dataset_id] = dataset_data
logger.info(f"Dataset {dataset_id} saved to file system")
else:
logger.warning(f"Failed to save dataset {dataset_id} to file, keeping in memory only")
self.dataset_storage[dataset_id] = dataset_data
return dataset_id
except Exception as e:
logger.error(f"Dataset building failed for {dataset_id}: {str(e)}")
# Store error information
error_data = {
'error': str(e),
'created_at': datetime.now().isoformat()
}
# Try to save error to file as well
self._save_dataset_to_file(dataset_id, error_data)
self.dataset_storage[dataset_id] = error_data
raise
def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Get built dataset by ID"""
# First check memory cache
if dataset_id in self.dataset_storage:
return self.dataset_storage[dataset_id]
# If not in memory, try to load from file
dataset_data = self._load_dataset_from_file(dataset_id)
if dataset_data:
# Add to memory cache
self.dataset_storage[dataset_id] = dataset_data
return dataset_data
return None
def dataset_exists(self, dataset_id: str) -> bool:
"""Check if dataset exists"""
# Check memory cache first
if dataset_id in self.dataset_storage:
return True
# Check file system
return self._get_dataset_file_path(dataset_id).exists()
def delete_dataset(self, dataset_id: str) -> bool:
"""Delete dataset from both memory and file system"""
try:
# Remove from memory
if dataset_id in self.dataset_storage:
del self.dataset_storage[dataset_id]
# Remove file
file_path = self._get_dataset_file_path(dataset_id)
if file_path.exists():
file_path.unlink()
logger.info(f"Dataset {dataset_id} deleted from file system")
return True
else:
logger.warning(f"Dataset file {dataset_id} not found for deletion")
return False
except Exception as e:
logger.error(f"Failed to delete dataset {dataset_id}: {e}")
return False
def list_datasets(self) -> List[Dict[str, Any]]:
"""List all datasets with their basic information"""
datasets = []
for dataset_id, dataset_info in self.dataset_storage.items():
if "error" not in dataset_info:
datasets.append({
"dataset_id": dataset_id,
"stats": dataset_info["stats"],
"created_at": dataset_info["created_at"]
})
# Sort by creation time (newest first)
datasets.sort(key=lambda x: x["created_at"], reverse=True)
return datasets
def get_dataset_stats(self) -> Dict[str, Any]:
"""Get overall statistics about stored datasets"""
total_datasets = len(self.dataset_storage)
error_datasets = sum(1 for data in self.dataset_storage.values() if "error" in data)
valid_datasets = total_datasets - error_datasets
total_records = 0
total_new_embeddings = 0
total_reused_embeddings = 0
for dataset_info in self.dataset_storage.values():
if "stats" in dataset_info:
stats = dataset_info["stats"]
total_records += stats.get("total_records", 0)
total_new_embeddings += stats.get("new_embeddings", 0)
total_reused_embeddings += stats.get("reused_embeddings", 0)
return {
"total_datasets": total_datasets,
"valid_datasets": valid_datasets,
"error_datasets": error_datasets,
"total_records": total_records,
"total_new_embeddings": total_new_embeddings,
"total_reused_embeddings": total_reused_embeddings,
"storage_directory": str(self.storage_dir)
}

ml_new/training/embedding_models.toml Normal file

@ -0,0 +1,12 @@
# Embedding Models Configuration
model = "qwen3-embedding"
[models.qwen3-embedding]
name = "text-embedding-v4"
dimensions = 2048
type = "openai-compatible"
api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
max_tokens = 8192
max_batch_size = 10
api_key_env = "ALIYUN_KEY"

ml_new/training/embedding_service.py Normal file

@ -0,0 +1,143 @@
"""
Embedding service for generating embeddings using OpenAI-compatible API
"""
import asyncio
import hashlib
from typing import List, Dict, Any, Optional
import logging
from openai import AsyncOpenAI
import os
from config_loader import config_loader
from dotenv import load_dotenv
load_dotenv()
logger = logging.getLogger(__name__)
class EmbeddingService:
def __init__(self):
# Get configuration from config loader
self.embedding_models = config_loader.get_embedding_models()
# Initialize OpenAI client (will be configured per model)
self.clients: Dict[str, AsyncOpenAI] = {}
self._initialize_clients()
# Rate limiting
self.max_requests_per_minute = int(os.getenv("MAX_REQUESTS_PER_MINUTE", "100"))
self.request_interval = 60.0 / self.max_requests_per_minute
def _initialize_clients(self):
"""Initialize OpenAI clients for different models/endpoints"""
for model_name, model_config in self.embedding_models.items():
if model_config.type == "openai-compatible":
# Get API key from environment variable specified in config
api_key = os.getenv(model_config.api_key_env)
self.clients[model_name] = AsyncOpenAI(
api_key=api_key,
base_url=model_config.api_endpoint
)
logger.info(f"Initialized client for model {model_name}")
async def generate_embeddings_batch(
self,
texts: List[str],
model: str,
batch_size: Optional[int] = None
) -> List[List[float]]:
"""Generate embeddings for a batch of texts"""
# Get model configuration
if model not in self.embedding_models:
raise ValueError(f"Model '{model}' not found in configuration")
model_config = self.embedding_models[model]
# Use model's max_batch_size if not specified
if batch_size is None:
batch_size = model_config.max_batch_size
# Validate model and get expected dimensions
expected_dims = model_config.dimensions
if model not in self.clients:
raise ValueError(f"No client configured for model '{model}'")
client = self.clients[model]
all_embeddings = []
# Process in batches to avoid API limits
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
try:
# Rate limiting
if i > 0:
await asyncio.sleep(self.request_interval)
# Generate embeddings
response = await client.embeddings.create(
model=model_config.name,
input=batch,
dimensions=expected_dims
)
batch_embeddings = [data.embedding for data in response.data]
all_embeddings.extend(batch_embeddings)
logger.info(f"Generated embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
except Exception as e:
logger.error(f"Error generating embeddings for batch {i//batch_size + 1}: {e}")
# For now, fill with zeros as fallback (could implement retry logic)
zero_embedding = [0.0] * expected_dims
all_embeddings.extend([zero_embedding] * len(batch))
return all_embeddings
def create_text_checksum(self, text: str) -> str:
"""Create MD5 checksum for text deduplication"""
return hashlib.md5(text.encode('utf-8')).hexdigest()
def combine_video_text(self, title: str, description: str, tags: str) -> str:
"""Combine video metadata into a single text for embedding"""
        parts = [
            "标题:" + title.strip() if title else "",
            "简介:" + description.strip() if description else "",
            "标签:" + tags.strip() if tags else ""
        ]
# Filter out empty parts and join
combined = '\n'.join(filter(None, parts))
return combined
async def health_check(self) -> Dict[str, Any]:
"""Check if embedding service is healthy"""
try:
# Test with a simple embedding using the first available model
model_name = list(self.embedding_models.keys())[0]
test_embedding = await self.generate_embeddings_batch(
["health check"],
model_name,
batch_size=1
)
return {
"status": "healthy",
"service": "embedding_service",
"model": model_name,
"dimensions": len(test_embedding[0]) if test_embedding else 0,
"available_models": list(self.embedding_models.keys())
}
except Exception as e:
return {
"status": "unhealthy",
"service": "embedding_service",
"error": str(e)
}
# Global embedding service instance
embedding_service = EmbeddingService()

ml_new/training/main.py

@ -1,246 +1,113 @@
-from fastapi import FastAPI, HTTPException, BackgroundTasks
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-from typing import List, Dict, Optional, Any
-import asyncio
-import uuid
-import logging
-from datetime import datetime
-import json
-
-# Setup logging
+"""
+Main FastAPI application for ML training service
+"""
+import logging
+import uvicorn
+from contextlib import asynccontextmanager
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from database import DatabaseManager
+from embedding_service import EmbeddingService
+from dataset_service import DatasetBuilder
+from api_routes import router, set_dataset_builder
+
+# Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Initialize FastAPI app
-app = FastAPI(
-    title="CVSA ML Training API",
-    version="1.0.0",
-    description="ML training service for video classification"
-)
-
-# Enable CORS for web UI
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["http://localhost:3000", "http://localhost:5173"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+# Global service instances
+db_manager = None
+embedding_service = None
+dataset_builder = None
 
-# Pydantic models
-class Hyperparameter(BaseModel):
-    name: str
-    type: str  # 'number', 'boolean', 'select'
-    value: Any
-    range: Optional[tuple] = None
-    options: Optional[List[str]] = None
-    description: Optional[str] = None
-
-class TrainingConfig(BaseModel):
-    learning_rate: float = 1e-4
-    batch_size: int = 32
-    epochs: int = 10
-    early_stop: bool = True
-    patience: int = 3
-    embedding_model: str = "text-embedding-3-small"
-
-class TrainingRequest(BaseModel):
-    experiment_name: str
-    config: TrainingConfig
-    dataset: Dict[str, Any]
-
-class TrainingStatus(BaseModel):
-    experiment_id: str
-    status: str  # 'pending', 'running', 'completed', 'failed'
-    progress: Optional[float] = None
-    current_epoch: Optional[int] = None
-    total_epochs: Optional[int] = None
-    metrics: Optional[Dict[str, float]] = None
-    error: Optional[str] = None
-
-class ExperimentResult(BaseModel):
-    experiment_id: str
-    experiment_name: str
-    config: TrainingConfig
-    metrics: Dict[str, float]
-    created_at: str
-    status: str
-
-class EmbeddingRequest(BaseModel):
-    texts: List[str]
-    model: str
-
-# In-memory storage for experiments (in production, use database)
-training_sessions: Dict[str, Dict] = {}
-experiments: Dict[str, ExperimentResult] = {}
-
-# Default hyperparameters that will be dynamically discovered
-DEFAULT_HYPERPARAMETERS = [
-    Hyperparameter(
-        name="learning_rate",
-        type="number",
-        value=1e-4,
-        range=(1e-6, 1e-2),
-        description="Learning rate for optimizer"
-    ),
-    Hyperparameter(
-        name="batch_size",
-        type="number",
-        value=32,
-        range=(8, 256),
-        description="Training batch size"
-    ),
-    Hyperparameter(
-        name="epochs",
-        type="number",
-        value=10,
-        range=(1, 100),
-        description="Number of training epochs"
-    ),
-    Hyperparameter(
-        name="early_stop",
-        type="boolean",
-        value=True,
-        description="Enable early stopping"
-    ),
-    Hyperparameter(
-        name="patience",
-        type="number",
-        value=3,
-        range=(1, 20),
-        description="Early stopping patience"
-    ),
-    Hyperparameter(
-        name="embedding_model",
-        type="select",
-        value="text-embedding-3-small",
-        options=["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
-        description="Embedding model to use"
-    )
-]
-
-@app.get("/")
-async def root():
-    return {"message": "CVSA ML Training API", "version": "1.0.0"}
-
-@app.get("/health")
-async def health_check():
-    return {"status": "healthy", "service": "ml-training-api"}
-
-@app.get("/hyperparameters", response_model=List[Hyperparameter])
-async def get_hyperparameters():
-    """Get all available hyperparameters for the current model"""
-    return DEFAULT_HYPERPARAMETERS
-
-@app.post("/train")
-async def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
-    """Start a new training experiment"""
-    experiment_id = str(uuid.uuid4())
-
-    # Store training session
-    training_sessions[experiment_id] = {
-        "experiment_id": experiment_id,
-        "experiment_name": request.experiment_name,
-        "config": request.config.dict(),
-        "dataset": request.dataset,
-        "status": "pending",
-        "created_at": datetime.now().isoformat()
-    }
-
-    # Start background training task
-    background_tasks.add_task(run_training, experiment_id, request)
-
-    return {"experiment_id": experiment_id}
-
-@app.get("/train/{experiment_id}/status", response_model=TrainingStatus)
-async def get_training_status(experiment_id: str):
-    """Get training status for an experiment"""
-    if experiment_id not in training_sessions:
-        raise HTTPException(status_code=404, detail="Experiment not found")
-
-    session = training_sessions[experiment_id]
-    return TrainingStatus(
-        experiment_id=experiment_id,
-        status=session.get("status", "unknown"),
-        progress=session.get("progress"),
-        current_epoch=session.get("current_epoch"),
-        total_epochs=session.get("total_epochs"),
-        metrics=session.get("metrics"),
-        error=session.get("error")
-    )
-
-@app.get("/experiments", response_model=List[ExperimentResult])
-async def list_experiments():
-    """List all experiments"""
-    return list(experiments.values())
-
-@app.get("/experiments/{experiment_id}", response_model=ExperimentResult)
-async def get_experiment(experiment_id: str):
-    """Get experiment details"""
-    if experiment_id not in experiments:
-        raise HTTPException(status_code=404, detail="Experiment not found")
-    return experiments[experiment_id]
-
-@app.post("/embeddings")
-async def generate_embeddings(request: EmbeddingRequest):
-    """Generate embeddings using OpenAI-compatible API"""
-    # This is a placeholder implementation
-    # In production, this would call actual embedding API
-    embeddings = []
-    for text in request.texts:
-        # Mock embedding generation
-        embedding = [0.1] * 1536  # Mock 1536-dimensional embedding
-        embeddings.append(embedding)
-    return embeddings
-
-async def run_training(experiment_id: str, request: TrainingRequest):
-    """Background task to run training"""
-    try:
-        session = training_sessions[experiment_id]
-        session["status"] = "running"
-        session["total_epochs"] = request.config.epochs
-
-        # Simulate training process
-        for epoch in range(request.config.epochs):
-            session["current_epoch"] = epoch + 1
-            session["progress"] = (epoch + 1) / request.config.epochs
-
-            # Simulate training metrics
-            session["metrics"] = {
-                "loss": max(0.0, 1.0 - (epoch + 1) * 0.1),
-                "accuracy": min(0.95, 0.5 + (epoch + 1) * 0.05),
-                "val_loss": max(0.0, 0.8 - (epoch + 1) * 0.08),
-                "val_accuracy": min(0.92, 0.45 + (epoch + 1) * 0.04)
-            }
-
-            logger.info(f"Training epoch {epoch + 1}/{request.config.epochs}")
-            await asyncio.sleep(1)  # Simulate training time
-
-        # Training completed
-        session["status"] = "completed"
-        final_metrics = session["metrics"]
-
-        # Store final experiment result
-        experiments[experiment_id] = ExperimentResult(
-            experiment_id=experiment_id,
-            experiment_name=request.experiment_name,
-            config=request.config,
-            metrics=final_metrics,
-            created_at=session["created_at"],
-            status="completed"
-        )
-
-        logger.info(f"Training completed for experiment {experiment_id}")
-
-    except Exception as e:
-        session["status"] = "failed"
-        session["error"] = str(e)
-        logger.error(f"Training failed for experiment {experiment_id}: {str(e)}")
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Application lifespan manager for startup and shutdown events"""
+    global db_manager, embedding_service, dataset_builder
+
+    # Startup
+    logger.info("Initializing services...")
+    try:
+        # Database manager
+        db_manager = DatabaseManager()
+        await db_manager.connect()  # Initialize database connection pool
+        logger.info("Database manager initialized and connected")
+
+        # Embedding service
+        embedding_service = EmbeddingService()
+        logger.info("Embedding service initialized")
+
+        # Dataset builder
+        dataset_builder = DatasetBuilder(db_manager, embedding_service)
+        logger.info("Dataset builder initialized")
+
+        # Set global dataset builder instance
+        set_dataset_builder(dataset_builder)
+
+        logger.info("All services initialized successfully")
+    except Exception as e:
+        logger.error(f"Failed to initialize services: {e}")
+        raise
+
+    # Yield control to the application
+    yield
+
+    # Shutdown
+    logger.info("Shutting down services...")
+    try:
+        if db_manager:
+            await db_manager.close()
+            logger.info("Database connection pool closed")
+    except Exception as e:
+        logger.error(f"Error during shutdown: {e}")
+
+def create_app() -> FastAPI:
+    """Create and configure FastAPI application"""
+    # Create FastAPI app with lifespan manager
+    app = FastAPI(
+        title="ML Training Service",
+        description="ML training, dataset building, and experiment management service",
+        version="1.0.0",
+        lifespan=lifespan
+    )
+
+    # Add CORS middleware
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],  # Configure appropriately for production
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+
+    # Include API routes
+    app.include_router(router)
+
+    return app
+
+def main():
+    """Main entry point"""
+    app = create_app()
+
+    # Run the application
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=8322,
+        log_level="info",
+        access_log=True
+    )
 
 if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    main()

ml_new/training/models.py Normal file

@ -0,0 +1,62 @@
"""
Data models for dataset building functionality
"""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
from datetime import datetime
class DatasetBuildRequest(BaseModel):
"""Request model for dataset building"""
aid_list: List[int] = Field(..., description="List of video AIDs")
embedding_model: str = Field(..., description="Embedding model name")
force_regenerate: bool = Field(False, description="Whether to force regenerate embeddings")
class DatasetBuildResponse(BaseModel):
"""Response model for dataset building"""
dataset_id: str
total_records: int
status: str
message: str
created_at: Optional[datetime] = None
class DatasetRecord(BaseModel):
"""Model for a single dataset record"""
aid: int
embedding: List[float]
label: bool
metadata: Dict[str, Any]
user_labels: List[Dict[str, Any]]
inconsistent: bool
text_checksum: str
class DatasetInfo(BaseModel):
"""Model for dataset information"""
dataset_id: str
dataset: List[DatasetRecord]
stats: Dict[str, Any]
created_at: datetime
class DatasetBuildStats(BaseModel):
"""Statistics for dataset building process"""
total_records: int
new_embeddings: int
reused_embeddings: int
inconsistent_labels: int
embedding_model: str
processing_time: Optional[float] = None
class EmbeddingModelInfo(BaseModel):
"""Information about embedding models"""
name: str
dimensions: int
type: str
api_endpoint: Optional[str] = None
max_tokens: Optional[int] = None
max_batch_size: Optional[int] = None

ml_new/training/requirements.txt

@ -9,4 +9,8 @@ scikit-learn==1.3.2
 pandas==2.1.3
 openai==1.3.7
 psycopg2-binary==2.9.9
-sqlalchemy==2.0.23
+sqlalchemy==2.0.23
+asyncpg==0.29.0
+toml==0.10.2
+aiohttp==3.9.0
+python-dotenv==1.1.0

mutagen.yml Normal file

@ -0,0 +1,26 @@
# mutagen.yml
sync:
development:
alpha: "."
beta: "root@cvsa-hk-02:/web/cvsa"
maxStagingFileSize: "5MB"
mode: "two-way-resolved"
ignore:
paths:
- "**/node_modules/"
- "*.log"
- "/model"
- "**/.DS_Store"
- ".env"
- ".env.*"
- "**/logs/"
- ".git/"
- ".jj/"
- "dist"
- "**/build/"
- "**/.react-router"
- "redis/"
- "temp/"
- "mutagen.yml"
- "/ml_new/"
- "/ml/"

mutagen.yml.lock Normal file

@ -0,0 +1 @@
proj_EPp5s45rolBZ729uBHaoBMzV51GCoGkpnJuTkK6owzH

packages/ml_panel/README.md Normal file

@ -0,0 +1,267 @@
# CVSA ML Infrastructure Refactoring Project
## Project Overview
This project refactors the existing ML service infrastructure, migrating from the original `ml/filter` system to a new architecture with separated frontend and backend. The main goal is to provide a modern web UI for ML training, experiment management, and data processing.
### Core Features
- **Data pipeline management**: fetch and preprocess training data from the PostgreSQL database
- **Experiment management**: training parameter configuration, experiment tracking, and result visualization
- **Hyperparameter tuning**: dynamic hyperparameter configuration and adjustment
- **Data labeling interface**: simple, easy-to-use tools for labeling and managing data
- **Model training**: training a binary video classification model
- **Embedding management**: support for multiple embedding models and vector dimensions
## Architecture
### Tech Stack
- **Frontend**: React + TypeScript + Vite + Tailwind CSS + shadcn/ui
- **Backend**: FastAPI + Python
- **Database**: PostgreSQL + Drizzle ORM
- **Vector database**: PostgreSQL pgvector
- **Package management**: Bun (TypeScript) + pip (Python)
### Layered Architecture
```
┌─────────────────┐    ┌──────────────┐    ┌──────────────────┐
│     Web UI      │    │   FastAPI    │    │    Database      │
│   (React TS)    │◄──►│   (Python)   │◄──►│   (PostgreSQL)   │
└─────────────────┘    └──────────────┘    └──────────────────┘
```
## Directory Structure
### Frontend project (`packages/ml_panel/`)
```
packages/ml_panel/
├── src/                   # Frontend application
│   ├── App.tsx            # Main application component
│   ├── main.tsx           # Application entry point
│   ├── index.css          # Global styles
│   └── lib/
│       └── utils.ts       # Frontend utility functions
├── lib/                   # Core library files
│   ├── types.ts           # Shared type definitions
│   ├── ml-client.ts       # ML API client
│   ├── data-pipeline/     # Data pipeline types
│   │   └── types.ts
│   └── index.ts           # Exports
├── package.json
├── vite.config.ts
└── tailwind.config.js
```
### Backend service (`ml_new/training/`)
```
ml_new/training/
├── main.py              # FastAPI main service
├── requirements.txt     # Python dependencies
└── ...                  # Other service files
```
### Database schema
Uses the existing definitions in `packages/core/drizzle/main/schema.ts`:
- `videoTypeLabelInInternal`: user label data
- `embeddingsInInternal`: embedding storage
- `bilibiliMetadata`: video metadata
## Completed Work
### 1. Core type definitions
**File**: `packages/ml_panel/lib/types.ts`
- Defines the core data structures
- `DatasetRecord`: dataset record
- `UserLabel`: user label
- `EmbeddingModel`: embedding model configuration
- `TrainingConfig`: training configuration
- `ExperimentResult`: experiment result
- `InconsistentLabel`: inconsistently labeled data
### 2. Data pipeline types
**File**: `packages/ml_panel/lib/data-pipeline/types.ts`
- `VideoMetadata`: video metadata
- `VideoTypeLabel`: label data
- `EmbeddingRecord`: embedding record
- `DataPipelineConfig`: pipeline configuration
- `ProcessedDataset`: processed dataset
### 3. ML client
**File**: `packages/ml_panel/lib/ml-client.ts`
- `MLClient` class for communicating with FastAPI
- Hyperparameter retrieval and updates
- Training job start and status monitoring
- Experiment management
- Embedding generation interface
### 4. FastAPI service skeleton
**File**: `ml_new/training/main.py`
- Basic FastAPI application setup
- CORS middleware configuration
- In-memory training session management
- Basic API endpoint definitions
### 5. Project configuration
- **Frontend**: `packages/ml_panel/package.json` - React + Vite + TypeScript setup
- **Backend**: `ml_new/training/requirements.txt` - Python dependencies
- **Main project**: `packages/ml/package.json` - monorepo workspace configuration
## Implementation Status
### Completed
- [x] Basic project scaffolding
- [x] Core type definitions
- [x] ML API client
- [x] FastAPI service skeleton
- [x] Frontend project configuration
### To Do
- [ ] Core data pipeline logic
- [ ] React UI components
- [ ] Full FastAPI service functionality
- [ ] Database connection and data retrieval logic
- [ ] User label processing
- [ ] Embedding management
- [ ] Label consistency checks
- [ ] Training job queue
- [ ] Experiment tracking and visualization
- [ ] Dynamic hyperparameter configuration
- [ ] End-to-end frontend/backend integration tests
## Data Flow Design
### 1. Dataset creation flow (optimized for high RTT)
```mermaid
graph TD
    A[Frontend requests dataset creation] --> B[Select embedding model]
    B --> C[Load model parameters from TOML config]
    C --> D[Batch-pull raw text from database]
    D --> E[Preprocess text]
    E --> F[Hash texts and deduplicate]
    F --> G[Batch-query existing embeddings]
    G --> H[Identify texts needing new embeddings]
    H --> I[Batch-call embedding API]
    I --> J[Batch-write to embeddings table]
    J --> K[Pull complete embeddings data]
    K --> L[Merge label data]
    L --> M[Build final dataset<br/>format: embeddings, label]
```
### 2. Data retrieval flow (database optimization)
```mermaid
graph TD
    A[Remote PostgreSQL database<br/>RTT: 100ms] --> B[videoTypeLabelInInternal]
    A --> C[embeddingsInInternal]
    A --> D[bilibiliMetadata]
    B --> E[Batch-fetch each user's latest label<br/>avoids per-row queries]
    C --> F[Batch-fetch embeddings<br/>all dimensions in one query]
    D --> G[Batch-fetch video metadata<br/>IN query avoids N+1]
    E --> H[Label consistency check]
    F --> I[Vector data processing]
    G --> J[Data merge]
    H --> K[Dataset construction]
    I --> K
    J --> K
```
### 3. Training flow
```mermaid
graph TD
    A[Frontend configuration] --> B[Hyperparameter setup]
    B --> C[FastAPI receives request]
    C --> D[Data pipeline processing]
    D --> E[Model training]
    E --> F[Live status updates]
    F --> G[Result storage]
    G --> H[Frontend display]
```
## Technical Notes
### 1. High-performance database design (RTT optimization)
- **Batch operations**: avoid per-row query loops; use `IN` clauses and batch `INSERT`/`UPDATE` (see the sketch below)
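For example, a batched metadata lookup with asyncpg (a sketch following the pattern in `ml_new/training/database.py`; one round trip instead of one query per AID):

```python
import asyncpg

async def fetch_metadata_batch(pool: asyncpg.Pool, aid_list: list[int]) -> dict[int, dict]:
    # A single ANY($1) query replaces len(aid_list) round trips,
    # which matters when each round trip costs ~100ms of RTT
    async with pool.acquire() as conn:
        rows = await conn.fetch(
            "SELECT aid, title, description, tags "
            "FROM bilibili_metadata WHERE aid = ANY($1::bigint[])",
            aid_list,
        )
    return {int(r["aid"]): dict(r) for r in rows}
```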
### 2. Embedding management
- **Multi-model support**: `embeddingsInInternal` stores vectors of different dimensions (2048/1536/1024)
- **Deduplication**: text hashes prevent regenerating embeddings for identical text
- **Batch processing**: embeddings are generated and stored in batches
- **Caching strategy**: existing embeddings are reused whenever possible
### 3. Binary classification model architecture
- Migrates from the previous 3-class system to binary classification
- Input: precomputed embeddings (rather than raw text)
- Supports switching between multiple embedding models
### 4. Data consistency handling
- **User labels**: `videoTypeLabelInInternal` stores labels from multiple users
- **Latest label**: each user's most recent label is treated as the effective one
- **Consistency check**: videos with conflicting labels across users are flagged for manual review
## Roadmap
### Phase 1: Core functionality (priority)
1. **Data pipeline implementation**
   - Label data retrieval and consistency checks
   - Embedding generation and storage
   - Dataset construction logic
2. **FastAPI service completion**
   - Build the new model architecture (embedding input, direct binary classification head)
   - Migrate the existing ml/filter training logic
   - Expose hyperparameters dynamically
   - Integrate OpenAI-compatible embedding APIs
   - Training job queue management
### Phase 2: User interface
1. **Dataset creation UI**
   - Embedding model selection
   - Data preview and filtering
   - Progress display
2. **Training configuration UI**
   - Dynamic hyperparameter rendering
   - Parameter validation and constraints
3. **Experiment management and tracking**
   - Experiment history and comparison
   - Live training status monitoring
   - Result visualization
### Phase 3: Advanced features
1. **Automatic hyperparameter tuning**
2. **Model version management**
3. **Batch training support**
4. **Performance optimization**
## Notes
1. **Database performance**: the remote database has high RTT; avoid N+1 queries and use batch operations
2. **Label consistency**: implement automatic detection of inconsistent labels
3. **Embedding model support**: keep interfaces open for additional embedding models


@ -1,47 +0,0 @@
// Data pipeline specific types
import type { DatasetRecord, UserLabel, EmbeddingModel, InconsistentLabel } from "../types";
// Database types from packages/core
export interface VideoMetadata {
aid: number;
title: string;
description: string;
tags: string;
createdAt?: string;
}
export interface VideoTypeLabel {
id: number;
aid: number;
label: boolean;
user: string;
createdAt: string;
}
export interface EmbeddingRecord {
id: number;
modelName: string;
dataChecksum: string;
vec2048?: number[];
vec1536?: number[];
vec1024?: number[];
createdAt?: string;
}
export interface DataPipelineConfig {
embeddingModels: EmbeddingModel[];
batchSize: number;
requireConsensus: boolean;
maxInconsistentRatio: number;
}
export interface ProcessedDataset {
records: DatasetRecord[];
inconsistentLabels: InconsistentLabel[];
statistics: {
totalRecords: number;
labeledRecords: number;
inconsistentRecords: number;
embeddingCoverage: Record<string, number>;
};
}


@ -1,107 +0,0 @@
// ML Client for communicating with FastAPI service
import type { TrainingConfig, ExperimentResult } from './types';
export interface Hyperparameter {
name: string;
type: 'number' | 'boolean' | 'select';
value: any;
range?: [number, number];
options?: string[];
description?: string;
}
export interface TrainingRequest {
experimentName: string;
config: TrainingConfig;
dataset: {
aid: number[];
embeddings: Record<string, number[]>;
labels: Record<number, boolean>;
};
}
export interface TrainingStatus {
experimentId: string;
status: 'pending' | 'running' | 'completed' | 'failed';
progress?: number;
currentEpoch?: number;
totalEpochs?: number;
metrics?: Record<string, number>;
error?: string;
}
export class MLClient {
private baseUrl: string;
constructor(baseUrl: string = 'http://localhost:8000') {
this.baseUrl = baseUrl;
}
// Get available hyperparameters from the model
async getHyperparameters(): Promise<Hyperparameter[]> {
const response = await fetch(`${this.baseUrl}/hyperparameters`);
if (!response.ok) {
throw new Error(`Failed to get hyperparameters: ${response.statusText}`);
}
return (await response.json()) as Hyperparameter[];
}
// Start a training experiment
async startTraining(request: TrainingRequest): Promise<{ experimentId: string }> {
const response = await fetch(`${this.baseUrl}/train`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(request),
});
if (!response.ok) {
throw new Error(`Failed to start training: ${response.statusText}`);
}
return (await response.json()) as { experimentId: string };
}
// Get training status
async getTrainingStatus(experimentId: string): Promise<TrainingStatus> {
const response = await fetch(`${this.baseUrl}/train/${experimentId}/status`);
if (!response.ok) {
throw new Error(`Failed to get training status: ${response.statusText}`);
}
return (await response.json()) as TrainingStatus;
}
// Get experiment results
async getExperimentResult(experimentId: string): Promise<ExperimentResult> {
const response = await fetch(`${this.baseUrl}/experiments/${experimentId}`);
if (!response.ok) {
throw new Error(`Failed to get experiment result: ${response.statusText}`);
}
return (await response.json()) as ExperimentResult;
}
// List all experiments
async listExperiments(): Promise<ExperimentResult[]> {
const response = await fetch(`${this.baseUrl}/experiments`);
if (!response.ok) {
throw new Error(`Failed to list experiments: ${response.statusText}`);
}
return (await response.json()) as ExperimentResult[];
}
// Generate embeddings using OpenAI-compatible API
async generateEmbeddings(texts: string[], model: string): Promise<number[][]> {
const response = await fetch(`${this.baseUrl}/embeddings`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ texts, model }),
});
if (!response.ok) {
throw new Error(`Failed to generate embeddings: ${response.statusText}`);
}
return (await response.json()) as number[][];
}
}


@ -1,54 +0,0 @@
// Shared ML types and interfaces
export interface DatasetRecord {
aid: number;
title: string;
description: string;
tags: string;
embedding?: number[];
label?: boolean;
userLabels?: UserLabel[];
}
export interface UserLabel {
user: string;
label: boolean;
createdAt: string;
}
export interface EmbeddingModel {
name: string;
dimensions: number;
type: "openai-compatible" | "local";
apiEndpoint?: string;
}
export interface TrainingConfig {
learningRate: number;
batchSize: number;
epochs: number;
earlyStop: boolean;
patience?: number;
embeddingModel: string;
}
export interface ExperimentResult {
experimentId: string;
config: TrainingConfig;
metrics: {
accuracy: number;
precision: number;
recall: number;
f1: number;
};
createdAt: string;
status: "running" | "completed" | "failed";
}
export interface InconsistentLabel {
aid: number;
title: string;
description: string;
tags: string;
labels: UserLabel[];
consensus?: boolean;
}