ref: file structure of ml_new
This commit is contained in:
parent
79e8b05ee9
commit
664784dd3e
2
.gitignore
vendored
2
.gitignore
vendored
@ -32,7 +32,7 @@ model/
|
||||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
data/
|
||||
./data/
|
||||
redis/
|
||||
|
||||
# Build
|
||||
|
||||
@ -6,7 +6,7 @@ import toml
|
||||
import os
|
||||
from typing import Dict
|
||||
from pydantic import BaseModel
|
||||
from logger_config import get_logger
|
||||
from ml_new.config.logger_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -28,7 +28,7 @@ class ConfigLoader:
|
||||
if config_path is None:
|
||||
# Default to the embedding_models.toml file we created
|
||||
config_path = os.path.join(
|
||||
os.path.dirname(__file__), "embedding_models.toml"
|
||||
os.path.dirname(__file__), "..", "embedding_models.toml"
|
||||
)
|
||||
|
||||
self.config_path = config_path
|
||||
@ -4,13 +4,13 @@ Database connection and operations for ML training service
|
||||
|
||||
from collections import defaultdict
|
||||
import os
|
||||
import hashlib
|
||||
from typing import List, Dict, Optional, Any
|
||||
from datetime import datetime
|
||||
import asyncpg
|
||||
from config_loader import config_loader
|
||||
from ml_new.config.config_loader import config_loader
|
||||
from ml_new.config.logger_config import get_logger
|
||||
from dotenv import load_dotenv
|
||||
from logger_config import get_logger
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@ -8,12 +8,12 @@ from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
import threading
|
||||
|
||||
from database import DatabaseManager
|
||||
from embedding_service import EmbeddingService
|
||||
from config_loader import config_loader
|
||||
from logger_config import get_logger
|
||||
from models import TaskStatus, DatasetBuildTaskStatus, TaskProgress
|
||||
from dataset_storage_parquet import ParquetDatasetStorage
|
||||
from ml_new.data.database import DatabaseManager
|
||||
from ml_new.data.embedding_service import EmbeddingService
|
||||
from ml_new.config.config_loader import config_loader
|
||||
from ml_new.config.logger_config import get_logger
|
||||
from ml_new.models import TaskStatus, DatasetBuildTaskStatus, TaskProgress
|
||||
from ml_new.data.dataset_storage_parquet import ParquetDatasetStorage
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@ -5,13 +5,12 @@ Efficient dataset storage using Parquet format for better space utilization and
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from datetime import datetime
|
||||
import pyarrow as pa
|
||||
import pyarrow.parquet as pq
|
||||
from logger_config import get_logger
|
||||
from ml_new.config.logger_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -6,13 +6,13 @@ import hashlib
|
||||
from typing import List, Dict, Any, Optional
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
from config_loader import config_loader
|
||||
from ml_new.config.config_loader import config_loader
|
||||
from dotenv import load_dotenv
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer
|
||||
import onnxruntime as ort
|
||||
from logger_config import get_logger
|
||||
from ml_new.config.logger_config import get_logger
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@ -15,6 +15,6 @@ api_key_env = "ALIYUN_KEY"
|
||||
name = "jina-embedding-v3-m2v-1024"
|
||||
dimensions = 1024
|
||||
type = "legacy"
|
||||
model_path = "../../model/embedding/model.onnx"
|
||||
model_path = "./model/embedding/model.onnx"
|
||||
tokenizer_name = "jinaai/jina-embeddings-v3"
|
||||
max_batch_size = 128
|
||||
|
||||
@ -2,16 +2,17 @@
|
||||
Main FastAPI application for ML training service
|
||||
"""
|
||||
|
||||
import os
|
||||
import uvicorn
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from database import DatabaseManager
|
||||
from embedding_service import EmbeddingService
|
||||
from dataset_service import DatasetBuilder
|
||||
from api_routes import router, set_dataset_builder
|
||||
from logger_config import get_logger
|
||||
from data.database import DatabaseManager
|
||||
from data.embedding_service import EmbeddingService
|
||||
from data.dataset_service import DatasetBuilder
|
||||
from ml_new.routes.main import router, set_dataset_builder
|
||||
from config.logger_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -26,39 +27,43 @@ dataset_builder = None
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan manager for startup and shutdown events"""
|
||||
global db_manager, embedding_service, dataset_builder
|
||||
|
||||
|
||||
# Startup
|
||||
logger.info("Initializing services...")
|
||||
|
||||
|
||||
try:
|
||||
# Database manager
|
||||
db_manager = DatabaseManager()
|
||||
await db_manager.connect() # Initialize database connection pool
|
||||
logger.info("Database manager initialized and connected")
|
||||
|
||||
|
||||
# Embedding service
|
||||
embedding_service = EmbeddingService()
|
||||
logger.info("Embedding service initialized")
|
||||
|
||||
|
||||
# Dataset builder
|
||||
dataset_builder = DatasetBuilder(db_manager, embedding_service)
|
||||
dataset_builder = DatasetBuilder(
|
||||
db_manager,
|
||||
embedding_service,
|
||||
os.path.join(os.path.dirname(__file__), "./datasets"),
|
||||
)
|
||||
logger.info("Dataset builder initialized")
|
||||
|
||||
|
||||
# Set global dataset builder instance
|
||||
set_dataset_builder(dataset_builder)
|
||||
|
||||
|
||||
logger.info("All services initialized successfully")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize services: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# Yield control to the application
|
||||
yield
|
||||
|
||||
|
||||
# Shutdown
|
||||
logger.info("Shutting down services...")
|
||||
|
||||
|
||||
try:
|
||||
if db_manager:
|
||||
await db_manager.close()
|
||||
@ -69,15 +74,15 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create and configure FastAPI application"""
|
||||
|
||||
|
||||
# Create FastAPI app with lifespan manager
|
||||
app = FastAPI(
|
||||
title="ML Training Service",
|
||||
description="ML training, dataset building, and experiment management service",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@ -86,17 +91,17 @@ def create_app() -> FastAPI:
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# Include API routes
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point"""
|
||||
app = create_app()
|
||||
|
||||
|
||||
# Run the application
|
||||
uvicorn.run(
|
||||
app,
|
||||
@ -108,4 +113,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from config_loader import config_loader
|
||||
from ml_new.config.config_loader import config_loader
|
||||
from models import (
|
||||
DatasetBuildRequest,
|
||||
DatasetBuildResponse,
|
||||
@ -20,8 +20,8 @@ from models import (
|
||||
DatasetCreateRequest,
|
||||
DatasetCreateResponse
|
||||
)
|
||||
from dataset_service import DatasetBuilder
|
||||
from logger_config import get_logger
|
||||
from ml_new.data.dataset_service import DatasetBuilder
|
||||
from ml_new.config.logger_config import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
Loading…
Reference in New Issue
Block a user