add: script for testing model
parent c14c680228
commit 3d96f4986d

ml/api/main.py (176 lines changed)
@@ -18,68 +18,78 @@ app = FastAPI(title="CVSA ML API", version="1.0.0")

 tokenizer = None
 classifier_model = None


 class ClassificationRequest(BaseModel):
     title: str
     description: str
     tags: str
     aid: int = None


 class ClassificationResponse(BaseModel):
     label: int
     probabilities: List[float]
     aid: int = None


 class HealthResponse(BaseModel):
     status: str
     models_loaded: bool


 def load_models():
     """Load the tokenizer and classifier models"""
     global tokenizer, classifier_model

     try:
         # Load tokenizer
         logger.info("Loading tokenizer...")
         tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3")

         # Load classifier model
         logger.info("Loading classifier model...")
         from model_config import VideoClassifierV3_15

         model_path = "../../model/akari/3.17.pt"
         classifier_model = VideoClassifierV3_15()
-        classifier_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+        classifier_model.load_state_dict(
+            torch.load(model_path, map_location=torch.device("cpu"))
+        )
         classifier_model.eval()

         logger.info("All models loaded successfully")
         return True

     except Exception as e:
         logger.error(f"Failed to load models: {str(e)}")
         return False


 def softmax(logits: np.ndarray) -> np.ndarray:
     """Apply softmax to logits"""
     exp_logits = np.exp(logits - np.max(logits))
     return exp_logits / np.sum(exp_logits)
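
The max-subtraction in `softmax` above is the standard numerical-stability trick: shifting every logit by the same constant leaves the normalized result unchanged but keeps `np.exp` from overflowing on large logits. A standalone sketch (not part of the diff) of the identity:

    import numpy as np

    logits = np.array([1000.0, 1001.0, 1002.0])
    naive = np.exp(logits)                      # overflows: [inf, inf, inf]
    shifted = np.exp(logits - np.max(logits))   # finite: [e^-2, e^-1, 1]
    print(shifted / shifted.sum())              # [0.09003057 0.24472847 0.66524096]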

 def get_jina_embeddings_1024(texts: List[str]) -> np.ndarray:
     """Get Jina embeddings using tokenizer and ONNX-like processing"""
     if tokenizer is None:
         raise ValueError("Tokenizer not loaded")

     import onnxruntime as ort

     session = ort.InferenceSession("../../model/embedding/model.onnx")

     encoded_inputs = tokenizer(
         texts,
         add_special_tokens=False,  # Key: do not add special tokens (matches the JS implementation)
         return_attention_mask=False,
-        return_tensors=None  # return native Python lists for easier downstream processing
+        return_tensors=None,  # return native Python lists for easier downstream processing
     )
-    input_ids = encoded_inputs["input_ids"]  # shape: [batch_size, seq_len_i] (each sample's length may differ)
+    input_ids = encoded_inputs[
+        "input_ids"
+    ]  # shape: [batch_size, seq_len_i] (each sample's length may differ)

     # 2. Compute offsets (exactly matching the JS cumsum logic)
     # First, get the token length of each sample
     lengths = [len(ids) for ids in input_ids]
@@ -91,25 +101,28 @@ def get_jina_embeddings_1024(texts: List[str]) -> np.ndarray:
         cumsum.append(current_sum)
     # Build offsets: starts at 0, followed by the cumulative sums
     offsets = [0] + cumsum  # shape: [batch_size]

     # 3. Flatten input_ids into a one-dimensional array
     flattened_input_ids = []
     for ids in input_ids:
         flattened_input_ids.extend(ids)  # concatenate all token ids directly
     flattened_input_ids = np.array(flattened_input_ids, dtype=np.int64)

     # 4. Prepare the ONNX inputs (tensor shapes consistent with the JS side)
     inputs = {
         "input_ids": ort.OrtValue.ortvalue_from_numpy(flattened_input_ids),
-        "offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64))
+        "offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64)),
     }

     # 5. Run model inference
     outputs = session.run(None, inputs)
-    embeddings = outputs[0]  # assume the first output is the embeddings, shape: [batch_size, embedding_dim]
+    embeddings = outputs[
+        0
+    ]  # assume the first output is the embeddings, shape: [batch_size, embedding_dim]

     return torch.tensor(embeddings, dtype=torch.float32).numpy()
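
The function above feeds the ONNX model an EmbeddingBag-style layout: all token ids are concatenated into one flat array, and `offsets[i]` marks where sample i begins. A standalone sketch of what the hunk computes, with made-up token ids:

    # Two samples with token id lists of different lengths.
    input_ids = [[101, 7592], [101, 2023, 2003]]

    lengths = [len(ids) for ids in input_ids]   # [2, 3]
    cumsum = []
    current_sum = 0
    for n in lengths:
        current_sum += n
        cumsum.append(current_sum)              # [2, 5]
    offsets = [0] + cumsum                      # [0, 2, 5]

    flattened = [t for ids in input_ids for t in ids]
    # flattened == [101, 7592, 101, 2023, 2003]; sample i spans
    # flattened[offsets[i]:offsets[i + 1]].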

 @app.on_event("startup")
 async def startup_event():
     """Load models on startup"""
@@ -117,91 +130,126 @@ async def startup_event():
     if not success:
         logger.error("Failed to load models during startup")


 @app.get("/health", response_model=HealthResponse)
 async def health_check():
     """Health check endpoint"""
     models_loaded = tokenizer is not None and classifier_model is not None
     return HealthResponse(
         status="healthy" if models_loaded else "models_not_loaded",
-        models_loaded=models_loaded
+        models_loaded=models_loaded,
     )


 @app.post("/classify", response_model=ClassificationResponse)
 async def classify_video(request: ClassificationRequest):
     """Classify a video based on title, description, and tags"""
     try:
         if tokenizer is None or classifier_model is None:
             raise HTTPException(status_code=503, detail="Models not loaded")

         # Get embeddings for each channel
-        texts = [request.title, request.description, request.tags]
+        texts = [
+            request.title or "no title",
+            request.description or "no description",
+            request.tags or "no tags",
+        ]
         embeddings = get_jina_embeddings_1024(texts)

         # Prepare input for classifier (batch_size=1, channels=3, embedding_dim=1024)
         channel_features = torch.tensor(embeddings).unsqueeze(0)  # [1, 3, 1024]

         # Run inference
         with torch.no_grad():
             logits = classifier_model(channel_features)
             probabilities = softmax(logits.numpy()[0])
             predicted_label = int(np.argmax(probabilities))

-        logger.info(f"Classification completed for aid {request.aid}: label={predicted_label}")
+        logger.info(
+            f"Classification completed for aid {request.aid}: label={predicted_label}"
+        )

         return ClassificationResponse(
-            label=predicted_label,
-            probabilities=probabilities.tolist(),
-            aid=request.aid
+            label=predicted_label, probabilities=probabilities.tolist(), aid=request.aid
         )

     except Exception as e:
         logger.error(f"Classification error for aid {request.aid}: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}")


 @app.post("/classify_batch")
 async def classify_video_batch(requests: List[ClassificationRequest]):
-    """Classify multiple videos in batch"""
+    """Classify multiple videos in batch using true batch processing"""
     try:
         if tokenizer is None or classifier_model is None:
             raise HTTPException(status_code=503, detail="Models not loaded")

+        if not requests:
+            return {"results": []}

         results = []

+        # Collect all texts for batch processing
+        all_title_texts = []
+        all_desc_texts = []
+        all_tags_texts = []

         for request in requests:
-            try:
-                # Get embeddings for each channel
-                texts = [request.title, request.description, request.tags]
-                embeddings = get_jina_embeddings_1024(texts)
-
-                # Prepare input for classifier
-                channel_features = torch.tensor(embeddings).unsqueeze(0)
-
-                # Run inference
-                with torch.no_grad():
-                    logits = classifier_model(channel_features)
-                    probabilities = softmax(logits.numpy()[0])
-                    predicted_label = int(np.argmax(probabilities))
-
-                results.append({
-                    "aid": request.aid,
-                    "label": predicted_label,
-                    "probabilities": probabilities.tolist()
-                })
-
-            except Exception as e:
-                logger.error(f"Batch classification error for aid {request.aid}: {str(e)}")
-                results.append({
-                    "aid": request.aid,
-                    "label": -1,
-                    "probabilities": [],
-                    "error": str(e)
-                })
+            # Handle missing or empty fields
+            title = request.title or "no title"
+            description = request.description or "no description"
+            tags = request.tags or "no tags"
+
+            all_title_texts.append(title)
+            all_desc_texts.append(description)
+            all_tags_texts.append(tags)
+
+        # Process all titles in batch
+        title_embeddings = get_jina_embeddings_1024(all_title_texts)
+
+        # Process all descriptions in batch
+        desc_embeddings = get_jina_embeddings_1024(all_desc_texts)
+
+        # Process all tags in batch
+        tags_embeddings = get_jina_embeddings_1024(all_tags_texts)
+
+        # Stack embeddings: [batch_size, 3, embedding_dim]
+        batch_features = np.stack(
+            [title_embeddings, desc_embeddings, tags_embeddings], axis=1
+        )
+
+        # Convert to tensor and run inference for entire batch
+        channel_features = torch.tensor(batch_features, dtype=torch.float32)
+
+        print(channel_features.shape)
+
+        with torch.no_grad():
+            logits = classifier_model(channel_features)  # [batch_size, num_classes]
+            probabilities_batch = softmax(logits.numpy())
+            predicted_labels = np.argmax(probabilities_batch, axis=1)
+
+        # Prepare results
+        for i, request in enumerate(requests):
+            results.append(
+                {
+                    "aid": request.aid,
+                    "label": int(predicted_labels[i]),
+                    "probabilities": probabilities_batch[i].tolist(),
+                }
+            )

         logger.info(f"Batch classification completed for {len(requests)} requests")

         return {"results": results}

     except Exception as e:
         logger.error(f"Batch classification failed: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Batch classification failed: {str(e)}")
+        raise HTTPException(
+            status_code=500, detail=f"Batch classification failed: {str(e)}"
+        )


 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8544)
+
+    uvicorn.run(app, host="0.0.0.0", port=8544)
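
For a quick smoke test of the rewritten batch endpoint, a request shaped like the models above might look as follows (a sketch using the `requests` package; the host and port follow the uvicorn defaults in this file, and the aid values are placeholders):

    import requests

    payload = [
        {"title": "example title", "description": "example description",
         "tags": "tag1,tag2", "aid": 1},
        # Empty fields fall back to "no title" / "no description" / "no tags".
        {"title": "", "description": "", "tags": "", "aid": 2},
    ]

    resp = requests.post("http://localhost:8544/classify_batch", json=payload)
    resp.raise_for_status()
    for item in resp.json()["results"]:
        print(item["aid"], item["label"], item["probabilities"])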
@@ -475,6 +475,7 @@ class DatasetBuilder:

     def list_datasets(self) -> List[Dict[str, Any]]:
         """List all datasets with their basic information"""
+        self.storage._load_metadata_cache()
         return self.storage.list_datasets()

     def get_dataset_stats(self) -> Dict[str, Any]:

@@ -153,16 +153,7 @@ async def list_datasets_endpoint():
         raise HTTPException(status_code=503, detail="Dataset builder not available")

     datasets = dataset_builder.list_datasets()
-    # Add description to each dataset
-    datasets_with_description = []
-    for dataset in datasets:
-        dataset_info = dataset_builder.get_dataset(dataset["dataset_id"])
-        if dataset_info and "description" in dataset_info:
-            dataset["description"] = dataset_info["description"]
-        else:
-            dataset["description"] = None
-        datasets_with_description.append(dataset)
-    return {"datasets": datasets_with_description}
+    return {"datasets": datasets}


 @router.delete("/dataset/{dataset_id}")
@@ -189,85 +189,7 @@ class EmbeddingClassifier(nn.Module):
         }


-class AttentionEmbeddingClassifier(EmbeddingClassifier):
-    """
-    Enhanced classifier with self-attention mechanism
-    """
-
-    def __init__(
-        self,
-        input_dim: int = 2048,
-        hidden_dims: Optional[Tuple[int, ...]] = None,
-        dropout_rate: float = 0.3,
-        batch_norm: bool = True,
-        activation: str = "relu",
-        attention_dim: int = 512
-    ):
-        super().__init__(input_dim, hidden_dims, dropout_rate, batch_norm, activation)
-
-        # Self-attention mechanism
-        self.attention_dim = attention_dim
-        self.attention = nn.MultiheadAttention(
-            embed_dim=input_dim,
-            num_heads=8,
-            dropout=dropout_rate,
-            batch_first=True
-        )
-
-        # Attention projection layer
-        self.attention_projection = nn.Linear(input_dim, attention_dim)
-
-        # Re-initialize attention weights
-        self._initialize_attention_weights()
-
-        logger.info(f"Initialized AttentionEmbeddingClassifier with attention_dim={attention_dim}")
-
-    def _initialize_attention_weights(self):
-        """Initialize attention mechanism weights"""
-        for module in self.attention.modules():
-            if isinstance(module, nn.Linear):
-                nn.init.xavier_uniform_(module.weight)
-                if module.bias is not None:
-                    nn.init.constant_(module.bias, 0)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass with attention mechanism
-
-        Args:
-            x: Input embeddings of shape (batch_size, input_dim)
-
-        Returns:
-            logits of shape (batch_size, 1)
-        """
-        # Ensure input is float tensor
-        if not x.dtype == torch.float32:
-            x = x.float()
-
-        # Add sequence dimension for attention (batch_size, 1, input_dim)
-        x_expanded = x.unsqueeze(1)
-
-        # Apply self-attention
-        attended, attention_weights = self.attention(x_expanded, x_expanded, x_expanded)
-
-        # Remove sequence dimension (batch_size, input_dim)
-        attended = attended.squeeze(1)
-
-        # Project to attention dimension
-        attended = self.attention_projection(attended)
-
-        # Process through original classification layers
-        for layer in self.layers:
-            attended = layer(attended)
-
-        # Final classification layer
-        logits = self.classifier(attended)
-
-        return logits
-
-
 def create_model(
-    model_type: str = "standard",
     input_dim: int = 2048,
     hidden_dims: Optional[Tuple[int, ...]] = None,
     **kwargs

@@ -276,7 +198,6 @@ def create_model(
     Factory function to create embedding classifier models

     Args:
-        model_type: Type of model ('standard', 'attention')
         input_dim: Input embedding dimension
         hidden_dims: Hidden layer dimensions
         **kwargs: Additional model arguments

@@ -284,18 +205,12 @@ def create_model(
     Returns:
         Initialized model
     """
-    if model_type == "standard":
-        return EmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)
-    elif model_type == "attention":
-        return AttentionEmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)
-    else:
-        raise ValueError(f"Unknown model type: {model_type}")
+    return EmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)


 if __name__ == "__main__":
     # Test model creation and forward pass
     model = create_model(
-        model_type="standard",
         input_dim=2048,
         hidden_dims=(512, 256, 128),
         dropout_rate=0.3
ml_new/training/test.py (new file, 636 lines)

@@ -0,0 +1,636 @@
#!/usr/bin/env python3
"""
Test script for evaluating trained models on a dataset
"""

import argparse
import json
import sys
from pathlib import Path

import numpy as np
import torch
import aiohttp
import asyncio
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
from typing import Optional, List, Dict, Any

# Add the parent directory to the path to import ml_new modules
sys.path.append(str(Path(__file__).parent.parent))

from ml_new.training.models import create_model
from ml_new.training.data_loader import DatasetLoader, EmbeddingDataset
from ml_new.config.logger_config import get_logger

logger = get_logger(__name__)


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(
        description="Test embedding classification model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Required arguments
    parser.add_argument(
        "--dataset-id",
        type=str,
        required=True,
        help="ID of the dataset to use for testing"
    )

    parser.add_argument(
        "--experiment",
        type=str,
        help="Name of the experiment to load model from"
    )

    # Optional arguments
    parser.add_argument(
        "--datasets-dir",
        type=str,
        default="training/datasets",
        help="Directory containing dataset files"
    )

    parser.add_argument(
        "--checkpoints-dir",
        type=str,
        default="training/checkpoints",
        help="Directory containing model checkpoints"
    )

    parser.add_argument(
        "--checkpoint-file",
        type=str,
        default="best_model.pth",
        help="Checkpoint file to load (relative to experiment dir)"
    )

    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for testing"
    )

    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
        help="Number of worker processes for data loading"
    )

    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device to use (auto, cpu, cuda)"
    )

    parser.add_argument(
        "--normalize",
        action="store_true",
        default=False,
        help="Normalize embeddings during testing"
    )

    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output file for detailed results (JSON)"
    )

    parser.add_argument(
        "--threshold",
        type=float,
        default=0.5,
        help="Classification threshold"
    )

    parser.add_argument(
        "--use-api",
        action="store_true",
        default=False,
        help="Use API model instead of local model"
    )

    parser.add_argument(
        "--api-url",
        type=str,
        default="http://localhost:8544",
        help="API base URL"
    )

    parser.add_argument(
        "--misclassified-output",
        type=str,
        default=None,
        help="Output file for misclassified samples (FN and FP aids)"
    )

    return parser.parse_args()


def setup_device(device_arg: str):
    """Set up the torch device"""
    if device_arg == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device_arg)

    logger.info(f"Using device: {device}")
    return device


def load_model_from_experiment(
    checkpoints_dir: str,
    experiment_name: str,
    checkpoint_file: str,
    device: torch.device
):
    """
    Load a trained model from an experiment checkpoint

    Args:
        checkpoints_dir: Directory containing checkpoints
        experiment_name: Name of the experiment
        checkpoint_file: Checkpoint file name
        device: Device to load the model onto

    Returns:
        Loaded model
    """
    checkpoint_path = Path(checkpoints_dir) / experiment_name / checkpoint_file

    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    logger.info(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Get model config from checkpoint
    model_config = checkpoint.get('model_config', {})

    # Create model with saved config
    model = create_model(
        input_dim=model_config.get('input_dim', 2048),
        hidden_dims=tuple(model_config.get('hidden_dims', [512, 256, 128])),
        dropout_rate=model_config.get('dropout_rate', 0.3),
        batch_norm=model_config.get('batch_norm', True)
    )

    # Load state dict
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    logger.info(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
    logger.info(f"Model config: {model_config}")

    return model, model_config
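
For context, `load_model_from_experiment` assumes the checkpoint is a dict with a `model_state_dict` entry plus optional `model_config` and `epoch` entries. A sketch of writing a compatible checkpoint (the experiment name and epoch number are placeholders):

    import torch
    from ml_new.training.models import create_model

    # Store the same config next to the weights so the test script can rebuild the model.
    model_config = {
        "input_dim": 2048,
        "hidden_dims": [512, 256, 128],
        "dropout_rate": 0.3,
        "batch_norm": True,
    }
    model = create_model(
        input_dim=model_config["input_dim"],
        hidden_dims=tuple(model_config["hidden_dims"]),
        dropout_rate=model_config["dropout_rate"],
        batch_norm=model_config["batch_norm"],
    )

    checkpoint = {
        "model_state_dict": model.state_dict(),
        "model_config": model_config,
        "epoch": 42,  # illustrative
    }
    # "my_experiment" is a placeholder experiment directory.
    torch.save(checkpoint, "training/checkpoints/my_experiment/best_model.pth")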

def evaluate_model(
    model,
    test_loader: DataLoader,
    device: torch.device,
    threshold: float = 0.5
):
    """
    Evaluate model on test set

    Args:
        model: The model to evaluate
        test_loader: Test data loader
        device: Device to use
        threshold: Classification threshold

    Returns:
        Tuple of (metrics, predictions, probabilities, true_labels, fn_aids, fp_aids)
    """
    model.eval()
    criterion = torch.nn.BCEWithLogitsLoss()

    total_loss = 0.0
    all_predictions = []
    all_labels = []
    all_probabilities = []
    all_metadata = []
    fn_aids = []
    fp_aids = []

    with torch.no_grad():
        for batch_idx, (embeddings, labels, metadata) in enumerate(test_loader):
            embeddings = embeddings.to(device)
            labels = labels.to(device).float()

            # Forward pass
            outputs = model(embeddings)
            loss = criterion(outputs.squeeze(), labels)

            # Collect statistics
            total_loss += loss.item()

            # Get predictions and probabilities
            probabilities = torch.sigmoid(outputs).squeeze()
            predictions = (probabilities > threshold).long()

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())

            # Collect metadata and track FN/FP
            batch_metadata = metadata if isinstance(metadata, list) else [metadata]
            all_metadata.extend(batch_metadata)

            # Track FN and FP aids for this batch
            for i, (true_label, pred_label) in enumerate(zip(labels.cpu().numpy(), predictions.cpu().numpy())):
                if isinstance(batch_metadata[i], dict) and 'aid' in batch_metadata[i]:
                    aid = batch_metadata[i]['aid']
                    if true_label == 1 and pred_label == 0:  # False Negative
                        fn_aids.append(aid)
                    elif true_label == 0 and pred_label == 1:  # False Positive
                        fp_aids.append(aid)

            if (batch_idx + 1) % 10 == 0:
                logger.info(f"Processed {batch_idx + 1}/{len(test_loader)} batches")

    # Calculate metrics
    test_loss = total_loss / len(test_loader)
    test_accuracy = accuracy_score(all_labels, all_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='binary', zero_division=0
    )

    try:
        test_auc = roc_auc_score(all_labels, all_probabilities)
    except ValueError:
        test_auc = 0.0

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)
    if cm.size == 4:
        tn, fp, fn, tp = cm.ravel()
    else:
        tn, fp, fn, tp = 0, 0, 0, 0

    metrics = {
        'loss': test_loss,
        'accuracy': test_accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': test_auc,
        'true_negatives': int(tn),
        'false_positives': int(fp),
        'false_negatives': int(fn),
        'true_positives': int(tp),
        'total_samples': len(all_labels),
        'threshold': threshold
    }

    # Add class distribution
    unique, counts = np.unique(all_labels, return_counts=True)
    metrics['class_distribution'] = {int(k): int(v) for k, v in zip(unique, counts)}

    return metrics, all_predictions, all_probabilities, all_labels, fn_aids, fp_aids


async def call_api_batch(session: aiohttp.ClientSession, api_url: str, requests: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]:
    """Call the classification API for batch requests"""
    try:
        url = f"{api_url}/classify_batch"
        async with session.post(url, json=requests) as response:
            if response.status == 200:
                result = await response.json()
                return result.get('results', [])
            else:
                logger.warning(f"Batch API request failed with status {response.status}")
                return None
    except Exception as e:
        logger.warning(f"Batch API request failed: {e}")
        return None


def convert_api_label_to_bool(api_label: int) -> int:
    """Convert an API label to a binary label (non-zero = positive)"""
    return 1 if api_label != 0 else 0


async def evaluate_with_api(
    embeddings: np.ndarray,
    labels: np.ndarray,
    metadata: List[Dict[str, Any]],
    api_url: str,
    batch_size: int = 32
):
    """
    Evaluate using the API instead of a local model

    Args:
        embeddings: Array of embeddings (not used for API calls)
        labels: Ground truth labels
        metadata: Metadata containing title, description, tags, aid
        api_url: API base URL
        batch_size: Number of requests per API batch call

    Returns:
        Tuple of (metrics, predictions, probabilities, true_labels, fn_aids, fp_aids)
    """
    logger.info(f"Using API at {api_url} for evaluation")

    # Prepare API requests
    requests = []
    for i, meta in enumerate(metadata):
        # Extract metadata fields for the API
        title = meta.get('title', '')
        description = meta.get('description', '')
        tags = meta.get('tags', '')
        aid = meta.get('aid', i)

        # Handle missing or empty fields
        if not title:
            title = f"Video {aid}"
        if not description:
            description = ""
        if not tags:
            tags = ""

        request_data = {
            "title": title,
            "description": description,
            "tags": tags,
            "aid": aid
        }
        requests.append(request_data)

    # Split requests into batches
    num_batches = (len(requests) + batch_size - 1) // batch_size
    logger.info(f"Making {num_batches} batch API requests with batch_size={batch_size} for {len(requests)} total requests")

    # Process all batches
    all_predictions = []
    all_probabilities = []
    all_labels = labels.tolist()
    all_aids = [meta.get('aid', i) for i, meta in enumerate(metadata)]
    failed_requests = 0
    fn_aids = []
    fp_aids = []

    async with aiohttp.ClientSession() as session:
        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(requests))
            batch_requests = requests[start_idx:end_idx]

            logger.info(f"Processing batch {batch_idx + 1}/{num_batches} ({len(batch_requests)} requests)")

            results = await call_api_batch(session, api_url, batch_requests)

            if results is None:
                logger.error(f"Batch {batch_idx + 1} API request failed completely")
                # Create dummy results for this batch
                all_predictions.extend([0] * len(batch_requests))
                all_probabilities.extend([0.0] * len(batch_requests))
                failed_requests += len(batch_requests)
                continue

            for i, result in enumerate(results):
                global_idx = start_idx + i
                if not isinstance(result, dict) or 'error' in result:
                    error_msg = result.get('error', 'Unknown error') if isinstance(result, dict) else 'Invalid result'
                    logger.warning(f"Failed to get API prediction for request {global_idx}: {error_msg}")
                    failed_requests += 1
                    all_predictions.append(0)
                    all_probabilities.append(0.0)
                    continue

                # Convert API response to our format
                api_label = result.get('label', -1)
                probabilities = result.get('probabilities')

                # Convert to binary (non-zero = positive)
                prediction = convert_api_label_to_bool(api_label)
                # Use the probability of the positive class
                if probabilities and len(probabilities) > 0:
                    positive_prob = 1 - probabilities[0]
                else:
                    logger.warning(f"No probabilities for request {global_idx}")
                    failed_requests += 1
                    all_predictions.append(0)
                    all_probabilities.append(0.0)
                    continue

                all_predictions.append(prediction)
                all_probabilities.append(positive_prob)

    if failed_requests > 0:
        logger.warning(f"Failed to get API predictions for {failed_requests} requests")

    # Collect FN and FP aids
    for i, (true_label, pred_label) in enumerate(zip(all_labels, all_predictions)):
        aid = all_aids[i]
        if true_label == 1 and pred_label == 0:  # False Negative
            fn_aids.append(aid)
        elif true_label == 0 and pred_label == 1:  # False Positive
            fp_aids.append(aid)

    # Calculate metrics
    test_accuracy = accuracy_score(all_labels, all_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_predictions, average='binary', zero_division=0
    )

    try:
        test_auc = roc_auc_score(all_labels, all_probabilities)
    except ValueError:
        test_auc = 0.0

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)
    if cm.size == 4:
        tn, fp, fn, tp = cm.ravel()
    else:
        tn, fp, fn, tp = 0, 0, 0, 0

    metrics = {
        'accuracy': test_accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': test_auc,
        'true_negatives': int(tn),
        'false_positives': int(fp),
        'false_negatives': int(fn),
        'true_positives': int(tp),
        'total_samples': len(all_labels),
        'failed_requests': failed_requests
    }

    # Add class distribution
    unique, counts = np.unique(all_labels, return_counts=True)
    metrics['class_distribution'] = {int(k): int(v) for k, v in zip(unique, counts)}

    return metrics, all_predictions, all_probabilities, all_labels, fn_aids, fp_aids


def main():
    """Main test function"""
    args = parse_args()

    # Set up device
    device = setup_device(args.device)

    # Check if dataset exists
    loader = DatasetLoader(args.datasets_dir)
    datasets = loader.list_datasets()

    if args.dataset_id not in datasets:
        logger.error(f"Dataset '{args.dataset_id}' not found in {args.datasets_dir}")
        logger.info(f"Available datasets: {datasets}")
        sys.exit(1)

    # Load dataset (use the entire dataset as the test set)
    try:
        logger.info(f"Loading dataset {args.dataset_id}...")
        embeddings, labels, metadata = loader.load_dataset(args.dataset_id)

        logger.info(f"Dataset loaded: {len(embeddings)} samples")
        logger.info(f"Embedding dimension: {embeddings.shape[1]}")

    except Exception as e:
        logger.error(f"Failed to load dataset: {e}")
        sys.exit(1)

    # Choose evaluation method
    if args.use_api:
        # Use the API for evaluation
        logger.info("Using API-based evaluation")

        # Run async evaluation
        metrics, predictions, probabilities, true_labels, fn_aids, fp_aids = asyncio.run(
            evaluate_with_api(
                embeddings, labels, metadata,
                args.api_url,
                args.batch_size
            )
        )

        # For API mode, we don't have a model_config
        model_config = {"type": "api", "api_url": args.api_url}

    else:
        # Use a local model for evaluation
        # Check if the experiment exists
        experiment_dir = Path(args.checkpoints_dir) / args.experiment
        if not experiment_dir.exists():
            logger.error(f"Experiment '{args.experiment}' not found in {args.checkpoints_dir}")
            available = [d.name for d in Path(args.checkpoints_dir).iterdir() if d.is_dir()]
            logger.info(f"Available experiments: {available}")
            sys.exit(1)

        # Load model
        try:
            model, model_config = load_model_from_experiment(
                args.checkpoints_dir,
                args.experiment,
                args.checkpoint_file,
                device
            )
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            sys.exit(1)

        # Create test dataset and loader
        test_dataset = EmbeddingDataset(
            embeddings, labels, metadata,
            normalize=args.normalize
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers
        )

        # Evaluate model
        logger.info("Starting local model evaluation...")
        metrics, predictions, probabilities, true_labels, fn_aids, fp_aids = evaluate_model(
            model, test_loader, device, args.threshold
        )

    # Print results
    logger.info("=" * 50)
    logger.info("Test Results")
    logger.info("=" * 50)
    logger.info(f"Dataset: {args.dataset_id}")
    if args.use_api:
        logger.info(f"Method: API ({args.api_url})")
    else:
        logger.info(f"Experiment: {args.experiment}")
    logger.info(f"Total samples: {metrics['total_samples']}")
    logger.info(f"Class distribution: {metrics['class_distribution']}")
    if 'failed_requests' in metrics:
        logger.info(f"Failed API requests: {metrics['failed_requests']}")
    logger.info("-" * 50)
    if 'loss' in metrics:
        logger.info(f"Loss: {metrics['loss']:.4f}")
    logger.info(f"Accuracy: {metrics['accuracy']:.4f}")
    logger.info(f"Precision: {metrics['precision']:.4f}")
    logger.info(f"Recall: {metrics['recall']:.4f}")
    logger.info(f"F1 Score: {metrics['f1']:.4f}")
    logger.info(f"AUC: {metrics['auc']:.4f}")
    logger.info("-" * 50)
    logger.info(f"True Positives: {metrics['true_positives']}")
    logger.info(f"True Negatives: {metrics['true_negatives']}")
    logger.info(f"False Positives: {metrics['false_positives']}")
    logger.info(f"False Negatives: {metrics['false_negatives']}")
    logger.info("=" * 50)

    # Save detailed results if requested
    if args.output:
        output_path = Path(args.output)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        results = {
            'dataset_id': args.dataset_id,
            'experiment': args.experiment,
            'checkpoint': args.checkpoint_file,
            'model_config': model_config,
            'metrics': metrics,
            'predictions': [int(p) for p in predictions],
            'probabilities': [float(p) for p in probabilities],
            'labels': [int(l) for l in true_labels]
        }

        with open(output_path, 'w') as f:
            json.dump(results, f, indent=2)

        logger.info(f"Detailed results saved to {output_path}")

    # Save misclassified samples (FN and FP aids) if requested
    if args.misclassified_output:
        misclassified_path = Path(args.misclassified_output)
        misclassified_path.parent.mkdir(parents=True, exist_ok=True)

        misclassified_data = {
            'false_negatives': fn_aids,
            'false_positives': fp_aids,
            'fn_count': len(fn_aids),
            'fp_count': len(fp_aids),
            'total_misclassified': len(fn_aids) + len(fp_aids)
        }

        with open(misclassified_path, 'w') as f:
            json.dump(misclassified_data, f, indent=2)

        logger.info(f"Misclassified samples (FN: {len(fn_aids)}, FP: {len(fp_aids)}) saved to {misclassified_path}")


if __name__ == "__main__":
    main()
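
Putting the new script together, two plausible invocations (the dataset id and experiment name below are placeholders; the flags come from parse_args above):

    # Evaluate a local checkpoint:
    python ml_new/training/test.py --dataset-id my_dataset --experiment my_experiment \
        --checkpoint-file best_model.pth --threshold 0.5 --output results.json

    # Evaluate through the running API instead of a local model:
    python ml_new/training/test.py --dataset-id my_dataset --use-api \
        --api-url http://localhost:8544 --misclassified-output misclassified.json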
@@ -62,15 +62,6 @@ def parse_args():
         help="Directory containing dataset files"
     )

-    # Model arguments
-    parser.add_argument(
-        "--model-type",
-        type=str,
-        choices=["standard", "attention"],
-        default="standard",
-        help="Type of model architecture"
-    )
-
     parser.add_argument(
         "--input-dim",
         type=int,

@@ -353,11 +344,11 @@ def main():
     # Create experiment name if not provided
     if args.experiment_name is None:
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-        args.experiment_name = f"{args.model_type}_{args.dataset_id}_{timestamp}"
+        args.experiment_name = f"{timestamp}_{args.dataset_id}"

     logger.info(f"Starting experiment: {args.experiment_name}")
     logger.info(f"Dataset: {args.dataset_id}")
-    logger.info(f"Model: {args.model_type} with hidden dims {args.hidden_dims}")
+    logger.info(f"Model: hidden dims {args.hidden_dims}")

     # Load dataset and create data loaders
     try:

@@ -387,7 +378,6 @@ def main():
     try:
         logger.info("Creating model...")
         model = create_model(
-            model_type=args.model_type,
             input_dim=args.input_dim,
             hidden_dims=tuple(args.hidden_dims),
             dropout_rate=args.dropout_rate,

@@ -452,7 +442,6 @@ def main():
         "experiment_name": args.experiment_name,
         "dataset_id": args.dataset_id,
         "model_config": {
-            "model_type": args.model_type,
             "input_dim": args.input_dim,
             "hidden_dims": args.hidden_dims,
             "dropout_rate": args.dropout_rate,

@@ -417,7 +417,7 @@ class ModelTrainer:

     def load_checkpoint(self, checkpoint_path: str, load_optimizer: bool = True) -> None:
         """Load model from checkpoint"""
-        checkpoint = torch.load(checkpoint_path, map_location=self.device)
+        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

         self.model.load_state_dict(checkpoint['model_state_dict'])

@@ -504,7 +504,6 @@ if __name__ == "__main__":

     # Create dummy model and data
     model = create_model(
-        model_type="standard",
         input_dim=2048,
         hidden_dims=(512, 256, 128)
     )
@@ -1,10 +1,9 @@
 import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
 import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
-import { CardDescription, CardTitle } from "@/components/ui/card";
+import { CardTitle } from "@/components/ui/card";
 import { DatasetManager } from "@/components/DatasetManager";
 import { TaskMonitor } from "@/components/TaskMonitor";
 import { SamplingPanel } from "@/components/SamplingPanel";
-import { Database, Activity, Settings } from "lucide-react";
+import { Database, Activity } from "lucide-react";

 const queryClient = new QueryClient({
   defaultOptions: {

@@ -21,22 +20,17 @@ function App() {
     <div className="min-h-screen flex justify-center">
       <div className="container lg:max-w-3xl xl:max-w-4xl bg-background py-8 px-3">
         <div className="mb-8">
-          <h1 className="text-3xl font-bold tracking-tight">ML Dataset Management Panel</h1>
-          <p className="text-muted-foreground">
-            Create and manage machine learning datasets with multiple sampling strategies and task monitoring
-          </p>
+          <h1 className="text-3xl font-bold tracking-tight">
+            CVSA Machine Learning Panel
+          </h1>
         </div>

         <Tabs defaultValue="datasets" className="space-y-4">
-          <TabsList className="grid w-full grid-cols-3">
+          <TabsList className="grid w-full grid-cols-2">
             <TabsTrigger value="datasets" className="flex items-center gap-2">
               <Database className="h-4 w-4" />
               Datasets
             </TabsTrigger>
-            <TabsTrigger value="sampling" className="flex items-center gap-2">
-              <Settings className="h-4 w-4" />
-              Sampling
-            </TabsTrigger>
             <TabsTrigger value="monitor" className="flex items-center gap-2">
               <Activity className="h-4 w-4" />
               Tasks

@@ -44,22 +38,11 @@ function App() {
           </TabsList>

           <TabsContent value="datasets" className="space-y-4">
             <CardTitle>Dataset Management</CardTitle>
-            <CardDescription>View, create and manage your machine learning datasets</CardDescription>
             <DatasetManager />
           </TabsContent>

-          <TabsContent value="sampling" className="space-y-4">
-            <CardTitle>Sampling Strategy Configuration</CardTitle>
-            <CardDescription>
-              Configure different data sampling strategies to create balanced datasets
-            </CardDescription>
-            <SamplingPanel />
-          </TabsContent>
-
           <TabsContent value="monitor" className="space-y-4">
-            <CardTitle>Task Monitor</CardTitle>
-            <CardDescription>Monitor real-time status and progress of dataset building tasks</CardDescription>
             <TaskMonitor />
           </TabsContent>
         </Tabs>
@@ -1,4 +1,4 @@
-import { useState } from "react";
+import { useState } from "react";
 import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
 import { Button } from "@/components/ui/button";
 import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";

@@ -20,19 +20,22 @@ import {
   SelectValue
 } from "@/components/ui/select";
 import { Textarea } from "@/components/ui/textarea";
-import { Trash2, Plus, Database, FileText, Calendar, Activity } from "lucide-react";
+import { Trash2, Plus, Database, Upload } from "lucide-react";
 import { apiClient } from "@/lib/api";
 import { toast } from "sonner";
-import { Spinner } from "@/components/ui/spinner"
+import { Spinner } from "@/components/ui/spinner";

 export function DatasetManager() {
   const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false);
   const [createFormData, setCreateFormData] = useState({
+    creationMode: "sampling", // "sampling" or "aidList"
     strategy: "all",
     limit: "",
     embeddingModel: "",
     description: "",
-    forceRegenerate: false
+    forceRegenerate: false,
+    aidListFile: null as File | null,
+    aidList: [] as number[]
   });

   const queryClient = useQueryClient();

@@ -45,7 +48,7 @@ export function DatasetManager() {
   });

   // Fetch embedding models
-  const { data: modelsData, isLoading: modelsLoading } = useQuery({
+  const { data: modelsData } = useQuery({
     queryKey: ["embedding-models"],
     queryFn: () => apiClient.getEmbeddingModels()
   });

@@ -57,11 +60,14 @@ export function DatasetManager() {
       toast.success("Dataset creation task started");
       setIsCreateDialogOpen(false);
       setCreateFormData({
+        creationMode: "sampling",
         strategy: "all",
         limit: "",
         embeddingModel: "",
         description: "",
-        forceRegenerate: false
+        forceRegenerate: false,
+        aidListFile: null,
+        aidList: []
       });
       queryClient.invalidateQueries({ queryKey: ["datasets"] });
       queryClient.invalidateQueries({ queryKey: ["tasks"] });

@@ -83,23 +89,68 @@ export function DatasetManager() {
     }
   });

+  // Build dataset mutation
+  const buildDatasetMutation = useMutation({
+    mutationFn: (data: {
+      aid_list: number[];
+      embedding_model: string;
+      force_regenerate?: boolean;
+      description?: string;
+    }) => apiClient.buildDataset(data),
+    onSuccess: () => {
+      toast.success("Dataset build task started");
+      setIsCreateDialogOpen(false);
+      setCreateFormData({
+        creationMode: "sampling",
+        strategy: "all",
+        limit: "",
+        embeddingModel: "",
+        description: "",
+        forceRegenerate: false,
+        aidListFile: null,
+        aidList: []
+      });
+      queryClient.invalidateQueries({ queryKey: ["datasets"] });
+      queryClient.invalidateQueries({ queryKey: ["tasks"] });
+    },
+    onError: (error: Error) => {
+      toast.error(`Build failed: ${error.message}`);
+    }
+  });
+
   const handleCreateDataset = () => {
     if (!createFormData.embeddingModel) {
       toast.error("Please select an embedding model");
       return;
     }

-    const requestData = {
-      sampling: {
-        strategy: createFormData.strategy,
-        ...(createFormData.limit && { limit: parseInt(createFormData.limit) })
-      },
-      embedding_model: createFormData.embeddingModel,
-      force_regenerate: createFormData.forceRegenerate,
-      description: createFormData.description || undefined
-    };
+    if (createFormData.creationMode === "sampling") {
+      const requestData = {
+        sampling: {
+          strategy: createFormData.strategy,
+          ...(createFormData.limit && { limit: parseInt(createFormData.limit) })
+        },
+        embedding_model: createFormData.embeddingModel,
+        force_regenerate: createFormData.forceRegenerate,
+        description: createFormData.description || undefined
+      };

-    createDatasetMutation.mutate(requestData);
+      createDatasetMutation.mutate(requestData);
+    } else if (createFormData.creationMode === "aidList") {
+      if (createFormData.aidList.length === 0) {
+        toast.error("Please upload an aid list file");
+        return;
+      }
+
+      const requestData = {
+        aid_list: createFormData.aidList,
+        embedding_model: createFormData.embeddingModel,
+        force_regenerate: createFormData.forceRegenerate,
+        description: createFormData.description || undefined
+      };
+
+      buildDatasetMutation.mutate(requestData);
+    }
   };

   const handleDeleteDataset = (datasetId: string) => {

@@ -108,16 +159,67 @@ export function DatasetManager() {
     }
   };

-  const formatDate = (dateString: string) => {
-    return new Date(dateString).toLocaleString("en-US");
+  // Parse aid list file
+  const parseAidListFile = (file: File): Promise<number[]> => {
+    return new Promise((resolve, reject) => {
+      const reader = new FileReader();
+      reader.onload = (e) => {
+        try {
+          const content = e.target?.result as string;
+          const lines = content.split("\n").filter((line) => line.trim());
+          const aidList: number[] = [];
+
+          for (const line of lines) {
+            const trimmed = line.trim();
+            if (trimmed) {
+              const aid = parseInt(trimmed, 10);
+              if (!isNaN(aid)) {
+                aidList.push(aid);
+              }
+            }
+          }
+
+          resolve(aidList);
+        } catch (error) {
+          reject(new Error("Failed to parse file"));
+        }
+      };
+      reader.onerror = () => reject(new Error("Failed to read file"));
+      reader.readAsText(file);
+    });
+  };

-  const formatFileSize = (bytes: number) => {
-    if (bytes === 0) return "0 Bytes";
-    const k = 1024;
-    const sizes = ["Bytes", "KB", "MB", "GB"];
-    const i = Math.floor(Math.log(bytes) / Math.log(k));
-    return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + " " + sizes[i];
+  // Handle file upload
+  const handleFileUpload = async (event: React.ChangeEvent<HTMLInputElement>) => {
+    const file = event.target.files?.[0];
+    if (!file) return;
+
+    if (!file.name.endsWith(".txt") && !file.name.endsWith(".csv")) {
+      toast.error("Please upload a .txt or .csv file");
+      return;
+    }
+
+    try {
+      const aidList = await parseAidListFile(file);
+      if (aidList.length === 0) {
+        toast.error("No valid AIDs found in the file");
+        return;
+      }
+
+      setCreateFormData((prev) => ({
+        ...prev,
+        aidListFile: file,
+        aidList: aidList
+      }));
+
+      toast.success(`Loaded ${aidList.length} AIDs from file`);
+    } catch (error) {
+      toast.error("Failed to parse aid list file");
+    }
+  };
+
+  const formatDate = (dateString: string) => {
+    return new Date(dateString).toLocaleString("en-US");
   };

   if (datasetsLoading) {

@@ -150,43 +252,120 @@ export function DatasetManager() {
             <DialogHeader>
               <DialogTitle>Create New Dataset</DialogTitle>
               <DialogDescription>
-                Select sampling strategy and configuration parameters to create a new dataset
+                Select sampling strategy and configuration parameters to create a
+                new dataset
               </DialogDescription>
             </DialogHeader>

             <div className="grid gap-4 py-4">
               <div className="grid gap-2">
-                <Label htmlFor="strategy">Sampling Strategy</Label>
+                <Label htmlFor="creationMode">Creation Mode</Label>
                 <Select
-                  value={createFormData.strategy}
+                  value={createFormData.creationMode}
                   onValueChange={(value) =>
-                    setCreateFormData((prev) => ({ ...prev, strategy: value }))
+                    setCreateFormData((prev) => ({
+                      ...prev,
+                      creationMode: value,
+                      // Reset aid list when switching modes
+                      aidListFile: null,
+                      aidList: []
+                    }))
                   }
                 >
                   <SelectTrigger>
-                    <SelectValue placeholder="Select sampling strategy" />
+                    <SelectValue placeholder="Select creation mode" />
                   </SelectTrigger>
                   <SelectContent>
-                    <SelectItem value="all">All Videos</SelectItem>
-                    <SelectItem value="random">Random Sampling</SelectItem>
+                    <SelectItem value="sampling">Sampling Strategy</SelectItem>
+                    <SelectItem value="aidList">Upload Aid List</SelectItem>
                   </SelectContent>
                 </Select>
               </div>

-              {createFormData.strategy === "random" && (
+              {createFormData.creationMode === "sampling" && (
                 <div className="grid gap-2">
-                  <Label htmlFor="limit">Sample Count</Label>
-                  <Textarea
-                    id="limit"
-                    placeholder="Enter number of samples, e.g., 1000"
-                    value={createFormData.limit}
-                    onChange={(e) =>
+                  <Label htmlFor="strategy">Sampling Strategy</Label>
+                  <Select
+                    value={createFormData.strategy}
+                    onValueChange={(value) =>
                       setCreateFormData((prev) => ({
                         ...prev,
-                        limit: e.target.value
+                        strategy: value
                       }))
                     }
+                  >
+                    <SelectTrigger>
+                      <SelectValue placeholder="Select sampling strategy" />
+                    </SelectTrigger>
+                    <SelectContent>
+                      <SelectItem value="all">All Videos</SelectItem>
+                      <SelectItem value="random">Random Sampling</SelectItem>
+                    </SelectContent>
+                  </Select>
                 </div>
               )}

+              {createFormData.creationMode === "sampling" &&
+                createFormData.strategy === "random" && (
+                  <div className="grid gap-2">
+                    <Label htmlFor="limit">Sample Count</Label>
+                    <Textarea
+                      id="limit"
+                      placeholder="Enter number of samples, e.g., 1000"
+                      value={createFormData.limit}
+                      onChange={(e) =>
+                        setCreateFormData((prev) => ({
+                          ...prev,
+                          limit: e.target.value
+                        }))
+                      }
+                    />
+                  </div>
+                )}
+
+              {createFormData.creationMode === "aidList" && (
+                <div className="grid gap-2">
+                  <Label htmlFor="aidListFile">Aid List File</Label>
+                  <div
+                    className="border-2 border-dashed rounded-lg p-4 cursor-pointer"
+                    onClick={() =>
+                      document.getElementById("aidListFile")?.click()
+                    }
+                  >
+                    <div className="flex flex-col items-center space-y-2">
+                      <Upload className="h-8 w-8 text-secondary-foreground" />
+                      <div className="text-sm text-secondary-foreground text-center">
+                        {createFormData.aidListFile
+                          ? `${createFormData.aidListFile.name} (${createFormData.aidList.length} AIDs loaded)`
+                          : "Click to upload a .txt or .csv file containing AIDs (one per line)"}
+                      </div>
+                      <Button
+                        type="button"
+                        variant="outline"
+                        size="sm"
+                        className="mt-2"
+                        onClick={(e) => {
+                          e.stopPropagation();
+                          document.getElementById("aidListFile")?.click();
+                        }}
+                      >
+                        Choose File
+                      </Button>
+                    </div>
+                  </div>
+                  <input
+                    id="aidListFile"
+                    type="file"
+                    accept=".txt,.csv"
+                    onChange={handleFileUpload}
+                    className="hidden"
+                  />
+                  {createFormData.aidList.length > 0 && (
+                    <div className="text-sm text-green-600">
+                      ✓ Loaded {createFormData.aidList.length} AIDs from{" "}
+                      {createFormData.aidListFile?.name}
+                    </div>
+                  )}
+                </div>
+              )}

@@ -253,9 +432,18 @@ export function DatasetManager() {
             </Button>
             <Button
               onClick={handleCreateDataset}
-              disabled={createDatasetMutation.isPending}
+              disabled={
+                createDatasetMutation.isPending ||
+                buildDatasetMutation.isPending ||
+                (createFormData.creationMode === "aidList" &&
+                  createFormData.aidList.length === 0)
+              }
             >
-              {createDatasetMutation.isPending ? "Creating..." : "Create Dataset"}
+              {createDatasetMutation.isPending || buildDatasetMutation.isPending
+                ? "Creating..."
+                : createFormData.creationMode === "sampling"
+                  ? "Create Dataset"
+                  : "Build Dataset"}
             </Button>
           </DialogFooter>
         </DialogContent>

@@ -270,8 +458,8 @@ export function DatasetManager() {
           <CardHeader className="pb-3">
             <div className="flex items-start justify-between">
               <div className="flex items-center space-x-2">
-                <CardTitle className="text-base">
-                  {dataset.dataset_id.slice(0, 8)}...{dataset.dataset_id.slice(-8)}
+                <CardTitle className="text-base line-clamp-1">
+                  {dataset.dataset_id}
                 </CardTitle>
               </div>
               <Button

@@ -288,22 +476,13 @@ export function DatasetManager() {
             )}
           </CardHeader>
           <CardContent>
-            <div className="grid grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-4 text-sm">
-              <div className="flex items-center space-x-2">
-                <span>{dataset.stats.total_records} records</span>
-              </div>
-              <div className="flex items-center space-x-2">
-                <span>{dataset.stats.embedding_model}</span>
-              </div>
-              <div className="flex items-center space-x-2">
-                <span>{formatDate(dataset.created_at)}</span>
-              </div>
-              <div className="flex items-center space-x-2">
-                <span className="text-muted-foreground">
-                  New: {dataset.stats.new_embeddings}
-                </span>
-              </div>
-
+            <div className="flex flex-wrap gap-5 text-sm leading-1">
+              <span>{dataset.stats.total_records} records</span>
+              <span>{dataset.stats.embedding_model}</span>
+              <span>{formatDate(dataset.created_at)}</span>
+              <span className="text-muted-foreground">
+                New: {dataset.stats.new_embeddings}
+              </span>
             </div>
           </CardContent>
         </Card>
@@ -1,234 +0,0 @@
-import { useState } from "react";
-import { useMutation, useQuery } from "@tanstack/react-query";
-import { Button } from "@/components/ui/button";
-import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
-import { Input } from "@/components/ui/input";
-import { Label } from "@/components/ui/label";
-import {
-  Select,
-  SelectContent,
-  SelectItem,
-  SelectTrigger,
-  SelectValue
-} from "@/components/ui/select";
-import { Textarea } from "@/components/ui/textarea";
-import { Badge } from "@/components/ui/badge";
-import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
-import { Alert, AlertDescription } from "@/components/ui/alert";
-import { Database, Play, TestTube, Settings, BarChart3 } from "lucide-react";
-import { apiClient } from "@/lib/api";
-import type { SamplingResponse, DatasetCreateResponse } from "@/types/api";
-
-interface SamplingConfig {
-  strategy: string;
-  limit?: number;
-}
-
-export function SamplingPanel() {
-  const [samplingConfig, setSamplingConfig] = useState<SamplingConfig>({
-    strategy: "all",
-    limit: undefined,
-  });
-
-  const [embeddingModel, setEmbeddingModel] = useState<string>("");
-  const [description, setDescription] = useState<string>("");
-
-  // Test sampling mutation
-  const testSamplingMutation = useMutation({
-    mutationFn: (config: SamplingConfig) => apiClient.sampleDataset(config),
-    onSuccess: (data: SamplingResponse) => {
-      console.log("Sampling test successful:", data);
-    },
-    onError: (error: Error) => {
-      console.error("Sampling test failed:", error);
-    }
-  });
-
-  // Create dataset with sampling mutation
-  const createDatasetMutation = useMutation({
-    mutationFn: (config: {
-      sampling: SamplingConfig;
-      embedding_model: string;
-      description?: string;
-    }) => apiClient.createDatasetWithSampling(config),
-    onSuccess: (data: DatasetCreateResponse) => {
-      console.log("Dataset created successfully:", data);
-    },
-    onError: (error: Error) => {
-      console.error("Dataset creation failed:", error);
-    }
-  });
-
-  const handleStrategyChange = (strategy: string) => {
-    setSamplingConfig((prev) => ({ ...prev, strategy }));
-  };
-
-  const handleLimitChange = (limit: string) => {
-    setSamplingConfig((prev) => ({
-      ...prev,
-      limit: limit ? parseInt(limit) : undefined
-    }));
-  };
-
-  const handleTestSampling = () => {
-    testSamplingMutation.mutate(samplingConfig);
-  };
-
-  const handleCreateDataset = () => {
-    if (!embeddingModel) {
-      alert("Please select an embedding model");
-      return;
-    }
-
-    createDatasetMutation.mutate({
-      sampling: samplingConfig,
-      embedding_model: embeddingModel,
-      description: description || undefined
-    });
-  };
-
-  const getStrategyDescription = (strategy: string) => {
-    switch (strategy) {
-      case "all":
-        return "Sample all labeled videos";
-      case "random":
-        return "Randomly sample specified number of labeled videos";
-      default:
-        return "Unknown strategy";
-    }
-  };
-
-  return (
-    <div className="space-y-6">
-      <Tabs defaultValue="configure" className="w-full">
-        <TabsList className="w-full mb-4">
-          <TabsTrigger value="configure">
-            <Settings className="h-4 w-4 mr-2" />
-            Configure Sampling
-          </TabsTrigger>
-          <TabsTrigger value="test">
-            <TestTube className="h-4 w-4 mr-2" />
-            Test Sampling
-          </TabsTrigger>
-        </TabsList>
-
-        <TabsContent value="configure" className="space-y-4">
-          <Card>
-            <CardHeader>
-              <CardTitle>Sampling Strategy Configuration</CardTitle>
-              <CardDescription>Select data sampling strategy and parameters</CardDescription>
-            </CardHeader>
-            <CardContent className="space-y-4">
-              <div className="grid grid-cols-2 gap-4">
-                <div className="space-y-2">
-                  <Label htmlFor="strategy">Sampling Strategy</Label>
-                  <Select
-                    value={samplingConfig.strategy}
-                    onValueChange={handleStrategyChange}
-                  >
-                    <SelectTrigger>
-                      <SelectValue placeholder="Select strategy" />
-                    </SelectTrigger>
-                    <SelectContent>
-                      <SelectItem value="all">All Labeled Videos</SelectItem>
-                      <SelectItem value="random">Random Sampling</SelectItem>
-                    </SelectContent>
-                  </Select>
-                  <p className="text-sm text-muted-foreground">
-                    {getStrategyDescription(samplingConfig.strategy)}
-                  </p>
-                </div>
-
-                {samplingConfig.strategy === "random" && (
-                  <div className="space-y-2">
-                    <Label htmlFor="limit">Sample Count</Label>
-                    <Input
-                      id="limit"
-                      type="number"
-                      placeholder="e.g., 1000"
-                      value={samplingConfig.limit || ""}
-                      onChange={(e) => handleLimitChange(e.target.value)}
-                    />
-                  </div>
-                )}
-              </div>
-            </CardContent>
-          </Card>
-        </TabsContent>
-
-        <TabsContent value="test" className="space-y-4">
-          <Card>
-            <CardHeader>
-              <CardTitle>Test Sampling</CardTitle>
-              <CardDescription>
-                Test sampling strategy and view data statistics for sampling
-              </CardDescription>
-            </CardHeader>
-            <CardContent className="space-y-4">
-              <div className="flex space-x-4">
-                <Button
-                  onClick={handleTestSampling}
-                  disabled={testSamplingMutation.isPending}
-                  className="flex-1"
-                >
-                  <Play className="h-4 w-4 mr-2" />
-                  {testSamplingMutation.isPending ? "Testing..." : "Start Test"}
-                </Button>
-              </div>
-
-              {testSamplingMutation.isSuccess && testSamplingMutation.data && (
-                <Alert>
-                  <BarChart3 className="h-4 w-4" />
-                  <AlertDescription>
-                    <div className="space-y-2">
-                      <div className="flex items-center justify-between">
-                        <span>Total available data:</span>
-                        <Badge variant="outline">
-                          {(
-                            testSamplingMutation.data as SamplingResponse
-                          ).total_available.toLocaleString()}
-                        </Badge>
-                      </div>
-                      <div className="flex items-center justify-between">
-                        <span>Will sample:</span>
-                        <Badge>
-                          {(
-                            testSamplingMutation.data as SamplingResponse
-                          ).sampled_count.toLocaleString()}
-                        </Badge>
-                      </div>
-                      <div className="flex items-center justify-between">
-                        <span>Sampling ratio:</span>
-                        <Badge variant="secondary">
-                          {(
-                            ((
-                              testSamplingMutation.data as SamplingResponse
-                            ).sampled_count /
-                              (
-                                testSamplingMutation.data as SamplingResponse
-                              ).total_available) *
-                              100
-                          ).toFixed(1)}
-                          %
-                        </Badge>
-                      </div>
-                    </div>
-                  </AlertDescription>
-                </Alert>
-              )}
-
-              {testSamplingMutation.isError && (
-                <Alert variant="destructive">
-                  <AlertDescription>
-                    Test failed: {(testSamplingMutation.error as Error).message}
-                  </AlertDescription>
-                </Alert>
-              )}
-            </CardContent>
-          </Card>
-        </TabsContent>
-      </Tabs>
-    </div>
-  );
-}
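The deleted SamplingPanel computed the sampling ratio inline in JSX, with three separate casts of testSamplingMutation.data. If the panel is ever restored, that arithmetic reads more cleanly as a small helper; a minimal sketch, using only the two response fields visible above (formatRatio is a hypothetical name, not part of this codebase):

// Percent of available records the current strategy would sample,
// rendered with one decimal place, e.g. formatRatio(1000, 4000) === "25.0".
function formatRatio(sampled: number, total: number): string {
  if (total === 0) return "0.0"; // guard: avoid NaN from 0/0
  return ((sampled / total) * 100).toFixed(1);
}

In the JSX, formatRatio(data.sampled_count, data.total_available) would replace the nested cast-and-divide expression shown in the deleted hunk.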
@@ -7,14 +7,14 @@ import {
   SelectContent,
   SelectItem,
   SelectTrigger,
-  SelectValue,
+  SelectValue
 } from "@/components/ui/select";
 import { Progress } from "@/components/ui/progress";
 import { Badge } from "@/components/ui/badge";
-import { RefreshCw, Play, Pause, CheckCircle, XCircle, Clock } from "lucide-react";
+import { RefreshCw, Clock } from "lucide-react";
 import { apiClient } from "@/lib/api";
 import type { TasksResponse } from "@/types/api";
-import { Spinner } from "@/components/ui/spinner"
+import { Spinner } from "@/components/ui/spinner";

 export function TaskMonitor() {
   const [statusFilter, setStatusFilter] = useState<string>("all");
@@ -33,21 +33,6 @@ export function TaskMonitor() {
     refetchInterval: 500
   });

-  const getStatusIcon = (status: string) => {
-    switch (status) {
-      case "running":
-        return <Play className="h-4 w-4 text-blue-500" />;
-      case "completed":
-        return <CheckCircle className="h-4 w-4 text-green-500" />;
-      case "failed":
-        return <XCircle className="h-4 w-4 text-red-500" />;
-      case "pending":
-        return <Clock className="h-4 w-4 text-yellow-500" />;
-      default:
-        return <Pause className="h-4 w-4 text-gray-500" />;
-    }
-  };
-
   const getStatusBadgeVariant = (status: string) => {
     switch (status) {
       case "running":
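Context for the hunk above: TaskMonitor polls the task list every 500 ms via TanStack Query's refetchInterval option, which is what lets status changes appear without a manual refresh, and getStatusIcon could be dropped because the status Badge already conveys the same state. A minimal sketch of that polling pattern (the query key and the fetcher wiring are assumptions for illustration; this diff does not show them):

import { useQuery } from "@tanstack/react-query";

// Hypothetical task shape, reduced to the fields used in this component.
interface Task {
  task_id: string;
  status: string;
}

function useTasks(fetchTasks: () => Promise<Task[]>) {
  // refetchInterval re-runs the query on a fixed cadence; polling stops
  // automatically when the component using this hook unmounts.
  return useQuery({
    queryKey: ["tasks"],
    queryFn: fetchTasks,
    refetchInterval: 500,
  });
}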
@@ -80,7 +65,7 @@ export function TaskMonitor() {
   if (tasksLoading) {
     return (
       <div className="flex items-center justify-center h-64">
-        <Spinner/>
+        <Spinner />
       </div>
     );
   }
@@ -118,13 +103,10 @@ export function TaskMonitor() {
             <CardContent className="p-4">
               <div className="flex items-start justify-between mb-3">
                 <div className="flex items-center space-x-2">
-                  {getStatusIcon(task.status)}
-                  <span className="font-mono text-sm">
-                    {task.task_id.slice(0, 8)}...
-                  </span>
                   <Badge variant={getStatusBadgeVariant(task.status)}>
                     {task.status}
                   </Badge>
+                  <span className="font-mono text-sm">{task.task_id}</span>
                 </div>
                 <div className="text-sm text-muted-foreground">
                   {formatDate(task.created_at)}
@@ -152,14 +134,18 @@ export function TaskMonitor() {
               <div className="grid grid-cols-2 md:grid-cols-4 gap-4 text-sm">
                 {task.started_at && (
                   <div>
-                    <span className="text-muted-foreground">Start Time:</span>
+                    <span className="text-muted-foreground">
+                      Start Time:
+                    </span>
                     <br />
                     {formatDate(task.started_at)}
                   </div>
                 )}
                 {task.completed_at && (
                   <div>
-                    <span className="text-muted-foreground">Complete Time:</span>
+                    <span className="text-muted-foreground">
+                      Complete Time:
+                    </span>
                     <br />
                     {formatDate(task.completed_at)}
                   </div>
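formatDate itself is not part of this diff; a plausible stand-in, for reference only (the real helper may format differently):

// Hypothetical formatDate: accepts the timestamp strings the API returns
// (task.created_at, task.started_at, task.completed_at) and renders a
// locale-aware date-time string.
function formatDate(iso: string): string {
  return new Date(iso).toLocaleString();
}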
@@ -4,7 +4,6 @@ import type {
   EmbeddingModelsResponse,
   DatasetsResponse,
   DatasetDetail,
-  SamplingStats,
   SamplingResponse,
   DatasetCreateResponse,
   Task,