"""
|
|
Dataset building service - handles the complete dataset construction flow
|
|
"""
|
|
|
|
import asyncio
|
|
from pathlib import Path
|
|
from typing import List, Dict, Any, Optional
|
|
from datetime import datetime
|
|
import threading
|
|
|
|
from ml_new.data.database import DatabaseManager
|
|
from ml_new.data.embedding_service import EmbeddingService
|
|
from ml_new.config.config_loader import config_loader
|
|
from ml_new.config.logger_config import get_logger
|
|
from ml_new.models import TaskStatus, DatasetBuildTaskStatus, TaskProgress
|
|
from ml_new.data.dataset_storage_parquet import ParquetDatasetStorage
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|


class DatasetBuilder:
    """Service for building datasets with the specified flow"""
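
    # Typical usage (a minimal sketch; assumes already-constructed
    # DatabaseManager and EmbeddingService instances, and a model key that
    # exists in the TOML config -- the names below are illustrative only):
    #
    #     builder = DatasetBuilder(db_manager, embedding_service)
    #     task_id = await builder.start_dataset_build_task(
    #         dataset_id="example-dataset",
    #         aid_list=[170001, 170002],
    #         embedding_model="example-model",
    #     )
    #     status = builder.get_task_status(task_id)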

    def __init__(self, db_manager: DatabaseManager, embedding_service: EmbeddingService, storage_dir: str = "datasets"):
        self.db_manager = db_manager
        self.embedding_service = embedding_service
        self.storage_dir = Path(storage_dir)
        # Create the storage directory (including missing parents) if needed
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        self.storage = ParquetDatasetStorage(str(self.storage_dir))

        # Task status tracking
        self._task_status_lock = threading.Lock()
        self.task_statuses: Dict[str, DatasetBuildTaskStatus] = {}
        self.running_tasks: Dict[str, asyncio.Task] = {}

    def _create_task_status(self, task_id: str, dataset_id: str, aid_list: List[int],
                            embedding_model: str, force_regenerate: bool) -> DatasetBuildTaskStatus:
        """Create initial task status"""
        with self._task_status_lock:
            task_status = DatasetBuildTaskStatus(
                task_id=task_id,
                status=TaskStatus.PENDING,
                dataset_id=dataset_id,
                aid_list=aid_list,
                embedding_model=embedding_model,
                force_regenerate=force_regenerate,
                created_at=datetime.now(),
                progress=TaskProgress(
                    current_step="initialized",
                    total_steps=8,
                    completed_steps=0,
                    percentage=0.0,
                    message="Task initialized"
                )
            )
            self.task_statuses[task_id] = task_status
            return task_status

    def _update_task_status(self, task_id: str, **kwargs):
        """Update task status with new values"""
        with self._task_status_lock:
            if task_id in self.task_statuses:
                task_status = self.task_statuses[task_id]
                for key, value in kwargs.items():
                    if hasattr(task_status, key):
                        setattr(task_status, key, value)

    def _update_task_progress(self, task_id: str, current_step: str, completed_steps: int,
                              message: Optional[str] = None, percentage: Optional[float] = None):
        """Update task progress"""
        with self._task_status_lock:
            if task_id in self.task_statuses:
                task_status = self.task_statuses[task_id]
                if percentage is not None:
                    progress_percentage = percentage
                else:
                    progress_percentage = (completed_steps / task_status.progress.total_steps) * 100 if task_status.progress else 0.0

                task_status.progress = TaskProgress(
                    current_step=current_step,
                    # Fall back to 8 to stay consistent with the 8-step build flow
                    total_steps=task_status.progress.total_steps if task_status.progress else 8,
                    completed_steps=completed_steps,
                    percentage=progress_percentage,
                    message=message
                )

    def get_task_status(self, task_id: str) -> Optional[DatasetBuildTaskStatus]:
        """Get task status by task ID"""
        with self._task_status_lock:
            return self.task_statuses.get(task_id)

    def list_tasks(self, status_filter: Optional[TaskStatus] = None) -> List[DatasetBuildTaskStatus]:
        """List all tasks, optionally filtered by status"""
        with self._task_status_lock:
            tasks = list(self.task_statuses.values())
            if status_filter:
                tasks = [task for task in tasks if task.status == status_filter]
            # Sort by creation time (newest first)
            tasks.sort(key=lambda x: x.created_at, reverse=True)
            return tasks

    def get_task_statistics(self) -> Dict[str, Any]:
        """Get statistics about all tasks"""
        with self._task_status_lock:
            total_tasks = len(self.task_statuses)
            status_counts = {
                TaskStatus.PENDING: 0,
                TaskStatus.RUNNING: 0,
                TaskStatus.COMPLETED: 0,
                TaskStatus.FAILED: 0,
                TaskStatus.CANCELLED: 0
            }

            for task_status in self.task_statuses.values():
                status_counts[task_status.status] += 1

            return {
                "total_tasks": total_tasks,
                "status_counts": status_counts,
                "running_tasks": status_counts[TaskStatus.RUNNING]
            }

    async def cleanup_completed_tasks(self, max_age_hours: int = 24):
        """Clean up completed/failed tasks older than specified hours"""
        cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
        cleaned_count = 0

        with self._task_status_lock:
            tasks_to_remove = []

            for task_id, task_status in self.task_statuses.items():
                if task_status.completed_at:
                    try:
                        completed_time = task_status.completed_at.timestamp()
                        if completed_time < cutoff_time:
                            tasks_to_remove.append(task_id)
                    except Exception as e:
                        logger.warning(f"Failed to check completion time for task {task_id}: {e}")

            for task_id in tasks_to_remove:
                # Remove from task statuses
                del self.task_statuses[task_id]

                # Remove from running tasks if still there
                if task_id in self.running_tasks:
                    del self.running_tasks[task_id]

                cleaned_count += 1

        if cleaned_count > 0:
            logger.info(f"Cleaned up {cleaned_count} old tasks")

        return cleaned_count
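
    # A host application could run this periodically (sketch; the interval is
    # an arbitrary choice, not something this module prescribes):
    #
    #     async def cleanup_loop(builder: "DatasetBuilder"):
    #         while True:
    #             await builder.cleanup_completed_tasks(max_age_hours=24)
    #             await asyncio.sleep(3600)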

    async def build_dataset_with_task_tracking(self, task_id: str, dataset_id: str, aid_list: List[int],
                                               embedding_model: str, force_regenerate: bool = False,
                                               description: Optional[str] = None) -> str:
        """
        Build dataset with task status tracking

        Steps:
        1. Resolve the embedding model from the TOML config
        2. Pull video metadata (raw text) from the database
        3. Pull user labels from the database
        4. Prepare combined text and per-text checksums (no further preprocessing yet)
        5. Check which embeddings already exist (deduplicated by checksum)
        6. Generate embeddings for texts that are missing them
        7. Write the new embeddings to the embeddings table
        8. Assemble the final dataset records (embedding, label) and save them
        """

        # Update task status to running
        self._update_task_status(task_id, status=TaskStatus.RUNNING, started_at=datetime.now())

        try:
            logger.info(f"Starting dataset building task {dataset_id} (task_id: {task_id})")

            # Step 1: Get model configuration
            EMBEDDING_MODELS = config_loader.get_embedding_models()
            if embedding_model not in EMBEDDING_MODELS:
                raise ValueError(f"Invalid embedding model: {embedding_model}")

            model_config = EMBEDDING_MODELS[embedding_model]
            self._update_task_progress(task_id, "getting_metadata", 1, "Retrieving video metadata from database", 3)

            # Step 2: Get video metadata from database
            metadata = await self.db_manager.get_video_metadata(aid_list)
            self._update_task_progress(task_id, "getting_labels", 2, "Retrieving user labels from database", 15)

            # Step 3: Get user labels
            labels = await self.db_manager.get_user_labels(aid_list)
            self._update_task_progress(task_id, "preparing_text", 3, "Preparing text data and checksums", 30)

            # Step 4: Prepare text data and checksums
            text_data = []
            total_aids = len(aid_list)

            for i, aid in enumerate(aid_list):
                if aid in metadata:
                    # Combine title, description, tags
                    combined_text = self.embedding_service.combine_video_text(
                        metadata[aid]['title'],
                        metadata[aid]['description'],
                        metadata[aid]['tags']
                    )

                    # Create checksum for deduplication
                    checksum = self.embedding_service.create_text_checksum(combined_text)

                    text_data.append({
                        'aid': aid,
                        'text': combined_text,
                        'checksum': checksum
                    })

                # Update progress for text preparation
                if i % 10 == 0 or i == total_aids - 1:  # Update every 10 items or at the end
                    self._update_task_progress(
                        task_id,
                        "preparing_text",
                        3,
                        f"Prepared {i + 1}/{total_aids} text entries",
                        30 + 3 * (i + 1) / total_aids
                    )

            self._update_task_progress(task_id, "checking_embeddings", 4, "Checking existing embeddings", 33)

            # Step 5: Check existing embeddings
            checksums = [item['checksum'] for item in text_data]
            existing_embeddings = await self.db_manager.get_existing_embeddings(checksums, embedding_model)
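            # Rows in the embeddings table are keyed by text checksum and model
            # name, so identical text embedded with the same model is fetched
            # once here and reused across datasets.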

            # Step 6: Generate new embeddings for texts that don't have them,
            # deduplicated by checksum so identical texts are embedded only once
            new_embeddings_needed = []
            seen_checksums = set()
            for item in text_data:
                if (item['checksum'] not in existing_embeddings or force_regenerate) \
                        and item['checksum'] not in seen_checksums:
                    seen_checksums.add(item['checksum'])
                    new_embeddings_needed.append(item['text'])

            new_embeddings_count = 0
            if new_embeddings_needed:
                self._update_task_progress(
                    task_id,
                    "generating_embeddings",
                    5,
                    f"Generating {len(new_embeddings_needed)} new embeddings",
                    50
                )

                # Create progress callback for embedding generation
                def embedding_progress_callback(current: int, total: int):
                    percentage = (current / total) * 25 + 50  # Map into the 50-75% band (steps 5-6)
                    self._update_task_progress(
                        task_id,
                        "generating_embeddings",
                        5,
                        f"Generated {current}/{total} embeddings",
                        percentage
                    )

                logger.info(f"Generating {len(new_embeddings_needed)} new embeddings")
                generated_embeddings = await self.embedding_service.generate_embeddings_batch(
                    new_embeddings_needed,
                    embedding_model,
                    progress_callback=embedding_progress_callback
                )

                # Step 7: Store new embeddings in database
                embeddings_to_store = []
                for text, embedding in zip(new_embeddings_needed, generated_embeddings):
                    checksum = self.embedding_service.create_text_checksum(text)
                    embeddings_to_store.append({
                        'model_name': embedding_model,
                        'checksum': checksum,
                        'dimensions': model_config.dimensions,
                        'vector': embedding
                    })

                self._update_task_progress(
                    task_id,
                    "inserting_embeddings",
                    6,
                    f"Inserting {len(embeddings_to_store)} embeddings into database",
                    75
                )

                await self.db_manager.insert_embeddings(embeddings_to_store)
                new_embeddings_count = len(embeddings_to_store)

                # Update existing embeddings cache
                for emb_data in embeddings_to_store:
                    existing_embeddings[emb_data['checksum']] = {
                        'checksum': emb_data['checksum'],
                        f'vec_{model_config.dimensions}': emb_data['vector']
                    }

            self._update_task_progress(task_id, "building_dataset", 7, "Building final dataset", 95)

            # Step 8: Build final dataset
            dataset = []
            inconsistent_count = 0

            for i, item in enumerate(text_data):
                aid = item['aid']
                checksum = item['checksum']

                # Get embedding vector
                embedding_vector = None
                if checksum in existing_embeddings:
                    vec_key = f'vec_{model_config.dimensions}'
                    if vec_key in existing_embeddings[checksum]:
                        embedding_vector = existing_embeddings[checksum][vec_key]

                # Get labels for this aid
                aid_labels = labels.get(aid, [])

                # Determine final label using consensus (majority vote)
                final_label = None
                if aid_labels:
                    positive_votes = sum(1 for lbl in aid_labels if lbl['label'])
                    final_label = positive_votes > len(aid_labels) / 2
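                    # Note: an even split (e.g. [True, False]) has no strict
                    # majority and therefore resolves to False; such splits are
                    # still flagged as inconsistent below.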

                # Check for inconsistent labels (votes split between classes)
                inconsistent = False
                if len(aid_labels) > 1:
                    positive = sum(1 for lbl in aid_labels if lbl['label'])
                    inconsistent = 0 < positive < len(aid_labels)

                if inconsistent:
                    inconsistent_count += 1

                # Use an explicit None check: embedding vectors may be
                # array-like objects with ambiguous truthiness
                if embedding_vector is not None and final_label is not None:
                    dataset.append({
                        'aid': aid,
                        'embedding': embedding_vector,
                        'label': final_label,
                        'metadata': metadata.get(aid, {}),
                        'user_labels': aid_labels,
                        'inconsistent': inconsistent,
                        'text_checksum': checksum
                    })

                # Update progress for dataset building
                if i % 10 == 0 or i == len(text_data) - 1:  # Update every 10 items or at the end
                    self._update_task_progress(
                        task_id,
                        "building_dataset",
                        7,
                        f"Built {i + 1}/{len(text_data)} dataset records",
                        95 + 5 * (i + 1) / len(text_data)
                    )

            reused_count = len(dataset) - new_embeddings_count

            logger.info(f"Dataset building completed: {len(dataset)} records, {new_embeddings_count} new, {reused_count} reused, {inconsistent_count} inconsistent")

            # Prepare dataset data
            dataset_data = {
                'dataset': dataset,
                'description': description,
                'stats': {
                    'total_records': len(dataset),
                    'new_embeddings': new_embeddings_count,
                    'reused_embeddings': reused_count,
                    'inconsistent_labels': inconsistent_count,
                    'embedding_model': embedding_model
                },
                'created_at': datetime.now().isoformat()
            }

            self._update_task_progress(task_id, "saving_dataset", 8, "Saving dataset to storage")

            # Save to file using efficient storage (Parquet format)
            if self.storage.save_dataset(dataset_id, dataset_data['dataset'], description, dataset_data['stats']):
                logger.info(f"Dataset {dataset_id} saved to efficient storage (Parquet format)")
            else:
                logger.warning(f"Failed to save dataset {dataset_id} to storage")

            # Update task status to completed
            result = {
                'dataset_id': dataset_id,
                'stats': dataset_data['stats']
            }

            self._update_task_status(
                task_id,
                status=TaskStatus.COMPLETED,
                completed_at=datetime.now(),
                result=result,
                progress=TaskProgress(
                    current_step="completed",
                    total_steps=8,
                    completed_steps=8,
                    percentage=100.0,
                    message="Dataset building completed successfully"
                )
            )

            return dataset_id

        except Exception as e:
            logger.error(f"Dataset building failed for {dataset_id}: {str(e)}")

            # Update task status to failed
            self._update_task_status(
                task_id,
                status=TaskStatus.FAILED,
                completed_at=datetime.now(),
                error_message=str(e),
                progress=TaskProgress(
                    current_step="failed",
                    total_steps=8,
                    completed_steps=0,
                    percentage=0.0,
                    message=f"Task failed: {str(e)}"
                )
            )

            # Persist an empty dataset carrying the error information so the
            # failure remains visible in storage listings
            self.storage.save_dataset(dataset_id, [], description=str(e), stats={'error': True})
            raise
    async def start_dataset_build_task(self, dataset_id: str, aid_list: List[int],
                                       embedding_model: str, force_regenerate: bool = False,
                                       description: Optional[str] = None) -> str:
        """
        Start a dataset building task and return task ID for status tracking
        """
        task_id = str(uuid.uuid4())

        # Create task status
        self._create_task_status(task_id, dataset_id, aid_list, embedding_model, force_regenerate)

        # Start the actual task
        task = asyncio.create_task(
            self.build_dataset_with_task_tracking(task_id, dataset_id, aid_list, embedding_model, force_regenerate, description)
        )

        # Store the running task
        with self._task_status_lock:
            self.running_tasks[task_id] = task

        return task_id
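
    # Callers typically poll for completion (sketch; arguments elided and the
    # poll interval is an arbitrary choice):
    #
    #     task_id = await builder.start_dataset_build_task(...)
    #     while True:
    #         status = builder.get_task_status(task_id)
    #         if status and status.status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
    #             break
    #         await asyncio.sleep(1.0)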

    def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """Get built dataset by ID"""
        # Use pure Parquet storage interface
        return self.storage.load_dataset_full(dataset_id)

    def dataset_exists(self, dataset_id: str) -> bool:
        """Check if dataset exists"""
        return self.storage.dataset_exists(dataset_id)

    def delete_dataset(self, dataset_id: str) -> bool:
        """Delete dataset from storage"""
        return self.storage.delete_dataset(dataset_id)

    def list_datasets(self) -> List[Dict[str, Any]]:
        """List all datasets with their basic information"""
        return self.storage.list_datasets()

    def get_dataset_stats(self) -> Dict[str, Any]:
        """Get overall statistics about stored datasets"""
        return self.storage.get_dataset_stats()

    def load_dataset_metadata_fast(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """Load only a dataset's metadata, without reading the record data"""
        return self.storage.load_dataset_metadata(dataset_id)

    def load_dataset_partial(self, dataset_id: str, columns: Optional[List[str]] = None,
                             filters: Optional[Dict[str, Any]] = None):
        """Load a subset of a dataset, optionally selecting columns and applying filters"""
        return self.storage.load_dataset_partial(dataset_id, columns, filters)
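
# Example of a column-projected load (sketch; assumes the Parquet columns
# mirror the record fields built above, e.g. "embedding" and "label"):
#
#     rows = builder.load_dataset_partial("example-dataset",
#                                         columns=["embedding", "label"])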