"""
|
|
API routes for the ML training service
|
|
"""
|
|
|
|
import uuid
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from ml_new.config.config_loader import config_loader
|
|
from models import (
|
|
DatasetBuildRequest,
|
|
DatasetBuildResponse,
|
|
TaskStatus,
|
|
TaskStatusResponse,
|
|
TaskListResponse,
|
|
SamplingRequest,
|
|
SamplingResponse,
|
|
DatasetCreateRequest,
|
|
DatasetCreateResponse
|
|
)
|
|
from ml_new.data.dataset_service import DatasetBuilder
|
|
from ml_new.config.logger_config import get_logger
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
# Create router
|
|
router = APIRouter(prefix="/v1")
|
|
|
|
# Global dataset builder instance (will be set by main.py)
|
|
dataset_builder: DatasetBuilder = None


def set_dataset_builder(builder: DatasetBuilder):
    """Set the global dataset builder instance"""
    global dataset_builder
    dataset_builder = builder
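# NOTE: set_dataset_builder() above is expected to be called from main.py during
# application startup. A minimal wiring sketch (illustrative only; the real
# main.py and the DatasetBuilder constructor arguments are not shown here):
#
#   from fastapi import FastAPI
#   app = FastAPI()
#   set_dataset_builder(DatasetBuilder(...))
#   app.include_router(router)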


@router.get("/health")
async def health_check():
    """Health check endpoint"""
    if not dataset_builder:
        return JSONResponse(
            status_code=503,
            content={"status": "unavailable", "message": "Dataset builder not initialized"}
        )

    try:
        # Check embedding service health
        embedding_health = await dataset_builder.embedding_service.health_check()
    except Exception as e:
        embedding_health = {"status": "unhealthy", "error": str(e)}

    # Check database connection (pool should already be initialized)
    db_status = "disconnected"
    if dataset_builder.db_manager.is_connected:
        try:
            response = await dataset_builder.db_manager.pool.fetch("SELECT 1 FROM information_schema.tables")
            db_status = "connected" if response else "disconnected"
        except Exception as e:
            db_status = f"error: {str(e)}"

    return {
        "status": "healthy",
        "service": "ml-training-api",
        "embedding_service": embedding_health,
        "database": db_status,
        "available_models": list(config_loader.get_embedding_models().keys())
    }
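# Quick check of the endpoint above (illustrative; host and port depend on how
# the service is run, e.g. under uvicorn):
#   curl http://localhost:8000/v1/health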


@router.get("/models/embedding")
async def get_embedding_models():
    """Get available embedding models"""
    return {
        "models": {
            name: {
                "name": config.name,
                "dimensions": config.dimensions,
                "type": config.type,
                "api_endpoint": config.api_endpoint,
                "max_tokens": config.max_tokens,
                "max_batch_size": config.max_batch_size
            }
            for name, config in config_loader.get_embedding_models().items()
        }
    }


@router.post("/dataset/build", response_model=DatasetBuildResponse)
async def build_dataset_endpoint(request: DatasetBuildRequest):
    """Build dataset endpoint with task tracking"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    # Validate embedding model
    if request.embedding_model not in config_loader.get_embedding_models():
        raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")

    dataset_id = str(uuid.uuid4())

    # Start task-based dataset building
    task_id = await dataset_builder.start_dataset_build_task(
        dataset_id,
        request.aid_list,
        request.embedding_model,
        request.force_regenerate,
        request.description
    )

    return DatasetBuildResponse(
        dataset_id=dataset_id,
        total_records=len(request.aid_list),
        status="started",
        message=f"Dataset building started with task ID: {task_id}",
        description=request.description
    )
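# Example request for the endpoint above (illustrative; field names follow
# DatasetBuildRequest, concrete values are placeholders):
#   POST /v1/dataset/build
#   {"aid_list": [101, 102], "embedding_model": "<name from /v1/models/embedding>",
#    "force_regenerate": false, "description": "demo build"}
# The returned message contains the task ID, which can then be polled via
# GET /v1/tasks/{task_id}.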


@router.get("/dataset/{dataset_id}")
async def get_dataset_endpoint(dataset_id: str):
    """Get built dataset by ID"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    if not dataset_builder.dataset_exists(dataset_id):
        raise HTTPException(status_code=404, detail="Dataset not found")

    dataset_info = dataset_builder.get_dataset(dataset_id)

    if "error" in dataset_info:
        raise HTTPException(status_code=500, detail=dataset_info["error"])

    return {
        "dataset_id": dataset_id,
        "dataset": dataset_info["dataset"],
        "description": dataset_info.get("description"),
        "stats": dataset_info["stats"],
        "created_at": dataset_info["created_at"]
    }


@router.get("/datasets")
async def list_datasets_endpoint():
    """List all built datasets"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    datasets = dataset_builder.list_datasets()

    # Add description to each dataset
    datasets_with_description = []
    for dataset in datasets:
        dataset_info = dataset_builder.get_dataset(dataset["dataset_id"])
        if dataset_info and "description" in dataset_info:
            dataset["description"] = dataset_info["description"]
        else:
            dataset["description"] = None
        datasets_with_description.append(dataset)

    return {"datasets": datasets_with_description}


@router.delete("/dataset/{dataset_id}")
async def delete_dataset_endpoint(dataset_id: str):
    """Delete a built dataset"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    if dataset_builder.delete_dataset(dataset_id):
        return {"message": f"Dataset {dataset_id} deleted successfully"}
    else:
        raise HTTPException(status_code=404, detail="Dataset not found")


@router.get("/datasets/stats")
async def get_dataset_stats_endpoint():
    """Get overall statistics about stored datasets"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    stats = dataset_builder.get_dataset_stats()
    return stats


# Task Status Endpoints

# NOTE: the literal /tasks/stats path must be registered before /tasks/{task_id},
# otherwise the parameterized route would capture "stats" as a task_id.
@router.get("/tasks/stats")
async def get_task_statistics_endpoint():
    """Get statistics about all tasks"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    return dataset_builder.get_task_statistics()


@router.get("/tasks/{task_id}", response_model=TaskStatusResponse)
async def get_task_status_endpoint(task_id: str):
    """Get status of a specific task"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    task_status = dataset_builder.get_task_status(task_id)
    if not task_status:
        raise HTTPException(status_code=404, detail="Task not found")

    # Convert to response model
    progress_dict = None
    if task_status.progress:
        progress_dict = {
            "current_step": task_status.progress.current_step,
            "total_steps": task_status.progress.total_steps,
            "completed_steps": task_status.progress.completed_steps,
            "percentage": task_status.progress.percentage,
            "message": task_status.progress.message
        }

    return TaskStatusResponse(
        task_id=task_status.task_id,
        status=task_status.status,
        progress=progress_dict,
        result=task_status.result,
        error=task_status.error_message,
        created_at=task_status.created_at,
        started_at=task_status.started_at,
        completed_at=task_status.completed_at
    )


@router.get("/tasks", response_model=TaskListResponse)
async def list_tasks_endpoint(status: Optional[TaskStatus] = None, limit: int = 50):
    """List all tasks, optionally filtered by status"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    tasks = dataset_builder.list_tasks(status_filter=status)

    # Limit results
    if limit > 0:
        tasks = tasks[:limit]

    # Convert to response models
    task_responses = []
    for task_status in tasks:
        progress_dict = None
        if task_status.progress:
            progress_dict = {
                "current_step": task_status.progress.current_step,
                "total_steps": task_status.progress.total_steps,
                "completed_steps": task_status.progress.completed_steps,
                "percentage": task_status.progress.percentage,
                "message": task_status.progress.message
            }

        task_responses.append(TaskStatusResponse(
            task_id=task_status.task_id,
            status=task_status.status,
            progress=progress_dict,
            result=task_status.result,
            error=task_status.error_message,
            created_at=task_status.created_at,
            started_at=task_status.started_at,
            completed_at=task_status.completed_at
        ))

    # Get statistics
    stats = dataset_builder.get_task_statistics()

    return TaskListResponse(
        tasks=task_responses,
        total_count=stats["total_tasks"],
        pending_count=stats["status_counts"][TaskStatus.PENDING],
        running_count=stats["status_counts"][TaskStatus.RUNNING],
        completed_count=stats["status_counts"][TaskStatus.COMPLETED],
        failed_count=stats["status_counts"][TaskStatus.FAILED]
    )
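# Example queries for the endpoint above (illustrative; valid status values
# depend on how TaskStatus is serialized in models.py):
#   GET /v1/tasks?limit=10
#   GET /v1/tasks?status=<TaskStatus value>&limit=10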


@router.post("/tasks/cleanup")
async def cleanup_tasks_endpoint(max_age_hours: int = 24):
    """Clean up completed/failed tasks older than specified hours"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    cleaned_count = await dataset_builder.cleanup_completed_tasks(max_age_hours)
    return {"message": f"Cleaned up {cleaned_count} tasks older than {max_age_hours} hours"}


# Sampling Endpoints

@router.post("/dataset/sample", response_model=SamplingResponse)
async def sample_dataset_endpoint(request: SamplingRequest):
    """Sample AIDs based on strategy"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    try:
        # Get AIDs based on strategy
        aid_list = await dataset_builder.db_manager.get_aids_by_strategy(
            strategy=request.strategy,
            limit=request.limit,
        )

        # Get statistics
        total_available = await dataset_builder.db_manager.get_all_aids_count()

        return SamplingResponse(
            strategy=request.strategy,
            total_available=total_available,
            sampled_count=len(aid_list),
            aid_list=aid_list,
            filters_applied={
                "limit": request.limit
            },
            sampling_info={
                "strategy_description": _get_strategy_description(request.strategy),
                "sample_ratio": len(aid_list) / total_available if total_available > 0 else 0
            }
        )

    except Exception as e:
        logger.error(f"Sampling failed: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Sampling failed: {str(e)}")
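# Example request for the endpoint above (illustrative; "all" and "random" are
# the strategies listed in _get_strategy_description at the bottom of this
# module, other fields follow SamplingRequest):
#   POST /v1/dataset/sample
#   {"strategy": "random", "limit": 1000}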


@router.post("/dataset/create-with-sampling", response_model=DatasetCreateResponse)
async def create_dataset_with_sampling_endpoint(request: DatasetCreateRequest):
    """Create dataset using sampling strategy"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    # Validate embedding model
    if request.embedding_model not in config_loader.get_embedding_models():
        raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")

    dataset_id = str(uuid.uuid4())

    try:
        # First sample the AIDs
        sampling_response = await sample_dataset_endpoint(request.sampling)
        aid_list = sampling_response.aid_list

        if not aid_list:
            raise HTTPException(status_code=400, detail="No AIDs found matching the sampling criteria")

        # Start task-based dataset building with sampled AIDs
        task_id = await dataset_builder.start_dataset_build_task(
            dataset_id,
            aid_list,
            request.embedding_model,
            request.force_regenerate,
            request.description
        )

        return DatasetCreateResponse(
            dataset_id=dataset_id,
            sampling_response=sampling_response,
            task_id=task_id,
            total_records=len(aid_list),
            status="started",
            message=f"Dataset building started with task ID: {task_id}",
            description=request.description
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Dataset creation with sampling failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Dataset creation failed: {str(e)}")
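# Example request for the endpoint above (illustrative; fields follow
# DatasetCreateRequest, with the nested "sampling" object matching SamplingRequest):
#   POST /v1/dataset/create-with-sampling
#   {"sampling": {"strategy": "all", "limit": 5000},
#    "embedding_model": "<name from /v1/models/embedding>",
#    "force_regenerate": false, "description": "full labeled set"}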


def _get_strategy_description(strategy: str) -> str:
    """Get description for sampling strategy"""
    descriptions = {
        "all": "All labeled videos in the database",
        "random": "Randomly sampled labeled videos"
    }
    return descriptions.get(strategy, "Unknown sampling strategy")