update: use parquet as dataset format
This commit is contained in:
parent d3f26f7e7b
commit 24582ccaf5
@ -188,18 +188,6 @@ async def get_dataset_stats_endpoint():
    stats = dataset_builder.get_dataset_stats()
    return stats


@router.post("/datasets/cleanup")
async def cleanup_datasets_endpoint(max_age_days: int = 30):
    """Remove datasets older than specified days"""

    if not dataset_builder:
        raise HTTPException(status_code=503, detail="Dataset builder not available")

    await dataset_builder.cleanup_old_datasets(max_age_days)
    return {"message": f"Cleanup completed for datasets older than {max_age_days} days"}


# Task Status Endpoints

@router.get("/tasks/{task_id}", response_model=TaskStatusResponse)
@ -2,7 +2,6 @@
Dataset building service - handles the complete dataset construction flow
"""

import json
import asyncio
from pathlib import Path
from typing import List, Dict, Any, Optional
@ -14,6 +13,7 @@ from embedding_service import EmbeddingService
from config_loader import config_loader
from logger_config import get_logger
from models import TaskStatus, DatasetBuildTaskStatus, TaskProgress
from dataset_storage_parquet import ParquetDatasetStorage


logger = get_logger(__name__)
@ -28,63 +28,13 @@ class DatasetBuilder:
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(exist_ok=True)

        # Load existing datasets from file system
        self.dataset_storage: Dict[str, Dict] = self._load_all_datasets()
        self.storage = ParquetDatasetStorage(str(self.storage_dir))

        # Task status tracking
        self._task_status_lock = threading.Lock()
        self.task_statuses: Dict[str, DatasetBuildTaskStatus] = {}
        self.running_tasks: Dict[str, asyncio.Task] = {}


    def _get_dataset_file_path(self, dataset_id: str) -> Path:
        """Get file path for dataset"""
        return self.storage_dir / f"{dataset_id}.json"


    def _load_dataset_from_file(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """Load dataset from file"""
        file_path = self._get_dataset_file_path(dataset_id)
        if not file_path.exists():
            return None

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"Failed to load dataset {dataset_id} from file: {e}")
            return None


    def _save_dataset_to_file(self, dataset_id: str, dataset_data: Dict[str, Any]) -> bool:
        """Save dataset to file"""
        file_path = self._get_dataset_file_path(dataset_id)

        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(dataset_data, f, ensure_ascii=False, indent=2)
            return True
        except Exception as e:
            logger.error(f"Failed to save dataset {dataset_id} to file: {e}")
            return False


    def _load_all_datasets(self) -> Dict[str, Dict]:
        """Load all datasets from file system"""
        datasets = {}

        try:
            for file_path in self.storage_dir.glob("*.json"):
                dataset_id = file_path.stem
                dataset_data = self._load_dataset_from_file(dataset_id)
                if dataset_data:
                    datasets[dataset_id] = dataset_data
            logger.info(f"Loaded {len(datasets)} datasets from file system")
        except Exception as e:
            logger.error(f"Failed to load datasets from file system: {e}")

        return datasets


    def _create_task_status(self, task_id: str, dataset_id: str, aid_list: List[int],
                            embedding_model: str, force_regenerate: bool) -> DatasetBuildTaskStatus:
@ -100,7 +50,7 @@ class DatasetBuilder:
            created_at=datetime.now(),
            progress=TaskProgress(
                current_step="initialized",
                total_steps=7,
                total_steps=8,
                completed_steps=0,
                percentage=0.0,
                message="Task initialized"
@ -214,36 +164,6 @@ class DatasetBuilder:
        return cleaned_count


    async def cleanup_old_datasets(self, max_age_days: int = 30):
        """Remove datasets older than specified days"""
        try:
            cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 60 * 60)
            removed_count = 0

            for dataset_id in list(self.dataset_storage.keys()):
                dataset_info = self.dataset_storage[dataset_id]
                if "created_at" in dataset_info:
                    try:
                        created_time = datetime.fromisoformat(dataset_info["created_at"]).timestamp()
                        if created_time < cutoff_time:
                            # Remove from memory
                            del self.dataset_storage[dataset_id]

                            # Remove file
                            file_path = self._get_dataset_file_path(dataset_id)
                            if file_path.exists():
                                file_path.unlink()

                            removed_count += 1
                    except Exception as e:
                        logger.warning(f"Failed to process dataset {dataset_id} for cleanup: {e}")

            if removed_count > 0:
                logger.info(f"Cleaned up {removed_count} old datasets")

        except Exception as e:
            logger.error(f"Failed to cleanup old datasets: {e}")

    async def build_dataset_with_task_tracking(self, task_id: str, dataset_id: str, aid_list: List[int],
                                               embedding_model: str, force_regenerate: bool = False,
                                               description: Optional[str] = None) -> str:
@ -272,15 +192,15 @@ class DatasetBuilder:
                raise ValueError(f"Invalid embedding model: {embedding_model}")

            model_config = EMBEDDING_MODELS[embedding_model]
            self._update_task_progress(task_id, "getting_metadata", 1, "Retrieving video metadata from database")
            self._update_task_progress(task_id, "getting_metadata", 1, "Retrieving video metadata from database", 3)

            # Step 2: Get video metadata from database
            metadata = await self.db_manager.get_video_metadata(aid_list)
            self._update_task_progress(task_id, "getting_labels", 2, "Retrieving user labels from database")
            self._update_task_progress(task_id, "getting_labels", 2, "Retrieving user labels from database", 15)

            # Step 3: Get user labels
            labels = await self.db_manager.get_user_labels(aid_list)
            self._update_task_progress(task_id, "preparing_text", 3, "Preparing text data and checksums")
            self._update_task_progress(task_id, "preparing_text", 3, "Preparing text data and checksums", 30)

            # Step 4: Prepare text data and checksums
            text_data = []
@ -314,7 +234,7 @@ class DatasetBuilder:
                        f"Prepared {i + 1}/{total_aids} text entries"
                    )

            self._update_task_progress(task_id, "checking_embeddings", 4, "Checking existing embeddings")
            self._update_task_progress(task_id, "checking_embeddings", 4, "Checking existing embeddings", 33)

            # Step 5: Check existing embeddings
            checksums = [item['checksum'] for item in text_data]
@ -332,13 +252,26 @@ class DatasetBuilder:
                    task_id,
                    "generating_embeddings",
                    5,
                    f"Generating {len(new_embeddings_needed)} new embeddings"
                    f"Generating {len(new_embeddings_needed)} new embeddings",
                    50
                )

                # Create progress callback for embedding generation
                def embedding_progress_callback(current: int, total: int):
                    percentage = (current / total) * 25 + 50  # Map to progress range 5-6
                    self._update_task_progress(
                        task_id,
                        "generating_embeddings",
                        5,
                        f"Generated {current}/{total} embeddings",
                        percentage
                    )

                logger.info(f"Generating {len(new_embeddings_needed)} new embeddings")
                generated_embeddings = await self.embedding_service.generate_embeddings_batch(
                    new_embeddings_needed,
                    embedding_model
                    embedding_model,
                    progress_callback=embedding_progress_callback
                )

                # Step 7: Store new embeddings in database
@ -352,6 +285,14 @@ class DatasetBuilder:
                        'vector': embedding
                    })

                self._update_task_progress(
                    task_id,
                    "inserting_embeddings",
                    6,
                    f"Inserting {len(embeddings_to_store)} embeddings into database",
                    75
                )

                await self.db_manager.insert_embeddings(embeddings_to_store)
                new_embeddings_count = len(embeddings_to_store)

@ -362,7 +303,7 @@ class DatasetBuilder:
                f'vec_{model_config.dimensions}': emb_data['vector']
            }

            self._update_task_progress(task_id, "building_dataset", 6, "Building final dataset")
            self._update_task_progress(task_id, "building_dataset", 7, "Building final dataset", 95)

            # Step 8: Build final dataset
            dataset = []
@ -410,12 +351,13 @@ class DatasetBuilder:

            # Update progress for dataset building
            if i % 10 == 0 or i == len(text_data) - 1:  # Update every 10 items or at the end
                progress_pct = 6 + (i + 1) / len(text_data)
                progress_pct = 7 + (i + 1) / len(text_data)
                self._update_task_progress(
                    task_id,
                    "building_dataset",
                    min(6, int(progress_pct)),
                    f"Built {i + 1}/{len(text_data)} dataset records"
                    min(7, int(progress_pct)),
                    f"Built {i + 1}/{len(text_data)} dataset records",
                    100
                )

        reused_count = len(dataset) - new_embeddings_count
@ -436,15 +378,13 @@ class DatasetBuilder:
                'created_at': datetime.now().isoformat()
            }

            self._update_task_progress(task_id, "saving_dataset", 7, "Saving dataset to storage")
            self._update_task_progress(task_id, "saving_dataset", 8, "Saving dataset to storage")

            # Save to file and memory cache
            if self._save_dataset_to_file(dataset_id, dataset_data):
                self.dataset_storage[dataset_id] = dataset_data
                logger.info(f"Dataset {dataset_id} saved to file system")
            # Save to file using efficient storage (Parquet format)
            if self.storage.save_dataset(dataset_id, dataset_data['dataset'], description, dataset_data['stats']):
                logger.info(f"Dataset {dataset_id} saved to efficient storage (Parquet format)")
            else:
                logger.warning(f"Failed to save dataset {dataset_id} to file, keeping in memory only")
                self.dataset_storage[dataset_id] = dataset_data
                logger.warning(f"Failed to save dataset {dataset_id} to storage")

            # Update task status to completed
            result = {
@ -459,8 +399,8 @@ class DatasetBuilder:
                result=result,
                progress=TaskProgress(
                    current_step="completed",
                    total_steps=7,
                    completed_steps=7,
                    total_steps=8,
                    completed_steps=8,
                    percentage=100.0,
                    message="Dataset building completed successfully"
                )
@ -479,7 +419,7 @@ class DatasetBuilder:
                error_message=str(e),
                progress=TaskProgress(
                    current_step="failed",
                    total_steps=7,
                    total_steps=8,
                    completed_steps=0,
                    percentage=0.0,
                    message=f"Task failed: {str(e)}"
@ -492,9 +432,8 @@ class DatasetBuilder:
                'created_at': datetime.now().isoformat()
            }

            # Try to save error to file as well
            self._save_dataset_to_file(dataset_id, error_data)
            self.dataset_storage[dataset_id] = error_data
            # Try to save error using new storage
            self.storage.save_dataset(dataset_id, [], description=str(e), stats={'error': True})
            raise
@ -523,91 +462,28 @@ class DatasetBuilder:

    def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """Get built dataset by ID"""
        # First check memory cache
        if dataset_id in self.dataset_storage:
            return self.dataset_storage[dataset_id]

        # If not in memory, try to load from file
        dataset_data = self._load_dataset_from_file(dataset_id)
        if dataset_data:
            # Add to memory cache
            self.dataset_storage[dataset_id] = dataset_data
            return dataset_data

        return None
        # Use pure Parquet storage interface
        return self.storage.load_dataset_full(dataset_id)

    def dataset_exists(self, dataset_id: str) -> bool:
        """Check if dataset exists"""
        # Check memory cache first
        if dataset_id in self.dataset_storage:
            return True

        # Check file system
        return self._get_dataset_file_path(dataset_id).exists()
        return self.storage.dataset_exists(dataset_id)

    def delete_dataset(self, dataset_id: str) -> bool:
        """Delete dataset from both memory and file system"""
        try:
            # Remove from memory
            if dataset_id in self.dataset_storage:
                del self.dataset_storage[dataset_id]

            # Remove file
            file_path = self._get_dataset_file_path(dataset_id)
            if file_path.exists():
                file_path.unlink()
                logger.info(f"Dataset {dataset_id} deleted from file system")
                return True
            else:
                logger.warning(f"Dataset file {dataset_id} not found for deletion")
                return False

        except Exception as e:
            logger.error(f"Failed to delete dataset {dataset_id}: {e}")
            return False
        return self.storage.delete_dataset(dataset_id)

    def list_datasets(self) -> List[Dict[str, Any]]:
        """List all datasets with their basic information"""
        datasets = []

        self._load_all_datasets()

        for dataset_id, dataset_info in self.dataset_storage.items():
            if "error" not in dataset_info:
                datasets.append({
                    "dataset_id": dataset_id,
                    "stats": dataset_info["stats"],
                    "created_at": dataset_info["created_at"]
                })

        # Sort by creation time (newest first)
        datasets.sort(key=lambda x: x["created_at"], reverse=True)

        return datasets
        return self.storage.list_datasets()

    def get_dataset_stats(self) -> Dict[str, Any]:
        """Get overall statistics about stored datasets"""
        total_datasets = len(self.dataset_storage)
        error_datasets = sum(1 for data in self.dataset_storage.values() if "error" in data)
        valid_datasets = total_datasets - error_datasets

        total_records = 0
        total_new_embeddings = 0
        total_reused_embeddings = 0

        for dataset_info in self.dataset_storage.values():
            if "stats" in dataset_info:
                stats = dataset_info["stats"]
                total_records += stats.get("total_records", 0)
                total_new_embeddings += stats.get("new_embeddings", 0)
                total_reused_embeddings += stats.get("reused_embeddings", 0)

        return {
            "total_datasets": total_datasets,
            "valid_datasets": valid_datasets,
            "error_datasets": error_datasets,
            "total_records": total_records,
            "total_new_embeddings": total_new_embeddings,
            "total_reused_embeddings": total_reused_embeddings,
            "storage_directory": str(self.storage_dir)
        }
        return self.storage.get_dataset_stats()

    def load_dataset_metadata_fast(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        return self.storage.load_dataset_metadata(dataset_id)

    def load_dataset_partial(self, dataset_id: str, columns: Optional[List[str]] = None,
                             filters: Optional[Dict[str, Any]] = None):
        return self.storage.load_dataset_partial(dataset_id, columns, filters)
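Editor's note: a minimal usage sketch (not part of the commit) of the pass-through helpers DatasetBuilder now exposes on top of ParquetDatasetStorage; the builder instance and dataset id below are placeholders.

# Hypothetical caller, assuming `builder` is an already-initialized DatasetBuilder.
for info in builder.list_datasets():                        # delegates to ParquetDatasetStorage.list_datasets()
    print(info["dataset_id"], info["total_records"], info["file_size_mb"], "MB")

meta = builder.load_dataset_metadata_fast("my_dataset")     # reads only the .metadata.json sidecar
df = builder.load_dataset_partial("my_dataset", columns=["aid", "label"])
# note: the storage layer always re-adds 'aid', 'label' and 'embedding' as required columns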
ml_new/training/dataset_storage_parquet.py (new file, 393 lines)
@ -0,0 +1,393 @@
"""
Efficient dataset storage using Parquet format for better space utilization and loading performance
"""

import pandas as pd
import numpy as np
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional, Union
from datetime import datetime
import pyarrow as pa
import pyarrow.parquet as pq
from logger_config import get_logger

logger = get_logger(__name__)

class ParquetDatasetStorage:
    def __init__(self, storage_dir: str = "datasets"):
        self.storage_dir = Path(storage_dir)
        self.storage_dir.mkdir(exist_ok=True)

        # Parquet file extension
        self.parquet_ext = ".parquet"
        self.metadata_ext = ".metadata.json"

        # In-memory cache: only cache metadata to avoid large file memory usage
        self.metadata_cache: Dict[str, Dict[str, Any]] = {}
        self._load_metadata_cache()

    def _get_dataset_files(self, dataset_id: str) -> tuple[Path, Path]:
        """Get file paths for the dataset"""
        base_path = self.storage_dir / dataset_id
        data_file = base_path.with_suffix(self.parquet_ext)
        metadata_file = base_path.with_suffix(self.metadata_ext)
        return data_file, metadata_file

    def _load_metadata_cache(self):
        """Load metadata cache"""
        try:
            for metadata_file in self.storage_dir.glob("*.metadata.json"):
                try:
                    # Remove ".metadata" suffix
                    dataset_id = metadata_file.stem[:-9]
                    with open(metadata_file, 'r', encoding='utf-8') as f:
                        metadata = json.load(f)
                    self.metadata_cache[dataset_id] = metadata
                except Exception as e:
                    logger.warning(f"Failed to load metadata for {metadata_file}: {e}")

            logger.info(f"Loaded metadata for {len(self.metadata_cache)} datasets")

        except Exception as e:
            logger.error(f"Failed to load metadata cache: {e}")

    def save_dataset(self, dataset_id: str, dataset: List[Dict[str, Any]],
                     description: Optional[str] = None, stats: Optional[Dict[str, Any]] = None) -> bool:
        """
        Save dataset using Parquet format

        Args:
            dataset_id: Dataset ID
            dataset: Dataset content
            description: Dataset description
            stats: Dataset statistics

        Returns:
            bool: Whether the save was successful
        """
        try:
            data_file, metadata_file = self._get_dataset_files(dataset_id)

            # Ensure directory exists
            data_file.parent.mkdir(parents=True, exist_ok=True)

            # Prepare data: convert embedding vectors to numpy arrays
            if not dataset:
                logger.warning(f"Empty dataset for {dataset_id}")
                return False

            # Analyze data structure
            first_item = dataset[0]
            embedding_dim = len(first_item.get('embedding', []))

            # Build DataFrame
            records = []
            for item in dataset:
                record = {
                    'aid': item.get('aid'),
                    'label': item.get('label'),
                    'inconsistent': item.get('inconsistent', False),
                    'text_checksum': item.get('text_checksum'),
                    # Store embedding as a separate column
                    'embedding': item.get('embedding', []),
                    # Store metadata as JSON string
                    'metadata_json': json.dumps(item.get('metadata', {}), ensure_ascii=False),
                    'user_labels_json': json.dumps(item.get('user_labels', []), ensure_ascii=False)
                }
                records.append(record)

            # Create DataFrame
            df = pd.DataFrame(records)

            # Convert embedding column to numpy arrays
            df['embedding'] = df['embedding'].apply(lambda x: np.array(x, dtype=np.float32) if x else np.array([], dtype=np.float32))

            # Use PyArrow Schema for type safety
            schema = pa.schema([
                ('aid', pa.int64()),
                ('label', pa.bool_()),
                ('inconsistent', pa.bool_()),
                ('text_checksum', pa.string()),
                ('embedding', pa.list_(pa.float32())),
                ('metadata_json', pa.string()),
                ('user_labels_json', pa.string())
            ])

            # Convert to PyArrow Table
            table = pa.Table.from_pandas(df, schema=schema)

            # Write Parquet file with efficient compression settings
            pq.write_table(
                table,
                data_file,
                compression='zstd',  # Better compression ratio
                compression_level=6,  # Balance compression ratio and speed
                use_dictionary=True,  # Enable dictionary encoding
                write_page_index=True,  # Support fast metadata access
                write_statistics=True  # Enable statistics
            )

            # Save metadata
            metadata = {
                'dataset_id': dataset_id,
                'description': description,
                'stats': stats or {},
                'created_at': datetime.now().isoformat(),
                'file_format': 'parquet_v1',
                'embedding_dimension': embedding_dim,
                'total_records': len(dataset),
                'columns': list(df.columns),
                'file_size_bytes': data_file.stat().st_size,
                'compression': 'zstd'
            }

            with open(metadata_file, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, ensure_ascii=False, indent=2)

            # Update cache
            self.metadata_cache[dataset_id] = metadata

            logger.info(f"Saved dataset {dataset_id} to Parquet: {len(dataset)} records, {data_file.stat().st_size} bytes")
            return True

        except Exception as e:
            logger.error(f"Failed to save dataset {dataset_id}: {e}")
            return False

    def load_dataset_metadata(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """
        Quickly load dataset metadata (without loading the entire file)

        Args:
            dataset_id: Dataset ID

        Returns:
            Dict: Metadata, or None if not found
        """
        # Check cache
        if dataset_id in self.metadata_cache:
            return self.metadata_cache[dataset_id]

        # Load from file
        _, metadata_file = self._get_dataset_files(dataset_id)
        if not metadata_file.exists():
            return None

        try:
            with open(metadata_file, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            # Update cache
            self.metadata_cache[dataset_id] = metadata
            return metadata

        except Exception as e:
            logger.error(f"Failed to load metadata for {dataset_id}: {e}")
            return None

    def load_dataset_partial(self, dataset_id: str, columns: Optional[List[str]] = None,
                             filters: Optional[Dict[str, Any]] = None) -> Optional[pd.DataFrame]:
        """
        Partially load the dataset (only reading specified columns or rows meeting criteria)

        Args:
            dataset_id: Dataset ID
            columns: Columns to read, None to read all
            filters: Filtering conditions, format {column: value}

        Returns:
            pd.DataFrame: Loaded data, or None if failed
        """
        data_file, _ = self._get_dataset_files(dataset_id)
        if not data_file.exists():
            return None

        try:
            # Read Parquet file, supporting column selection and filtering
            if columns:
                # Ensure necessary columns exist
                all_columns = ['aid', 'label', 'inconsistent', 'text_checksum', 'embedding', 'metadata_json', 'user_labels_json']
                required_cols = ['aid', 'label', 'embedding']  # These are fundamentally needed
                columns = list(set(columns + required_cols))

                # Filter out non-existent columns
                columns = [col for col in columns if col in all_columns]

            # Use pyarrow to read, supporting filters
            if filters:
                # Build filter expressions
                expressions = []
                for col, value in filters.items():
                    if col == 'label':
                        expressions.append(pa.compute.equal(pa.field(col), value))
                    elif col == 'aid':
                        expressions.append(pa.compute.equal(pa.field(col), value))

                if expressions:
                    filter_expr = expressions[0]
                    for expr in expressions[1:]:
                        filter_expr = pa.compute.and_(filter_expr, expr)
                else:
                    filter_expr = None
            else:
                filter_expr = None

            # Read data
            if columns and filter_expr:
                table = pq.read_table(data_file, columns=columns, filter=filter_expr)
            elif columns:
                table = pq.read_table(data_file, columns=columns)
            elif filter_expr:
                table = pq.read_table(data_file, filter=filter_expr)
            else:
                table = pq.read_table(data_file)

            # Convert to DataFrame
            df = table.to_pandas()

            # Handle embedding column
            if 'embedding' in df.columns:
                df['embedding'] = df['embedding'].apply(lambda x: x.tolist() if hasattr(x, 'tolist') else list(x))

            logger.info(f"Loaded partial dataset {dataset_id}: {len(df)} rows, {len(df.columns)} columns")
            return df

        except Exception as e:
            logger.error(f"Failed to load partial dataset {dataset_id}: {e}")
            return None

    def load_dataset_full(self, dataset_id: str) -> Optional[Dict[str, Any]]:
        """
        Fully load the dataset (maintaining backward compatibility format)

        Args:
            dataset_id: Dataset ID

        Returns:
            Dict: Full dataset data, or None if failed
        """
        data_file, _ = self._get_dataset_files(dataset_id)
        if not data_file.exists():
            return None

        try:
            # Load metadata
            metadata = self.load_dataset_metadata(dataset_id)
            if not metadata:
                return None

            # Load data
            df = self.load_dataset_partial(dataset_id)
            if df is None:
                return None

            # Convert to original format
            dataset = []
            for _, row in df.iterrows():
                record = {
                    'aid': int(row['aid']),
                    'embedding': row['embedding'],
                    'label': bool(row['label']),
                    'metadata': json.loads(row['metadata_json']) if row['metadata_json'] else {},
                    'user_labels': json.loads(row['user_labels_json']) if row['user_labels_json'] else [],
                    'inconsistent': bool(row['inconsistent']),
                    'text_checksum': row['text_checksum']
                }
                dataset.append(record)

            return {
                'dataset': dataset,
                'description': metadata.get('description'),
                'stats': metadata.get('stats', {}),
                'created_at': metadata.get('created_at')
            }

        except Exception as e:
            logger.error(f"Failed to load full dataset {dataset_id}: {e}")
            return None

    def dataset_exists(self, dataset_id: str) -> bool:
        """Check if the dataset exists"""
        data_file, _ = self._get_dataset_files(dataset_id)
        return data_file.exists()

    def delete_dataset(self, dataset_id: str) -> bool:
        """Delete a dataset"""
        try:
            data_file, metadata_file = self._get_dataset_files(dataset_id)

            # Delete files
            if data_file.exists():
                data_file.unlink()
            if metadata_file.exists():
                metadata_file.unlink()

            # Remove from cache
            if dataset_id in self.metadata_cache:
                del self.metadata_cache[dataset_id]

            logger.info(f"Deleted dataset {dataset_id}")
            return True

        except Exception as e:
            logger.error(f"Failed to delete dataset {dataset_id}: {e}")
            return False

    def list_datasets(self) -> List[Dict[str, Any]]:
        """List metadata for all datasets"""
        datasets = []

        for dataset_id, metadata in self.metadata_cache.items():
            datasets.append({
                "dataset_id": dataset_id,
                "description": metadata.get("description"),
                "stats": metadata.get("stats", {}),
                "created_at": metadata.get("created_at"),
                "total_records": metadata.get("total_records", 0),
                "file_size_mb": round(metadata.get("file_size_bytes", 0) / (1024 * 1024), 2),
                "embedding_dimension": metadata.get("embedding_dimension"),
                "file_format": metadata.get("file_format")
            })

        # Sort by creation time descending
        datasets.sort(key=lambda x: x["created_at"], reverse=True)
        return datasets

    def get_dataset_stats(self) -> Dict[str, Any]:
        """Get overall statistics"""
        total_datasets = len(self.metadata_cache)
        total_records = sum(m.get("total_records", 0) for m in self.metadata_cache.values())
        total_size_bytes = sum(m.get("file_size_bytes", 0) for m in self.metadata_cache.values())

        return {
            "total_datasets": total_datasets,
            "total_records": total_records,
            "total_size_mb": round(total_size_bytes / (1024 * 1024), 2),
            "average_size_mb": round(total_size_bytes / total_datasets / (1024 * 1024), 2) if total_datasets > 0 else 0,
            "storage_directory": str(self.storage_dir),
            "storage_format": "parquet_v1"
        }

    def migrate_from_json(self, dataset_id: str, json_data: Dict[str, Any]) -> bool:
        """
        Migrate a dataset from JSON format to Parquet format

        Args:
            dataset_id: Dataset ID
            json_data: Data in JSON format

        Returns:
            bool: Migration success status
        """
        try:
            dataset = json_data.get('dataset', [])
            description = json_data.get('description')
            stats = json_data.get('stats')

            return self.save_dataset(dataset_id, dataset, description, stats)

        except Exception as e:
            logger.error(f"Failed to migrate dataset {dataset_id} from JSON: {e}")
            return False
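Editor's note: a minimal round-trip sketch of ParquetDatasetStorage as defined above; the record values are toy placeholders, and real embeddings use the model's full dimension.

from dataset_storage_parquet import ParquetDatasetStorage

storage = ParquetDatasetStorage("datasets")

record = {
    'aid': 170001,                    # hypothetical video id
    'embedding': [0.1, 0.2, 0.3],     # toy 3-dim vector, for illustration only
    'label': True,
    'metadata': {'title': 'example'},
    'user_labels': [],
    'inconsistent': False,
    'text_checksum': 'deadbeef'
}

if storage.save_dataset("demo", [record], description="demo dataset"):
    meta = storage.load_dataset_metadata("demo")    # cheap: reads only the .metadata.json sidecar
    full = storage.load_dataset_full("demo")        # converts back to the original dict format
    print(meta['total_records'], len(full['dataset']))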
@ -125,7 +125,8 @@ class EmbeddingService:
        self,
        texts: List[str],
        model: str,
        batch_size: Optional[int] = None
        batch_size: Optional[int] = None,
        progress_callback: Optional[callable] = None
    ) -> List[List[float]]:
        """Generate embeddings for a batch of texts"""

@ -137,9 +138,9 @@ class EmbeddingService:

        # Handle different model types
        if model_config.type == "legacy":
            return self._generate_legacy_embeddings_batch(texts, model, batch_size)
            return self._generate_legacy_embeddings_batch(texts, model, batch_size, progress_callback)
        elif model_config.type == "openai-compatible":
            return await self._generate_openai_embeddings_batch(texts, model, batch_size)
            return await self._generate_openai_embeddings_batch(texts, model, batch_size, progress_callback=progress_callback)
        else:
            raise ValueError(f"Unsupported model type: {model_config.type}")

@ -147,7 +148,8 @@ class EmbeddingService:
        self,
        texts: List[str],
        model: str,
        batch_size: Optional[int] = None
        batch_size: Optional[int] = None,
        progress_callback: Optional[callable] = None
    ) -> List[List[float]]:
        """Generate embeddings using legacy ONNX model"""
        if model not in self.legacy_models:
@ -162,9 +164,13 @@ class EmbeddingService:
        expected_dims = model_config.dimensions
        all_embeddings = []

        # Calculate total batches for progress tracking
        total_batches = (len(texts) + batch_size - 1) // batch_size

        # Process in batches
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_idx = i // batch_size + 1

            try:
                # Generate embeddings using legacy method
@ -174,23 +180,33 @@ class EmbeddingService:
                batch_embeddings = embeddings.tolist()
                all_embeddings.extend(batch_embeddings)

                logger.info(f"Generated legacy embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
                logger.info(f"Generated legacy embeddings for batch {batch_idx}/{total_batches}")

                # Update progress if callback provided
                if progress_callback:
                    progress_callback(batch_idx, total_batches)

            except Exception as e:
                logger.error(f"Error generating legacy embeddings for batch {i//batch_size + 1}: {e}")
                logger.error(f"Error generating legacy embeddings for batch {batch_idx}: {e}")
                # Fill with zeros as fallback
                zero_embedding = [0.0] * expected_dims
                all_embeddings.extend([zero_embedding] * len(batch))

        # Final progress update
        if progress_callback:
            progress_callback(total_batches, total_batches)

        return all_embeddings

    async def _generate_openai_embeddings_batch(
        self,
        texts: List[str],
        model: str,
        batch_size: Optional[int] = None
        batch_size: Optional[int] = None,
        max_concurrent: int = 6,
        progress_callback: Optional[callable] = None
    ) -> List[List[float]]:
        """Generate embeddings using OpenAI-compatible API"""
        """Generate embeddings using OpenAI-compatible API with parallel requests"""
        model_config = self.embedding_models[model]

        # Use model's max_batch_size if not specified
@ -204,34 +220,71 @@ class EmbeddingService:
            raise ValueError(f"No client configured for model '{model}'")

        client = self.clients[model]
        all_embeddings = []

        # Process in batches to avoid API limits
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]

            try:
                # Rate limiting
                if i > 0:
                    await asyncio.sleep(self.request_interval)

                # Generate embeddings
                response = await client.embeddings.create(
                    model=model_config.name,
                    input=batch,
                    dimensions=expected_dims
                )

                batch_embeddings = [data.embedding for data in response.data]
                all_embeddings.extend(batch_embeddings)

                logger.info(f"Generated embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")

            except Exception as e:
                logger.error(f"Error generating embeddings for batch {i//batch_size + 1}: {e}")
                # For now, fill with zeros as fallback (could implement retry logic)
                zero_embedding = [0.0] * expected_dims
                all_embeddings.extend([zero_embedding] * len(batch))
        # Split texts into batches
        batches = [(i, texts[i:i + batch_size]) for i in range(0, len(texts), batch_size)]
        total_batches = len(batches)

        # Track completed batches for progress reporting
        completed_batches = 0
        completed_batches_lock = asyncio.Lock()

        # Semaphore for concurrency control
        semaphore = asyncio.Semaphore(max_concurrent)

        async def process_batch(batch_idx: int, batch: List[str]) -> tuple[int, List[List[float]]]:
            """Process a single batch with semaphore-controlled concurrency"""
            nonlocal completed_batches
            async with semaphore:
                try:
                    # Rate limiting
                    if batch_idx > 0:
                        await asyncio.sleep(self.request_interval)

                    # Generate embeddings
                    response = await client.embeddings.create(
                        model=model_config.name,
                        input=batch,
                        dimensions=expected_dims
                    )

                    batch_embeddings = [data.embedding for data in response.data]
                    logger.info(f"Generated embeddings for batch {batch_idx//batch_size + 1}/{total_batches}")

                    # Update completed batch count and progress
                    async with completed_batches_lock:
                        completed_batches += 1
                        if progress_callback:
                            progress_callback(completed_batches, total_batches)

                    return (batch_idx, batch_embeddings)

                except Exception as e:
                    logger.error(f"Error generating embeddings for batch {batch_idx//batch_size + 1}: {e}")
                    # Fill with zeros as fallback
                    zero_embedding = [0.0] * expected_dims

                    # Update completed batch count and progress even for failed batches
                    async with completed_batches_lock:
                        completed_batches += 1
                        if progress_callback:
                            progress_callback(completed_batches, total_batches)

                    return (batch_idx, [zero_embedding] * len(batch))

        # Process all batches concurrently with semaphore limit
        tasks = [process_batch(i, batch) for i, batch in batches]
        results = await asyncio.gather(*tasks)

        # Final progress update
        if progress_callback:
            progress_callback(total_batches, total_batches)

        # Sort results by batch index and flatten
        results.sort(key=lambda x: x[0])
        all_embeddings = []
        for _, embeddings in results:
            all_embeddings.extend(embeddings)

        return all_embeddings

@ -268,7 +321,8 @@ class EmbeddingService:
            test_embedding = await self.generate_embeddings_batch(
                ["health check"],
                model_name,
                batch_size=1
                batch_size=1,
                progress_callback=None
            )

            return {
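Editor's note: a brief sketch of how the new progress_callback parameter might be wired by a caller; the surrounding async context, service instance, and model key are assumptions, not part of the commit.

# Inside an async function, with `embedding_service` an initialized EmbeddingService:
def log_progress(current: int, total: int):
    # current/total are completed vs. total batches, as reported by the code above
    print(f"embeddings: {current}/{total} batches ({current / total * 100:.0f}%)")

embeddings = await embedding_service.generate_embeddings_batch(
    texts,
    "model-key",                      # placeholder model identifier
    progress_callback=log_progress
)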
@ -13,4 +13,5 @@ sqlalchemy==2.0.23
asyncpg==0.29.0
toml==0.10.2
aiohttp==3.9.0
python-dotenv==1.1.0
python-dotenv==1.1.0
pyarrow==15.0.0
@ -30,7 +30,7 @@ export function TaskMonitor() {
      const params = statusFilter === "all" ? {} : { status: statusFilter };
      return apiClient.getTasks(params.status, 50);
    },
    refetchInterval: 5000 // Refresh every 5 seconds
    refetchInterval: 500
  });

  const getStatusIcon = (status: string) => {