add: legacy embedding model support, task-based dataset building
parent 77668bbb52
commit 7dbef68cdc
@@ -2,4 +2,4 @@
 1. Always use bun as package manager.

-2. Always write comments in English.
+2. IMPORTANT: **Always write comments inside your code (through tool calls) in English, and respond (including descriptions of changes, requests, and questions to the user) in the same language as the user's query.**
@@ -1,140 +0,0 @@
# ML Training Service

A FastAPI-based ML training service for dataset building, embedding generation, and experiment management.

## Architecture

The service is organized into modular components:

```
ml_new/training/
├── main.py                  # FastAPI application entry point
├── models.py                # Pydantic data models
├── config_loader.py         # Configuration loading from TOML
├── database.py              # Database connection and operations
├── embedding_service.py     # Embedding generation service
├── dataset_service.py       # Dataset building logic
├── api_routes.py            # API endpoint definitions
├── embedding_models.toml    # Embedding model configurations
└── requirements.txt         # Python dependencies
```

## Key Components

### 1. Main Application (`main.py`)
- FastAPI app initialization
- CORS middleware configuration
- Service dependency injection
- Startup/shutdown event handlers

### 2. Data Models (`models.py`)
- `DatasetBuildRequest`: Request model for dataset building
- `DatasetBuildResponse`: Response model for dataset building
- `DatasetRecord`: Individual dataset record structure
- `EmbeddingModelInfo`: Embedding model configuration

### 3. Configuration (`config_loader.py`)
- Loads embedding model configurations from TOML
- Manages model parameters (dimensions, API endpoints, etc.)

### 4. Database Layer (`database.py`)
- PostgreSQL connection management
- CRUD operations for video metadata, user labels, and embeddings
- Optimized batch queries to avoid N+1 problems

### 5. Embedding Service (`embedding_service.py`)
- Integration with OpenAI-compatible embedding APIs
- Text preprocessing and checksum generation
- Batch embedding generation with rate limiting

### 6. Dataset Building (`dataset_service.py`)
- Complete dataset construction workflow:
  1. Pull raw text from database
  2. Text preprocessing (placeholder)
  3. Batch embedding generation with deduplication
  4. Embedding storage and caching
  5. Final dataset compilation with labels

### 7. API Routes (`api_routes.py`)
- `GET /v1/health`: Health check
- `GET /v1/models/embedding`: List available embedding models
- `POST /v1/dataset/build`: Build new dataset
- `GET /v1/dataset/{id}`: Retrieve built dataset
- `GET /v1/datasets`: List all datasets
- `DELETE /v1/dataset/{id}`: Delete dataset

## Dataset Building Flow

1. **Model Selection**: Choose embedding model from TOML configuration
2. **Data Retrieval**: Pull video metadata and user labels from PostgreSQL
3. **Text Processing**: Combine title, description, and tags
4. **Deduplication**: Generate checksums to avoid duplicate embeddings (sketched below)
5. **Batch Processing**: Generate embeddings for new texts only
6. **Storage**: Store embeddings in database with caching
7. **Final Assembly**: Combine embeddings with labels using consensus mechanism
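The deduplication step (4) is the heart of this flow. A minimal sketch of the idea, assuming the MD5 checksums mentioned under Features; the function names here are illustrative, not the service's actual helpers:

```python
import hashlib

def text_checksum(text: str) -> str:
    # MD5 over the combined title/description/tags string
    return hashlib.md5(text.encode("utf-8")).hexdigest()

def texts_needing_embeddings(texts: list[str], existing: set[str]) -> list[str]:
    # Keep only texts whose checksum is not already in the embeddings table
    return [t for t in texts if text_checksum(t) not in existing]
```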
## Configuration

### Embedding Models (`embedding_models.toml`)

```toml
[text-embedding-3-large]
name = "text-embedding-3-large"
dimensions = 3072
type = "openai"
api_endpoint = "https://api.openai.com/v1/embeddings"
max_tokens = 8192
max_batch_size = 100
```

### Environment Variables

- `DATABASE_URL`: PostgreSQL connection string
- `OPENAI_API_KEY`: OpenAI API key for embedding generation

## Usage

### Start the Service

```bash
cd ml_new/training
python main.py
```

### Build a Dataset

```bash
curl -X POST "http://localhost:8322/v1/dataset/build" \
  -H "Content-Type: application/json" \
  -d '{
    "aid_list": [170001, 170002, 170003],
    "embedding_model": "text-embedding-3-large",
    "force_regenerate": false
  }'
```

### Check Health

```bash
curl "http://localhost:8322/v1/health"
```

### List Embedding Models

```bash
curl "http://localhost:8322/v1/models/embedding"
```

## Features

- **High Performance**: Optimized database queries with batch operations
- **Deduplication**: Text-level deduplication using MD5 checksums
- **Consensus Labels**: Majority-vote mechanism for user annotations
- **Batch Processing**: Efficient embedding generation and storage
- **Error Handling**: Comprehensive error handling and logging
- **Async Support**: Fully asynchronous operations for scalability
- **CORS Enabled**: Ready for frontend integration

## Production Considerations

- Replace in-memory dataset storage with a database
- Add authentication and authorization
- Implement rate limiting for API endpoints
- Add monitoring and metrics collection
- Configure proper logging levels
- Set up database connection pooling
- Add API documentation with OpenAPI/Swagger
ml_new/training/api_routes.py

@@ -2,18 +2,19 @@
 API routes for the ML training service
 """

-import logging
 import uuid
+from typing import Optional

-from fastapi import APIRouter, HTTPException, BackgroundTasks
+from fastapi import APIRouter, HTTPException
 from fastapi.responses import JSONResponse

 from config_loader import config_loader
-from models import DatasetBuildRequest, DatasetBuildResponse
+from models import DatasetBuildRequest, DatasetBuildResponse, TaskStatus, TaskStatusResponse, TaskListResponse
 from dataset_service import DatasetBuilder
+from logger_config import get_logger

-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)

 # Create router
 router = APIRouter(prefix="/v1")
@@ -80,8 +81,8 @@ async def get_embedding_models():


 @router.post("/dataset/build", response_model=DatasetBuildResponse)
-async def build_dataset_endpoint(request: DatasetBuildRequest, background_tasks: BackgroundTasks):
-    """Build dataset endpoint"""
+async def build_dataset_endpoint(request: DatasetBuildRequest):
+    """Build dataset endpoint with task tracking"""

     if not dataset_builder:
         raise HTTPException(status_code=503, detail="Dataset builder not available")

@@ -91,20 +92,22 @@ async def build_dataset_endpoint(request: DatasetBuildRequest, background_tasks:
         raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")

     dataset_id = str(uuid.uuid4())
-    # Start background task for dataset building
-    background_tasks.add_task(
-        dataset_builder.build_dataset,
-        dataset_id,
+
+    # Start task-based dataset building
+    task_id = await dataset_builder.start_dataset_build_task(
+        dataset_id,
         request.aid_list,
         request.embedding_model,
-        request.force_regenerate
+        request.force_regenerate,
+        request.description
     )

     return DatasetBuildResponse(
         dataset_id=dataset_id,
         total_records=len(request.aid_list),
         status="started",
-        message="Dataset building started"
+        message=f"Dataset building started with task ID: {task_id}",
+        description=request.description
     )
@@ -126,6 +129,7 @@ async def get_dataset_endpoint(dataset_id: str):
     return {
         "dataset_id": dataset_id,
         "dataset": dataset_info["dataset"],
+        "description": dataset_info.get("description"),
         "stats": dataset_info["stats"],
         "created_at": dataset_info["created_at"]
     }

@@ -143,6 +147,7 @@ async def list_datasets():
         if "error" not in dataset_info:
             datasets.append({
                 "dataset_id": dataset_id,
+                "description": dataset_info.get("description"),
                 "stats": dataset_info["stats"],
                 "created_at": dataset_info["created_at"]
             })
@@ -171,7 +176,16 @@ async def list_datasets_endpoint():
         raise HTTPException(status_code=503, detail="Dataset builder not available")

     datasets = dataset_builder.list_datasets()
-    return {"datasets": datasets}
+    # Add description to each dataset
+    datasets_with_description = []
+    for dataset in datasets:
+        dataset_info = dataset_builder.get_dataset(dataset["dataset_id"])
+        if dataset_info and "description" in dataset_info:
+            dataset["description"] = dataset_info["description"]
+        else:
+            dataset["description"] = None
+        datasets_with_description.append(dataset)
+    return {"datasets": datasets_with_description}


 @router.get("/datasets/stats")
@@ -193,4 +207,111 @@ async def cleanup_datasets_endpoint(max_age_days: int = 30):
         raise HTTPException(status_code=503, detail="Dataset builder not available")

     await dataset_builder.cleanup_old_datasets(max_age_days)
-    return {"message": f"Cleanup completed for datasets older than {max_age_days} days"}
+    return {"message": f"Cleanup completed for datasets older than {max_age_days} days"}
+
+
+# Task Status Endpoints
+
+@router.get("/tasks/{task_id}", response_model=TaskStatusResponse)
+async def get_task_status_endpoint(task_id: str):
+    """Get status of a specific task"""
+
+    if not dataset_builder:
+        raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+    task_status = dataset_builder.get_task_status(task_id)
+    if not task_status:
+        raise HTTPException(status_code=404, detail="Task not found")
+
+    # Convert to response model
+    progress_dict = None
+    if task_status.progress:
+        progress_dict = {
+            "current_step": task_status.progress.current_step,
+            "total_steps": task_status.progress.total_steps,
+            "completed_steps": task_status.progress.completed_steps,
+            "percentage": task_status.progress.percentage,
+            "message": task_status.progress.message
+        }
+
+    return TaskStatusResponse(
+        task_id=task_status.task_id,
+        status=task_status.status,
+        progress=progress_dict,
+        result=task_status.result,
+        error=task_status.error_message,
+        created_at=task_status.created_at,
+        started_at=task_status.started_at,
+        completed_at=task_status.completed_at
+    )
+
+
+@router.get("/tasks", response_model=TaskListResponse)
+async def list_tasks_endpoint(status: Optional[TaskStatus] = None, limit: int = 50):
+    """List all tasks, optionally filtered by status"""
+
+    if not dataset_builder:
+        raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+    tasks = dataset_builder.list_tasks(status_filter=status)
+
+    # Limit results
+    if limit > 0:
+        tasks = tasks[:limit]
+
+    # Convert to response models
+    task_responses = []
+    for task_status in tasks:
+        progress_dict = None
+        if task_status.progress:
+            progress_dict = {
+                "current_step": task_status.progress.current_step,
+                "total_steps": task_status.progress.total_steps,
+                "completed_steps": task_status.progress.completed_steps,
+                "percentage": task_status.progress.percentage,
+                "message": task_status.progress.message
+            }
+
+        task_responses.append(TaskStatusResponse(
+            task_id=task_status.task_id,
+            status=task_status.status,
+            progress=progress_dict,
+            result=task_status.result,
+            error=task_status.error_message,
+            created_at=task_status.created_at,
+            started_at=task_status.started_at,
+            completed_at=task_status.completed_at
+        ))
+
+    # Get statistics
+    stats = dataset_builder.get_task_statistics()
+
+    return TaskListResponse(
+        tasks=task_responses,
+        total_count=stats["total_tasks"],
+        pending_count=stats["status_counts"][TaskStatus.PENDING],
+        running_count=stats["status_counts"][TaskStatus.RUNNING],
+        completed_count=stats["status_counts"][TaskStatus.COMPLETED],
+        failed_count=stats["status_counts"][TaskStatus.FAILED]
+    )
+
+
+@router.get("/tasks/stats")
+async def get_task_statistics_endpoint():
+    """Get statistics about all tasks"""
+
+    if not dataset_builder:
+        raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+    return dataset_builder.get_task_statistics()
+
+
+@router.post("/tasks/cleanup")
+async def cleanup_tasks_endpoint(max_age_hours: int = 24):
+    """Clean up completed/failed tasks older than specified hours"""
+
+    if not dataset_builder:
+        raise HTTPException(status_code=503, detail="Dataset builder not available")
+
+    cleaned_count = await dataset_builder.cleanup_completed_tasks(max_age_hours)
+    return {"message": f"Cleaned up {cleaned_count} tasks older than {max_age_hours} hours"}
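A client-side sketch for the new task endpoints (assumes the `requests` package and the service's default port 8322 from `main.py`; note the build response carries the task ID only inside its `message` string, so it is recovered with a regex):

```python
import re
import time

import requests

BASE = "http://localhost:8322/v1"

# Kick off a build; aid values are illustrative
resp = requests.post(f"{BASE}/dataset/build", json={
    "aid_list": [170001, 170002],
    "embedding_model": "jina-embedding-v3-m2v",
    "force_regenerate": False,
}).json()

match = re.search(r"task ID: (\S+)", resp["message"])
task_id = match.group(1) if match else None

# Poll until the task reaches a terminal status
while task_id:
    status = requests.get(f"{BASE}/tasks/{task_id}").json()
    print(status["status"], (status.get("progress") or {}).get("percentage"))
    if status["status"] in ("completed", "failed", "cancelled"):
        break
    time.sleep(1)
```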
ml_new/training/config_loader.py

@@ -6,9 +6,9 @@ import toml
 import os
 from typing import Dict
 from pydantic import BaseModel
-import logging
+from logger_config import get_logger

-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)


 class EmbeddingModelConfig(BaseModel):

@@ -19,6 +19,8 @@ class EmbeddingModelConfig(BaseModel):
     max_tokens: int = 8191
     max_batch_size: int = 8
     api_key_env: str = "OPENAI_API_KEY"
+    model_path: str = ""
+    tokenizer_name: str = ""


 class ConfigLoader:

@@ -31,6 +33,7 @@ class ConfigLoader:

         self.config_path = config_path
         self.embedding_models: Dict[str, EmbeddingModelConfig] = {}
+        self.selected_model: str = None
         self._load_config()

     def _load_config(self):

@@ -51,6 +54,8 @@ class ConfigLoader:
                 self.embedding_models[model_key] = EmbeddingModelConfig(
                     **model_data
                 )

+            self.selected_model = config_data.get("model", list(self.embedding_models.keys())[0])
+
             logger.info(
                 f"Loaded {len(self.embedding_models)} embedding models from {self.config_path}"

@@ -58,6 +63,10 @@ class ConfigLoader:

         except Exception as e:
             logger.error(f"Failed to load config from {self.config_path}: {e}")

+    def get_selected_model(self) -> str:
+        """Get selected model for health check"""
+        return self.selected_model
+
     def get_embedding_models(self) -> Dict[str, EmbeddingModelConfig]:
         """Get all available embedding models"""
ml_new/training/database.py

@@ -7,13 +7,13 @@ import hashlib
 from typing import List, Dict, Optional, Any
 from datetime import datetime
 import asyncpg
-import logging
 from config_loader import config_loader
 from dotenv import load_dotenv
+from logger_config import get_logger

 load_dotenv()

-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)

 # Database configuration
 DATABASE_URL = os.getenv("DATABASE_URL")

@@ -120,7 +120,7 @@ class DatabaseManager:

         async with self.pool.acquire() as conn:
             query = """
-                SELECT data_checksum, vec_2048, vec_1536, vec_1024, created_at
+                SELECT data_checksum, dimensions, vec_2048, vec_1536, vec_1024, created_at
                 FROM internal.embeddings
                 WHERE model_name = $1 AND data_checksum = ANY($2::text[])
             """

@@ -177,14 +177,15 @@ class DatabaseManager:

             query = f"""
                 INSERT INTO internal.embeddings
-                (model_name, data_checksum, {vec_column}, created_at)
-                VALUES ($1, $2, $3, $4)
-                ON CONFLICT (data_checksum) DO NOTHING
+                (model_name, dimensions, data_checksum, {vec_column}, created_at)
+                VALUES ($1, $2, $3, $4, $5)
+                ON CONFLICT (model_name, dimensions, data_checksum) DO NOTHING
             """

             await conn.execute(
                 query,
                 data["model_name"],
+                data["dimensions"],
                 data["checksum"],
                 vector_str,
                 datetime.now(),
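The new `ON CONFLICT (model_name, dimensions, data_checksum)` target only works if a matching unique constraint exists on `internal.embeddings`; the migration itself is not part of this diff. A hedged sketch of what it might look like, using asyncpg (index name and DSN are assumptions):

```python
import asyncio
import asyncpg

async def migrate(dsn: str) -> None:
    conn = await asyncpg.connect(dsn)
    try:
        # Unique index matching the new ON CONFLICT target
        await conn.execute(
            """
            CREATE UNIQUE INDEX IF NOT EXISTS embeddings_model_dim_checksum_idx
            ON internal.embeddings (model_name, dimensions, data_checksum)
            """
        )
    finally:
        await conn.close()

# asyncio.run(migrate("postgresql://user:pass@localhost/db"))
```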
ml_new/training/dataset_service.py

@@ -2,20 +2,21 @@
 Dataset building service - handles the complete dataset construction flow
 """

 import os
 import json
-import logging
 import asyncio
 from pathlib import Path
 from typing import List, Dict, Any, Optional
 from datetime import datetime
+import threading

 from database import DatabaseManager
 from embedding_service import EmbeddingService
 from config_loader import config_loader
+from logger_config import get_logger
+from models import TaskStatus, DatasetBuildTaskStatus, TaskProgress

-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)


 class DatasetBuilder:

@@ -29,6 +30,11 @@ class DatasetBuilder:

         # Load existing datasets from file system
         self.dataset_storage: Dict[str, Dict] = self._load_all_datasets()

+        # Task status tracking
+        self._task_status_lock = threading.Lock()
+        self.task_statuses: Dict[str, DatasetBuildTaskStatus] = {}
+        self.running_tasks: Dict[str, asyncio.Task] = {}
+
     def _get_dataset_file_path(self, dataset_id: str) -> Path:
@@ -80,6 +86,134 @@ class DatasetBuilder:
         return datasets

+    def _create_task_status(self, task_id: str, dataset_id: str, aid_list: List[int],
+                            embedding_model: str, force_regenerate: bool) -> DatasetBuildTaskStatus:
+        """Create initial task status"""
+        with self._task_status_lock:
+            task_status = DatasetBuildTaskStatus(
+                task_id=task_id,
+                status=TaskStatus.PENDING,
+                dataset_id=dataset_id,
+                aid_list=aid_list,
+                embedding_model=embedding_model,
+                force_regenerate=force_regenerate,
+                created_at=datetime.now(),
+                progress=TaskProgress(
+                    current_step="initialized",
+                    total_steps=7,
+                    completed_steps=0,
+                    percentage=0.0,
+                    message="Task initialized"
+                )
+            )
+            self.task_statuses[task_id] = task_status
+            return task_status
+
+    def _update_task_status(self, task_id: str, **kwargs):
+        """Update task status with new values"""
+        with self._task_status_lock:
+            if task_id in self.task_statuses:
+                task_status = self.task_statuses[task_id]
+                for key, value in kwargs.items():
+                    if hasattr(task_status, key):
+                        setattr(task_status, key, value)
+                self.task_statuses[task_id] = task_status
+
+    def _update_task_progress(self, task_id: str, current_step: str, completed_steps: int,
+                              message: str = None, percentage: float = None):
+        """Update task progress"""
+        with self._task_status_lock:
+            if task_id in self.task_statuses:
+                task_status = self.task_statuses[task_id]
+                if percentage is not None:
+                    progress_percentage = percentage
+                else:
+                    progress_percentage = (completed_steps / task_status.progress.total_steps) * 100 if task_status.progress else 0.0
+
+                task_status.progress = TaskProgress(
+                    current_step=current_step,
+                    total_steps=task_status.progress.total_steps if task_status.progress else 7,
+                    completed_steps=completed_steps,
+                    percentage=progress_percentage,
+                    message=message
+                )
+                self.task_statuses[task_id] = task_status
+
+    def get_task_status(self, task_id: str) -> Optional[DatasetBuildTaskStatus]:
+        """Get task status by task ID"""
+        with self._task_status_lock:
+            return self.task_statuses.get(task_id)
+
+    def list_tasks(self, status_filter: Optional[TaskStatus] = None) -> List[DatasetBuildTaskStatus]:
+        """List all tasks, optionally filtered by status"""
+        with self._task_status_lock:
+            tasks = list(self.task_statuses.values())
+            if status_filter:
+                tasks = [task for task in tasks if task.status == status_filter]
+            # Sort by creation time (newest first)
+            tasks.sort(key=lambda x: x.created_at, reverse=True)
+            return tasks
+
+    def get_task_statistics(self) -> Dict[str, Any]:
+        """Get statistics about all tasks"""
+        with self._task_status_lock:
+            total_tasks = len(self.task_statuses)
+            status_counts = {
+                TaskStatus.PENDING: 0,
+                TaskStatus.RUNNING: 0,
+                TaskStatus.COMPLETED: 0,
+                TaskStatus.FAILED: 0,
+                TaskStatus.CANCELLED: 0
+            }
+
+            for task_status in self.task_statuses.values():
+                status_counts[task_status.status] += 1
+
+            return {
+                "total_tasks": total_tasks,
+                "status_counts": status_counts,
+                "running_tasks": status_counts[TaskStatus.RUNNING]
+            }
+
+    async def cleanup_completed_tasks(self, max_age_hours: int = 24):
+        """Clean up completed/failed tasks older than specified hours"""
+        cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
+        cleaned_count = 0
+
+        with self._task_status_lock:
+            tasks_to_remove = []
+
+            for task_id, task_status in self.task_statuses.items():
+                if task_status.completed_at:
+                    try:
+                        completed_time = task_status.completed_at.timestamp()
+                        if completed_time < cutoff_time:
+                            tasks_to_remove.append(task_id)
+                    except Exception as e:
+                        logger.warning(f"Failed to check completion time for task {task_id}: {e}")
+
+            for task_id in tasks_to_remove:
+                # Remove from task statuses
+                del self.task_statuses[task_id]
+
+                # Remove from running tasks if still there
+                if task_id in self.running_tasks:
+                    del self.running_tasks[task_id]
+
+                cleaned_count += 1
+
+        if cleaned_count > 0:
+            logger.info(f"Cleaned up {cleaned_count} old tasks")
+
+        return cleaned_count
+
     async def cleanup_old_datasets(self, max_age_days: int = 30):
         """Remove datasets older than specified days"""
         try:
@@ -110,43 +244,54 @@ class DatasetBuilder:
         except Exception as e:
             logger.error(f"Failed to cleanup old datasets: {e}")

-    async def build_dataset(self, dataset_id: str, aid_list: List[int], embedding_model: str, force_regenerate: bool = False) -> str:
+    async def build_dataset_with_task_tracking(self, task_id: str, dataset_id: str, aid_list: List[int],
+                                               embedding_model: str, force_regenerate: bool = False,
+                                               description: Optional[str] = None) -> str:
         """
-        Build dataset with the specified flow:
-        1. Select embedding model (from TOML config)
-        2. Pull raw text from database
-        3. Preprocess (placeholder for now)
-        4. Batch get embeddings (deduplicate by hash, skip if already in embeddings table)
-        5. Write to embeddings table
-        6. Pull all needed embeddings to create dataset with format: embeddings, label
+        Build dataset with task status tracking
+
+        Steps:
+        1. Initialize task status
+        2. Select embedding model (from TOML config)
+        3. Pull raw text from database
+        4. Preprocess (placeholder for now)
+        5. Batch get embeddings (deduplicate by hash, skip if already in embeddings table)
+        6. Write to embeddings table
+        7. Pull all needed embeddings to create dataset with format: embeddings, label
         """

+        # Update task status to running
+        self._update_task_status(task_id, status=TaskStatus.RUNNING, started_at=datetime.now())
+
         try:
-            logger.info(f"Starting dataset building task {dataset_id}")
+            logger.info(f"Starting dataset building task {dataset_id} (task_id: {task_id})")

+            # Step 1: Get model configuration
             EMBEDDING_MODELS = config_loader.get_embedding_models()

-            # Get model configuration
             if embedding_model not in EMBEDDING_MODELS:
                 raise ValueError(f"Invalid embedding model: {embedding_model}")

             model_config = EMBEDDING_MODELS[embedding_model]
+            self._update_task_progress(task_id, "getting_metadata", 1, "Retrieving video metadata from database")

-            # Step 1: Get video metadata from database
+            # Step 2: Get video metadata from database
             metadata = await self.db_manager.get_video_metadata(aid_list)
+            self._update_task_progress(task_id, "getting_labels", 2, "Retrieving user labels from database")

-            # Step 2: Get user labels
+            # Step 3: Get user labels
             labels = await self.db_manager.get_user_labels(aid_list)
+            self._update_task_progress(task_id, "preparing_text", 3, "Preparing text data and checksums")

-            # Step 3: Prepare text data and checksums
+            # Step 4: Prepare text data and checksums
             text_data = []
+            total_aids = len(aid_list)

-            for aid in aid_list:
+            for i, aid in enumerate(aid_list):
                 if aid in metadata:
                     # Combine title, description, tags
                     combined_text = self.embedding_service.combine_video_text(
                         metadata[aid]['title'],
                         metadata[aid]['description'],
                         metadata[aid]['tags']
                     )
@@ -158,12 +303,24 @@ class DatasetBuilder:
                         'text': combined_text,
                         'checksum': checksum
                     })

+                    # Update progress for text preparation
+                    if i % 10 == 0 or i == total_aids - 1:  # Update every 10 items or at the end
+                        progress_pct = 3 + (i + 1) / total_aids
+                        self._update_task_progress(
+                            task_id,
+                            "preparing_text",
+                            min(3, int(progress_pct)),
+                            f"Prepared {i + 1}/{total_aids} text entries"
+                        )
+
-            # Step 4: Check existing embeddings
+            self._update_task_progress(task_id, "checking_embeddings", 4, "Checking existing embeddings")
+
+            # Step 5: Check existing embeddings
             checksums = [item['checksum'] for item in text_data]
             existing_embeddings = await self.db_manager.get_existing_embeddings(checksums, embedding_model)

-            # Step 5: Generate new embeddings for texts that don't have them
+            # Step 6: Generate new embeddings for texts that don't have them
             new_embeddings_needed = []
             for item in text_data:
                 if item['checksum'] not in existing_embeddings or force_regenerate:
@@ -171,13 +328,20 @@ class DatasetBuilder:

             new_embeddings_count = 0
             if new_embeddings_needed:
+                self._update_task_progress(
+                    task_id,
+                    "generating_embeddings",
+                    5,
+                    f"Generating {len(new_embeddings_needed)} new embeddings"
+                )
+
                 logger.info(f"Generating {len(new_embeddings_needed)} new embeddings")
                 generated_embeddings = await self.embedding_service.generate_embeddings_batch(
                     new_embeddings_needed,
                     embedding_model
                 )

-                # Step 6: Store new embeddings in database
+                # Step 7: Store new embeddings in database
                 embeddings_to_store = []
                 for i, (text, embedding) in enumerate(zip(new_embeddings_needed, generated_embeddings)):
                     checksum = self.embedding_service.create_text_checksum(text)
@@ -198,11 +362,13 @@ class DatasetBuilder:
                     f'vec_{model_config.dimensions}': emb_data['vector']
                 }

-            # Step 7: Build final dataset
+            self._update_task_progress(task_id, "building_dataset", 6, "Building final dataset")
+
+            # Step 8: Build final dataset
             dataset = []
             inconsistent_count = 0

-            for item in text_data:
+            for i, item in enumerate(text_data):
                 aid = item['aid']
                 checksum = item['checksum']
@@ -241,6 +407,16 @@ class DatasetBuilder:
                     'inconsistent': inconsistent,
                     'text_checksum': checksum
                 })

+                # Update progress for dataset building
+                if i % 10 == 0 or i == len(text_data) - 1:  # Update every 10 items or at the end
+                    progress_pct = 6 + (i + 1) / len(text_data)
+                    self._update_task_progress(
+                        task_id,
+                        "building_dataset",
+                        min(6, int(progress_pct)),
+                        f"Built {i + 1}/{len(text_data)} dataset records"
+                    )
+
             reused_count = len(dataset) - new_embeddings_count
@@ -249,6 +425,7 @@ class DatasetBuilder:
             # Prepare dataset data
             dataset_data = {
                 'dataset': dataset,
+                'description': description,
                 'stats': {
                     'total_records': len(dataset),
                     'new_embeddings': new_embeddings_count,

@@ -259,6 +436,8 @@ class DatasetBuilder:
                 'created_at': datetime.now().isoformat()
             }

+            self._update_task_progress(task_id, "saving_dataset", 7, "Saving dataset to storage")
+
             # Save to file and memory cache
             if self._save_dataset_to_file(dataset_id, dataset_data):
                 self.dataset_storage[dataset_id] = dataset_data
@@ -267,11 +446,46 @@ class DatasetBuilder:
                 logger.warning(f"Failed to save dataset {dataset_id} to file, keeping in memory only")
                 self.dataset_storage[dataset_id] = dataset_data

+            # Update task status to completed
+            result = {
+                'dataset_id': dataset_id,
+                'stats': dataset_data['stats']
+            }
+
+            self._update_task_status(
+                task_id,
+                status=TaskStatus.COMPLETED,
+                completed_at=datetime.now(),
+                result=result,
+                progress=TaskProgress(
+                    current_step="completed",
+                    total_steps=7,
+                    completed_steps=7,
+                    percentage=100.0,
+                    message="Dataset building completed successfully"
+                )
+            )
+
             return dataset_id

         except Exception as e:
             logger.error(f"Dataset building failed for {dataset_id}: {str(e)}")

+            # Update task status to failed
+            self._update_task_status(
+                task_id,
+                status=TaskStatus.FAILED,
+                completed_at=datetime.now(),
+                error_message=str(e),
+                progress=TaskProgress(
+                    current_step="failed",
+                    total_steps=7,
+                    completed_steps=0,
+                    percentage=0.0,
+                    message=f"Task failed: {str(e)}"
+                )
+            )
+
             # Store error information
             error_data = {
                 'error': str(e),

@@ -283,6 +497,30 @@ class DatasetBuilder:
             self.dataset_storage[dataset_id] = error_data
             raise

+    async def start_dataset_build_task(self, dataset_id: str, aid_list: List[int],
+                                       embedding_model: str, force_regenerate: bool = False,
+                                       description: Optional[str] = None) -> str:
+        """
+        Start a dataset building task and return task ID for status tracking
+        """
+        import uuid
+        task_id = str(uuid.uuid4())
+
+        # Create task status
+        task_status = self._create_task_status(task_id, dataset_id, aid_list, embedding_model, force_regenerate)
+
+        # Start the actual task
+        task = asyncio.create_task(
+            self.build_dataset_with_task_tracking(task_id, dataset_id, aid_list, embedding_model, force_regenerate, description)
+        )
+
+        # Store the running task
+        with self._task_status_lock:
+            self.running_tasks[task_id] = task
+
+        return task_id
+
     def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
         """Get built dataset by ID"""
         # First check memory cache
ml_new/training/embedding_models.toml

@@ -1,6 +1,6 @@
 # Embedding Models Configuration

-model = "qwen3-embedding"
+model = "jina-embedding-v3-m2v"

 [models.qwen3-embedding]
 name = "text-embedding-v4"

@@ -10,3 +10,11 @@ api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
 max_tokens = 8192
 max_batch_size = 10
 api_key_env = "ALIYUN_KEY"
+
+[models.jina-embedding-v3-m2v]
+name = "jina-embedding-v3-m2v-1024"
+dimensions = 1024
+type = "legacy"
+model_path = "../../model/embedding/model.onnx"
+tokenizer_name = "jinaai/jina-embeddings-v3"
+max_batch_size = 128
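For reference, a minimal sketch of how a script might inspect this file with the same `toml` package `config_loader.py` imports (keys follow the entries above; the default `"openai-compatible"` fallback is an assumption):

```python
import toml

config = toml.load("embedding_models.toml")
selected = config.get("model")  # e.g. "jina-embedding-v3-m2v"

for key, spec in config.get("models", {}).items():
    kind = spec.get("type", "openai-compatible")
    print(key, kind, spec.get("dimensions"))
```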
ml_new/training/embedding_service.py

@@ -1,18 +1,22 @@
 """
-Embedding service for generating embeddings using OpenAI-compatible API
+Embedding service for generating embeddings using OpenAI-compatible API and legacy methods
 """
 import asyncio
 import hashlib
 from typing import List, Dict, Any, Optional
-import logging
 from openai import AsyncOpenAI
 import os
 from config_loader import config_loader
 from dotenv import load_dotenv
+import torch
+import numpy as np
+from transformers import AutoTokenizer
+import onnxruntime as ort
+from logger_config import get_logger

 load_dotenv()

-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)

 class EmbeddingService:
     def __init__(self):

@@ -21,7 +25,10 @@ class EmbeddingService:

         # Initialize OpenAI client (will be configured per model)
         self.clients: Dict[str, AsyncOpenAI] = {}
+        self.legacy_models: Dict[str, Dict[str, Any]] = {}
+
         self._initialize_clients()
+        self._initialize_legacy_models()

         # Rate limiting
         self.max_requests_per_minute = int(os.getenv("MAX_REQUESTS_PER_MINUTE", "100"))

@@ -30,19 +37,93 @@ class EmbeddingService:
     def _initialize_clients(self):
         """Initialize OpenAI clients for different models/endpoints"""
         for model_name, model_config in self.embedding_models.items():
-            if model_config.type == "openai-compatible":
-                # Get API key from environment variable specified in config
-                api_key = os.getenv(model_config.api_key_env)
-
-                self.clients[model_name] = AsyncOpenAI(
-                    api_key=api_key,
-                    base_url=model_config.api_endpoint
-                )
-                logger.info(f"Initialized client for model {model_name}")
+            if model_config.type != "openai-compatible":
+                continue
+
+            # Get API key from environment variable specified in config
+            api_key = os.getenv(model_config.api_key_env)
+
+            self.clients[model_name] = AsyncOpenAI(
+                api_key=api_key,
+                base_url=model_config.api_endpoint
+            )
+            logger.info(f"Initialized client for model {model_name}")
+
+    def _initialize_legacy_models(self):
+        """Initialize legacy ONNX models for embedding generation"""
+        for model_name, model_config in self.embedding_models.items():
+            if model_config.type != "legacy":
+                continue
+            try:
+                # Load tokenizer
+                tokenizer = AutoTokenizer.from_pretrained(model_config.tokenizer_name)
+
+                # Load ONNX model
+                session = ort.InferenceSession(model_config.model_path)
+
+                self.legacy_models[model_name] = {
+                    "tokenizer": tokenizer,
+                    "session": session,
+                    "config": model_config
+                }
+                logger.info(f"Initialized legacy model {model_name}")
+
+            except Exception as e:
+                logger.error(f"Failed to initialize legacy model {model_name}: {e}")
+
+    def get_jina_embeddings_1024(self, texts: List[str], model_name: str) -> np.ndarray:
+        """Generate embeddings using legacy Jina method (same as ml/api/main.py)"""
+        if model_name not in self.legacy_models:
+            raise ValueError(f"Legacy model '{model_name}' not initialized")
+
+        legacy_model = self.legacy_models[model_name]
+        tokenizer = legacy_model["tokenizer"]
+        session = legacy_model["session"]
+
+        # Encode inputs using tokenizer
+        encoded_inputs = tokenizer(
+            texts,
+            add_special_tokens=False,  # Don't add special tokens (consistent with JS)
+            return_attention_mask=False,
+            return_tensors=None  # Return native Python lists for easier processing
+        )
+        input_ids = encoded_inputs["input_ids"]  # Shape: [batch_size, seq_len_i] (variable length per sample)
+
+        # Calculate offsets (consistent with JS cumsum logic)
+        # Get token length for each sample first
+        lengths = [len(ids) for ids in input_ids]
+        # Calculate cumulative sum (exclude last sample)
+        cumsum = []
+        current_sum = 0
+        for l in lengths[:-1]:  # Only accumulate first n-1 samples
+            current_sum += l
+            cumsum.append(current_sum)
+        # Build offsets: start with 0, followed by cumulative sums
+        offsets = [0] + cumsum  # Shape: [batch_size]
+
+        # Flatten input_ids to 1D array
+        flattened_input_ids = []
+        for ids in input_ids:
+            flattened_input_ids.extend(ids)  # Directly concatenate all token ids
+        flattened_input_ids = np.array(flattened_input_ids, dtype=np.int64)
+
+        # Prepare ONNX inputs (consistent tensor shapes with JS)
+        inputs = {
+            "input_ids": ort.OrtValue.ortvalue_from_numpy(flattened_input_ids),
+            "offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64))
+        }
+
+        # Run model inference
+        outputs = session.run(None, inputs)
+        embeddings = outputs[0]  # Assume first output is embeddings, shape: [batch_size, embedding_dim]
+
+        return torch.tensor(embeddings, dtype=torch.float32).numpy()

     async def generate_embeddings_batch(
         self,
         texts: List[str],
         model: str,
         batch_size: Optional[int] = None
     ) -> List[List[float]]:

@@ -54,6 +135,64 @@ class EmbeddingService:

         model_config = self.embedding_models[model]

+        # Handle different model types
+        if model_config.type == "legacy":
+            return self._generate_legacy_embeddings_batch(texts, model, batch_size)
+        elif model_config.type == "openai-compatible":
+            return await self._generate_openai_embeddings_batch(texts, model, batch_size)
+        else:
+            raise ValueError(f"Unsupported model type: {model_config.type}")
+
+    def _generate_legacy_embeddings_batch(
+        self,
+        texts: List[str],
+        model: str,
+        batch_size: Optional[int] = None
+    ) -> List[List[float]]:
+        """Generate embeddings using legacy ONNX model"""
+        if model not in self.legacy_models:
+            raise ValueError(f"Legacy model '{model}' not initialized")
+
+        model_config = self.embedding_models[model]
+
+        # Use model's max_batch_size if not specified
+        if batch_size is None:
+            batch_size = model_config.max_batch_size
+
+        expected_dims = model_config.dimensions
+        all_embeddings = []
+
+        # Process in batches
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i:i + batch_size]
+
+            try:
+                # Generate embeddings using legacy method
+                embeddings = self.get_jina_embeddings_1024(batch, model)
+
+                # Convert to list of lists (expected format)
+                batch_embeddings = embeddings.tolist()
+                all_embeddings.extend(batch_embeddings)
+
+                logger.info(f"Generated legacy embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
+
+            except Exception as e:
+                logger.error(f"Error generating legacy embeddings for batch {i//batch_size + 1}: {e}")
+                # Fill with zeros as fallback
+                zero_embedding = [0.0] * expected_dims
+                all_embeddings.extend([zero_embedding] * len(batch))
+
+        return all_embeddings
+
+    async def _generate_openai_embeddings_batch(
+        self,
+        texts: List[str],
+        model: str,
+        batch_size: Optional[int] = None
+    ) -> List[List[float]]:
+        """Generate embeddings using OpenAI-compatible API"""
+        model_config = self.embedding_models[model]
+
         # Use model's max_batch_size if not specified
         if batch_size is None:
             batch_size = model_config.max_batch_size

@@ -115,11 +254,19 @@ class EmbeddingService:
     async def health_check(self) -> Dict[str, Any]:
         """Check if embedding service is healthy"""
         try:
+            if not self.embedding_models:
+                return {
+                    "status": "unhealthy",
+                    "service": "embedding_service",
+                    "error": "No embedding models configured"
+                }
+
             # Test with a simple embedding using the first available model
-            model_name = list(self.embedding_models.keys())[0]
+            model_name = config_loader.get_selected_model()
+            model_config = self.embedding_models[model_name]

             test_embedding = await self.generate_embeddings_batch(
                 ["health check"],
                 model_name,
                 batch_size=1
             )

@@ -128,13 +275,16 @@ class EmbeddingService:
                 "status": "healthy",
                 "service": "embedding_service",
                 "model": model_name,
+                "model_type": model_config.type,
                 "dimensions": len(test_embedding[0]) if test_embedding else 0,
-                "available_models": list(self.embedding_models.keys())
+                "available_models": list(self.embedding_models.keys()),
+                "legacy_models": list(self.legacy_models.keys()),
+                "openai_clients": list(self.clients.keys())
             }

         except Exception as e:
             return {
                 "status": "unhealthy",
                 "service": "embedding_service",
                 "error": str(e)
             }
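To make the offsets computation in `get_jina_embeddings_1024` concrete, a small worked example with made-up token IDs (pure Python; no tokenizer or ONNX session needed):

```python
# Made-up token ids for three texts of lengths 3, 2, 4
input_ids = [[5, 9, 2], [7, 1], [4, 4, 8, 3]]

lengths = [len(ids) for ids in input_ids]  # [3, 2, 4]
cumsum, current = [], 0
for n in lengths[:-1]:                     # exclude the last sample
    current += n
    cumsum.append(current)                 # [3, 5]
offsets = [0] + cumsum                     # [0, 3, 5]

flattened = [t for ids in input_ids for t in ids]
# flattened = [5, 9, 2, 7, 1, 4, 4, 8, 3]; sample i spans
# flattened[offsets[i] : offsets[i+1]] (the last sample runs to the end)
```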
ml_new/training/logger_config.py (new file)
@@ -0,0 +1,168 @@
"""
Unified logging configuration for ml_new training project
Provides colorful level formatting with [level]: [msg] format
"""
import logging
import sys


class ColorfulFormatter(logging.Formatter):
    """Custom formatter with colorful level names and [level]: [msg] format"""

    # ANSI color codes for different log levels
    COLORS = {
        'DEBUG': '\033[36m',     # Cyan
        'INFO': '\033[32m',      # Green
        'WARNING': '\033[33m',   # Yellow
        'ERROR': '\033[31m',     # Red
        'CRITICAL': '\033[35m',  # Magenta
    }
    RESET = '\033[0m'  # Reset color

    def format(self, record):
        # Extract level name
        level_name = record.levelname

        # Apply color to level name
        color = self.COLORS.get(level_name, self.RESET)
        colored_level = f"{color}{level_name}{self.RESET}"

        # Create new format: [colored_level]: [msg]
        # Get the message part only (without level name, module, etc.)
        if hasattr(record, 'message'):
            message = record.message
        else:
            message = record.getMessage()

        # Return in the requested format
        return f"{colored_level}: {message}"


class ColoredConsoleHandler(logging.StreamHandler):
    """Console handler that outputs to stderr with colors"""

    def __init__(self, stream=None):
        if stream is None:
            stream = sys.stderr
        super().__init__(stream)
        self.setFormatter(ColorfulFormatter())


def setup_logger(
    name: str = None,
    level: str = "INFO",
    log_file: str = None,
    console_output: bool = True,
    file_output: bool = False
) -> logging.Logger:
    """
    Setup a unified logger with colorful formatting

    Args:
        name: Logger name (defaults to project name)
        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        log_file: Optional log file path
        console_output: Whether to output to console
        file_output: Whether to output to file

    Returns:
        Configured logger instance
    """

    # Create logger
    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, level.upper()))

    # Clear existing handlers to avoid duplicates
    logger.handlers.clear()

    # Console handler with colorful formatting
    if console_output:
        console_handler = ColoredConsoleHandler()
        logger.addHandler(console_handler)

    # File handler with standard formatting (no colors in files)
    if file_output and log_file:
        file_handler = logging.FileHandler(log_file)
        file_formatter = logging.Formatter(
            '[%(levelname)s]: %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    # Prevent propagation to root logger to avoid duplicate messages
    logger.propagate = False

    return logger


def get_logger(name: str = None) -> logging.Logger:
    """
    Get a logger instance with unified configuration

    Args:
        name: Logger name (defaults to 'ml_new.training')

    Returns:
        Logger instance
    """
    if name is None:
        name = 'ml_new.training'

    logger = logging.getLogger(name)

    # Only configure if not already configured
    if not logger.handlers:
        return setup_logger(name=name)

    return logger


# Project-wide logger configuration
def configure_project_logger():
    """Configure the main project logger with all modules"""

    # Configure root logger for the project
    root_logger = logging.getLogger('ml_new')
    root_logger.setLevel(logging.INFO)

    # Clear existing handlers
    root_logger.handlers.clear()

    # Add console handler
    console_handler = ColoredConsoleHandler()
    root_logger.addHandler(console_handler)

    # Prevent propagation to avoid duplicates
    root_logger.propagate = False

    return root_logger


# Convenience function for quick setup
def quick_setup(level: str = "INFO") -> logging.Logger:
    """
    Quick setup for individual modules

    Args:
        level: Logging level

    Returns:
        Logger instance for the calling module
    """
    # Get the calling module name
    import inspect
    frame = inspect.currentframe().f_back
    module_name = frame.f_globals.get('__name__', 'ml_new.training')

    return setup_logger(name=module_name, level=level)


if __name__ == "__main__":
    # Test the logger configuration
    logger = get_logger('test_logger')
    logger.debug("This is a debug message")
    logger.info("This is an info message")
    logger.warning("This is a warning message")
    logger.error("This is an error message")
    logger.critical("This is a critical message")
ml_new/training/main.py

@@ -2,7 +2,6 @@
 Main FastAPI application for ML training service
 """

-import logging
 import uvicorn
 from contextlib import asynccontextmanager
 from fastapi import FastAPI

@@ -12,10 +11,9 @@ from database import DatabaseManager
 from embedding_service import EmbeddingService
 from dataset_service import DatasetBuilder
 from api_routes import router, set_dataset_builder
+from logger_config import get_logger

-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__)


 # Global service instances

@@ -105,7 +103,7 @@ def main():
         host="0.0.0.0",
         port=8322,
         log_level="info",
-        access_log=True
+        access_log=True,
     )
ml_new/training/models.py

@@ -12,6 +12,7 @@ class DatasetBuildRequest(BaseModel):
     aid_list: List[int] = Field(..., description="List of video AIDs")
     embedding_model: str = Field(..., description="Embedding model name")
     force_regenerate: bool = Field(False, description="Whether to force regenerate embeddings")
+    description: Optional[str] = Field(None, description="Optional description for the dataset")


 class DatasetBuildResponse(BaseModel):

@@ -20,6 +21,7 @@ class DatasetBuildResponse(BaseModel):
     total_records: int
     status: str
     message: str
+    description: Optional[str] = None
     created_at: Optional[datetime] = None

@@ -59,4 +61,67 @@ class EmbeddingModelInfo(BaseModel):
     type: str
     api_endpoint: Optional[str] = None
     max_tokens: Optional[int] = None
-    max_batch_size: Optional[int] = None
+    max_batch_size: Optional[int] = None
+
+
+from typing import List, Optional, Dict, Any, Literal
+from pydantic import BaseModel, Field
+from datetime import datetime
+from enum import Enum
+
+
+class TaskStatus(str, Enum):
+    """Task status enumeration"""
+    PENDING = "pending"
+    RUNNING = "running"
+    COMPLETED = "completed"
+    FAILED = "failed"
+    CANCELLED = "cancelled"
+
+
+class TaskProgress(BaseModel):
+    """Progress information for a task"""
+    current_step: str
+    total_steps: int
+    completed_steps: int
+    percentage: float
+    message: Optional[str] = None
+    estimated_time_remaining: Optional[float] = None
+
+
+class DatasetBuildTaskStatus(BaseModel):
+    """Status model for dataset building task"""
+    task_id: str
+    status: TaskStatus
+    dataset_id: Optional[str] = None
+    aid_list: List[int]
+    embedding_model: str
+    force_regenerate: bool
+    progress: Optional[TaskProgress] = None
+    error_message: Optional[str] = None
+    created_at: datetime
+    started_at: Optional[datetime] = None
+    completed_at: Optional[datetime] = None
+    result: Optional[Dict[str, Any]] = None
+
+
+class TaskStatusResponse(BaseModel):
+    """Response model for task status endpoint"""
+    task_id: str
+    status: TaskStatus
+    progress: Optional[Dict[str, Any]] = None
+    result: Optional[Dict[str, Any]] = None
+    error: Optional[str] = None
+    created_at: datetime
+    started_at: Optional[datetime] = None
+    completed_at: Optional[datetime] = None
+
+
+class TaskListResponse(BaseModel):
+    """Response model for listing tasks"""
+    tasks: List[TaskStatusResponse]
+    total_count: int
+    pending_count: int
+    running_count: int
+    completed_count: int
+    failed_count: int
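A quick sanity check of the new progress model, run from `ml_new/training` (the field values are illustrative, mirroring how `dataset_service` reports step 4 of 7):

```python
from models import TaskProgress

progress = TaskProgress(
    current_step="checking_embeddings",
    total_steps=7,
    completed_steps=4,
    percentage=4 / 7 * 100,
    message="Checking existing embeddings",
)
print(round(progress.percentage, 1))  # 57.1
```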