
add: legacy embedding model support, task-based dataset building

This commit is contained in:
alikia2x (寒寒) 2025-12-10 17:10:14 +08:00
parent 77668bbb52
commit 7dbef68cdc
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG Key ID: 56209E0CCD8420C6
11 changed files with 831 additions and 213 deletions

View File

@ -2,4 +2,4 @@
1. Always use bun as package manager.
2. Always write comments in English.
2. IMPORTANT: **Always write comments inside your code (through tool calls) in English, and respond (including descriptions of changes, requests, and questions to the user) in the same language as the user's query.**

View File

@ -1,140 +0,0 @@
# ML Training Service
A FastAPI-based ML training service for dataset building, embedding generation, and experiment management.
## Architecture
The service is organized into modular components:
```
ml_new/training/
├── main.py # FastAPI application entry point
├── models.py # Pydantic data models
├── config_loader.py # Configuration loading from TOML
├── database.py # Database connection and operations
├── embedding_service.py # Embedding generation service
├── dataset_service.py # Dataset building logic
├── api_routes.py # API endpoint definitions
├── embedding_models.toml # Embedding model configurations
└── requirements.txt # Python dependencies
```
## Key Components
### 1. Main Application (`main.py`)
- FastAPI app initialization
- CORS middleware configuration
- Service dependency injection
- Startup/shutdown event handlers
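A minimal skeleton of that wiring might look like this (a sketch only; the `lifespan` pattern and CORS settings are assumptions based on the imports and `uvicorn.run` parameters visible in the `main.py` diff later in this commit):

```python
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: construct services and hand them to the routes module, e.g.
    # set_dataset_builder(DatasetBuilder(DatabaseManager(), EmbeddingService()))
    yield
    # Shutdown: close DB pools, cancel in-flight tasks, etc.

app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # open CORS for frontend integration
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8322)
```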
### 2. Data Models (`models.py`)
- `DatasetBuildRequest`: Request model for dataset building
- `DatasetBuildResponse`: Response model for dataset building
- `DatasetRecord`: Individual dataset record structure
- `EmbeddingModelInfo`: Embedding model configuration
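As a concrete example, the build request model (as it appears in the `models.py` diff later in this commit) is essentially:

```python
from typing import List, Optional

from pydantic import BaseModel, Field

class DatasetBuildRequest(BaseModel):
    aid_list: List[int] = Field(..., description="List of video AIDs")
    embedding_model: str = Field(..., description="Embedding model name")
    force_regenerate: bool = Field(False, description="Whether to force regenerate embeddings")
    description: Optional[str] = Field(None, description="Optional description for the dataset")
```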
### 3. Configuration (`config_loader.py`)
- Loads embedding model configurations from TOML
- Manages model parameters (dimensions, API endpoints, etc.)
### 4. Database Layer (`database.py`)
- PostgreSQL connection management
- CRUD operations for video metadata, user labels, and embeddings
- Optimized batch queries to avoid N+1 problems
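The batch pattern is essentially one `ANY(...)` query per table instead of one query per AID (a minimal sketch with asyncpg; the table and column names here are illustrative, though the `= ANY($1::...)` idiom matches the queries in `database.py`):

```python
import asyncpg

async def get_video_metadata(pool: asyncpg.Pool, aid_list: list[int]) -> dict[int, dict]:
    # One round trip for all AIDs instead of N single-row queries (avoids N+1).
    query = """
        SELECT aid, title, description, tags
        FROM video_metadata            -- illustrative table name
        WHERE aid = ANY($1::bigint[])
    """
    async with pool.acquire() as conn:
        rows = await conn.fetch(query, aid_list)
    return {r["aid"]: dict(r) for r in rows}
```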
### 5. Embedding Service (`embedding_service.py`)
- Integration with OpenAI-compatible embedding APIs
- Text preprocessing and checksum generation
- Batch embedding generation with rate limiting
### 6. Dataset Building (`dataset_service.py`)
- Complete dataset construction workflow:
1. Pull raw text from database
2. Text preprocessing (placeholder)
3. Batch embedding generation with deduplication
4. Embedding storage and caching
5. Final dataset compilation with labels
### 7. API Routes (`api_routes.py`)
- `GET /v1/health`: Health check
- `GET /v1/models/embedding`: List available embedding models
- `POST /v1/dataset/build`: Build new dataset
- `GET /v1/dataset/{id}`: Retrieve built dataset
- `GET /v1/datasets`: List all datasets
- `DELETE /v1/dataset/{id}`: Delete dataset
## Dataset Building Flow
1. **Model Selection**: Choose embedding model from TOML configuration
2. **Data Retrieval**: Pull video metadata and user labels from PostgreSQL
3. **Text Processing**: Combine title, description, and tags
4. **Deduplication**: Generate checksums to avoid duplicate embeddings (sketched below)
5. **Batch Processing**: Generate embeddings for new texts only
6. **Storage**: Store embeddings in database with caching
7. **Final Assembly**: Combine embeddings with labels using consensus mechanism
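Steps 4-5 boil down to hashing each combined text and only embedding checksums the database has not seen (a minimal sketch, assuming the MD5 checksums mentioned under Features; the function names are illustrative):

```python
import hashlib

def text_checksum(text: str) -> str:
    # MD5 over the combined title/description/tags text
    return hashlib.md5(text.encode("utf-8")).hexdigest()

def select_texts_to_embed(texts: list[str], existing_checksums: set[str],
                          force_regenerate: bool = False) -> list[str]:
    """Keep only texts whose checksum is not already in the embeddings table."""
    if force_regenerate:
        return list(texts)
    return [t for t in texts if text_checksum(t) not in existing_checksums]
```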
## Configuration
### Embedding Models (`embedding_models.toml`)
```toml
[text-embedding-3-large]
name = "text-embedding-3-large"
dimensions = 3072
type = "openai"
api_endpoint = "https://api.openai.com/v1/embeddings"
max_tokens = 8192
max_batch_size = 100
```
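A sketch of how such a file can be parsed into typed configs (`config_loader.py` uses the `toml` package and a Pydantic model; the fields below mirror the sample above, but the loading code itself is illustrative):

```python
import toml
from pydantic import BaseModel

class EmbeddingModelConfig(BaseModel):
    name: str
    dimensions: int
    type: str = "openai"
    api_endpoint: str = ""
    max_tokens: int = 8191
    max_batch_size: int = 8

def load_models(path: str = "embedding_models.toml") -> dict[str, EmbeddingModelConfig]:
    data = toml.load(path)
    # Skip scalar top-level keys; each table becomes one model config.
    return {key: EmbeddingModelConfig(**value)
            for key, value in data.items() if isinstance(value, dict)}
```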
### Environment Variables
- `DATABASE_URL`: PostgreSQL connection string
- `OPENAI_API_KEY`: OpenAI API key for embedding generation
## Usage
### Start the Service
```bash
cd ml_new/training
python main.py
```
### Build a Dataset
```bash
curl -X POST "http://localhost:8322/v1/dataset/build" \
-H "Content-Type: application/json" \
-d '{
"aid_list": [170001, 170002, 170003],
"embedding_model": "text-embedding-3-large",
"force_regenerate": false
}'
```
### Check Health
```bash
curl "http://localhost:8322/v1/health"
```
### List Embedding Models
```bash
curl "http://localhost:8322/v1/models/embedding"
```
## Features
- **High Performance**: Optimized database queries with batch operations
- **Deduplication**: Text-level deduplication using MD5 checksums
- **Consensus Labels**: Majority vote mechanism for user annotations (see the sketch after this list)
- **Batch Processing**: Efficient embedding generation and storage
- **Error Handling**: Comprehensive error handling and logging
- **Async Support**: Fully asynchronous operations for scalability
- **CORS Enabled**: Ready for frontend integration
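The consensus step reduces multiple user labels per video to one training label and flags disagreement (a minimal sketch; the real logic lives in `dataset_service.py`, which also records an `inconsistent` flag per record, and "any non-unanimous vote" is one reasonable definition of disagreement):

```python
from collections import Counter

def consensus_label(labels: list[int]) -> tuple[int, bool]:
    """Majority vote over user labels; the bool marks disagreement."""
    counts = Counter(labels)
    label, _ = counts.most_common(1)[0]
    inconsistent = len(counts) > 1
    return label, inconsistent

# consensus_label([1, 1, 0]) -> (1, True); consensus_label([0, 0]) -> (0, False)
```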
## Production Considerations
- Replace in-memory dataset storage with database
- Add authentication and authorization
- Implement rate limiting for API endpoints
- Add monitoring and metrics collection
- Configure proper logging levels
- Set up database connection pooling
- Add API documentation with OpenAPI/Swagger

View File

@ -2,18 +2,19 @@
API routes for the ML training service
"""
import logging
import uuid
from typing import Optional
from fastapi import APIRouter, HTTPException, BackgroundTasks
from fastapi import APIRouter, HTTPException
from fastapi.responses import JSONResponse
from config_loader import config_loader
from models import DatasetBuildRequest, DatasetBuildResponse
from models import DatasetBuildRequest, DatasetBuildResponse, TaskStatus, TaskStatusResponse, TaskListResponse
from dataset_service import DatasetBuilder
from logger_config import get_logger
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
# Create router
router = APIRouter(prefix="/v1")
@ -80,8 +81,8 @@ async def get_embedding_models():
@router.post("/dataset/build", response_model=DatasetBuildResponse)
async def build_dataset_endpoint(request: DatasetBuildRequest, background_tasks: BackgroundTasks):
"""Build dataset endpoint"""
async def build_dataset_endpoint(request: DatasetBuildRequest):
"""Build dataset endpoint with task tracking"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
@ -91,20 +92,22 @@ async def build_dataset_endpoint(request: DatasetBuildRequest, background_tasks:
raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")
dataset_id = str(uuid.uuid4())
# Start background task for dataset building
background_tasks.add_task(
dataset_builder.build_dataset,
# Start task-based dataset building
task_id = await dataset_builder.start_dataset_build_task(
dataset_id,
request.aid_list,
request.embedding_model,
request.force_regenerate
request.force_regenerate,
request.description
)
return DatasetBuildResponse(
dataset_id=dataset_id,
total_records=len(request.aid_list),
status="started",
message="Dataset building started"
message=f"Dataset building started with task ID: {task_id}",
description=request.description
)
@ -126,6 +129,7 @@ async def get_dataset_endpoint(dataset_id: str):
return {
"dataset_id": dataset_id,
"dataset": dataset_info["dataset"],
"description": dataset_info.get("description"),
"stats": dataset_info["stats"],
"created_at": dataset_info["created_at"]
}
@ -143,6 +147,7 @@ async def list_datasets():
if "error" not in dataset_info:
datasets.append({
"dataset_id": dataset_id,
"description": dataset_info.get("description"),
"stats": dataset_info["stats"],
"created_at": dataset_info["created_at"]
})
@ -171,7 +176,16 @@ async def list_datasets_endpoint():
raise HTTPException(status_code=503, detail="Dataset builder not available")
datasets = dataset_builder.list_datasets()
return {"datasets": datasets}
# Add description to each dataset
datasets_with_description = []
for dataset in datasets:
dataset_info = dataset_builder.get_dataset(dataset["dataset_id"])
if dataset_info and "description" in dataset_info:
dataset["description"] = dataset_info["description"]
else:
dataset["description"] = None
datasets_with_description.append(dataset)
return {"datasets": datasets_with_description}
@router.get("/datasets/stats")
@ -193,4 +207,111 @@ async def cleanup_datasets_endpoint(max_age_days: int = 30):
raise HTTPException(status_code=503, detail="Dataset builder not available")
await dataset_builder.cleanup_old_datasets(max_age_days)
return {"message": f"Cleanup completed for datasets older than {max_age_days} days"}
return {"message": f"Cleanup completed for datasets older than {max_age_days} days"}
# Task Status Endpoints
@router.get("/tasks/{task_id}", response_model=TaskStatusResponse)
async def get_task_status_endpoint(task_id: str):
"""Get status of a specific task"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
task_status = dataset_builder.get_task_status(task_id)
if not task_status:
raise HTTPException(status_code=404, detail="Task not found")
# Convert to response model
progress_dict = None
if task_status.progress:
progress_dict = {
"current_step": task_status.progress.current_step,
"total_steps": task_status.progress.total_steps,
"completed_steps": task_status.progress.completed_steps,
"percentage": task_status.progress.percentage,
"message": task_status.progress.message
}
return TaskStatusResponse(
task_id=task_status.task_id,
status=task_status.status,
progress=progress_dict,
result=task_status.result,
error=task_status.error_message,
created_at=task_status.created_at,
started_at=task_status.started_at,
completed_at=task_status.completed_at
)
@router.get("/tasks", response_model=TaskListResponse)
async def list_tasks_endpoint(status: Optional[TaskStatus] = None, limit: int = 50):
"""List all tasks, optionally filtered by status"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
tasks = dataset_builder.list_tasks(status_filter=status)
# Limit results
if limit > 0:
tasks = tasks[:limit]
# Convert to response models
task_responses = []
for task_status in tasks:
progress_dict = None
if task_status.progress:
progress_dict = {
"current_step": task_status.progress.current_step,
"total_steps": task_status.progress.total_steps,
"completed_steps": task_status.progress.completed_steps,
"percentage": task_status.progress.percentage,
"message": task_status.progress.message
}
task_responses.append(TaskStatusResponse(
task_id=task_status.task_id,
status=task_status.status,
progress=progress_dict,
result=task_status.result,
error=task_status.error_message,
created_at=task_status.created_at,
started_at=task_status.started_at,
completed_at=task_status.completed_at
))
# Get statistics
stats = dataset_builder.get_task_statistics()
return TaskListResponse(
tasks=task_responses,
total_count=stats["total_tasks"],
pending_count=stats["status_counts"][TaskStatus.PENDING],
running_count=stats["status_counts"][TaskStatus.RUNNING],
completed_count=stats["status_counts"][TaskStatus.COMPLETED],
failed_count=stats["status_counts"][TaskStatus.FAILED]
)
@router.get("/tasks/stats")
async def get_task_statistics_endpoint():
"""Get statistics about all tasks"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
return dataset_builder.get_task_statistics()
@router.post("/tasks/cleanup")
async def cleanup_tasks_endpoint(max_age_hours: int = 24):
"""Clean up completed/failed tasks older than specified hours"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
cleaned_count = await dataset_builder.cleanup_completed_tasks(max_age_hours)
return {"message": f"Cleaned up {cleaned_count} tasks older than {max_age_hours} hours"}

View File

@ -6,9 +6,9 @@ import toml
import os
from typing import Dict
from pydantic import BaseModel
import logging
from logger_config import get_logger
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
class EmbeddingModelConfig(BaseModel):
@ -19,6 +19,8 @@ class EmbeddingModelConfig(BaseModel):
max_tokens: int = 8191
max_batch_size: int = 8
api_key_env: str = "OPENAI_API_KEY"
model_path: str = ""
tokenizer_name: str = ""
class ConfigLoader:
@ -31,6 +33,7 @@ class ConfigLoader:
self.config_path = config_path
self.embedding_models: Dict[str, EmbeddingModelConfig] = {}
self.selected_model: str = None
self._load_config()
def _load_config(self):
@ -51,6 +54,8 @@ class ConfigLoader:
self.embedding_models[model_key] = EmbeddingModelConfig(
**model_data
)
self.selected_model = config_data.get("model", list(self.embedding_models.keys())[0])
logger.info(
f"Loaded {len(self.embedding_models)} embedding models from {self.config_path}"
@ -58,6 +63,10 @@ class ConfigLoader:
except Exception as e:
logger.error(f"Failed to load config from {self.config_path}: {e}")
def get_selected_model(self) -> str:
"""Get selected model for health check"""
return self.selected_model
def get_embedding_models(self) -> Dict[str, EmbeddingModelConfig]:
"""Get all available embedding models"""

View File

@ -7,13 +7,13 @@ import hashlib
from typing import List, Dict, Optional, Any
from datetime import datetime
import asyncpg
import logging
from config_loader import config_loader
from dotenv import load_dotenv
from logger_config import get_logger
load_dotenv()
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
# Database configuration
DATABASE_URL = os.getenv("DATABASE_URL")
@ -120,7 +120,7 @@ class DatabaseManager:
async with self.pool.acquire() as conn:
query = """
SELECT data_checksum, vec_2048, vec_1536, vec_1024, created_at
SELECT data_checksum, dimensions, vec_2048, vec_1536, vec_1024, created_at
FROM internal.embeddings
WHERE model_name = $1 AND data_checksum = ANY($2::text[])
"""
@ -177,14 +177,15 @@ class DatabaseManager:
query = f"""
INSERT INTO internal.embeddings
(model_name, data_checksum, {vec_column}, created_at)
VALUES ($1, $2, $3, $4)
ON CONFLICT (data_checksum) DO NOTHING
(model_name, dimensions, data_checksum, {vec_column}, created_at)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (model_name, dimensions, data_checksum) DO NOTHING
"""
await conn.execute(
query,
data["model_name"],
data["dimensions"],
data["checksum"],
vector_str,
datetime.now(),
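Note that the new `ON CONFLICT (model_name, dimensions, data_checksum)` target only works if a matching unique constraint or index exists; this commit does not show a migration, so something along these lines is presumably needed (an assumption, with illustrative index naming):

```python
# Hypothetical migration -- the actual schema change is not part of this commit.
import asyncio
import os

import asyncpg

MIGRATION = """
ALTER TABLE internal.embeddings
    ADD COLUMN IF NOT EXISTS dimensions integer;
CREATE UNIQUE INDEX IF NOT EXISTS embeddings_model_dims_checksum_key
    ON internal.embeddings (model_name, dimensions, data_checksum);
"""

async def migrate() -> None:
    conn = await asyncpg.connect(os.getenv("DATABASE_URL"))
    try:
        await conn.execute(MIGRATION)
    finally:
        await conn.close()

if __name__ == "__main__":
    asyncio.run(migrate())
```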

View File

@ -2,20 +2,21 @@
Dataset building service - handles the complete dataset construction flow
"""
import os
import json
import logging
import asyncio
from pathlib import Path
from typing import List, Dict, Any, Optional
from datetime import datetime
import threading
from database import DatabaseManager
from embedding_service import EmbeddingService
from config_loader import config_loader
from logger_config import get_logger
from models import TaskStatus, DatasetBuildTaskStatus, TaskProgress
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
class DatasetBuilder:
@ -29,6 +30,11 @@ class DatasetBuilder:
# Load existing datasets from file system
self.dataset_storage: Dict[str, Dict] = self._load_all_datasets()
# Task status tracking
self._task_status_lock = threading.Lock()
self.task_statuses: Dict[str, DatasetBuildTaskStatus] = {}
self.running_tasks: Dict[str, asyncio.Task] = {}
def _get_dataset_file_path(self, dataset_id: str) -> Path:
@ -80,6 +86,134 @@ class DatasetBuilder:
return datasets
def _create_task_status(self, task_id: str, dataset_id: str, aid_list: List[int],
embedding_model: str, force_regenerate: bool) -> DatasetBuildTaskStatus:
"""Create initial task status"""
with self._task_status_lock:
task_status = DatasetBuildTaskStatus(
task_id=task_id,
status=TaskStatus.PENDING,
dataset_id=dataset_id,
aid_list=aid_list,
embedding_model=embedding_model,
force_regenerate=force_regenerate,
created_at=datetime.now(),
progress=TaskProgress(
current_step="initialized",
total_steps=7,
completed_steps=0,
percentage=0.0,
message="Task initialized"
)
)
self.task_statuses[task_id] = task_status
return task_status
def _update_task_status(self, task_id: str, **kwargs):
"""Update task status with new values"""
with self._task_status_lock:
if task_id in self.task_statuses:
task_status = self.task_statuses[task_id]
for key, value in kwargs.items():
if hasattr(task_status, key):
setattr(task_status, key, value)
self.task_statuses[task_id] = task_status
def _update_task_progress(self, task_id: str, current_step: str, completed_steps: int,
message: str = None, percentage: float = None):
"""Update task progress"""
with self._task_status_lock:
if task_id in self.task_statuses:
task_status = self.task_statuses[task_id]
if percentage is not None:
progress_percentage = percentage
else:
progress_percentage = (completed_steps / task_status.progress.total_steps) * 100 if task_status.progress else 0.0
task_status.progress = TaskProgress(
current_step=current_step,
total_steps=task_status.progress.total_steps if task_status.progress else 7,
completed_steps=completed_steps,
percentage=progress_percentage,
message=message
)
self.task_statuses[task_id] = task_status
def get_task_status(self, task_id: str) -> Optional[DatasetBuildTaskStatus]:
"""Get task status by task ID"""
with self._task_status_lock:
return self.task_statuses.get(task_id)
def list_tasks(self, status_filter: Optional[TaskStatus] = None) -> List[DatasetBuildTaskStatus]:
"""List all tasks, optionally filtered by status"""
with self._task_status_lock:
tasks = list(self.task_statuses.values())
if status_filter:
tasks = [task for task in tasks if task.status == status_filter]
# Sort by creation time (newest first)
tasks.sort(key=lambda x: x.created_at, reverse=True)
return tasks
def get_task_statistics(self) -> Dict[str, Any]:
"""Get statistics about all tasks"""
with self._task_status_lock:
total_tasks = len(self.task_statuses)
status_counts = {
TaskStatus.PENDING: 0,
TaskStatus.RUNNING: 0,
TaskStatus.COMPLETED: 0,
TaskStatus.FAILED: 0,
TaskStatus.CANCELLED: 0
}
for task_status in self.task_statuses.values():
status_counts[task_status.status] += 1
return {
"total_tasks": total_tasks,
"status_counts": status_counts,
"running_tasks": status_counts[TaskStatus.RUNNING]
}
async def cleanup_completed_tasks(self, max_age_hours: int = 24):
"""Clean up completed/failed tasks older than specified hours"""
cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
cleaned_count = 0
with self._task_status_lock:
tasks_to_remove = []
for task_id, task_status in self.task_statuses.items():
if task_status.completed_at:
try:
completed_time = task_status.completed_at.timestamp()
if completed_time < cutoff_time:
tasks_to_remove.append(task_id)
except Exception as e:
logger.warning(f"Failed to check completion time for task {task_id}: {e}")
for task_id in tasks_to_remove:
# Remove from task statuses
del self.task_statuses[task_id]
# Remove from running tasks if still there
if task_id in self.running_tasks:
del self.running_tasks[task_id]
cleaned_count += 1
if cleaned_count > 0:
logger.info(f"Cleaned up {cleaned_count} old tasks")
return cleaned_count
async def cleanup_old_datasets(self, max_age_days: int = 30):
"""Remove datasets older than specified days"""
try:
@ -110,43 +244,54 @@ class DatasetBuilder:
except Exception as e:
logger.error(f"Failed to cleanup old datasets: {e}")
async def build_dataset(self, dataset_id: str, aid_list: List[int], embedding_model: str, force_regenerate: bool = False) -> str:
async def build_dataset_with_task_tracking(self, task_id: str, dataset_id: str, aid_list: List[int],
embedding_model: str, force_regenerate: bool = False,
description: Optional[str] = None) -> str:
"""
Build dataset with the specified flow:
1. Select embedding model (from TOML config)
2. Pull raw text from database
3. Preprocess (placeholder for now)
4. Batch get embeddings (deduplicate by hash, skip if already in embeddings table)
5. Write to embeddings table
6. Pull all needed embeddings to create dataset with format: embeddings, label
Build dataset with task status tracking
Steps:
1. Initialize task status
2. Select embedding model (from TOML config)
3. Pull raw text from database
4. Preprocess (placeholder for now)
5. Batch get embeddings (deduplicate by hash, skip if already in embeddings table)
6. Write to embeddings table
7. Pull all needed embeddings to create dataset with format: embeddings, label
"""
# Update task status to running
self._update_task_status(task_id, status=TaskStatus.RUNNING, started_at=datetime.now())
try:
logger.info(f"Starting dataset building task {dataset_id}")
logger.info(f"Starting dataset building task {dataset_id} (task_id: {task_id})")
# Step 1: Get model configuration
EMBEDDING_MODELS = config_loader.get_embedding_models()
# Get model configuration
if embedding_model not in EMBEDDING_MODELS:
raise ValueError(f"Invalid embedding model: {embedding_model}")
model_config = EMBEDDING_MODELS[embedding_model]
self._update_task_progress(task_id, "getting_metadata", 1, "Retrieving video metadata from database")
# Step 1: Get video metadata from database
# Step 2: Get video metadata from database
metadata = await self.db_manager.get_video_metadata(aid_list)
self._update_task_progress(task_id, "getting_labels", 2, "Retrieving user labels from database")
# Step 2: Get user labels
# Step 3: Get user labels
labels = await self.db_manager.get_user_labels(aid_list)
self._update_task_progress(task_id, "preparing_text", 3, "Preparing text data and checksums")
# Step 3: Prepare text data and checksums
# Step 4: Prepare text data and checksums
text_data = []
total_aids = len(aid_list)
for aid in aid_list:
for i, aid in enumerate(aid_list):
if aid in metadata:
# Combine title, description, tags
combined_text = self.embedding_service.combine_video_text(
metadata[aid]['title'],
metadata[aid]['description'],
metadata[aid]['tags']
)
@ -158,12 +303,24 @@ class DatasetBuilder:
'text': combined_text,
'checksum': checksum
})
# Update progress for text preparation
if i % 10 == 0 or i == total_aids - 1: # Update every 10 items or at the end
progress_pct = 3 + (i + 1) / total_aids
self._update_task_progress(
task_id,
"preparing_text",
min(3, int(progress_pct)),
f"Prepared {i + 1}/{total_aids} text entries"
)
# Step 4: Check existing embeddings
self._update_task_progress(task_id, "checking_embeddings", 4, "Checking existing embeddings")
# Step 5: Check existing embeddings
checksums = [item['checksum'] for item in text_data]
existing_embeddings = await self.db_manager.get_existing_embeddings(checksums, embedding_model)
# Step 5: Generate new embeddings for texts that don't have them
# Step 6: Generate new embeddings for texts that don't have them
new_embeddings_needed = []
for item in text_data:
if item['checksum'] not in existing_embeddings or force_regenerate:
@ -171,13 +328,20 @@ class DatasetBuilder:
new_embeddings_count = 0
if new_embeddings_needed:
self._update_task_progress(
task_id,
"generating_embeddings",
5,
f"Generating {len(new_embeddings_needed)} new embeddings"
)
logger.info(f"Generating {len(new_embeddings_needed)} new embeddings")
generated_embeddings = await self.embedding_service.generate_embeddings_batch(
new_embeddings_needed,
embedding_model
)
# Step 6: Store new embeddings in database
# Step 7: Store new embeddings in database
embeddings_to_store = []
for i, (text, embedding) in enumerate(zip(new_embeddings_needed, generated_embeddings)):
checksum = self.embedding_service.create_text_checksum(text)
@ -198,11 +362,13 @@ class DatasetBuilder:
f'vec_{model_config.dimensions}': emb_data['vector']
}
# Step 7: Build final dataset
self._update_task_progress(task_id, "building_dataset", 6, "Building final dataset")
# Step 8: Build final dataset
dataset = []
inconsistent_count = 0
for item in text_data:
for i, item in enumerate(text_data):
aid = item['aid']
checksum = item['checksum']
@ -241,6 +407,16 @@ class DatasetBuilder:
'inconsistent': inconsistent,
'text_checksum': checksum
})
# Update progress for dataset building
if i % 10 == 0 or i == len(text_data) - 1: # Update every 10 items or at the end
progress_pct = 6 + (i + 1) / len(text_data)
self._update_task_progress(
task_id,
"building_dataset",
min(6, int(progress_pct)),
f"Built {i + 1}/{len(text_data)} dataset records"
)
reused_count = len(dataset) - new_embeddings_count
@ -249,6 +425,7 @@ class DatasetBuilder:
# Prepare dataset data
dataset_data = {
'dataset': dataset,
'description': description,
'stats': {
'total_records': len(dataset),
'new_embeddings': new_embeddings_count,
@ -259,6 +436,8 @@ class DatasetBuilder:
'created_at': datetime.now().isoformat()
}
self._update_task_progress(task_id, "saving_dataset", 7, "Saving dataset to storage")
# Save to file and memory cache
if self._save_dataset_to_file(dataset_id, dataset_data):
self.dataset_storage[dataset_id] = dataset_data
@ -267,11 +446,46 @@ class DatasetBuilder:
logger.warning(f"Failed to save dataset {dataset_id} to file, keeping in memory only")
self.dataset_storage[dataset_id] = dataset_data
# Update task status to completed
result = {
'dataset_id': dataset_id,
'stats': dataset_data['stats']
}
self._update_task_status(
task_id,
status=TaskStatus.COMPLETED,
completed_at=datetime.now(),
result=result,
progress=TaskProgress(
current_step="completed",
total_steps=7,
completed_steps=7,
percentage=100.0,
message="Dataset building completed successfully"
)
)
return dataset_id
except Exception as e:
logger.error(f"Dataset building failed for {dataset_id}: {str(e)}")
# Update task status to failed
self._update_task_status(
task_id,
status=TaskStatus.FAILED,
completed_at=datetime.now(),
error_message=str(e),
progress=TaskProgress(
current_step="failed",
total_steps=7,
completed_steps=0,
percentage=0.0,
message=f"Task failed: {str(e)}"
)
)
# Store error information
error_data = {
'error': str(e),
@ -283,6 +497,30 @@ class DatasetBuilder:
self.dataset_storage[dataset_id] = error_data
raise
async def start_dataset_build_task(self, dataset_id: str, aid_list: List[int],
embedding_model: str, force_regenerate: bool = False,
description: Optional[str] = None) -> str:
"""
Start a dataset building task and return task ID for status tracking
"""
import uuid
task_id = str(uuid.uuid4())
# Create task status
task_status = self._create_task_status(task_id, dataset_id, aid_list, embedding_model, force_regenerate)
# Start the actual task
task = asyncio.create_task(
self.build_dataset_with_task_tracking(task_id, dataset_id, aid_list, embedding_model, force_regenerate, description)
)
# Store the running task
with self._task_status_lock:
self.running_tasks[task_id] = task
return task_id
def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Get built dataset by ID"""
# First check memory cache

View File

@ -1,6 +1,6 @@
# Embedding Models Configuration
model = "qwen3-embedding"
model = "jina-embedding-v3-m2v"
[models.qwen3-embedding]
name = "text-embedding-v4"
@ -10,3 +10,11 @@ api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
max_tokens = 8192
max_batch_size = 10
api_key_env = "ALIYUN_KEY"
[models.jina-embedding-v3-m2v]
name = "jina-embedding-v3-m2v-1024"
dimensions = 1024
type = "legacy"
model_path = "../../model/embedding/model.onnx"
tokenizer_name = "jinaai/jina-embeddings-v3"
max_batch_size = 128

View File

@ -1,18 +1,22 @@
"""
Embedding service for generating embeddings using OpenAI-compatible API
Embedding service for generating embeddings using OpenAI-compatible API and legacy methods
"""
import asyncio
import hashlib
from typing import List, Dict, Any, Optional
import logging
from openai import AsyncOpenAI
import os
from config_loader import config_loader
from dotenv import load_dotenv
import torch
import numpy as np
from transformers import AutoTokenizer
import onnxruntime as ort
from logger_config import get_logger
load_dotenv()
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
class EmbeddingService:
def __init__(self):
@ -21,7 +25,10 @@ class EmbeddingService:
# Initialize OpenAI client (will be configured per model)
self.clients: Dict[str, AsyncOpenAI] = {}
self.legacy_models: Dict[str, Dict[str, Any]] = {}
self._initialize_clients()
self._initialize_legacy_models()
# Rate limiting
self.max_requests_per_minute = int(os.getenv("MAX_REQUESTS_PER_MINUTE", "100"))
@ -30,19 +37,93 @@ class EmbeddingService:
def _initialize_clients(self):
"""Initialize OpenAI clients for different models/endpoints"""
for model_name, model_config in self.embedding_models.items():
if model_config.type == "openai-compatible":
# Get API key from environment variable specified in config
api_key = os.getenv(model_config.api_key_env)
if model_config.type != "openai-compatible":
continue
# Get API key from environment variable specified in config
api_key = os.getenv(model_config.api_key_env)
self.clients[model_name] = AsyncOpenAI(
api_key=api_key,
base_url=model_config.api_endpoint
)
logger.info(f"Initialized client for model {model_name}")
def _initialize_legacy_models(self):
"""Initialize legacy ONNX models for embedding generation"""
for model_name, model_config in self.embedding_models.items():
if model_config.type != "legacy":
continue
try:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_config.tokenizer_name)
self.clients[model_name] = AsyncOpenAI(
api_key=api_key,
base_url=model_config.api_endpoint
)
logger.info(f"Initialized client for model {model_name}")
# Load ONNX model
session = ort.InferenceSession(model_config.model_path)
self.legacy_models[model_name] = {
"tokenizer": tokenizer,
"session": session,
"config": model_config
}
logger.info(f"Initialized legacy model {model_name}")
except Exception as e:
logger.error(f"Failed to initialize legacy model {model_name}: {e}")
def get_jina_embeddings_1024(self, texts: List[str], model_name: str) -> np.ndarray:
"""Generate embeddings using legacy Jina method (same as ml/api/main.py)"""
if model_name not in self.legacy_models:
raise ValueError(f"Legacy model '{model_name}' not initialized")
legacy_model = self.legacy_models[model_name]
tokenizer = legacy_model["tokenizer"]
session = legacy_model["session"]
# Encode inputs using tokenizer
encoded_inputs = tokenizer(
texts,
add_special_tokens=False, # Don't add special tokens (consistent with JS)
return_attention_mask=False,
return_tensors=None # Return native Python lists for easier processing
)
input_ids = encoded_inputs["input_ids"] # Shape: [batch_size, seq_len_i] (variable length per sample)
# Calculate offsets (consistent with JS cumsum logic)
# Get token length for each sample first
lengths = [len(ids) for ids in input_ids]
# Calculate cumulative sum (exclude last sample)
cumsum = []
current_sum = 0
for l in lengths[:-1]: # Only accumulate first n-1 samples
current_sum += l
cumsum.append(current_sum)
# Build offsets: start with 0, followed by cumulative sums
offsets = [0] + cumsum # Shape: [batch_size]
# Flatten input_ids to 1D array
flattened_input_ids = []
for ids in input_ids:
flattened_input_ids.extend(ids) # Directly concatenate all token ids
flattened_input_ids = np.array(flattened_input_ids, dtype=np.int64)
# Prepare ONNX inputs (consistent tensor shapes with JS)
inputs = {
"input_ids": ort.OrtValue.ortvalue_from_numpy(flattened_input_ids),
"offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64))
}
# Run model inference
outputs = session.run(None, inputs)
embeddings = outputs[0] # Assume first output is embeddings, shape: [batch_size, embedding_dim]
return torch.tensor(embeddings, dtype=torch.float32).numpy()
async def generate_embeddings_batch(
self,
texts: List[str],
model: str,
batch_size: Optional[int] = None
) -> List[List[float]]:
@ -54,6 +135,64 @@ class EmbeddingService:
model_config = self.embedding_models[model]
# Handle different model types
if model_config.type == "legacy":
return self._generate_legacy_embeddings_batch(texts, model, batch_size)
elif model_config.type == "openai-compatible":
return await self._generate_openai_embeddings_batch(texts, model, batch_size)
else:
raise ValueError(f"Unsupported model type: {model_config.type}")
def _generate_legacy_embeddings_batch(
self,
texts: List[str],
model: str,
batch_size: Optional[int] = None
) -> List[List[float]]:
"""Generate embeddings using legacy ONNX model"""
if model not in self.legacy_models:
raise ValueError(f"Legacy model '{model}' not initialized")
model_config = self.embedding_models[model]
# Use model's max_batch_size if not specified
if batch_size is None:
batch_size = model_config.max_batch_size
expected_dims = model_config.dimensions
all_embeddings = []
# Process in batches
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
try:
# Generate embeddings using legacy method
embeddings = self.get_jina_embeddings_1024(batch, model)
# Convert to list of lists (expected format)
batch_embeddings = embeddings.tolist()
all_embeddings.extend(batch_embeddings)
logger.info(f"Generated legacy embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
except Exception as e:
logger.error(f"Error generating legacy embeddings for batch {i//batch_size + 1}: {e}")
# Fill with zeros as fallback
zero_embedding = [0.0] * expected_dims
all_embeddings.extend([zero_embedding] * len(batch))
return all_embeddings
async def _generate_openai_embeddings_batch(
self,
texts: List[str],
model: str,
batch_size: Optional[int] = None
) -> List[List[float]]:
"""Generate embeddings using OpenAI-compatible API"""
model_config = self.embedding_models[model]
# Use model's max_batch_size if not specified
if batch_size is None:
batch_size = model_config.max_batch_size
@ -115,11 +254,19 @@ class EmbeddingService:
async def health_check(self) -> Dict[str, Any]:
"""Check if embedding service is healthy"""
try:
if not self.embedding_models:
return {
"status": "unhealthy",
"service": "embedding_service",
"error": "No embedding models configured"
}
# Test with a simple embedding using the first available model
model_name = list(self.embedding_models.keys())[0]
model_name = config_loader.get_selected_model()
model_config = self.embedding_models[model_name]
test_embedding = await self.generate_embeddings_batch(
["health check"],
["health check"],
model_name,
batch_size=1
)
@ -128,13 +275,16 @@ class EmbeddingService:
"status": "healthy",
"service": "embedding_service",
"model": model_name,
"model_type": model_config.type,
"dimensions": len(test_embedding[0]) if test_embedding else 0,
"available_models": list(self.embedding_models.keys())
"available_models": list(self.embedding_models.keys()),
"legacy_models": list(self.legacy_models.keys()),
"openai_clients": list(self.clients.keys())
}
except Exception as e:
return {
"status": "unhealthy",
"status": "unhealthy",
"service": "embedding_service",
"error": str(e)
}

View File

@ -0,0 +1,168 @@
"""
Unified logging configuration for ml_new training project
Provides colorful level formatting with [level]: [msg] format
"""
import logging
import sys
class ColorfulFormatter(logging.Formatter):
"""Custom formatter with colorful level names and [level]: [msg] format"""
# ANSI color codes for different log levels
COLORS = {
'DEBUG': '\033[36m', # Cyan
'INFO': '\033[32m', # Green
'WARNING': '\033[33m', # Yellow
'ERROR': '\033[31m', # Red
'CRITICAL': '\033[35m', # Magenta
}
RESET = '\033[0m' # Reset color
def format(self, record):
# Extract level name
level_name = record.levelname
# Apply color to level name
color = self.COLORS.get(level_name, self.RESET)
colored_level = f"{color}{level_name}{self.RESET}"
# Create new format: [colored_level]: [msg]
# Get the message part only (without level name, module, etc.)
if hasattr(record, 'message'):
message = record.message
else:
message = record.getMessage()
# Return in the requested format
return f"{colored_level}: {message}"
class ColoredConsoleHandler(logging.StreamHandler):
"""Console handler that outputs to stderr with colors"""
def __init__(self, stream=None):
if stream is None:
stream = sys.stderr
super().__init__(stream)
self.setFormatter(ColorfulFormatter())
def setup_logger(
name: str = None,
level: str = "INFO",
log_file: str = None,
console_output: bool = True,
file_output: bool = False
) -> logging.Logger:
"""
Setup a unified logger with colorful formatting
Args:
name: Logger name (defaults to project name)
level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
log_file: Optional log file path
console_output: Whether to output to console
file_output: Whether to output to file
Returns:
Configured logger instance
"""
# Create logger
logger = logging.getLogger(name)
logger.setLevel(getattr(logging, level.upper()))
# Clear existing handlers to avoid duplicates
logger.handlers.clear()
# Console handler with colorful formatting
if console_output:
console_handler = ColoredConsoleHandler()
logger.addHandler(console_handler)
# File handler with standard formatting (no colors in files)
if file_output and log_file:
file_handler = logging.FileHandler(log_file)
file_formatter = logging.Formatter(
'[%(levelname)s]: %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
file_handler.setFormatter(file_formatter)
logger.addHandler(file_handler)
# Prevent propagation to root logger to avoid duplicate messages
logger.propagate = False
return logger
def get_logger(name: str = None) -> logging.Logger:
"""
Get a logger instance with unified configuration
Args:
name: Logger name (defaults to 'ml_new.training')
Returns:
Logger instance
"""
if name is None:
name = 'ml_new.training'
logger = logging.getLogger(name)
# Only configure if not already configured
if not logger.handlers:
return setup_logger(name=name)
return logger
# Project-wide logger configuration
def configure_project_logger():
"""Configure the main project logger with all modules"""
# Configure root logger for the project
root_logger = logging.getLogger('ml_new')
root_logger.setLevel(logging.INFO)
# Clear existing handlers
root_logger.handlers.clear()
# Add console handler
console_handler = ColoredConsoleHandler()
root_logger.addHandler(console_handler)
# Prevent propagation to avoid duplicates
root_logger.propagate = False
return root_logger
# Convenience function for quick setup
def quick_setup(level: str = "INFO") -> logging.Logger:
"""
Quick setup for individual modules
Args:
level: Logging level
Returns:
Logger instance for the calling module
"""
# Get the calling module name
import inspect
frame = inspect.currentframe().f_back
module_name = frame.f_globals.get('__name__', 'ml_new.training')
return setup_logger(name=module_name, level=level)
if __name__ == "__main__":
# Test the logger configuration
logger = get_logger('test_logger')
logger.debug("This is a debug message")
logger.info("This is an info message")
logger.warning("This is a warning message")
logger.error("This is an error message")
logger.critical("This is a critical message")

View File

@ -2,7 +2,6 @@
Main FastAPI application for ML training service
"""
import logging
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI
@ -12,10 +11,9 @@ from database import DatabaseManager
from embedding_service import EmbeddingService
from dataset_service import DatasetBuilder
from api_routes import router, set_dataset_builder
from logger_config import get_logger
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
# Global service instances
@ -105,7 +103,7 @@ def main():
host="0.0.0.0",
port=8322,
log_level="info",
access_log=True
access_log=True,
)

View File

@ -12,6 +12,7 @@ class DatasetBuildRequest(BaseModel):
aid_list: List[int] = Field(..., description="List of video AIDs")
embedding_model: str = Field(..., description="Embedding model name")
force_regenerate: bool = Field(False, description="Whether to force regenerate embeddings")
description: Optional[str] = Field(None, description="Optional description for the dataset")
class DatasetBuildResponse(BaseModel):
@ -20,6 +21,7 @@ class DatasetBuildResponse(BaseModel):
total_records: int
status: str
message: str
description: Optional[str] = None
created_at: Optional[datetime] = None
@ -59,4 +61,67 @@ class EmbeddingModelInfo(BaseModel):
type: str
api_endpoint: Optional[str] = None
max_tokens: Optional[int] = None
max_batch_size: Optional[int] = None
from typing import List, Optional, Dict, Any, Literal
from pydantic import BaseModel, Field
from datetime import datetime
from enum import Enum
class TaskStatus(str, Enum):
"""Task status enumeration"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
class TaskProgress(BaseModel):
"""Progress information for a task"""
current_step: str
total_steps: int
completed_steps: int
percentage: float
message: Optional[str] = None
estimated_time_remaining: Optional[float] = None
class DatasetBuildTaskStatus(BaseModel):
"""Status model for dataset building task"""
task_id: str
status: TaskStatus
dataset_id: Optional[str] = None
aid_list: List[int]
embedding_model: str
force_regenerate: bool
progress: Optional[TaskProgress] = None
error_message: Optional[str] = None
created_at: datetime
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
result: Optional[Dict[str, Any]] = None
class TaskStatusResponse(BaseModel):
"""Response model for task status endpoint"""
task_id: str
status: TaskStatus
progress: Optional[Dict[str, Any]] = None
result: Optional[Dict[str, Any]] = None
error: Optional[str] = None
created_at: datetime
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
class TaskListResponse(BaseModel):
"""Response model for listing tasks"""
tasks: List[TaskStatusResponse]
total_count: int
pending_count: int
running_count: int
completed_count: int
failed_count: int