# cvsa/ml_new/training/main.py
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Optional, Any
import asyncio
import uuid
import logging
from datetime import datetime
import json
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(
    title="CVSA ML Training API",
    version="1.0.0",
    description="ML training service for video classification"
)
# Enable CORS for web UI
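# NOTE: these origins assume the web UI is served from a local dev server;
# 3000 and 5173 are the usual React/Next.js and Vite defaults. Adjust for deployment.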
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://localhost:5173"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic models
class Hyperparameter(BaseModel):
    name: str
    type: str  # 'number', 'boolean', 'select'
    value: Any
    range: Optional[tuple] = None
    options: Optional[List[str]] = None
    description: Optional[str] = None
class TrainingConfig(BaseModel):
    learning_rate: float = 1e-4
    batch_size: int = 32
    epochs: int = 10
    early_stop: bool = True
    patience: int = 3
    embedding_model: str = "text-embedding-3-small"
class TrainingRequest(BaseModel):
    experiment_name: str
    config: TrainingConfig
    dataset: Dict[str, Any]
class TrainingStatus(BaseModel):
    experiment_id: str
    status: str  # 'pending', 'running', 'completed', 'failed'
    progress: Optional[float] = None
    current_epoch: Optional[int] = None
    total_epochs: Optional[int] = None
    metrics: Optional[Dict[str, float]] = None
    error: Optional[str] = None
class ExperimentResult(BaseModel):
    experiment_id: str
    experiment_name: str
    config: TrainingConfig
    metrics: Dict[str, float]
    created_at: str
    status: str
class EmbeddingRequest(BaseModel):
    texts: List[str]
    model: str
# In-memory storage for experiments (in production, use database)
training_sessions: Dict[str, Dict] = {}
experiments: Dict[str, ExperimentResult] = {}
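# NOTE: this state is process-local; it is lost on restart and is not shared
# across multiple worker processes.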
# Default hyperparameters that will be dynamically discovered
DEFAULT_HYPERPARAMETERS = [
    Hyperparameter(
        name="learning_rate",
        type="number",
        value=1e-4,
        range=(1e-6, 1e-2),
        description="Learning rate for optimizer"
    ),
    Hyperparameter(
        name="batch_size",
        type="number",
        value=32,
        range=(8, 256),
        description="Training batch size"
    ),
    Hyperparameter(
        name="epochs",
        type="number",
        value=10,
        range=(1, 100),
        description="Number of training epochs"
    ),
    Hyperparameter(
        name="early_stop",
        type="boolean",
        value=True,
        description="Enable early stopping"
    ),
    Hyperparameter(
        name="patience",
        type="number",
        value=3,
        range=(1, 20),
        description="Early stopping patience"
    ),
    Hyperparameter(
        name="embedding_model",
        type="select",
        value="text-embedding-3-small",
        options=["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"],
        description="Embedding model to use"
    )
]
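# NOTE: dynamic discovery is not implemented yet; the list above is static and
# mirrors the TrainingConfig defaults.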
@app.get("/")
async def root():
return {"message": "CVSA ML Training API", "version": "1.0.0"}
@app.get("/health")
async def health_check():
return {"status": "healthy", "service": "ml-training-api"}
@app.get("/hyperparameters", response_model=List[Hyperparameter])
async def get_hyperparameters():
"""Get all available hyperparameters for the current model"""
return DEFAULT_HYPERPARAMETERS
@app.post("/train")
async def start_training(request: TrainingRequest, background_tasks: BackgroundTasks):
"""Start a new training experiment"""
experiment_id = str(uuid.uuid4())
# Store training session
training_sessions[experiment_id] = {
"experiment_id": experiment_id,
"experiment_name": request.experiment_name,
"config": request.config.dict(),
"dataset": request.dataset,
"status": "pending",
"created_at": datetime.now().isoformat()
}
# Start background training task
background_tasks.add_task(run_training, experiment_id, request)
return {"experiment_id": experiment_id}
@app.get("/train/{experiment_id}/status", response_model=TrainingStatus)
async def get_training_status(experiment_id: str):
"""Get training status for an experiment"""
if experiment_id not in training_sessions:
raise HTTPException(status_code=404, detail="Experiment not found")
session = training_sessions[experiment_id]
return TrainingStatus(
experiment_id=experiment_id,
status=session.get("status", "unknown"),
progress=session.get("progress"),
current_epoch=session.get("current_epoch"),
total_epochs=session.get("total_epochs"),
metrics=session.get("metrics"),
error=session.get("error")
)
@app.get("/experiments", response_model=List[ExperimentResult])
async def list_experiments():
"""List all experiments"""
return list(experiments.values())
@app.get("/experiments/{experiment_id}", response_model=ExperimentResult)
async def get_experiment(experiment_id: str):
"""Get experiment details"""
if experiment_id not in experiments:
raise HTTPException(status_code=404, detail="Experiment not found")
return experiments[experiment_id]
@app.post("/embeddings")
async def generate_embeddings(request: EmbeddingRequest):
"""Generate embeddings using OpenAI-compatible API"""
# This is a placeholder implementation
# In production, this would call actual embedding API
embeddings = []
for text in request.texts:
# Mock embedding generation
embedding = [0.1] * 1536 # Mock 1536-dimensional embedding
embeddings.append(embedding)
return embeddings
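# A real implementation might look roughly like the sketch below (an assumption:
# it presumes the official `openai` package >= 1.0 and OPENAI_API_KEY in the environment):
#   from openai import OpenAI
#   client = OpenAI()
#   response = client.embeddings.create(model=request.model, input=request.texts)
#   return [item.embedding for item in response.data]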
async def run_training(experiment_id: str, request: TrainingRequest):
    """Background task to run training"""
    # Look up the session outside the try block so it is always bound in the except handler
    session = training_sessions[experiment_id]
    try:
        session["status"] = "running"
        session["total_epochs"] = request.config.epochs
        # Simulate training process
        for epoch in range(request.config.epochs):
            session["current_epoch"] = epoch + 1
            session["progress"] = (epoch + 1) / request.config.epochs
            # Simulate training metrics
            session["metrics"] = {
                "loss": max(0.0, 1.0 - (epoch + 1) * 0.1),
                "accuracy": min(0.95, 0.5 + (epoch + 1) * 0.05),
                "val_loss": max(0.0, 0.8 - (epoch + 1) * 0.08),
                "val_accuracy": min(0.92, 0.45 + (epoch + 1) * 0.04)
            }
            logger.info(f"Training epoch {epoch + 1}/{request.config.epochs}")
            await asyncio.sleep(1)  # Simulate training time
        # Training completed
        session["status"] = "completed"
        final_metrics = session.get("metrics", {})  # metrics are absent if epochs == 0
        # Store final experiment result
        experiments[experiment_id] = ExperimentResult(
            experiment_id=experiment_id,
            experiment_name=request.experiment_name,
            config=request.config,
            metrics=final_metrics,
            created_at=session["created_at"],
            status="completed"
        )
        logger.info(f"Training completed for experiment {experiment_id}")
    except Exception as e:
        session["status"] = "failed"
        session["error"] = str(e)
        logger.error(f"Training failed for experiment {experiment_id}: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
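# To try the service locally (a sketch; assumes fastapi, uvicorn, and pydantic are installed):
#   uvicorn main:app --reload --port 8000
#   curl http://localhost:8000/hyperparameters
#   curl http://localhost:8000/experiments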