From a4ec4ca01c8554592ac9904c5da9926dcfb23c12 Mon Sep 17 00:00:00 2001 From: alikia2x Date: Wed, 10 Dec 2025 18:39:37 +0800 Subject: [PATCH] add: frontent for new ML: ml_panel --- bun.lock | 11 + ml_new/training/.tokeignore | 1 + ml_new/training/api_routes.py | 154 ++++++-- ml_new/training/database.py | 208 +++++------ ml_new/training/dataset_service.py | 2 + ml_new/training/models.py | 112 ++++-- package.json | 1 + packages/ml_panel/.tokeignore | 1 + packages/ml_panel/README.md | 36 +- packages/ml_panel/components.json | 40 +-- packages/ml_panel/package.json | 6 + packages/ml_panel/src/App.tsx | 69 +++- .../src/components/DatasetManager.tsx | 337 ++++++++++++++++++ .../ml_panel/src/components/SamplingPanel.tsx | 234 ++++++++++++ .../ml_panel/src/components/TaskMonitor.tsx | 210 +++++++++++ packages/ml_panel/src/components/ui/alert.tsx | 60 ++++ packages/ml_panel/src/components/ui/badge.tsx | 40 +++ .../ml_panel/src/components/ui/button.tsx | 58 +++ packages/ml_panel/src/components/ui/card.tsx | 78 ++++ .../ml_panel/src/components/ui/dialog.tsx | 127 +++++++ packages/ml_panel/src/components/ui/input.tsx | 21 ++ packages/ml_panel/src/components/ui/label.tsx | 21 ++ .../ml_panel/src/components/ui/progress.tsx | 29 ++ .../ml_panel/src/components/ui/select.tsx | 172 +++++++++ .../ml_panel/src/components/ui/spinner.tsx | 16 + packages/ml_panel/src/components/ui/tabs.tsx | 52 +++ .../ml_panel/src/components/ui/textarea.tsx | 18 + packages/ml_panel/src/index.css | 210 +++++------ packages/ml_panel/src/lib/api.ts | 144 ++++++++ packages/ml_panel/src/lib/utils.ts | 6 +- packages/ml_panel/src/types/api.ts | 125 +++++++ 31 files changed, 2271 insertions(+), 328 deletions(-) create mode 100644 ml_new/training/.tokeignore create mode 100644 packages/ml_panel/.tokeignore create mode 100644 packages/ml_panel/src/components/DatasetManager.tsx create mode 100644 packages/ml_panel/src/components/SamplingPanel.tsx create mode 100644 packages/ml_panel/src/components/TaskMonitor.tsx create mode 100644 packages/ml_panel/src/components/ui/alert.tsx create mode 100644 packages/ml_panel/src/components/ui/badge.tsx create mode 100644 packages/ml_panel/src/components/ui/button.tsx create mode 100644 packages/ml_panel/src/components/ui/card.tsx create mode 100644 packages/ml_panel/src/components/ui/dialog.tsx create mode 100644 packages/ml_panel/src/components/ui/input.tsx create mode 100644 packages/ml_panel/src/components/ui/label.tsx create mode 100644 packages/ml_panel/src/components/ui/progress.tsx create mode 100644 packages/ml_panel/src/components/ui/select.tsx create mode 100644 packages/ml_panel/src/components/ui/spinner.tsx create mode 100644 packages/ml_panel/src/components/ui/tabs.tsx create mode 100644 packages/ml_panel/src/components/ui/textarea.tsx create mode 100644 packages/ml_panel/src/lib/api.ts create mode 100644 packages/ml_panel/src/types/api.ts diff --git a/bun.lock b/bun.lock index f215c8b..df0a796 100644 --- a/bun.lock +++ b/bun.lock @@ -4,6 +4,7 @@ "": { "name": "cvsa", "dependencies": { + "@tanstack/react-query": "^5.90.12", "arg": "^5.0.2", "dotenv": "^17.2.3", "drizzle-orm": "^0.44.7", @@ -84,6 +85,12 @@ "name": "ml_panel", "version": "0.0.0", "dependencies": { + "@radix-ui/react-dialog": "^1.1.15", + "@radix-ui/react-label": "^2.1.8", + "@radix-ui/react-progress": "^1.1.8", + "@radix-ui/react-select": "^2.2.6", + "@radix-ui/react-slot": "^1.2.4", + "@radix-ui/react-tabs": "^1.1.13", "@tailwindcss/vite": "^4.1.17", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", @@ -927,6 +934,10 @@ "@tailwindcss/vite": ["@tailwindcss/vite@4.1.17", "", { "dependencies": { "@tailwindcss/node": "4.1.17", "@tailwindcss/oxide": "4.1.17", "tailwindcss": "4.1.17" }, "peerDependencies": { "vite": "^5.2.0 || ^6 || ^7" } }, "sha512-4+9w8ZHOiGnpcGI6z1TVVfWaX/koK7fKeSYF3qlYg2xpBtbteP2ddBxiarL+HVgfSJGeK5RIxRQmKm4rTJJAwA=="], + "@tanstack/query-core": ["@tanstack/query-core@5.90.12", "", {}, "sha512-T1/8t5DhV/SisWjDnaiU2drl6ySvsHj1bHBCWNXd+/T+Hh1cf6JodyEYMd5sgwm+b/mETT4EV3H+zCVczCU5hg=="], + + "@tanstack/react-query": ["@tanstack/react-query@5.90.12", "", { "dependencies": { "@tanstack/query-core": "5.90.12" }, "peerDependencies": { "react": "^18 || ^19" } }, "sha512-graRZspg7EoEaw0a8faiUASCyJrqjKPdqJ9EwuDRUF9mEYJ1YPczI9H+/agJ0mOJkPCJDk0lsz5QTrLZ/jQ2rg=="], + "@tokenizer/inflate": ["@tokenizer/inflate@0.2.7", "", { "dependencies": { "debug": "^4.4.0", "fflate": "^0.8.2", "token-types": "^6.0.0" } }, "sha512-MADQgmZT1eKjp06jpI2yozxaU9uVs4GzzgSL+uEq7bVcJ9V1ZXQkeGNql1fsSI0gMy1vhvNTNbUqrx+pZfJVmg=="], "@tokenizer/token": ["@tokenizer/token@0.3.0", "", {}, "sha512-OvjF+z51L3ov0OyAU0duzsYuvO01PH7x4t6DJx+guahgTnBHkhJdG7soQeTSFLWN3efnHyibZ4Z8l2EuWwJN3A=="], diff --git a/ml_new/training/.tokeignore b/ml_new/training/.tokeignore new file mode 100644 index 0000000..4027c18 --- /dev/null +++ b/ml_new/training/.tokeignore @@ -0,0 +1 @@ +*.toml \ No newline at end of file diff --git a/ml_new/training/api_routes.py b/ml_new/training/api_routes.py index 5c772bc..937ead9 100644 --- a/ml_new/training/api_routes.py +++ b/ml_new/training/api_routes.py @@ -9,7 +9,17 @@ from fastapi import APIRouter, HTTPException from fastapi.responses import JSONResponse from config_loader import config_loader -from models import DatasetBuildRequest, DatasetBuildResponse, TaskStatus, TaskStatusResponse, TaskListResponse +from models import ( + DatasetBuildRequest, + DatasetBuildResponse, + TaskStatus, + TaskStatusResponse, + TaskListResponse, + SamplingRequest, + SamplingResponse, + DatasetCreateRequest, + DatasetCreateResponse +) from dataset_service import DatasetBuilder from logger_config import get_logger @@ -135,39 +145,6 @@ async def get_dataset_endpoint(dataset_id: str): } -@router.get("/datasets") -async def list_datasets(): - """List all built datasets""" - - if not dataset_builder: - raise HTTPException(status_code=503, detail="Dataset builder not available") - - datasets = [] - for dataset_id, dataset_info in dataset_builder.dataset_storage.items(): - if "error" not in dataset_info: - datasets.append({ - "dataset_id": dataset_id, - "description": dataset_info.get("description"), - "stats": dataset_info["stats"], - "created_at": dataset_info["created_at"] - }) - - return {"datasets": datasets} - - -@router.delete("/dataset/{dataset_id}") -async def delete_dataset_endpoint(dataset_id: str): - """Delete a built dataset""" - - if not dataset_builder: - raise HTTPException(status_code=503, detail="Dataset builder not available") - - if dataset_builder.delete_dataset(dataset_id): - return {"message": f"Dataset {dataset_id} deleted successfully"} - else: - raise HTTPException(status_code=404, detail="Dataset not found") - - @router.get("/datasets") async def list_datasets_endpoint(): """List all built datasets""" @@ -188,6 +165,19 @@ async def list_datasets_endpoint(): return {"datasets": datasets_with_description} +@router.delete("/dataset/{dataset_id}") +async def delete_dataset_endpoint(dataset_id: str): + """Delete a built dataset""" + + if not dataset_builder: + raise HTTPException(status_code=503, detail="Dataset builder not available") + + if dataset_builder.delete_dataset(dataset_id): + return {"message": f"Dataset {dataset_id} deleted successfully"} + else: + raise HTTPException(status_code=404, detail="Dataset not found") + + @router.get("/datasets/stats") async def get_dataset_stats_endpoint(): """Get overall statistics about stored datasets""" @@ -314,4 +304,98 @@ async def cleanup_tasks_endpoint(max_age_hours: int = 24): raise HTTPException(status_code=503, detail="Dataset builder not available") cleaned_count = await dataset_builder.cleanup_completed_tasks(max_age_hours) - return {"message": f"Cleaned up {cleaned_count} tasks older than {max_age_hours} hours"} \ No newline at end of file + return {"message": f"Cleaned up {cleaned_count} tasks older than {max_age_hours} hours"} + + +# Sampling Endpoints + +@router.post("/dataset/sample", response_model=SamplingResponse) +async def sample_dataset_endpoint(request: SamplingRequest): + """Sample AIDs based on strategy""" + + if not dataset_builder: + raise HTTPException(status_code=503, detail="Dataset builder not available") + + try: + # Get AIDs based on strategy + aid_list = await dataset_builder.db_manager.get_aids_by_strategy( + strategy=request.strategy, + limit=request.limit, + ) + + # Get statistics + total_available = await dataset_builder.db_manager.get_all_aids_count() + + return SamplingResponse( + strategy=request.strategy, + total_available=total_available, + sampled_count=len(aid_list), + aid_list=aid_list, + filters_applied={ + "limit": request.limit + }, + sampling_info={ + "strategy_description": _get_strategy_description(request.strategy), + "sample_ratio": len(aid_list) / total_available if total_available > 0 else 0 + } + ) + + except Exception as e: + logger.error(f"Sampling failed: {str(e)}") + raise HTTPException(status_code=400, detail=f"Sampling failed: {str(e)}") + +@router.post("/dataset/create-with-sampling", response_model=DatasetCreateResponse) +async def create_dataset_with_sampling_endpoint(request: DatasetCreateRequest): + """Create dataset using sampling strategy""" + + if not dataset_builder: + raise HTTPException(status_code=503, detail="Dataset builder not available") + + # Validate embedding model + if request.embedding_model not in config_loader.get_embedding_models(): + raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}") + + import uuid + dataset_id = str(uuid.uuid4()) + + try: + # First sample the AIDs + sampling_response = await sample_dataset_endpoint(request.sampling) + aid_list = sampling_response.aid_list + + if not aid_list: + raise HTTPException(status_code=400, detail="No AIDs found matching the sampling criteria") + + # Start task-based dataset building with sampled AIDs + task_id = await dataset_builder.start_dataset_build_task( + dataset_id, + aid_list, + request.embedding_model, + request.force_regenerate, + request.description + ) + + return DatasetCreateResponse( + dataset_id=dataset_id, + sampling_response=sampling_response, + task_id=task_id, + total_records=len(aid_list), + status="started", + message=f"Dataset building started with task ID: {task_id}", + description=request.description + ) + + except HTTPException: + raise + except Exception as e: + logger.error(f"Dataset creation with sampling failed: {str(e)}") + raise HTTPException(status_code=500, detail=f"Dataset creation failed: {str(e)}") + + +def _get_strategy_description(strategy: str) -> str: + """Get description for sampling strategy""" + descriptions = { + "all": "All labeled videos in the database", + "random": "Randomly sampled labeled videos" + } + return descriptions.get(strategy, "Unknown sampling strategy") \ No newline at end of file diff --git a/ml_new/training/database.py b/ml_new/training/database.py index c0a73de..b873656 100644 --- a/ml_new/training/database.py +++ b/ml_new/training/database.py @@ -2,6 +2,7 @@ Database connection and operations for ML training service """ +from collections import defaultdict import os import hashlib from typing import List, Dict, Optional, Any @@ -18,6 +19,7 @@ logger = get_logger(__name__) # Database configuration DATABASE_URL = os.getenv("DATABASE_URL") + class DatabaseManager: def __init__(self): self.pool: Optional[asyncpg.Pool] = None @@ -26,7 +28,7 @@ class DatabaseManager: """Initialize database connection pool""" try: self.pool = await asyncpg.create_pool(DATABASE_URL, min_size=5, max_size=20) - + logger.info("Database connection pool initialized") except Exception as e: logger.error(f"Failed to connect to database: {e}") @@ -37,7 +39,7 @@ class DatabaseManager: if self.pool: await self.pool.close() logger.info("Database connection pool closed") - + @property def is_connected(self) -> bool: """Check if database connection pool is initialized""" @@ -129,12 +131,24 @@ class DatabaseManager: result = {} for row in rows: checksum = row["data_checksum"] - + # Convert vector strings to lists if they exist - vec_2048 = self._parse_vector_string(row["vec_2048"]) if row["vec_2048"] else None - vec_1536 = self._parse_vector_string(row["vec_1536"]) if row["vec_1536"] else None - vec_1024 = self._parse_vector_string(row["vec_1024"]) if row["vec_1024"] else None - + vec_2048 = ( + self._parse_vector_string(row["vec_2048"]) + if row["vec_2048"] + else None + ) + vec_1536 = ( + self._parse_vector_string(row["vec_1536"]) + if row["vec_1536"] + else None + ) + vec_1024 = ( + self._parse_vector_string(row["vec_1024"]) + if row["vec_1024"] + else None + ) + result[checksum] = { "checksum": checksum, "vec_2048": vec_2048, @@ -144,36 +158,45 @@ class DatabaseManager: } return result - + def _parse_vector_string(self, vector_str: str) -> List[float]: """Parse vector string format '[1.0,2.0,3.0]' back to list""" if not vector_str: return [] - + try: # Remove brackets and split by comma vector_str = vector_str.strip() - if vector_str.startswith('[') and vector_str.endswith(']'): + if vector_str.startswith("[") and vector_str.endswith("]"): vector_str = vector_str[1:-1] - - return [float(x.strip()) for x in vector_str.split(',') if x.strip()] + + return [float(x.strip()) for x in vector_str.split(",") if x.strip()] except Exception as e: logger.warning(f"Failed to parse vector string '{vector_str}': {e}") return [] async def insert_embeddings(self, embeddings_data: List[Dict[str, Any]]) -> None: - """Batch insert embeddings into database""" + """Batch insert embeddings into database (Optimized)""" if not embeddings_data: return + batches = defaultdict(list) + now = datetime.now() + + for data in embeddings_data: + vector_str = str(data["vector"]) + # "[" + ",".join(map(str, data["vector"])) + "]" + + dim = data["dimensions"] + + batches[dim].append( + (data["model_name"], dim, data["checksum"], vector_str, now) + ) + async with self.pool.acquire() as conn: async with conn.transaction(): - for data in embeddings_data: - # Determine which vector column to use based on dimensions - vec_column = f"vec_{data['dimensions']}" - - # Convert vector list to string format for PostgreSQL - vector_str = "[" + ",".join(map(str, data["vector"])) + "]" + for dim, values in batches.items(): + vec_column = f"vec_{dim}" query = f""" INSERT INTO internal.embeddings @@ -182,103 +205,70 @@ class DatabaseManager: ON CONFLICT (model_name, dimensions, data_checksum) DO NOTHING """ - await conn.execute( - query, - data["model_name"], - data["dimensions"], - data["checksum"], - vector_str, - datetime.now(), - ) + await conn.executemany(query, values) - async def get_final_dataset( - self, aid_list: List[int], model_name: str - ) -> List[Dict[str, Any]]: - """Get final dataset with embeddings and labels""" - if not aid_list: - return [] + # Sampling Methods + async def get_all_aids(self) -> List[int]: + """Get all available AIDs from labeled data (internal.video_type_label)""" + async with self.pool.acquire() as conn: + query = "SELECT DISTINCT aid FROM internal.video_type_label WHERE aid IS NOT NULL" + rows = await conn.fetch(query) + return [int(row["aid"]) for row in rows] + + async def get_all_aids_count(self) -> List[int]: + """Get all available AIDs from labeled data (internal.video_type_label)""" + async with self.pool.acquire() as conn: + query = "SELECT COUNT(DISTINCT aid) FROM internal.video_type_label WHERE aid IS NOT NULL" + rows = await conn.fetch(query) + return rows[0]["count"] - # Get video metadata - metadata = await self.get_video_metadata(aid_list) + async def get_aids_by_strategy( + self, strategy: str, limit: Optional[int] = None + ) -> List[int]: + """Get AIDs based on sampling strategy""" + if strategy == "all": + return await self.get_all_aids() + elif strategy == "random": + return await self.get_random_aids(limit or 1000) + else: + raise ValueError(f"Unknown sampling strategy: {strategy}") - # Get user labels (latest per user) - labels = await self.get_user_labels(aid_list) + async def get_random_aids( + self, limit: int + ) -> List[int]: + """Get random AIDs from labeled data only""" + async with self.pool.acquire() as conn: + query = "SELECT aid FROM internal.video_type_label ORDER BY RANDOM() LIMIT $1" + rows = await conn.fetch(query, limit) + aids = [int(row["aid"]) for row in rows] + # deduplication + return list(set(aids)) - # Prepare text data for embedding - text_data = [] - aid_to_text = {} - - for aid in aid_list: - if aid in metadata: - # Combine title, description, and tags for embedding - text_parts = [ - metadata[aid]["title"], - metadata[aid]["description"], - metadata[aid]["tags"], - ] - combined_text = " ".join(filter(None, text_parts)) - - # Create checksum for deduplication - checksum = hashlib.md5(combined_text.encode("utf-8")).hexdigest() - - text_data.append( - {"aid": aid, "text": combined_text, "checksum": checksum} - ) - aid_to_text[checksum] = aid - - # Get checksums = [ existing embeddings - checks = [item["checksum"] for item in text_data] - existing_embeddings = await self.get_existing_embeddings(checks) - - # ums, model_name Prepare final dataset - dataset = [] - - for item in text_data: - aid = item["aid"] - checksum = item["checksum"] - - # Get embedding vector - embedding_vector = None - if checksum in existing_embeddings: - # Use existing embedding - emb_data = existing_embeddings[checksum] - if emb_data["vec_1536"]: - embedding_vector = emb_data["vec_1536"] - elif emb_data["vec_2048"]: - embedding_vector = emb_data["vec_2048"] - elif emb_data["vec_1024"]: - embedding_vector = emb_data["vec_1024"] - - # Get labels for this aid - aid_labels = labels.get(aid, []) - - # Determine final label using consensus (majority vote) - if aid_labels: - positive_votes = sum(1 for lbl in aid_labels if lbl["label"]) - final_label = positive_votes > len(aid_labels) / 2 - else: - final_label = None # No labels available - - # Check for inconsistent labels - inconsistent = len(aid_labels) > 1 and ( - sum(1 for lbl in aid_labels if lbl["label"]) != 0 - and sum(1 for lbl in aid_labels if lbl["label"]) != len(aid_labels) + async def get_sampling_stats(self) -> Dict[str, Any]: + """Get statistics about available labeled data for sampling""" + async with self.pool.acquire() as conn: + # Total labeled videos + total_labeled_query = ( + "SELECT COUNT(DISTINCT aid) as count FROM internal.video_type_label" ) + total_labeled_result = await conn.fetchrow(total_labeled_query) + total_labeled_videos = total_labeled_result["count"] - if embedding_vector and final_label is not None: - dataset.append( - { - "aid": aid, - "embedding": embedding_vector, - "label": final_label, - "metadata": metadata.get(aid, {}), - "user_labels": aid_labels, - "inconsistent": inconsistent, - "text_checksum": checksum, - } - ) + # Positive and negative labels + positive_query = "SELECT COUNT(DISTINCT aid) as count FROM internal.video_type_label WHERE label = true" + negative_query = "SELECT COUNT(DISTINCT aid) as count FROM internal.video_type_label WHERE label = false" - return dataset + positive_result = await conn.fetchrow(positive_query) + negative_result = await conn.fetchrow(negative_query) + + positive_labels = positive_result["count"] + negative_labels = negative_result["count"] + + return { + "total_labeled_videos": total_labeled_videos, + "positive_labels": positive_labels, + "negative_labels": negative_labels, + } # Global database manager instance diff --git a/ml_new/training/dataset_service.py b/ml_new/training/dataset_service.py index fc55e17..bf7ff81 100644 --- a/ml_new/training/dataset_service.py +++ b/ml_new/training/dataset_service.py @@ -570,6 +570,8 @@ class DatasetBuilder: """List all datasets with their basic information""" datasets = [] + self._load_all_datasets() + for dataset_id, dataset_info in self.dataset_storage.items(): if "error" not in dataset_info: datasets.append({ diff --git a/ml_new/training/models.py b/ml_new/training/models.py index ee027d5..1dc1db1 100644 --- a/ml_new/training/models.py +++ b/ml_new/training/models.py @@ -2,9 +2,51 @@ Data models for dataset building functionality """ -from typing import List, Optional, Dict, Any +from typing import List, Optional, Dict, Any, Literal from pydantic import BaseModel, Field from datetime import datetime +from enum import Enum + + +class TaskStatus(str, Enum): + """Task status enumeration""" + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class SamplingStrategy(str, Enum): + """Sampling strategy enumeration""" + ALL = "all" # All labeled AIDs + RANDOM = "random" # Random sampling from labeled data + + +class TaskProgress(BaseModel): + """Progress information for a task""" + current_step: str + total_steps: int + completed_steps: int + percentage: float + message: Optional[str] = None + estimated_time_remaining: Optional[float] = None + + +class DatasetBuildTaskStatus(BaseModel): + """Status model for dataset building task""" + task_id: str + status: TaskStatus + dataset_id: Optional[str] = None + aid_list: List[int] + embedding_model: str + force_regenerate: bool + progress: Optional[TaskProgress] = None + error_message: Optional[str] = None + created_at: datetime + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + result: Optional[Dict[str, Any]] = None class DatasetBuildRequest(BaseModel): @@ -64,45 +106,41 @@ class EmbeddingModelInfo(BaseModel): max_batch_size: Optional[int] = None -from typing import List, Optional, Dict, Any, Literal -from pydantic import BaseModel, Field -from datetime import datetime -from enum import Enum +# Sampling and Dataset Selection Models + +class SamplingRequest(BaseModel): + """Request model for dataset sampling""" + strategy: SamplingStrategy = Field(..., description="Sampling strategy to use") + limit: Optional[int] = Field(None, description="Maximum number of AIDs to sample (for random sampling)") -class TaskStatus(str, Enum): - """Task status enumeration""" - PENDING = "pending" - RUNNING = "running" - COMPLETED = "completed" - FAILED = "failed" - CANCELLED = "cancelled" - - -class TaskProgress(BaseModel): - """Progress information for a task""" - current_step: str - total_steps: int - completed_steps: int - percentage: float - message: Optional[str] = None - estimated_time_remaining: Optional[float] = None - - -class DatasetBuildTaskStatus(BaseModel): - """Status model for dataset building task""" - task_id: str - status: TaskStatus - dataset_id: Optional[str] = None +class SamplingResponse(BaseModel): + """Response model for dataset sampling""" + strategy: SamplingStrategy + total_available: int + sampled_count: int aid_list: List[int] - embedding_model: str - force_regenerate: bool - progress: Optional[TaskProgress] = None - error_message: Optional[str] = None - created_at: datetime - started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - result: Optional[Dict[str, Any]] = None + filters_applied: Optional[Dict[str, Any]] = None + sampling_info: Dict[str, Any] + + +class DatasetCreateRequest(BaseModel): + """Request model for creating dataset with sampling""" + sampling: SamplingRequest = Field(..., description="Sampling configuration") + embedding_model: str = Field(..., description="Embedding model name") + force_regenerate: bool = Field(False, description="Whether to force regenerate embeddings") + description: Optional[str] = Field(None, description="Optional description for the dataset") + + +class DatasetCreateResponse(BaseModel): + """Response model for dataset creation""" + dataset_id: str + sampling_response: SamplingResponse + task_id: str + total_records: int + status: str + message: str + description: Optional[str] = None class TaskStatusResponse(BaseModel): diff --git a/package.json b/package.json index 051bf7f..20044b4 100644 --- a/package.json +++ b/package.json @@ -18,6 +18,7 @@ } }, "dependencies": { + "@tanstack/react-query": "^5.90.12", "arg": "^5.0.2", "dotenv": "^17.2.3", "drizzle-orm": "^0.44.7", diff --git a/packages/ml_panel/.tokeignore b/packages/ml_panel/.tokeignore new file mode 100644 index 0000000..505542f --- /dev/null +++ b/packages/ml_panel/.tokeignore @@ -0,0 +1 @@ +src/components/ui \ No newline at end of file diff --git a/packages/ml_panel/README.md b/packages/ml_panel/README.md index e1ff783..98e3e49 100644 --- a/packages/ml_panel/README.md +++ b/packages/ml_panel/README.md @@ -170,7 +170,7 @@ graph TD ```mermaid graph TD A[PostgreSQL 远程数据库
RTT: 100ms] --> B[videoTypeLabelInInternal] - A --> C[embeddingsInInternal] + A --> C[embeddingsInInternal] A --> D[bilibiliMetadata] B --> E[批量获取用户最后一次标注
避免循环查询] C --> F[批量获取嵌入向量
一次性查询所有维度] @@ -226,32 +226,32 @@ graph TD ### Phase 1: 核心功能 (优先) 1. **数据管线实现** - - 标注数据获取和一致性检查 - - 嵌入向量生成和存储 - - 数据集构建逻辑 + - 标注数据获取和一致性检查 + - 嵌入向量生成和存储 + - 数据集构建逻辑 2. **FastAPI 服务完善** - - 构建新的模型架构(输入嵌入向量,直接二分类头) - - 迁移现有 ml/filter 训练逻辑 - - 实现超参数动态暴露 - - 集成 OpenAI 兼容嵌入 API - - 训练任务队列管理 + - 构建新的模型架构(输入嵌入向量,直接二分类头) + - 迁移现有 ml/filter 训练逻辑 + - 实现超参数动态暴露 + - 集成 OpenAI 兼容嵌入 API + - 训练任务队列管理 ### Phase 2: 用户界面 1. **数据集创建界面** - - 嵌入模型选择 - - 数据预览和筛选 - - 处理进度显示 + - 嵌入模型选择 + - 数据预览和筛选 + - 处理进度显示 2. **训练参数配置界面** - - 超参数动态渲染 - - 参数验证和约束 + - 超参数动态渲染 + - 参数验证和约束 3. **实验管理和追踪** - - 实验历史和比较 - - 训练状态实时监控 - - 结果可视化 + - 实验历史和比较 + - 训练状态实时监控 + - 结果可视化 ### Phase 3: 高级功能 @@ -264,4 +264,4 @@ graph TD 1. **数据库性能**: 远程数据库 RTT 高,避免 N+1 查询,使用批量操作 2. **标注一致性**: 实现自动的标注不一致检测 -3. **嵌入模型支持**: 为未来扩展多种嵌入模型预留接口 \ No newline at end of file +3. **嵌入模型支持**: 为未来扩展多种嵌入模型预留接口 diff --git a/packages/ml_panel/components.json b/packages/ml_panel/components.json index 2b0833f..6555943 100644 --- a/packages/ml_panel/components.json +++ b/packages/ml_panel/components.json @@ -1,22 +1,22 @@ { - "$schema": "https://ui.shadcn.com/schema.json", - "style": "new-york", - "rsc": false, - "tsx": true, - "tailwind": { - "config": "", - "css": "src/index.css", - "baseColor": "neutral", - "cssVariables": true, - "prefix": "" - }, - "iconLibrary": "lucide", - "aliases": { - "components": "@/components", - "utils": "@/lib/utils", - "ui": "@/components/ui", - "lib": "@/lib", - "hooks": "@/hooks" - }, - "registries": {} + "$schema": "https://ui.shadcn.com/schema.json", + "style": "new-york", + "rsc": false, + "tsx": true, + "tailwind": { + "config": "", + "css": "src/index.css", + "baseColor": "neutral", + "cssVariables": true, + "prefix": "" + }, + "iconLibrary": "lucide", + "aliases": { + "components": "@/components", + "utils": "@/lib/utils", + "ui": "@/components/ui", + "lib": "@/lib", + "hooks": "@/hooks" + }, + "registries": {} } diff --git a/packages/ml_panel/package.json b/packages/ml_panel/package.json index 69b74a4..efe506d 100644 --- a/packages/ml_panel/package.json +++ b/packages/ml_panel/package.json @@ -10,6 +10,12 @@ "preview": "vite preview" }, "dependencies": { + "@radix-ui/react-dialog": "^1.1.15", + "@radix-ui/react-label": "^2.1.8", + "@radix-ui/react-progress": "^1.1.8", + "@radix-ui/react-select": "^2.2.6", + "@radix-ui/react-slot": "^1.2.4", + "@radix-ui/react-tabs": "^1.1.13", "@tailwindcss/vite": "^4.1.17", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", diff --git a/packages/ml_panel/src/App.tsx b/packages/ml_panel/src/App.tsx index 7ce28db..ee5737a 100644 --- a/packages/ml_panel/src/App.tsx +++ b/packages/ml_panel/src/App.tsx @@ -1,5 +1,72 @@ +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import { CardDescription, CardTitle } from "@/components/ui/card"; +import { DatasetManager } from "@/components/DatasetManager"; +import { TaskMonitor } from "@/components/TaskMonitor"; +import { SamplingPanel } from "@/components/SamplingPanel"; +import { Database, Activity, Settings } from "lucide-react"; + +const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: 3, + refetchOnWindowFocus: false + } + } +}); + function App() { - return <>; + return ( + +
+
+
+

ML Dataset Management Panel

+

+ Create and manage machine learning datasets with multiple sampling strategies and task monitoring +

+
+ + + + + + Datasets + + + + Sampling + + + + Tasks + + + + + Dataset Management + View, create and manage your machine learning datasets + + + + + Sampling Strategy Configuration + + Configure different data sampling strategies to create balanced datasets + + + + + + Task Monitor + Monitor real-time status and progress of dataset building tasks + + + +
+
+
+ ); } export default App; diff --git a/packages/ml_panel/src/components/DatasetManager.tsx b/packages/ml_panel/src/components/DatasetManager.tsx new file mode 100644 index 0000000..bff1eaa --- /dev/null +++ b/packages/ml_panel/src/components/DatasetManager.tsx @@ -0,0 +1,337 @@ +import { useState } from "react"; +import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger +} from "@/components/ui/dialog"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue +} from "@/components/ui/select"; +import { Textarea } from "@/components/ui/textarea"; +import { Trash2, Plus, Database, FileText, Calendar, Activity } from "lucide-react"; +import { apiClient } from "@/lib/api"; +import { toast } from "sonner"; +import { Spinner } from "@/components/ui/spinner" + +export function DatasetManager() { + const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false); + const [createFormData, setCreateFormData] = useState({ + strategy: "all", + limit: "", + embeddingModel: "", + description: "", + forceRegenerate: false + }); + + const queryClient = useQueryClient(); + + // Fetch datasets + const { data: datasetsData, isLoading: datasetsLoading } = useQuery({ + queryKey: ["datasets"], + queryFn: () => apiClient.getDatasets(), + refetchInterval: 30000 // Refresh every 30 seconds + }); + + // Fetch embedding models + const { data: modelsData, isLoading: modelsLoading } = useQuery({ + queryKey: ["embedding-models"], + queryFn: () => apiClient.getEmbeddingModels() + }); + + // Create dataset mutation + const createDatasetMutation = useMutation({ + mutationFn: (data: any) => apiClient.createDatasetWithSampling(data), + onSuccess: () => { + toast.success("Dataset creation task started"); + setIsCreateDialogOpen(false); + setCreateFormData({ + strategy: "all", + limit: "", + embeddingModel: "", + description: "", + forceRegenerate: false + }); + queryClient.invalidateQueries({ queryKey: ["datasets"] }); + queryClient.invalidateQueries({ queryKey: ["tasks"] }); + }, + onError: (error: Error) => { + toast.error(`Creation failed: ${error.message}`); + } + }); + + // Delete dataset mutation + const deleteDatasetMutation = useMutation({ + mutationFn: (datasetId: string) => apiClient.deleteDataset(datasetId), + onSuccess: () => { + toast.success("Dataset deleted"); + queryClient.invalidateQueries({ queryKey: ["datasets"] }); + }, + onError: (error: Error) => { + toast.error(`Delete failed: ${error.message}`); + } + }); + + const handleCreateDataset = () => { + if (!createFormData.embeddingModel) { + toast.error("Please select an embedding model"); + return; + } + + const requestData = { + sampling: { + strategy: createFormData.strategy, + ...(createFormData.limit && { limit: parseInt(createFormData.limit) }) + }, + embedding_model: createFormData.embeddingModel, + force_regenerate: createFormData.forceRegenerate, + description: createFormData.description || undefined + }; + + createDatasetMutation.mutate(requestData); + }; + + const handleDeleteDataset = (datasetId: string) => { + if (window.confirm("Are you sure you want to delete this dataset?")) { + deleteDatasetMutation.mutate(datasetId); + } + }; + + const formatDate = (dateString: string) => { + return new Date(dateString).toLocaleString("en-US"); + }; + + const formatFileSize = (bytes: number) => { + if (bytes === 0) return "0 Bytes"; + const k = 1024; + const sizes = ["Bytes", "KB", "MB", "GB"]; + const i = Math.floor(Math.log(bytes) / Math.log(k)); + return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + " " + sizes[i]; + }; + + if (datasetsLoading) { + return ( +
+ +
+ ); + } + + return ( +
+ {/* Create Dataset Button */} +
+
+

Dataset List

+

+ {datasetsData?.datasets?.length || 0} datasets created +

+
+ + + + + + + + Create New Dataset + + Select sampling strategy and configuration parameters to create a new dataset + + + +
+
+ + +
+ + {createFormData.strategy === "random" && ( +
+ +