1
0
cvsa/ml_new/training/dataset_service.py

373 lines
15 KiB
Python

"""
Dataset building service - handles the complete dataset construction flow
"""
import os
import json
import logging
import asyncio
from pathlib import Path
from typing import List, Dict, Any, Optional
from datetime import datetime
from database import DatabaseManager
from embedding_service import EmbeddingService
from config_loader import config_loader
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class DatasetBuilder:
    """Build labelled embedding datasets and persist them as JSON files.

    Each dataset is stored as ``<storage_dir>/<dataset_id>.json`` and mirrored in
    the in-memory ``dataset_storage`` dict, which is populated from disk when the
    builder is constructed.
    """

    def __init__(self, db_manager: DatabaseManager, embedding_service: EmbeddingService, storage_dir: str = "datasets"):
        """Create the builder and load every dataset already stored on disk.

        Args:
            db_manager: Source of video metadata, user labels and stored embeddings.
            embedding_service: Combines/checksums video text and generates embeddings.
            storage_dir: Directory holding one ``<dataset_id>.json`` file per dataset.
                It is created (including missing parent directories) if absent.
        """
        self.db_manager = db_manager
        self.embedding_service = embedding_service
        self.storage_dir = Path(storage_dir)
        # parents=True so a nested path such as "data/datasets" also works on first run.
        self.storage_dir.mkdir(parents=True, exist_ok=True)
        # In-memory cache of every dataset (or error record), keyed by dataset_id.
        self.dataset_storage: Dict[str, Dict] = self._load_all_datasets()

    # ------------------------------------------------------------------ storage

    def _get_dataset_file_path(self, dataset_id: str) -> Path:
        """Return the JSON file path used to persist ``dataset_id``."""
        # NOTE(review): dataset_id is joined into the path unchanged. If it can come
        # from user input (e.g. an HTTP route), values like "../x" escape
        # storage_dir -- validate IDs upstream.
        return self.storage_dir / f"{dataset_id}.json"

    def _load_dataset_from_file(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """Read a single dataset file.

        Returns:
            The decoded JSON object, or ``None`` when the file is missing,
            unreadable, not valid JSON, or does not hold a JSON object.
        """
        file_path = self._get_dataset_file_path(dataset_id)
        if not file_path.exists():
            return None
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error(f"Failed to load dataset {dataset_id} from file: {e}")
            return None
        if not isinstance(data, dict):
            # Every consumer indexes the record like a dict; ignore anything else.
            logger.error(f"Dataset file for {dataset_id} does not contain a JSON object; ignoring it")
            return None
        return data

    def _save_dataset_to_file(self, dataset_id: str, dataset_data: Dict[str, Any]) -> bool:
        """Write ``dataset_data`` to the dataset's JSON file atomically.

        The JSON is first written to a sibling ``<id>.json.tmp`` file and then
        moved over the target with ``os.replace``. A crash or serialization error
        therefore never leaves a truncated ``<id>.json`` behind.

        Returns:
            True on success, False if writing failed (the error is logged).
        """
        file_path = self._get_dataset_file_path(dataset_id)
        # "<id>.json.tmp" does not match the "*.json" glob used when loading.
        tmp_path = file_path.with_name(file_path.name + ".tmp")
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(dataset_data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, file_path)
            return True
        except Exception as e:
            logger.error(f"Failed to save dataset {dataset_id} to file: {e}")
            # Best-effort removal of the partial temp file; the failure is already logged.
            try:
                if tmp_path.exists():
                    tmp_path.unlink()
            except OSError:
                pass
            return False

    def _load_all_datasets(self) -> Dict[str, Dict]:
        """Load every ``*.json`` file in the storage directory, keyed by file stem.

        Files that cannot be loaded are skipped. If listing the directory fails,
        the error is logged and whatever was loaded so far is returned.
        """
        datasets = {}
        try:
            for file_path in self.storage_dir.glob("*.json"):
                dataset_id = file_path.stem
                dataset_data = self._load_dataset_from_file(dataset_id)
                if dataset_data:
                    datasets[dataset_id] = dataset_data
            logger.info(f"Loaded {len(datasets)} datasets from file system")
        except Exception as e:
            logger.error(f"Failed to load datasets from file system: {e}")
        return datasets

    async def cleanup_old_datasets(self, max_age_days: int = 30):
        """Remove cached datasets whose ``created_at`` is older than ``max_age_days``.

        Only entries in the in-memory cache are examined, and entries without
        ``created_at`` are kept. Errors are logged, never raised.
        """
        try:
            cutoff_time = datetime.now().timestamp() - max_age_days * 24 * 60 * 60
            removed_count = 0
            for dataset_id, dataset_info in list(self.dataset_storage.items()):
                if "created_at" not in dataset_info:
                    continue
                try:
                    created_time = datetime.fromisoformat(dataset_info["created_at"]).timestamp()
                    if created_time >= cutoff_time:
                        continue
                    # Remove the file before the cache entry. If unlinking fails,
                    # the entry stays cached and memory and disk remain consistent.
                    file_path = self._get_dataset_file_path(dataset_id)
                    if file_path.exists():
                        file_path.unlink()
                    del self.dataset_storage[dataset_id]
                    removed_count += 1
                except Exception as e:
                    logger.warning(f"Failed to process dataset {dataset_id} for cleanup: {e}")
            if removed_count > 0:
                logger.info(f"Cleaned up {removed_count} old datasets")
        except Exception as e:
            logger.error(f"Failed to cleanup old datasets: {e}")

    # ----------------------------------------------------------------- building

    def _prepare_text_data(self, aid_list: List[int], metadata: Dict[int, Any]) -> List[Dict[str, Any]]:
        """Combine each video's title/description/tags into one text and checksum it.

        Aids missing from ``metadata`` are skipped; output order follows ``aid_list``.
        Each entry is a dict with keys ``aid``, ``text`` and ``checksum``.
        """
        text_data = []
        for aid in aid_list:
            if aid not in metadata:
                continue
            video = metadata[aid]
            combined_text = self.embedding_service.combine_video_text(
                video['title'],
                video['description'],
                video['tags']
            )
            text_data.append({
                'aid': aid,
                'text': combined_text,
                # The checksum is the deduplication key for stored embeddings.
                'checksum': self.embedding_service.create_text_checksum(combined_text)
            })
        return text_data

    async def _generate_missing_embeddings(
        self,
        text_data: List[Dict[str, Any]],
        existing_embeddings: Dict[str, Any],
        embedding_model: str,
        model_config: Any,
        force_regenerate: bool
    ) -> set:
        """Generate, store and cache embeddings for texts that still need one.

        A text needs an embedding when its checksum is missing from
        ``existing_embeddings`` or when ``force_regenerate`` is set. Texts sharing
        a checksum are embedded only once. New vectors are inserted into the
        database and added to ``existing_embeddings`` in place under the
        ``vec_<dimensions>`` key.

        Returns:
            The set of checksums whose embedding was generated by this call.
        """
        # checksum -> text in first-seen order; identical texts collapse to one entry.
        pending: Dict[str, str] = {}
        for item in text_data:
            checksum = item['checksum']
            if checksum in pending:
                continue
            if force_regenerate or checksum not in existing_embeddings:
                pending[checksum] = item['text']
        if not pending:
            return set()

        logger.info(f"Generating {len(pending)} new embeddings")
        generated_embeddings = await self.embedding_service.generate_embeddings_batch(
            list(pending.values()),
            embedding_model
        )

        # zip() pairs each checksum with the vector produced for its text. If the
        # service returns fewer vectors than requested, the extra checksums are dropped.
        embeddings_to_store = [
            {
                'model_name': embedding_model,
                'checksum': checksum,
                'dimensions': model_config.dimensions,
                'vector': embedding
            }
            for checksum, embedding in zip(pending, generated_embeddings)
        ]
        await self.db_manager.insert_embeddings(embeddings_to_store)

        # Make the freshly generated vectors visible to the dataset assembly step.
        vec_key = f'vec_{model_config.dimensions}'
        for emb_data in embeddings_to_store:
            existing_embeddings[emb_data['checksum']] = {
                'checksum': emb_data['checksum'],
                vec_key: emb_data['vector']
            }
        return {emb_data['checksum'] for emb_data in embeddings_to_store}

    async def build_dataset(self, dataset_id: str, aid_list: List[int], embedding_model: str, force_regenerate: bool = False) -> str:
        """
        Build dataset with the specified flow:
        1. Select embedding model (from TOML config)
        2. Pull raw text from database
        3. Preprocess (placeholder; no preprocessing is applied yet)
        4. Batch get embeddings (deduplicate by hash, skip if already in embeddings table)
        5. Write to embeddings table
        6. Pull all needed embeddings to create dataset with format: embeddings, label

        Args:
            dataset_id: Identifier under which the result (or error) is stored.
            aid_list: Video aids to include; aids without metadata are skipped.
            embedding_model: Key into the configured embedding models.
            force_regenerate: Regenerate embeddings even if already stored.

        Returns:
            ``dataset_id``. The result is kept in memory and saved to disk when possible.

        Raises:
            ValueError: If ``embedding_model`` is not configured.
            Exception: Any failure is re-raised after an error record is stored
                under ``dataset_id``.
        """
        try:
            logger.info(f"Starting dataset building task {dataset_id}")
            embedding_models = config_loader.get_embedding_models()
            if embedding_model not in embedding_models:
                raise ValueError(f"Invalid embedding model: {embedding_model}")
            model_config = embedding_models[embedding_model]
            vec_key = f'vec_{model_config.dimensions}'

            # Video metadata and user labels for the requested aids.
            metadata = await self.db_manager.get_video_metadata(aid_list)
            labels = await self.db_manager.get_user_labels(aid_list)

            text_data = self._prepare_text_data(aid_list, metadata)

            # Reuse stored embeddings; generate only the missing (or forced) ones.
            checksums = [item['checksum'] for item in text_data]
            existing_embeddings = await self.db_manager.get_existing_embeddings(checksums, embedding_model)
            new_checksums = await self._generate_missing_embeddings(
                text_data, existing_embeddings, embedding_model, model_config, force_regenerate
            )
            new_embeddings_count = len(new_checksums)

            # Assemble the dataset records.
            dataset = []
            inconsistent_count = 0
            for item in text_data:
                aid = item['aid']
                checksum = item['checksum']

                embedding_vector = None
                if checksum in existing_embeddings:
                    stored = existing_embeddings[checksum]
                    if vec_key in stored:
                        embedding_vector = stored[vec_key]

                aid_labels = labels.get(aid, [])
                total_votes = len(aid_labels)
                positive_votes = sum(1 for lbl in aid_labels if lbl['label'])
                # Majority vote; a tie is not a positive label. No votes means no label.
                final_label = (positive_votes > total_votes / 2) if aid_labels else None
                # Voters disagreed: at least one positive and at least one negative vote.
                inconsistent = 0 < positive_votes < total_votes
                if inconsistent:
                    inconsistent_count += 1

                # `is not None` instead of truthiness: array-like vectors raise on bool().
                if embedding_vector is not None and final_label is not None:
                    dataset.append({
                        'aid': aid,
                        'embedding': embedding_vector,
                        'label': final_label,
                        'metadata': metadata.get(aid, {}),
                        'user_labels': aid_labels,
                        'inconsistent': inconsistent,
                        'text_checksum': checksum
                    })

            # Records whose embedding came from the database rather than this run.
            # This is never negative, unlike `len(dataset) - new_embeddings_count`.
            reused_count = sum(1 for record in dataset if record['text_checksum'] not in new_checksums)
            logger.info(f"Dataset building completed: {len(dataset)} records, {new_embeddings_count} new, {reused_count} reused, {inconsistent_count} inconsistent")

            dataset_data = {
                'dataset': dataset,
                'stats': {
                    'total_records': len(dataset),
                    'new_embeddings': new_embeddings_count,
                    'reused_embeddings': reused_count,
                    'inconsistent_labels': inconsistent_count,
                    'embedding_model': embedding_model
                },
                'created_at': datetime.now().isoformat()
            }

            # Always cache in memory. Persisting to disk is best-effort.
            if self._save_dataset_to_file(dataset_id, dataset_data):
                logger.info(f"Dataset {dataset_id} saved to file system")
            else:
                logger.warning(f"Failed to save dataset {dataset_id} to file, keeping in memory only")
            self.dataset_storage[dataset_id] = dataset_data
            return dataset_id
        except Exception as e:
            logger.exception(f"Dataset building failed for {dataset_id}: {str(e)}")
            # Record the failure so callers polling dataset_id can see what happened.
            error_data = {
                'error': str(e),
                'created_at': datetime.now().isoformat()
            }
            self._save_dataset_to_file(dataset_id, error_data)
            self.dataset_storage[dataset_id] = error_data
            raise

    # ------------------------------------------------------------------ queries

    def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """Return a dataset (or its error record) by ID, loading it from disk if needed."""
        if dataset_id in self.dataset_storage:
            return self.dataset_storage[dataset_id]
        dataset_data = self._load_dataset_from_file(dataset_id)
        if dataset_data:
            self.dataset_storage[dataset_id] = dataset_data
            return dataset_data
        return None

    def dataset_exists(self, dataset_id: str) -> bool:
        """Return True if the dataset is cached in memory or has a file on disk."""
        if dataset_id in self.dataset_storage:
            return True
        return self._get_dataset_file_path(dataset_id).exists()

    def delete_dataset(self, dataset_id: str) -> bool:
        """Delete a dataset from disk and from the in-memory cache.

        Returns:
            True if the dataset existed on disk or in memory and was removed.
            False if it was not found anywhere or deletion failed.
        """
        try:
            # Delete the file first so a failed unlink leaves the cache untouched.
            file_path = self._get_dataset_file_path(dataset_id)
            removed_file = file_path.exists()
            if removed_file:
                file_path.unlink()
            removed_from_memory = dataset_id in self.dataset_storage
            if removed_from_memory:
                del self.dataset_storage[dataset_id]

            if removed_file:
                logger.info(f"Dataset {dataset_id} deleted from file system")
                return True
            if removed_from_memory:
                # e.g. a dataset whose save failed and that lived only in memory.
                logger.info(f"Dataset {dataset_id} deleted from memory (no file on disk)")
                return True
            logger.warning(f"Dataset file {dataset_id} not found for deletion")
            return False
        except Exception as e:
            logger.error(f"Failed to delete dataset {dataset_id}: {e}")
            return False

    def list_datasets(self) -> List[Dict[str, Any]]:
        """List successfully built datasets, newest first.

        Error records, and any loaded file lacking ``stats`` or ``created_at``,
        are skipped rather than raising ``KeyError``.
        """
        datasets = [
            {
                "dataset_id": dataset_id,
                "stats": dataset_info["stats"],
                "created_at": dataset_info["created_at"]
            }
            for dataset_id, dataset_info in self.dataset_storage.items()
            if "error" not in dataset_info and "stats" in dataset_info and "created_at" in dataset_info
        ]
        # ISO-8601 strings written by build_dataset sort chronologically.
        datasets.sort(key=lambda x: x["created_at"], reverse=True)
        return datasets

    def get_dataset_stats(self) -> Dict[str, Any]:
        """Aggregate counts across all cached datasets (including error records)."""
        total_datasets = len(self.dataset_storage)
        error_datasets = sum(1 for data in self.dataset_storage.values() if "error" in data)
        valid_datasets = total_datasets - error_datasets

        total_records = 0
        total_new_embeddings = 0
        total_reused_embeddings = 0
        for dataset_info in self.dataset_storage.values():
            if "stats" in dataset_info:
                stats = dataset_info["stats"]
                total_records += stats.get("total_records", 0)
                total_new_embeddings += stats.get("new_embeddings", 0)
                total_reused_embeddings += stats.get("reused_embeddings", 0)

        return {
            "total_datasets": total_datasets,
            "valid_datasets": valid_datasets,
            "error_datasets": error_datasets,
            "total_records": total_records,
            "total_new_embeddings": total_new_embeddings,
            "total_reused_embeddings": total_reused_embeddings,
            "storage_directory": str(self.storage_dir)
        }