# cvsa/ml_new/training/dataset_service.py
"""
Dataset building service - handles the complete dataset construction flow
"""
import json
import asyncio
import uuid
from pathlib import Path
from typing import List, Dict, Any, Optional
from datetime import datetime
import threading
from database import DatabaseManager
from embedding_service import EmbeddingService
from config_loader import config_loader
from logger_config import get_logger
from models import TaskStatus, DatasetBuildTaskStatus, TaskProgress
logger = get_logger(__name__)
class DatasetBuilder:
"""Service for building datasets with the specified flow"""
def __init__(self, db_manager: DatabaseManager, embedding_service: EmbeddingService, storage_dir: str = "datasets"):
self.db_manager = db_manager
self.embedding_service = embedding_service
self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(parents=True, exist_ok=True)
# Load existing datasets from file system
self.dataset_storage: Dict[str, Dict] = self._load_all_datasets()
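        # Note: every dataset JSON under storage_dir is loaded eagerly here;
        # get_dataset() also lazily backfills this cache for files that
        # appear after startup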
# Task status tracking
self._task_status_lock = threading.Lock()
self.task_statuses: Dict[str, DatasetBuildTaskStatus] = {}
self.running_tasks: Dict[str, asyncio.Task] = {}
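        # Note: finished asyncio.Tasks linger in running_tasks until
        # cleanup_completed_tasks() removes them alongside their statuses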
def _get_dataset_file_path(self, dataset_id: str) -> Path:
"""Get file path for dataset"""
return self.storage_dir / f"{dataset_id}.json"
def _load_dataset_from_file(self, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Load dataset from file"""
file_path = self._get_dataset_file_path(dataset_id)
if not file_path.exists():
return None
try:
with open(file_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Failed to load dataset {dataset_id} from file: {e}")
return None
def _save_dataset_to_file(self, dataset_id: str, dataset_data: Dict[str, Any]) -> bool:
"""Save dataset to file"""
file_path = self._get_dataset_file_path(dataset_id)
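        # Note: this write is not atomic; a crash mid-write can leave a
        # truncated file that _load_dataset_from_file will later fail to parse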
try:
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(dataset_data, f, ensure_ascii=False, indent=2)
return True
except Exception as e:
logger.error(f"Failed to save dataset {dataset_id} to file: {e}")
return False
def _load_all_datasets(self) -> Dict[str, Dict]:
"""Load all datasets from file system"""
datasets = {}
try:
for file_path in self.storage_dir.glob("*.json"):
dataset_id = file_path.stem
dataset_data = self._load_dataset_from_file(dataset_id)
if dataset_data:
datasets[dataset_id] = dataset_data
logger.info(f"Loaded {len(datasets)} datasets from file system")
except Exception as e:
logger.error(f"Failed to load datasets from file system: {e}")
return datasets
def _create_task_status(self, task_id: str, dataset_id: str, aid_list: List[int],
embedding_model: str, force_regenerate: bool) -> DatasetBuildTaskStatus:
"""Create initial task status"""
with self._task_status_lock:
task_status = DatasetBuildTaskStatus(
task_id=task_id,
status=TaskStatus.PENDING,
dataset_id=dataset_id,
aid_list=aid_list,
embedding_model=embedding_model,
force_regenerate=force_regenerate,
created_at=datetime.now(),
progress=TaskProgress(
current_step="initialized",
total_steps=7,
completed_steps=0,
percentage=0.0,
message="Task initialized"
)
)
self.task_statuses[task_id] = task_status
return task_status
def _update_task_status(self, task_id: str, **kwargs):
"""Update task status with new values"""
with self._task_status_lock:
if task_id in self.task_statuses:
task_status = self.task_statuses[task_id]
for key, value in kwargs.items():
if hasattr(task_status, key):
setattr(task_status, key, value)
self.task_statuses[task_id] = task_status
    def _update_task_progress(self, task_id: str, current_step: str, completed_steps: int,
                              message: Optional[str] = None, percentage: Optional[float] = None):
"""Update task progress"""
with self._task_status_lock:
if task_id in self.task_statuses:
task_status = self.task_statuses[task_id]
if percentage is not None:
progress_percentage = percentage
else:
progress_percentage = (completed_steps / task_status.progress.total_steps) * 100 if task_status.progress else 0.0
task_status.progress = TaskProgress(
current_step=current_step,
total_steps=task_status.progress.total_steps if task_status.progress else 7,
completed_steps=completed_steps,
percentage=progress_percentage,
message=message
)
self.task_statuses[task_id] = task_status
def get_task_status(self, task_id: str) -> Optional[DatasetBuildTaskStatus]:
"""Get task status by task ID"""
with self._task_status_lock:
return self.task_statuses.get(task_id)
def list_tasks(self, status_filter: Optional[TaskStatus] = None) -> List[DatasetBuildTaskStatus]:
"""List all tasks, optionally filtered by status"""
with self._task_status_lock:
tasks = list(self.task_statuses.values())
if status_filter:
tasks = [task for task in tasks if task.status == status_filter]
# Sort by creation time (newest first)
tasks.sort(key=lambda x: x.created_at, reverse=True)
return tasks
def get_task_statistics(self) -> Dict[str, Any]:
"""Get statistics about all tasks"""
with self._task_status_lock:
total_tasks = len(self.task_statuses)
status_counts = {
TaskStatus.PENDING: 0,
TaskStatus.RUNNING: 0,
TaskStatus.COMPLETED: 0,
TaskStatus.FAILED: 0,
TaskStatus.CANCELLED: 0
}
for task_status in self.task_statuses.values():
status_counts[task_status.status] += 1
return {
"total_tasks": total_tasks,
"status_counts": status_counts,
"running_tasks": status_counts[TaskStatus.RUNNING]
}
async def cleanup_completed_tasks(self, max_age_hours: int = 24):
"""Clean up completed/failed tasks older than specified hours"""
cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
cleaned_count = 0
with self._task_status_lock:
tasks_to_remove = []
for task_id, task_status in self.task_statuses.items():
if task_status.completed_at:
try:
completed_time = task_status.completed_at.timestamp()
if completed_time < cutoff_time:
tasks_to_remove.append(task_id)
except Exception as e:
logger.warning(f"Failed to check completion time for task {task_id}: {e}")
for task_id in tasks_to_remove:
# Remove from task statuses
del self.task_statuses[task_id]
# Remove from running tasks if still there
if task_id in self.running_tasks:
del self.running_tasks[task_id]
cleaned_count += 1
if cleaned_count > 0:
logger.info(f"Cleaned up {cleaned_count} old tasks")
return cleaned_count
async def cleanup_old_datasets(self, max_age_days: int = 30):
"""Remove datasets older than specified days"""
try:
cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 60 * 60)
removed_count = 0
for dataset_id in list(self.dataset_storage.keys()):
dataset_info = self.dataset_storage[dataset_id]
if "created_at" in dataset_info:
try:
created_time = datetime.fromisoformat(dataset_info["created_at"]).timestamp()
if created_time < cutoff_time:
# Remove from memory
del self.dataset_storage[dataset_id]
# Remove file
file_path = self._get_dataset_file_path(dataset_id)
if file_path.exists():
file_path.unlink()
removed_count += 1
except Exception as e:
logger.warning(f"Failed to process dataset {dataset_id} for cleanup: {e}")
if removed_count > 0:
logger.info(f"Cleaned up {removed_count} old datasets")
except Exception as e:
logger.error(f"Failed to cleanup old datasets: {e}")
async def build_dataset_with_task_tracking(self, task_id: str, dataset_id: str, aid_list: List[int],
embedding_model: str, force_regenerate: bool = False,
description: Optional[str] = None) -> str:
"""
Build dataset with task status tracking
Steps:
1. Initialize task status
2. Select embedding model (from TOML config)
3. Pull raw text from database
4. Preprocess (placeholder for now)
5. Batch get embeddings (deduplicate by hash, skip if already in embeddings table)
6. Write to embeddings table
7. Pull all needed embeddings to create dataset with format: embeddings, label
"""
# Update task status to running
self._update_task_status(task_id, status=TaskStatus.RUNNING, started_at=datetime.now())
try:
logger.info(f"Starting dataset building task {dataset_id} (task_id: {task_id})")
            # Resolve the embedding model from the TOML config (not a tracked step)
EMBEDDING_MODELS = config_loader.get_embedding_models()
if embedding_model not in EMBEDDING_MODELS:
raise ValueError(f"Invalid embedding model: {embedding_model}")
model_config = EMBEDDING_MODELS[embedding_model]
self._update_task_progress(task_id, "getting_metadata", 1, "Retrieving video metadata from database")
            # Step 1: Get video metadata from database
metadata = await self.db_manager.get_video_metadata(aid_list)
self._update_task_progress(task_id, "getting_labels", 2, "Retrieving user labels from database")
            # Step 2: Get user labels
labels = await self.db_manager.get_user_labels(aid_list)
self._update_task_progress(task_id, "preparing_text", 3, "Preparing text data and checksums")
            # Step 3: Prepare text data and checksums
text_data = []
total_aids = len(aid_list)
for i, aid in enumerate(aid_list):
if aid in metadata:
# Combine title, description, tags
combined_text = self.embedding_service.combine_video_text(
metadata[aid]['title'],
metadata[aid]['description'],
metadata[aid]['tags']
)
# Create checksum for deduplication
checksum = self.embedding_service.create_text_checksum(combined_text)
text_data.append({
'aid': aid,
'text': combined_text,
'checksum': checksum
})
                # Update progress every 10 items or at the end
                if i % 10 == 0 or i == total_aids - 1:
                    # Report fractional progress within step 3 as an overall percentage
                    progress_pct = (3 + (i + 1) / total_aids) / 7 * 100
                    self._update_task_progress(
                        task_id,
                        "preparing_text",
                        3,
                        f"Prepared {i + 1}/{total_aids} text entries",
                        percentage=progress_pct
                    )
self._update_task_progress(task_id, "checking_embeddings", 4, "Checking existing embeddings")
            # Step 4: Check existing embeddings
checksums = [item['checksum'] for item in text_data]
existing_embeddings = await self.db_manager.get_existing_embeddings(checksums, embedding_model)
            # Step 5a: Generate new embeddings for texts that don't have them
new_embeddings_needed = []
for item in text_data:
if item['checksum'] not in existing_embeddings or force_regenerate:
new_embeddings_needed.append(item['text'])
new_embeddings_count = 0
if new_embeddings_needed:
self._update_task_progress(
task_id,
"generating_embeddings",
5,
f"Generating {len(new_embeddings_needed)} new embeddings"
)
logger.info(f"Generating {len(new_embeddings_needed)} new embeddings")
generated_embeddings = await self.embedding_service.generate_embeddings_batch(
new_embeddings_needed,
embedding_model
)
                # Step 5b: Store new embeddings in database
embeddings_to_store = []
                for text, embedding in zip(new_embeddings_needed, generated_embeddings):
checksum = self.embedding_service.create_text_checksum(text)
embeddings_to_store.append({
'model_name': embedding_model,
'checksum': checksum,
'dimensions': model_config.dimensions,
'vector': embedding
})
await self.db_manager.insert_embeddings(embeddings_to_store)
new_embeddings_count = len(embeddings_to_store)
# Update existing embeddings cache
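                # The f'vec_{dimensions}' key mirrors the column naming returned
                # by get_existing_embeddings, so fresh vectors are looked up the
                # same way as reused ones when the dataset is assembled below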
for emb_data in embeddings_to_store:
existing_embeddings[emb_data['checksum']] = {
'checksum': emb_data['checksum'],
f'vec_{model_config.dimensions}': emb_data['vector']
}
self._update_task_progress(task_id, "building_dataset", 6, "Building final dataset")
            # Step 6: Build final dataset
dataset = []
inconsistent_count = 0
for i, item in enumerate(text_data):
aid = item['aid']
checksum = item['checksum']
# Get embedding vector
embedding_vector = None
if checksum in existing_embeddings:
vec_key = f'vec_{model_config.dimensions}'
if vec_key in existing_embeddings[checksum]:
embedding_vector = existing_embeddings[checksum][vec_key]
                # Get labels for this aid
                aid_labels = labels.get(aid, [])
                # Determine final label by majority vote; flag aids whose
                # labelers disagree (some positive, some negative votes)
                final_label = None
                inconsistent = False
                if aid_labels:
                    positive_votes = sum(1 for lbl in aid_labels if lbl['label'])
                    final_label = positive_votes > len(aid_labels) / 2
                    inconsistent = 0 < positive_votes < len(aid_labels)
                if inconsistent:
                    inconsistent_count += 1
                if embedding_vector is not None and final_label is not None:
dataset.append({
'aid': aid,
'embedding': embedding_vector,
'label': final_label,
'metadata': metadata.get(aid, {}),
'user_labels': aid_labels,
'inconsistent': inconsistent,
'text_checksum': checksum
})
                # Update progress every 10 items or at the end
                if i % 10 == 0 or i == len(text_data) - 1:
                    progress_pct = (6 + (i + 1) / len(text_data)) / 7 * 100
                    self._update_task_progress(
                        task_id,
                        "building_dataset",
                        6,
                        f"Built {i + 1}/{len(text_data)} dataset records",
                        percentage=progress_pct
                    )
            # Count reuse against the texts that needed no regeneration; counting
            # against the final dataset would be skewed by records dropped for
            # missing labels
            reused_count = len(text_data) - new_embeddings_count
logger.info(f"Dataset building completed: {len(dataset)} records, {new_embeddings_count} new, {reused_count} reused, {inconsistent_count} inconsistent")
# Prepare dataset data
dataset_data = {
'dataset': dataset,
'description': description,
'stats': {
'total_records': len(dataset),
'new_embeddings': new_embeddings_count,
'reused_embeddings': reused_count,
'inconsistent_labels': inconsistent_count,
'embedding_model': embedding_model
},
'created_at': datetime.now().isoformat()
}
self._update_task_progress(task_id, "saving_dataset", 7, "Saving dataset to storage")
            # Step 7: Save to file and memory cache
if self._save_dataset_to_file(dataset_id, dataset_data):
self.dataset_storage[dataset_id] = dataset_data
logger.info(f"Dataset {dataset_id} saved to file system")
else:
logger.warning(f"Failed to save dataset {dataset_id} to file, keeping in memory only")
self.dataset_storage[dataset_id] = dataset_data
# Update task status to completed
result = {
'dataset_id': dataset_id,
'stats': dataset_data['stats']
}
self._update_task_status(
task_id,
status=TaskStatus.COMPLETED,
completed_at=datetime.now(),
result=result,
progress=TaskProgress(
current_step="completed",
total_steps=7,
completed_steps=7,
percentage=100.0,
message="Dataset building completed successfully"
)
)
return dataset_id
except Exception as e:
logger.error(f"Dataset building failed for {dataset_id}: {str(e)}")
# Update task status to failed
self._update_task_status(
task_id,
status=TaskStatus.FAILED,
completed_at=datetime.now(),
error_message=str(e),
progress=TaskProgress(
current_step="failed",
total_steps=7,
completed_steps=0,
percentage=0.0,
message=f"Task failed: {str(e)}"
)
)
# Store error information
error_data = {
'error': str(e),
'created_at': datetime.now().isoformat()
}
# Try to save error to file as well
self._save_dataset_to_file(dataset_id, error_data)
self.dataset_storage[dataset_id] = error_data
raise
async def start_dataset_build_task(self, dataset_id: str, aid_list: List[int],
embedding_model: str, force_regenerate: bool = False,
description: Optional[str] = None) -> str:
"""
Start a dataset building task and return task ID for status tracking
"""
        task_id = str(uuid.uuid4())
# Create task status
task_status = self._create_task_status(task_id, dataset_id, aid_list, embedding_model, force_regenerate)
# Start the actual task
task = asyncio.create_task(
self.build_dataset_with_task_tracking(task_id, dataset_id, aid_list, embedding_model, force_regenerate, description)
)
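        # The build runs in the background: failures are recorded on the task
        # status (and re-raised inside the task), not surfaced to this caller,
        # so consumers should poll get_task_status(task_id)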
# Store the running task
with self._task_status_lock:
self.running_tasks[task_id] = task
return task_id
def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Get built dataset by ID"""
# First check memory cache
if dataset_id in self.dataset_storage:
return self.dataset_storage[dataset_id]
# If not in memory, try to load from file
dataset_data = self._load_dataset_from_file(dataset_id)
if dataset_data:
# Add to memory cache
self.dataset_storage[dataset_id] = dataset_data
return dataset_data
return None
def dataset_exists(self, dataset_id: str) -> bool:
"""Check if dataset exists"""
# Check memory cache first
if dataset_id in self.dataset_storage:
return True
# Check file system
return self._get_dataset_file_path(dataset_id).exists()
def delete_dataset(self, dataset_id: str) -> bool:
"""Delete dataset from both memory and file system"""
try:
# Remove from memory
if dataset_id in self.dataset_storage:
del self.dataset_storage[dataset_id]
# Remove file
file_path = self._get_dataset_file_path(dataset_id)
if file_path.exists():
file_path.unlink()
logger.info(f"Dataset {dataset_id} deleted from file system")
return True
else:
logger.warning(f"Dataset file {dataset_id} not found for deletion")
return False
except Exception as e:
logger.error(f"Failed to delete dataset {dataset_id}: {e}")
return False
def list_datasets(self) -> List[Dict[str, Any]]:
"""List all datasets with their basic information"""
datasets = []
for dataset_id, dataset_info in self.dataset_storage.items():
if "error" not in dataset_info:
                datasets.append({
                    "dataset_id": dataset_id,
                    "stats": dataset_info.get("stats", {}),
                    "created_at": dataset_info.get("created_at", "")
                })
# Sort by creation time (newest first)
datasets.sort(key=lambda x: x["created_at"], reverse=True)
return datasets
def get_dataset_stats(self) -> Dict[str, Any]:
"""Get overall statistics about stored datasets"""
total_datasets = len(self.dataset_storage)
error_datasets = sum(1 for data in self.dataset_storage.values() if "error" in data)
valid_datasets = total_datasets - error_datasets
total_records = 0
total_new_embeddings = 0
total_reused_embeddings = 0
for dataset_info in self.dataset_storage.values():
if "stats" in dataset_info:
stats = dataset_info["stats"]
total_records += stats.get("total_records", 0)
total_new_embeddings += stats.get("new_embeddings", 0)
total_reused_embeddings += stats.get("reused_embeddings", 0)
return {
"total_datasets": total_datasets,
"valid_datasets": valid_datasets,
"error_datasets": error_datasets,
"total_records": total_records,
"total_new_embeddings": total_new_embeddings,
"total_reused_embeddings": total_reused_embeddings,
"storage_directory": str(self.storage_dir)
}
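
# Example usage (illustrative sketch only): constructor arguments for
# DatabaseManager and EmbeddingService are project-specific and assumed here,
# and "demo-dataset" / "some-model" are hypothetical placeholders; the model
# name must be a key in the TOML embedding-model config.
#
#   async def main():
#       db = DatabaseManager(...)          # assumed project-specific setup
#       embedder = EmbeddingService(...)   # assumed project-specific setup
#       builder = DatasetBuilder(db, embedder, storage_dir="datasets")
#       task_id = await builder.start_dataset_build_task(
#           dataset_id="demo-dataset",
#           aid_list=[1, 2, 3],
#           embedding_model="some-model",
#       )
#       # Poll until the background build settles
#       while builder.get_task_status(task_id).status in (
#           TaskStatus.PENDING, TaskStatus.RUNNING,
#       ):
#           await asyncio.sleep(1)
#       print(builder.get_task_status(task_id).status)
#
#   asyncio.run(main())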