1
0
cvsa/ml_new/main.py

117 lines
2.8 KiB
Python

"""
Main FastAPI application for ML training service
"""
import os
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from data.database import DatabaseManager
from data.embedding_service import EmbeddingService
from data.dataset_service import DatasetBuilder
from ml_new.routes.main import router, set_dataset_builder
from config.logger_config import get_logger
logger = get_logger(__name__)
# Global service instances
db_manager = None
embedding_service = None
dataset_builder = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan manager for startup and shutdown events"""
global db_manager, embedding_service, dataset_builder
# Startup
logger.info("Initializing services...")
try:
# Database manager
db_manager = DatabaseManager()
await db_manager.connect() # Initialize database connection pool
logger.info("Database manager initialized and connected")
# Embedding service
embedding_service = EmbeddingService()
logger.info("Embedding service initialized")
# Dataset builder
dataset_builder = DatasetBuilder(
db_manager,
embedding_service,
os.path.join(os.path.dirname(__file__), "./datasets"),
)
logger.info("Dataset builder initialized")
# Set global dataset builder instance
set_dataset_builder(dataset_builder)
logger.info("All services initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize services: {e}")
raise
# Yield control to the application
yield
# Shutdown
logger.info("Shutting down services...")
try:
if db_manager:
await db_manager.close()
logger.info("Database connection pool closed")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
def create_app() -> FastAPI:
"""Create and configure FastAPI application"""
# Create FastAPI app with lifespan manager
app = FastAPI(
title="ML Training Service",
description="ML training, dataset building, and experiment management service",
version="1.0.0",
lifespan=lifespan,
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Configure appropriately for production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include API routes
app.include_router(router)
return app
def main():
"""Main entry point"""
app = create_app()
# Run the application
uvicorn.run(
app,
host="0.0.0.0",
port=8322,
log_level="info",
access_log=True,
)
if __name__ == "__main__":
main()