
update: use parquet as dataset format

This commit is contained in:
alikia2x (寒寒) 2025-12-10 20:57:44 +08:00
parent d3f26f7e7b
commit 24582ccaf5
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG Key ID: 56209E0CCD8420C6
6 changed files with 544 additions and 232 deletions

View File

@ -188,18 +188,6 @@ async def get_dataset_stats_endpoint():
stats = dataset_builder.get_dataset_stats()
return stats
@router.post("/datasets/cleanup")
async def cleanup_datasets_endpoint(max_age_days: int = 30):
"""Remove datasets older than specified days"""
if not dataset_builder:
raise HTTPException(status_code=503, detail="Dataset builder not available")
await dataset_builder.cleanup_old_datasets(max_age_days)
return {"message": f"Cleanup completed for datasets older than {max_age_days} days"}
# Task Status Endpoints
@router.get("/tasks/{task_id}", response_model=TaskStatusResponse)

View File

@ -2,7 +2,6 @@
Dataset building service - handles the complete dataset construction flow
"""
import json
import asyncio
from pathlib import Path
from typing import List, Dict, Any, Optional
@ -14,6 +13,7 @@ from embedding_service import EmbeddingService
from config_loader import config_loader
from logger_config import get_logger
from models import TaskStatus, DatasetBuildTaskStatus, TaskProgress
from dataset_storage_parquet import ParquetDatasetStorage
logger = get_logger(__name__)
@ -28,63 +28,13 @@ class DatasetBuilder:
self.storage_dir = Path(storage_dir)
self.storage_dir.mkdir(exist_ok=True)
# Load existing datasets from file system
self.dataset_storage: Dict[str, Dict] = self._load_all_datasets()
self.storage = ParquetDatasetStorage(str(self.storage_dir))
# Task status tracking
self._task_status_lock = threading.Lock()
self.task_statuses: Dict[str, DatasetBuildTaskStatus] = {}
self.running_tasks: Dict[str, asyncio.Task] = {}
def _get_dataset_file_path(self, dataset_id: str) -> Path:
"""Get file path for dataset"""
return self.storage_dir / f"{dataset_id}.json"
def _load_dataset_from_file(self, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Load dataset from file"""
file_path = self._get_dataset_file_path(dataset_id)
if not file_path.exists():
return None
try:
with open(file_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Failed to load dataset {dataset_id} from file: {e}")
return None
def _save_dataset_to_file(self, dataset_id: str, dataset_data: Dict[str, Any]) -> bool:
"""Save dataset to file"""
file_path = self._get_dataset_file_path(dataset_id)
try:
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(dataset_data, f, ensure_ascii=False, indent=2)
return True
except Exception as e:
logger.error(f"Failed to save dataset {dataset_id} to file: {e}")
return False
def _load_all_datasets(self) -> Dict[str, Dict]:
"""Load all datasets from file system"""
datasets = {}
try:
for file_path in self.storage_dir.glob("*.json"):
dataset_id = file_path.stem
dataset_data = self._load_dataset_from_file(dataset_id)
if dataset_data:
datasets[dataset_id] = dataset_data
logger.info(f"Loaded {len(datasets)} datasets from file system")
except Exception as e:
logger.error(f"Failed to load datasets from file system: {e}")
return datasets
def _create_task_status(self, task_id: str, dataset_id: str, aid_list: List[int],
embedding_model: str, force_regenerate: bool) -> DatasetBuildTaskStatus:
@ -100,7 +50,7 @@ class DatasetBuilder:
created_at=datetime.now(),
progress=TaskProgress(
current_step="initialized",
total_steps=7,
total_steps=8,
completed_steps=0,
percentage=0.0,
message="Task initialized"
@ -214,36 +164,6 @@ class DatasetBuilder:
return cleaned_count
async def cleanup_old_datasets(self, max_age_days: int = 30):
"""Remove datasets older than specified days"""
try:
cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 60 * 60)
removed_count = 0
for dataset_id in list(self.dataset_storage.keys()):
dataset_info = self.dataset_storage[dataset_id]
if "created_at" in dataset_info:
try:
created_time = datetime.fromisoformat(dataset_info["created_at"]).timestamp()
if created_time < cutoff_time:
# Remove from memory
del self.dataset_storage[dataset_id]
# Remove file
file_path = self._get_dataset_file_path(dataset_id)
if file_path.exists():
file_path.unlink()
removed_count += 1
except Exception as e:
logger.warning(f"Failed to process dataset {dataset_id} for cleanup: {e}")
if removed_count > 0:
logger.info(f"Cleaned up {removed_count} old datasets")
except Exception as e:
logger.error(f"Failed to cleanup old datasets: {e}")
async def build_dataset_with_task_tracking(self, task_id: str, dataset_id: str, aid_list: List[int],
embedding_model: str, force_regenerate: bool = False,
description: Optional[str] = None) -> str:
@ -272,15 +192,15 @@ class DatasetBuilder:
raise ValueError(f"Invalid embedding model: {embedding_model}")
model_config = EMBEDDING_MODELS[embedding_model]
self._update_task_progress(task_id, "getting_metadata", 1, "Retrieving video metadata from database")
self._update_task_progress(task_id, "getting_metadata", 1, "Retrieving video metadata from database", 3)
# Step 2: Get video metadata from database
metadata = await self.db_manager.get_video_metadata(aid_list)
self._update_task_progress(task_id, "getting_labels", 2, "Retrieving user labels from database")
self._update_task_progress(task_id, "getting_labels", 2, "Retrieving user labels from database", 15)
# Step 3: Get user labels
labels = await self.db_manager.get_user_labels(aid_list)
self._update_task_progress(task_id, "preparing_text", 3, "Preparing text data and checksums")
self._update_task_progress(task_id, "preparing_text", 3, "Preparing text data and checksums", 30)
# Step 4: Prepare text data and checksums
text_data = []
@ -314,7 +234,7 @@ class DatasetBuilder:
f"Prepared {i + 1}/{total_aids} text entries"
)
self._update_task_progress(task_id, "checking_embeddings", 4, "Checking existing embeddings")
self._update_task_progress(task_id, "checking_embeddings", 4, "Checking existing embeddings", 33)
# Step 5: Check existing embeddings
checksums = [item['checksum'] for item in text_data]
@ -332,13 +252,26 @@ class DatasetBuilder:
task_id,
"generating_embeddings",
5,
f"Generating {len(new_embeddings_needed)} new embeddings"
f"Generating {len(new_embeddings_needed)} new embeddings",
50
)
# Create progress callback for embedding generation
def embedding_progress_callback(current: int, total: int):
percentage = (current / total) * 25 + 50  # Map embedding progress onto the 50-75% range (steps 5-6)
self._update_task_progress(
task_id,
"generating_embeddings",
5,
f"Generated {current}/{total} embeddings",
percentage
)
logger.info(f"Generating {len(new_embeddings_needed)} new embeddings")
generated_embeddings = await self.embedding_service.generate_embeddings_batch(
new_embeddings_needed,
embedding_model
embedding_model,
progress_callback=embedding_progress_callback
)
# Step 7: Store new embeddings in database
@ -352,6 +285,14 @@ class DatasetBuilder:
'vector': embedding
})
self._update_task_progress(
task_id,
"inserting_embeddings",
6,
f"Inserting {len(embeddings_to_store)} embeddings into database",
75
)
await self.db_manager.insert_embeddings(embeddings_to_store)
new_embeddings_count = len(embeddings_to_store)
@ -362,7 +303,7 @@ class DatasetBuilder:
f'vec_{model_config.dimensions}': emb_data['vector']
}
self._update_task_progress(task_id, "building_dataset", 6, "Building final dataset")
self._update_task_progress(task_id, "building_dataset", 7, "Building final dataset", 95)
# Step 8: Build final dataset
dataset = []
@ -410,12 +351,13 @@ class DatasetBuilder:
# Update progress for dataset building
if i % 10 == 0 or i == len(text_data) - 1: # Update every 10 items or at the end
progress_pct = 6 + (i + 1) / len(text_data)
progress_pct = 7 + (i + 1) / len(text_data)
self._update_task_progress(
task_id,
"building_dataset",
min(6, int(progress_pct)),
f"Built {i + 1}/{len(text_data)} dataset records"
min(7, int(progress_pct)),
f"Built {i + 1}/{len(text_data)} dataset records",
100
)
reused_count = len(dataset) - new_embeddings_count
@ -436,15 +378,13 @@ class DatasetBuilder:
'created_at': datetime.now().isoformat()
}
self._update_task_progress(task_id, "saving_dataset", 7, "Saving dataset to storage")
self._update_task_progress(task_id, "saving_dataset", 8, "Saving dataset to storage")
# Save to file and memory cache
if self._save_dataset_to_file(dataset_id, dataset_data):
self.dataset_storage[dataset_id] = dataset_data
logger.info(f"Dataset {dataset_id} saved to file system")
# Save to file using efficient storage (Parquet format)
if self.storage.save_dataset(dataset_id, dataset_data['dataset'], description, dataset_data['stats']):
logger.info(f"Dataset {dataset_id} saved to efficient storage (Parquet format)")
else:
logger.warning(f"Failed to save dataset {dataset_id} to file, keeping in memory only")
self.dataset_storage[dataset_id] = dataset_data
logger.warning(f"Failed to save dataset {dataset_id} to storage")
# Update task status to completed
result = {
@ -459,8 +399,8 @@ class DatasetBuilder:
result=result,
progress=TaskProgress(
current_step="completed",
total_steps=7,
completed_steps=7,
total_steps=8,
completed_steps=8,
percentage=100.0,
message="Dataset building completed successfully"
)
@ -479,7 +419,7 @@ class DatasetBuilder:
error_message=str(e),
progress=TaskProgress(
current_step="failed",
total_steps=7,
total_steps=8,
completed_steps=0,
percentage=0.0,
message=f"Task failed: {str(e)}"
@ -492,9 +432,8 @@ class DatasetBuilder:
'created_at': datetime.now().isoformat()
}
# Try to save error to file as well
self._save_dataset_to_file(dataset_id, error_data)
self.dataset_storage[dataset_id] = error_data
# Try to save error using new storage
self.storage.save_dataset(dataset_id, [], description=str(e), stats={'error': True})
raise
@ -523,91 +462,28 @@ class DatasetBuilder:
def get_dataset(self, dataset_id: str) -> Optional[Dict[str, Any]]:
"""Get built dataset by ID"""
# First check memory cache
if dataset_id in self.dataset_storage:
return self.dataset_storage[dataset_id]
# If not in memory, try to load from file
dataset_data = self._load_dataset_from_file(dataset_id)
if dataset_data:
# Add to memory cache
self.dataset_storage[dataset_id] = dataset_data
return dataset_data
return None
# Use pure Parquet storage interface
return self.storage.load_dataset_full(dataset_id)
def dataset_exists(self, dataset_id: str) -> bool:
"""Check if dataset exists"""
# Check memory cache first
if dataset_id in self.dataset_storage:
return True
# Check file system
return self._get_dataset_file_path(dataset_id).exists()
return self.storage.dataset_exists(dataset_id)
def delete_dataset(self, dataset_id: str) -> bool:
"""Delete dataset from both memory and file system"""
try:
# Remove from memory
if dataset_id in self.dataset_storage:
del self.dataset_storage[dataset_id]
# Remove file
file_path = self._get_dataset_file_path(dataset_id)
if file_path.exists():
file_path.unlink()
logger.info(f"Dataset {dataset_id} deleted from file system")
return True
else:
logger.warning(f"Dataset file {dataset_id} not found for deletion")
return False
except Exception as e:
logger.error(f"Failed to delete dataset {dataset_id}: {e}")
return False
return self.storage.delete_dataset(dataset_id)
def list_datasets(self) -> List[Dict[str, Any]]:
"""List all datasets with their basic information"""
datasets = []
self._load_all_datasets()
for dataset_id, dataset_info in self.dataset_storage.items():
if "error" not in dataset_info:
datasets.append({
"dataset_id": dataset_id,
"stats": dataset_info["stats"],
"created_at": dataset_info["created_at"]
})
# Sort by creation time (newest first)
datasets.sort(key=lambda x: x["created_at"], reverse=True)
return datasets
return self.storage.list_datasets()
def get_dataset_stats(self) -> Dict[str, Any]:
"""Get overall statistics about stored datasets"""
total_datasets = len(self.dataset_storage)
error_datasets = sum(1 for data in self.dataset_storage.values() if "error" in data)
valid_datasets = total_datasets - error_datasets
total_records = 0
total_new_embeddings = 0
total_reused_embeddings = 0
for dataset_info in self.dataset_storage.values():
if "stats" in dataset_info:
stats = dataset_info["stats"]
total_records += stats.get("total_records", 0)
total_new_embeddings += stats.get("new_embeddings", 0)
total_reused_embeddings += stats.get("reused_embeddings", 0)
return {
"total_datasets": total_datasets,
"valid_datasets": valid_datasets,
"error_datasets": error_datasets,
"total_records": total_records,
"total_new_embeddings": total_new_embeddings,
"total_reused_embeddings": total_reused_embeddings,
"storage_directory": str(self.storage_dir)
}
return self.storage.get_dataset_stats()
def load_dataset_metadata_fast(self, dataset_id: str) -> Optional[Dict[str, Any]]:
return self.storage.load_dataset_metadata(dataset_id)
def load_dataset_partial(self, dataset_id: str, columns: Optional[List[str]] = None,
filters: Optional[Dict[str, Any]] = None):
return self.storage.load_dataset_partial(dataset_id, columns, filters)
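Note: DatasetBuilder now delegates all persistence to ParquetDatasetStorage, and the two new pass-through methods let callers read only what they need instead of materializing the full dataset. A minimal sketch of how a caller might use them (the function name and filter values below are illustrative, not part of this commit):

def summarize_dataset(builder: "DatasetBuilder", dataset_id: str):
    # Cheap call: reads only the .metadata.json sidecar, not the Parquet file.
    meta = builder.load_dataset_metadata_fast(dataset_id)
    if meta is None:
        return None
    # Column-pruned read: only rows with label=True, skipping the JSON blob columns.
    df = builder.load_dataset_partial(dataset_id, columns=["aid", "label"], filters={"label": True})
    return {
        "total_records": meta.get("total_records", 0),
        "positive_records": 0 if df is None else len(df),
    }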

View File

@ -0,0 +1,393 @@
"""
Efficient dataset storage using Parquet format for better space utilization and loading performance
"""
import pandas as pd
import numpy as np
import json
import os
from pathlib import Path
from typing import List, Dict, Any, Optional, Union
from datetime import datetime
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.parquet as pq
from logger_config import get_logger
logger = get_logger(__name__)
class ParquetDatasetStorage:
def __init__(self, storage_dir: str = "datasets"):
self.storage_dir = Path(storage_dir)
self.storage_dir.mkdir(exist_ok=True)
# Parquet file extension
self.parquet_ext = ".parquet"
self.metadata_ext = ".metadata.json"
# In-memory cache: only cache metadata to avoid large file memory usage
self.metadata_cache: Dict[str, Dict[str, Any]] = {}
self._load_metadata_cache()
def _get_dataset_files(self, dataset_id: str) -> tuple[Path, Path]:
"""Get file paths for the dataset"""
base_path = self.storage_dir / dataset_id
data_file = base_path.with_suffix(self.parquet_ext)
metadata_file = base_path.with_suffix(self.metadata_ext)
return data_file, metadata_file
def _load_metadata_cache(self):
"""Load metadata cache"""
try:
for metadata_file in self.storage_dir.glob("*.metadata.json"):
try:
# Remove ".metadata" suffix
dataset_id = metadata_file.stem[:-9]
with open(metadata_file, 'r', encoding='utf-8') as f:
metadata = json.load(f)
self.metadata_cache[dataset_id] = metadata
except Exception as e:
logger.warning(f"Failed to load metadata for {metadata_file}: {e}")
logger.info(f"Loaded metadata for {len(self.metadata_cache)} datasets")
except Exception as e:
logger.error(f"Failed to load metadata cache: {e}")
def save_dataset(self, dataset_id: str, dataset: List[Dict[str, Any]],
description: Optional[str] = None, stats: Optional[Dict[str, Any]] = None) -> bool:
"""
Save dataset using Parquet format
Args:
dataset_id: Dataset ID
dataset: Dataset content
description: Dataset description
stats: Dataset statistics
Returns:
bool: Whether the save was successful
"""
try:
data_file, metadata_file = self._get_dataset_files(dataset_id)
# Ensure directory exists
data_file.parent.mkdir(parents=True, exist_ok=True)
# Prepare data: convert embedding vectors to numpy arrays
if not dataset:
logger.warning(f"Empty dataset for {dataset_id}")
return False
# Analyze data structure
first_item = dataset[0]
embedding_dim = len(first_item.get('embedding', []))
# Build DataFrame
records = []
for item in dataset:
record = {
'aid': item.get('aid'),
'label': item.get('label'),
'inconsistent': item.get('inconsistent', False),
'text_checksum': item.get('text_checksum'),
# Store embedding as a separate column
'embedding': item.get('embedding', []),
# Store metadata as JSON string
'metadata_json': json.dumps(item.get('metadata', {}), ensure_ascii=False),
'user_labels_json': json.dumps(item.get('user_labels', []), ensure_ascii=False)
}
records.append(record)
# Create DataFrame
df = pd.DataFrame(records)
# Convert embedding column to numpy arrays
df['embedding'] = df['embedding'].apply(lambda x: np.array(x, dtype=np.float32) if x else np.array([], dtype=np.float32))
# Use PyArrow Schema for type safety
schema = pa.schema([
('aid', pa.int64()),
('label', pa.bool_()),
('inconsistent', pa.bool_()),
('text_checksum', pa.string()),
('embedding', pa.list_(pa.float32())),
('metadata_json', pa.string()),
('user_labels_json', pa.string())
])
# Convert to PyArrow Table
table = pa.Table.from_pandas(df, schema=schema)
# Write Parquet file with efficient compression settings
pq.write_table(
table,
data_file,
compression='zstd', # Better compression ratio
compression_level=6, # Balance compression ratio and speed
use_dictionary=True, # Enable dictionary encoding
write_page_index=True, # Support fast metadata access
write_statistics=True # Enable statistics
)
# Save metadata
metadata = {
'dataset_id': dataset_id,
'description': description,
'stats': stats or {},
'created_at': datetime.now().isoformat(),
'file_format': 'parquet_v1',
'embedding_dimension': embedding_dim,
'total_records': len(dataset),
'columns': list(df.columns),
'file_size_bytes': data_file.stat().st_size,
'compression': 'zstd'
}
with open(metadata_file, 'w', encoding='utf-8') as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
# Update cache
self.metadata_cache[dataset_id] = metadata
logger.info(f"Saved dataset {dataset_id} to Parquet: {len(dataset)} records, {data_file.stat().st_size} bytes")
return True
except Exception as e:
logger.error(f"Failed to save dataset {dataset_id}: {e}")
return False
def load_dataset_metadata(self, dataset_id: str) -> Optional[Dict[str, Any]]:
"""
Quickly load dataset metadata (without loading the entire file)
Args:
dataset_id: Dataset ID
Returns:
Dict: Metadata, or None if not found
"""
# Check cache
if dataset_id in self.metadata_cache:
return self.metadata_cache[dataset_id]
# Load from file
_, metadata_file = self._get_dataset_files(dataset_id)
if not metadata_file.exists():
return None
try:
with open(metadata_file, 'r', encoding='utf-8') as f:
metadata = json.load(f)
# Update cache
self.metadata_cache[dataset_id] = metadata
return metadata
except Exception as e:
logger.error(f"Failed to load metadata for {dataset_id}: {e}")
return None
def load_dataset_partial(self, dataset_id: str, columns: Optional[List[str]] = None,
filters: Optional[Dict[str, Any]] = None) -> Optional[pd.DataFrame]:
"""
Partially load the dataset (only reading specified columns or rows meeting criteria)
Args:
dataset_id: Dataset ID
columns: Columns to read, None to read all
filters: Filtering conditions, format {column: value}
Returns:
pd.DataFrame: Loaded data, or None if failed
"""
data_file, _ = self._get_dataset_files(dataset_id)
if not data_file.exists():
return None
try:
# Read Parquet file, supporting column selection and filtering
if columns:
# Ensure necessary columns exist
all_columns = ['aid', 'label', 'inconsistent', 'text_checksum', 'embedding', 'metadata_json', 'user_labels_json']
required_cols = ['aid', 'label', 'embedding'] # These are fundamentally needed
columns = list(set(columns + required_cols))
# Filter out non-existent columns
columns = [col for col in columns if col in all_columns]
# Use pyarrow compute expressions to support filters
if filters:
# Build filter expressions (only whitelisted columns are filterable)
expressions = []
for col, value in filters.items():
if col in ('label', 'aid'):
expressions.append(pc.field(col) == value)
if expressions:
filter_expr = expressions[0]
for expr in expressions[1:]:
filter_expr = filter_expr & expr
else:
filter_expr = None
else:
filter_expr = None
# Read data
if columns and filter_expr is not None:
table = pq.read_table(data_file, columns=columns, filters=filter_expr)
elif columns:
table = pq.read_table(data_file, columns=columns)
elif filter_expr is not None:
table = pq.read_table(data_file, filters=filter_expr)
else:
table = pq.read_table(data_file)
# Convert to DataFrame
df = table.to_pandas()
# Handle embedding column
if 'embedding' in df.columns:
df['embedding'] = df['embedding'].apply(lambda x: x.tolist() if hasattr(x, 'tolist') else list(x))
logger.info(f"Loaded partial dataset {dataset_id}: {len(df)} rows, {len(df.columns)} columns")
return df
except Exception as e:
logger.error(f"Failed to load partial dataset {dataset_id}: {e}")
return None
def load_dataset_full(self, dataset_id: str) -> Optional[Dict[str, Any]]:
"""
Fully load the dataset (maintaining backward compatibility format)
Args:
dataset_id: Dataset ID
Returns:
Dict: Full dataset data, or None if failed
"""
data_file, _ = self._get_dataset_files(dataset_id)
if not data_file.exists():
return None
try:
# Load metadata
metadata = self.load_dataset_metadata(dataset_id)
if not metadata:
return None
# Load data
df = self.load_dataset_partial(dataset_id)
if df is None:
return None
# Convert to original format
dataset = []
for _, row in df.iterrows():
record = {
'aid': int(row['aid']),
'embedding': row['embedding'],
'label': bool(row['label']),
'metadata': json.loads(row['metadata_json']) if row['metadata_json'] else {},
'user_labels': json.loads(row['user_labels_json']) if row['user_labels_json'] else [],
'inconsistent': bool(row['inconsistent']),
'text_checksum': row['text_checksum']
}
dataset.append(record)
return {
'dataset': dataset,
'description': metadata.get('description'),
'stats': metadata.get('stats', {}),
'created_at': metadata.get('created_at')
}
except Exception as e:
logger.error(f"Failed to load full dataset {dataset_id}: {e}")
return None
def dataset_exists(self, dataset_id: str) -> bool:
"""Check if the dataset exists"""
data_file, _ = self._get_dataset_files(dataset_id)
return data_file.exists()
def delete_dataset(self, dataset_id: str) -> bool:
"""Delete a dataset"""
try:
data_file, metadata_file = self._get_dataset_files(dataset_id)
# Delete files
if data_file.exists():
data_file.unlink()
if metadata_file.exists():
metadata_file.unlink()
# Remove from cache
if dataset_id in self.metadata_cache:
del self.metadata_cache[dataset_id]
logger.info(f"Deleted dataset {dataset_id}")
return True
except Exception as e:
logger.error(f"Failed to delete dataset {dataset_id}: {e}")
return False
def list_datasets(self) -> List[Dict[str, Any]]:
"""List metadata for all datasets"""
datasets = []
for dataset_id, metadata in self.metadata_cache.items():
datasets.append({
"dataset_id": dataset_id,
"description": metadata.get("description"),
"stats": metadata.get("stats", {}),
"created_at": metadata.get("created_at"),
"total_records": metadata.get("total_records", 0),
"file_size_mb": round(metadata.get("file_size_bytes", 0) / (1024 * 1024), 2),
"embedding_dimension": metadata.get("embedding_dimension"),
"file_format": metadata.get("file_format")
})
# Sort by creation time descending
datasets.sort(key=lambda x: x["created_at"], reverse=True)
return datasets
def get_dataset_stats(self) -> Dict[str, Any]:
"""Get overall statistics"""
total_datasets = len(self.metadata_cache)
total_records = sum(m.get("total_records", 0) for m in self.metadata_cache.values())
total_size_bytes = sum(m.get("file_size_bytes", 0) for m in self.metadata_cache.values())
return {
"total_datasets": total_datasets,
"total_records": total_records,
"total_size_mb": round(total_size_bytes / (1024 * 1024), 2),
"average_size_mb": round(total_size_bytes / total_datasets / (1024 * 1024), 2) if total_datasets > 0 else 0,
"storage_directory": str(self.storage_dir),
"storage_format": "parquet_v1"
}
def migrate_from_json(self, dataset_id: str, json_data: Dict[str, Any]) -> bool:
"""
Migrate a dataset from JSON format to Parquet format
Args:
dataset_id: Dataset ID
json_data: Data in JSON format
Returns:
bool: Migration success status
"""
try:
dataset = json_data.get('dataset', [])
description = json_data.get('description')
stats = json_data.get('stats')
return self.save_dataset(dataset_id, dataset, description, stats)
except Exception as e:
logger.error(f"Failed to migrate dataset {dataset_id} from JSON: {e}")
return False
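For reference, the storage class can be exercised on its own. A minimal, hypothetical round-trip, assuming records follow the shape save_dataset expects (aid, label, embedding, metadata, user_labels; all values below are made up):

from dataset_storage_parquet import ParquetDatasetStorage

storage = ParquetDatasetStorage("datasets")
records = [{
    "aid": 170001,
    "label": True,
    "inconsistent": False,
    "text_checksum": "abc123",
    "embedding": [0.1, 0.2, 0.3],
    "metadata": {"title": "example"},
    "user_labels": [True],
}]
if storage.save_dataset("demo", records, description="smoke test", stats={"total_records": 1}):
    meta = storage.load_dataset_metadata("demo")   # fast path: sidecar JSON only
    full = storage.load_dataset_full("demo")       # full path: Parquet back to the original dict format
    df = storage.load_dataset_partial("demo", columns=["aid", "embedding"])
    storage.delete_dataset("demo")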

View File

@ -125,7 +125,8 @@ class EmbeddingService:
self,
texts: List[str],
model: str,
batch_size: Optional[int] = None
batch_size: Optional[int] = None,
progress_callback: Optional[callable] = None
) -> List[List[float]]:
"""Generate embeddings for a batch of texts"""
@ -137,9 +138,9 @@ class EmbeddingService:
# Handle different model types
if model_config.type == "legacy":
return self._generate_legacy_embeddings_batch(texts, model, batch_size)
return self._generate_legacy_embeddings_batch(texts, model, batch_size, progress_callback)
elif model_config.type == "openai-compatible":
return await self._generate_openai_embeddings_batch(texts, model, batch_size)
return await self._generate_openai_embeddings_batch(texts, model, batch_size, progress_callback=progress_callback)
else:
raise ValueError(f"Unsupported model type: {model_config.type}")
@ -147,7 +148,8 @@ class EmbeddingService:
self,
texts: List[str],
model: str,
batch_size: Optional[int] = None
batch_size: Optional[int] = None,
progress_callback: Optional[callable] = None
) -> List[List[float]]:
"""Generate embeddings using legacy ONNX model"""
if model not in self.legacy_models:
@ -162,9 +164,13 @@ class EmbeddingService:
expected_dims = model_config.dimensions
all_embeddings = []
# Calculate total batches for progress tracking
total_batches = (len(texts) + batch_size - 1) // batch_size
# Process in batches
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
batch_idx = i // batch_size + 1
try:
# Generate embeddings using legacy method
@ -174,23 +180,33 @@ class EmbeddingService:
batch_embeddings = embeddings.tolist()
all_embeddings.extend(batch_embeddings)
logger.info(f"Generated legacy embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
logger.info(f"Generated legacy embeddings for batch {batch_idx}/{total_batches}")
# Update progress if callback provided
if progress_callback:
progress_callback(batch_idx, total_batches)
except Exception as e:
logger.error(f"Error generating legacy embeddings for batch {i//batch_size + 1}: {e}")
logger.error(f"Error generating legacy embeddings for batch {batch_idx}: {e}")
# Fill with zeros as fallback
zero_embedding = [0.0] * expected_dims
all_embeddings.extend([zero_embedding] * len(batch))
# Final progress update
if progress_callback:
progress_callback(total_batches, total_batches)
return all_embeddings
async def _generate_openai_embeddings_batch(
self,
texts: List[str],
model: str,
batch_size: Optional[int] = None
batch_size: Optional[int] = None,
max_concurrent: int = 6,
progress_callback: Optional[callable] = None
) -> List[List[float]]:
"""Generate embeddings using OpenAI-compatible API"""
"""Generate embeddings using OpenAI-compatible API with parallel requests"""
model_config = self.embedding_models[model]
# Use model's max_batch_size if not specified
@ -204,34 +220,71 @@ class EmbeddingService:
raise ValueError(f"No client configured for model '{model}'")
client = self.clients[model]
all_embeddings = []
# Process in batches to avoid API limits
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
try:
# Rate limiting
if i > 0:
await asyncio.sleep(self.request_interval)
# Generate embeddings
response = await client.embeddings.create(
model=model_config.name,
input=batch,
dimensions=expected_dims
)
batch_embeddings = [data.embedding for data in response.data]
all_embeddings.extend(batch_embeddings)
logger.info(f"Generated embeddings for batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
except Exception as e:
logger.error(f"Error generating embeddings for batch {i//batch_size + 1}: {e}")
# For now, fill with zeros as fallback (could implement retry logic)
zero_embedding = [0.0] * expected_dims
all_embeddings.extend([zero_embedding] * len(batch))
# Split texts into batches
batches = [(i, texts[i:i + batch_size]) for i in range(0, len(texts), batch_size)]
total_batches = len(batches)
# Track completed batches for progress reporting
completed_batches = 0
completed_batches_lock = asyncio.Lock()
# Semaphore for concurrency control
semaphore = asyncio.Semaphore(max_concurrent)
async def process_batch(batch_idx: int, batch: List[str]) -> tuple[int, List[List[float]]]:
"""Process a single batch with semaphore-controlled concurrency"""
nonlocal completed_batches
async with semaphore:
try:
# Rate limiting
if batch_idx > 0:
await asyncio.sleep(self.request_interval)
# Generate embeddings
response = await client.embeddings.create(
model=model_config.name,
input=batch,
dimensions=expected_dims
)
batch_embeddings = [data.embedding for data in response.data]
logger.info(f"Generated embeddings for batch {batch_idx//batch_size + 1}/{total_batches}")
# Update completed batch count and progress
async with completed_batches_lock:
completed_batches += 1
if progress_callback:
progress_callback(completed_batches, total_batches)
return (batch_idx, batch_embeddings)
except Exception as e:
logger.error(f"Error generating embeddings for batch {batch_idx//batch_size + 1}: {e}")
# Fill with zeros as fallback
zero_embedding = [0.0] * expected_dims
# Update completed batch count and progress even for failed batches
async with completed_batches_lock:
completed_batches += 1
if progress_callback:
progress_callback(completed_batches, total_batches)
return (batch_idx, [zero_embedding] * len(batch))
# Process all batches concurrently with semaphore limit
tasks = [process_batch(i, batch) for i, batch in batches]
results = await asyncio.gather(*tasks)
# Final progress update
if progress_callback:
progress_callback(total_batches, total_batches)
# Sort results by batch index and flatten
results.sort(key=lambda x: x[0])
all_embeddings = []
for _, embeddings in results:
all_embeddings.extend(embeddings)
return all_embeddings
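The core of this change is bounded concurrency: at most max_concurrent batches are in flight at once, progress is reported as batches complete, and results are re-sorted by batch offset before flattening. The same pattern in isolation, as a reference sketch (the names here are placeholders, not API from this repo):

import asyncio

async def run_bounded(jobs, worker, max_concurrent=6):
    # Run worker(item) over jobs with at most max_concurrent in flight,
    # returning results in the original submission order.
    semaphore = asyncio.Semaphore(max_concurrent)

    async def guarded(idx, item):
        async with semaphore:
            return idx, await worker(item)

    results = await asyncio.gather(*(guarded(i, j) for i, j in enumerate(jobs)))
    return [value for _, value in sorted(results, key=lambda pair: pair[0])]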
@ -268,7 +321,8 @@ class EmbeddingService:
test_embedding = await self.generate_embeddings_batch(
["health check"],
model_name,
batch_size=1
batch_size=1,
progress_callback=None
)
return {

View File

@ -13,4 +13,5 @@ sqlalchemy==2.0.23
asyncpg==0.29.0
toml==0.10.2
aiohttp==3.9.0
python-dotenv==1.1.0
python-dotenv==1.1.0
pyarrow==15.0.0

View File

@ -30,7 +30,7 @@ export function TaskMonitor() {
const params = statusFilter === "all" ? {} : { status: statusFilter };
return apiClient.getTasks(params.status, 50);
},
refetchInterval: 5000 // Refresh every 5 seconds
refetchInterval: 500 // Refresh every 0.5 seconds
});
const getStatusIcon = (status: string) => {