diff --git a/.gitignore b/.gitignore
index 3bb3b08..bc1f0ad 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,7 +32,7 @@ model/
 *.db
 *.sqlite
 *.sqlite3
-data/
+/data/
 redis/
 
 # Build
diff --git a/ml_new/config_loader.py b/ml_new/config/config_loader.py
similarity index 95%
rename from ml_new/config_loader.py
rename to ml_new/config/config_loader.py
index 5977c2b..f04ee06 100644
--- a/ml_new/config_loader.py
+++ b/ml_new/config/config_loader.py
@@ -6,7 +6,7 @@ import toml
 import os
 from typing import Dict
 from pydantic import BaseModel
-from logger_config import get_logger
+from ml_new.config.logger_config import get_logger
 
 logger = get_logger(__name__)
 
@@ -28,7 +28,7 @@ class ConfigLoader:
         if config_path is None:
             # Default to the embedding_models.toml file we created
             config_path = os.path.join(
-                os.path.dirname(__file__), "embedding_models.toml"
+                os.path.dirname(__file__), "..", "embedding_models.toml"
             )
         self.config_path = config_path
diff --git a/ml_new/logger_config.py b/ml_new/config/logger_config.py
similarity index 100%
rename from ml_new/logger_config.py
rename to ml_new/config/logger_config.py
diff --git a/ml_new/database.py b/ml_new/data/database.py
similarity index 98%
rename from ml_new/database.py
rename to ml_new/data/database.py
index b873656..820c695 100644
--- a/ml_new/database.py
+++ b/ml_new/data/database.py
@@ -4,13 +4,13 @@ Database connection and operations for ML training service
 
 from collections import defaultdict
 import os
-import hashlib
 from typing import List, Dict, Optional, Any
 from datetime import datetime
 import asyncpg
-from config_loader import config_loader
+from ml_new.config.config_loader import config_loader
+from ml_new.config.logger_config import get_logger
 from dotenv import load_dotenv
-from logger_config import get_logger
+
 
 load_dotenv()
diff --git a/ml_new/dataset_service.py b/ml_new/data/dataset_service.py
similarity index 98%
rename from ml_new/dataset_service.py
rename to ml_new/data/dataset_service.py
index 054b6c4..43c70b1 100644
--- a/ml_new/dataset_service.py
+++ b/ml_new/data/dataset_service.py
@@ -8,12 +8,12 @@ from typing import List, Dict, Any, Optional
 from datetime import datetime
 import threading
 
-from database import DatabaseManager
-from embedding_service import EmbeddingService
-from config_loader import config_loader
-from logger_config import get_logger
-from models import TaskStatus, DatasetBuildTaskStatus, TaskProgress
-from dataset_storage_parquet import ParquetDatasetStorage
+from ml_new.data.database import DatabaseManager
+from ml_new.data.embedding_service import EmbeddingService
+from ml_new.config.config_loader import config_loader
+from ml_new.config.logger_config import get_logger
+from ml_new.models import TaskStatus, DatasetBuildTaskStatus, TaskProgress
+from ml_new.data.dataset_storage_parquet import ParquetDatasetStorage
 
 logger = get_logger(__name__)
diff --git a/ml_new/dataset_storage_parquet.py b/ml_new/data/dataset_storage_parquet.py
similarity index 99%
rename from ml_new/dataset_storage_parquet.py
rename to ml_new/data/dataset_storage_parquet.py
index 26518ae..87e9402 100644
--- a/ml_new/dataset_storage_parquet.py
+++ b/ml_new/data/dataset_storage_parquet.py
@@ -5,13 +5,12 @@ Efficient dataset storage using Parquet format for better space utilization and
 
 import pandas as pd
 import numpy as np
 import json
-import os
 from pathlib import Path
 from typing import List, Dict, Any, Optional, Union
 from datetime import datetime
 import pyarrow as pa
 import pyarrow.parquet as pq
-from logger_config import get_logger
+from ml_new.config.logger_config import get_logger
 
 logger = get_logger(__name__)
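A note on the `config_loader.py` hunk above: the file now lives one level deeper (`ml_new/config/`), so the added `".."` component keeps the default path pointing at `ml_new/embedding_models.toml`. A minimal sanity check of that resolution (the `/repo` prefix is a hypothetical stand-in for the checkout location):

```python
# Verifies the new default path logic in ml_new/config/config_loader.py:
# __file__ is now ml_new/config/config_loader.py, so one ".." is needed
# to reach the TOML file that stayed behind in ml_new/.
import os

file_stand_in = "/repo/ml_new/config/config_loader.py"  # hypothetical __file__
config_path = os.path.join(
    os.path.dirname(file_stand_in), "..", "embedding_models.toml"
)
print(os.path.normpath(config_path))  # -> /repo/ml_new/embedding_models.toml
```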
diff --git a/ml_new/embedding_service.py b/ml_new/data/embedding_service.py
similarity index 99%
rename from ml_new/embedding_service.py
rename to ml_new/data/embedding_service.py
index eadd254..dc2097f 100644
--- a/ml_new/embedding_service.py
+++ b/ml_new/data/embedding_service.py
@@ -6,13 +6,13 @@ import hashlib
 from typing import List, Dict, Any, Optional
 from openai import AsyncOpenAI
 import os
-from config_loader import config_loader
+from ml_new.config.config_loader import config_loader
 from dotenv import load_dotenv
 import torch
 import numpy as np
 from transformers import AutoTokenizer
 import onnxruntime as ort
-from logger_config import get_logger
+from ml_new.config.logger_config import get_logger
 
 load_dotenv()
diff --git a/ml_new/embedding_models.toml b/ml_new/embedding_models.toml
index f9ba291..88a02f7 100644
--- a/ml_new/embedding_models.toml
+++ b/ml_new/embedding_models.toml
@@ -15,6 +15,6 @@ api_key_env = "ALIYUN_KEY"
 name = "jina-embedding-v3-m2v-1024"
 dimensions = 1024
 type = "legacy"
-model_path = "../../model/embedding/model.onnx"
+model_path = "./model/embedding/model.onnx"
 tokenizer_name = "jinaai/jina-embeddings-v3"
 max_batch_size = 128
diff --git a/ml_new/main.py b/ml_new/main.py
index fa3d930..ace62bf 100644
--- a/ml_new/main.py
+++ b/ml_new/main.py
@@ -2,16 +2,17 @@
 Main FastAPI application for ML training service
 """
 
+import os
 import uvicorn
 from contextlib import asynccontextmanager
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
-from database import DatabaseManager
-from embedding_service import EmbeddingService
-from dataset_service import DatasetBuilder
-from api_routes import router, set_dataset_builder
-from logger_config import get_logger
+from ml_new.data.database import DatabaseManager
+from ml_new.data.embedding_service import EmbeddingService
+from ml_new.data.dataset_service import DatasetBuilder
+from ml_new.routes.main import router, set_dataset_builder
+from ml_new.config.logger_config import get_logger
 
 logger = get_logger(__name__)
@@ -26,39 +27,43 @@ dataset_builder = None
 async def lifespan(app: FastAPI):
     """Application lifespan manager for startup and shutdown events"""
     global db_manager, embedding_service, dataset_builder
-    
+
     # Startup
     logger.info("Initializing services...")
-    
+
     try:
         # Database manager
         db_manager = DatabaseManager()
         await db_manager.connect()  # Initialize database connection pool
         logger.info("Database manager initialized and connected")
-    
+
         # Embedding service
         embedding_service = EmbeddingService()
         logger.info("Embedding service initialized")
-    
+
         # Dataset builder
-        dataset_builder = DatasetBuilder(db_manager, embedding_service)
+        dataset_builder = DatasetBuilder(
+            db_manager,
+            embedding_service,
+            os.path.join(os.path.dirname(__file__), "datasets"),
+        )
         logger.info("Dataset builder initialized")
-    
+
         # Set global dataset builder instance
         set_dataset_builder(dataset_builder)
-    
+
        logger.info("All services initialized successfully")
-    
+
     except Exception as e:
         logger.error(f"Failed to initialize services: {e}")
         raise
-    
+
     # Yield control to the application
     yield
-    
+
     # Shutdown
     logger.info("Shutting down services...")
-    
+
     try:
         if db_manager:
             await db_manager.close()
@@ -69,15 +74,15 @@
 
 def create_app() -> FastAPI:
     """Create and configure FastAPI application"""
-    
+
     # Create FastAPI app with lifespan manager
     app = FastAPI(
         title="ML Training Service",
         description="ML training, dataset building, and experiment management service",
         version="1.0.0",
-        lifespan=lifespan
+        lifespan=lifespan,
     )
-    
+
     # Add CORS middleware
     app.add_middleware(
         CORSMiddleware,
@@ -86,17 +91,17 @@ def create_app() -> FastAPI:
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    
+
     # Include API routes
     app.include_router(router)
-    
+
     return app
 
 
 def main():
     """Main entry point"""
     app = create_app()
-    
+
     # Run the application
     uvicorn.run(
         app,
@@ -108,4 +113,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
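The `main.py` changes keep FastAPI's lifespan pattern while making the `DatasetBuilder` output directory explicit. For reference, a minimal sketch of the startup/shutdown shape used above (nothing here beyond what the diff shows; the third `DatasetBuilder` argument is inferred from the new call site to be a datasets directory):

```python
# Minimal sketch of the lifespan startup/shutdown pattern used in main.py.
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: connect pools, load models, wire up services.
    print("services initialized")
    yield  # the application serves requests while suspended here
    # Shutdown: close pools and release resources.
    print("services shut down")


app = FastAPI(lifespan=lifespan)
```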
diff --git a/ml_new/api_routes.py b/ml_new/routes/main.py
similarity index 98%
rename from ml_new/api_routes.py
rename to ml_new/routes/main.py
index b49a522..66aa7c5 100644
--- a/ml_new/api_routes.py
+++ b/ml_new/routes/main.py
@@ -8,7 +8,7 @@ from typing import Optional
 from fastapi import APIRouter, HTTPException
 from fastapi.responses import JSONResponse
 
-from config_loader import config_loader
+from ml_new.config.config_loader import config_loader
 from models import (
     DatasetBuildRequest,
     DatasetBuildResponse,
@@ -20,8 +20,8 @@ from models import (
     DatasetCreateRequest,
     DatasetCreateResponse
 )
-from dataset_service import DatasetBuilder
-from logger_config import get_logger
+from ml_new.data.dataset_service import DatasetBuilder
+from ml_new.config.logger_config import get_logger
 
 logger = get_logger(__name__)
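One caveat with the rename above: the context lines still show a bare `from models import (` in `ml_new/routes/main.py`, which resolves only when `ml_new/` itself is on `sys.path`, while the rewritten imports assume the repository root is. A quick smoke test, run from the repository root, surfaces any such mismatch (this sketch assumes each new directory, `ml_new/config`, `ml_new/data`, and `ml_new/routes`, ships an `__init__.py`, which the diff does not show):

```python
# Smoke test: import every module touched by the reorganization under its
# new absolute path. Run from the repository root.
import importlib

MOVED_MODULES = [
    "ml_new.config.config_loader",
    "ml_new.config.logger_config",
    "ml_new.data.database",
    "ml_new.data.dataset_service",
    "ml_new.data.dataset_storage_parquet",
    "ml_new.data.embedding_service",
    "ml_new.routes.main",
]

for name in MOVED_MODULES:
    importlib.import_module(name)
    print(f"ok: {name}")
```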