
add: script for testing model

This commit is contained in:
alikia2x (寒寒) 2025-12-11 00:58:37 +08:00
parent c14c680228
commit 3d96f4986d
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG Key ID: 56209E0CCD8420C6
12 changed files with 1009 additions and 517 deletions

View File

@ -18,68 +18,78 @@ app = FastAPI(title="CVSA ML API", version="1.0.0")
tokenizer = None
classifier_model = None
class ClassificationRequest(BaseModel):
title: str
description: str
tags: str
aid: int = None
class ClassificationResponse(BaseModel):
label: int
probabilities: List[float]
aid: int = None
class HealthResponse(BaseModel):
status: str
models_loaded: bool
def load_models():
"""Load the tokenizer and classifier models"""
global tokenizer, classifier_model
try:
# Load tokenizer
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3")
# Load classifier model
logger.info("Loading classifier model...")
from model_config import VideoClassifierV3_15
model_path = "../../model/akari/3.17.pt"
classifier_model = VideoClassifierV3_15()
classifier_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
classifier_model.load_state_dict(
torch.load(model_path, map_location=torch.device("cpu"))
)
classifier_model.eval()
logger.info("All models loaded successfully")
return True
except Exception as e:
logger.error(f"Failed to load models: {str(e)}")
return False
def softmax(logits: np.ndarray) -> np.ndarray:
"""Apply softmax to logits"""
exp_logits = np.exp(logits - np.max(logits))
return exp_logits / np.sum(exp_logits)
def get_jina_embeddings_1024(texts: List[str]) -> np.ndarray:
"""Get Jina embeddings using tokenizer and ONNX-like processing"""
if tokenizer is None:
raise ValueError("Tokenizer not loaded")
import onnxruntime as ort
session = ort.InferenceSession("../../model/embedding/model.onnx")
encoded_inputs = tokenizer(
texts,
add_special_tokens=False, # Key: do not add special tokens, consistent with the JS implementation
return_attention_mask=False,
return_tensors=None # Return native Python lists for easier downstream processing
return_tensors=None, # Return native Python lists for easier downstream processing
)
input_ids = encoded_inputs["input_ids"] # 形状: [batch_size, seq_len_i](每个样本长度可能不同)
input_ids = encoded_inputs[
"input_ids"
] # 形状: [batch_size, seq_len_i](每个样本长度可能不同)
# 2. Compute offsets (matches the JS cumsum logic exactly)
# First get each sample's token length
lengths = [len(ids) for ids in input_ids]
@ -91,25 +101,28 @@ def get_jina_embeddings_1024(texts: List[str]) -> np.ndarray:
cumsum.append(current_sum)
# Build offsets: start at 0, followed by the cumulative sums
offsets = [0] + cumsum # Shape: [batch_size]
# 3. Flatten input_ids into a one-dimensional array
flattened_input_ids = []
for ids in input_ids:
flattened_input_ids.extend(ids) # Concatenate all token ids directly
flattened_input_ids = np.array(flattened_input_ids, dtype=np.int64)
# 4. Prepare ONNX inputs (tensor shapes kept consistent with JS)
inputs = {
"input_ids": ort.OrtValue.ortvalue_from_numpy(flattened_input_ids),
"offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64))
"offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64)),
}
# 5. Run model inference
outputs = session.run(None, inputs)
embeddings = outputs[0] # Assumes the first output is the embeddings, shape: [batch_size, embedding_dim]
embeddings = outputs[
0
] # Assumes the first output is the embeddings, shape: [batch_size, embedding_dim]
return torch.tensor(embeddings, dtype=torch.float32).numpy()
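def _demo_offsets_scheme():
    # Illustrative sketch only (not part of this commit): the offsets/flattening
    # scheme described in the comments above, shown with made-up token ids.
    input_ids = [[101, 202, 303], [404, 505]]  # two samples of different length
    lengths = [len(ids) for ids in input_ids]  # [3, 2]
    cumsum, running = [], 0
    for n in lengths:
        running += n
        cumsum.append(running)  # [3, 5]
    offsets = [0] + cumsum  # [0, 3, 5]
    flattened = [tok for ids in input_ids for tok in ids]  # [101, 202, 303, 404, 505]
    # offsets[i] marks where sample i starts in the flattened id list.
    return offsets, flattened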
@app.on_event("startup")
async def startup_event():
"""Load models on startup"""
@ -117,91 +130,126 @@ async def startup_event():
if not success:
logger.error("Failed to load models during startup")
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""Health check endpoint"""
models_loaded = tokenizer is not None and classifier_model is not None
return HealthResponse(
status="healthy" if models_loaded else "models_not_loaded",
models_loaded=models_loaded
models_loaded=models_loaded,
)
@app.post("/classify", response_model=ClassificationResponse)
async def classify_video(request: ClassificationRequest):
"""Classify a video based on title, description, and tags"""
try:
if tokenizer is None or classifier_model is None:
raise HTTPException(status_code=503, detail="Models not loaded")
# Get embeddings for each channel
texts = [request.title, request.description, request.tags]
texts = [
request.title or "no title",
request.description or "no description",
request.tags or "no tags",
]
embeddings = get_jina_embeddings_1024(texts)
# Prepare input for classifier (batch_size=1, channels=3, embedding_dim=1024)
channel_features = torch.tensor(embeddings).unsqueeze(0) # [1, 3, 1024]
# Run inference
with torch.no_grad():
logits = classifier_model(channel_features)
probabilities = softmax(logits.numpy()[0])
predicted_label = int(np.argmax(probabilities))
logger.info(f"Classification completed for aid {request.aid}: label={predicted_label}")
logger.info(
f"Classification completed for aid {request.aid}: label={predicted_label}"
)
return ClassificationResponse(
label=predicted_label,
probabilities=probabilities.tolist(),
aid=request.aid
label=predicted_label, probabilities=probabilities.tolist(), aid=request.aid
)
except Exception as e:
logger.error(f"Classification error for aid {request.aid}: {str(e)}")
raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}")
@app.post("/classify_batch")
async def classify_video_batch(requests: List[ClassificationRequest]):
"""Classify multiple videos in batch"""
"""Classify multiple videos in batch using true batch processing"""
try:
if tokenizer is None or classifier_model is None:
raise HTTPException(status_code=503, detail="Models not loaded")
if not requests:
return {"results": []}
results = []
# Collect all texts for batch processing
all_title_texts = []
all_desc_texts = []
all_tags_texts = []
for request in requests:
try:
# Get embeddings for each channel
texts = [request.title, request.description, request.tags]
embeddings = get_jina_embeddings_1024(texts)
# Prepare input for classifier
channel_features = torch.tensor(embeddings).unsqueeze(0)
# Run inference
with torch.no_grad():
logits = classifier_model(channel_features)
probabilities = softmax(logits.numpy()[0])
predicted_label = int(np.argmax(probabilities))
results.append({
"aid": request.aid,
"label": predicted_label,
"probabilities": probabilities.tolist()
})
except Exception as e:
logger.error(f"Batch classification error for aid {request.aid}: {str(e)}")
results.append({
"aid": request.aid,
"label": -1,
"probabilities": [],
"error": str(e)
})
# Handle missing or empty fields
title = request.title or "no title"
description = request.description or "no description"
tags = request.tags or "no tags"
all_title_texts.append(title)
all_desc_texts.append(description)
all_tags_texts.append(tags)
# Process all titles in batch
title_embeddings = get_jina_embeddings_1024(all_title_texts)
# Process all descriptions in batch
desc_embeddings = get_jina_embeddings_1024(all_desc_texts)
# Process all tags in batch
tags_embeddings = get_jina_embeddings_1024(all_tags_texts)
# Stack embeddings: [batch_size, 3, embedding_dim]
batch_features = np.stack(
[title_embeddings, desc_embeddings, tags_embeddings], axis=1
)
# Convert to tensor and run inference for entire batch
channel_features = torch.tensor(batch_features, dtype=torch.float32)
logger.debug(f"Batch channel features shape: {channel_features.shape}")
with torch.no_grad():
logits = classifier_model(channel_features) # [batch_size, num_classes]
probabilities_batch = softmax(logits.numpy())
predicted_labels = np.argmax(probabilities_batch, axis=1)
# Prepare results
for i, request in enumerate(requests):
results.append(
{
"aid": request.aid,
"label": int(predicted_labels[i]),
"probabilities": probabilities_batch[i].tolist(),
}
)
logger.info(f"Batch classification completed for {len(requests)} requests")
return {"results": results}
except Exception as e:
logger.error(f"Batch classification failed: {str(e)}")
raise HTTPException(status_code=500, detail=f"Batch classification failed: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Batch classification failed: {str(e)}"
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8544)
uvicorn.run(app, host="0.0.0.0", port=8544)
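For reference, a minimal client sketch for the new /classify_batch endpoint (not part of this commit): the payload is a JSON array of the same objects accepted by /classify, and the host/port assume the uvicorn defaults above.

import requests

batch = [
    {"title": "video one", "description": "first description", "tags": "tagA", "aid": 1},
    {"title": "video two", "description": "second description", "tags": "tagB", "aid": 2},
]
resp = requests.post("http://localhost:8544/classify_batch", json=batch, timeout=60)
resp.raise_for_status()
for item in resp.json()["results"]:
    print(item["aid"], item["label"], item["probabilities"])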

View File

@ -475,6 +475,7 @@ class DatasetBuilder:
def list_datasets(self) -> List[Dict[str, Any]]:
"""List all datasets with their basic information"""
self.storage._load_metadata_cache()
return self.storage.list_datasets()
def get_dataset_stats(self) -> Dict[str, Any]:

View File

@ -153,16 +153,7 @@ async def list_datasets_endpoint():
raise HTTPException(status_code=503, detail="Dataset builder not available")
datasets = dataset_builder.list_datasets()
# Add description to each dataset
datasets_with_description = []
for dataset in datasets:
dataset_info = dataset_builder.get_dataset(dataset["dataset_id"])
if dataset_info and "description" in dataset_info:
dataset["description"] = dataset_info["description"]
else:
dataset["description"] = None
datasets_with_description.append(dataset)
return {"datasets": datasets_with_description}
return {"datasets": datasets}
@router.delete("/dataset/{dataset_id}")

View File

@ -189,85 +189,7 @@ class EmbeddingClassifier(nn.Module):
}
class AttentionEmbeddingClassifier(EmbeddingClassifier):
"""
Enhanced classifier with self-attention mechanism
"""
def __init__(
self,
input_dim: int = 2048,
hidden_dims: Optional[Tuple[int, ...]] = None,
dropout_rate: float = 0.3,
batch_norm: bool = True,
activation: str = "relu",
attention_dim: int = 512
):
super().__init__(input_dim, hidden_dims, dropout_rate, batch_norm, activation)
# Self-attention mechanism
self.attention_dim = attention_dim
self.attention = nn.MultiheadAttention(
embed_dim=input_dim,
num_heads=8,
dropout=dropout_rate,
batch_first=True
)
# Attention projection layer
self.attention_projection = nn.Linear(input_dim, attention_dim)
# Re-initialize attention weights
self._initialize_attention_weights()
logger.info(f"Initialized AttentionEmbeddingClassifier with attention_dim={attention_dim}")
def _initialize_attention_weights(self):
"""Initialize attention mechanism weights"""
for module in self.attention.modules():
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass with attention mechanism
Args:
x: Input embeddings of shape (batch_size, input_dim)
Returns:
logits of shape (batch_size, 1)
"""
# Ensure input is float tensor
if not x.dtype == torch.float32:
x = x.float()
# Add sequence dimension for attention (batch_size, 1, input_dim)
x_expanded = x.unsqueeze(1)
# Apply self-attention
attended, attention_weights = self.attention(x_expanded, x_expanded, x_expanded)
# Remove sequence dimension (batch_size, input_dim)
attended = attended.squeeze(1)
# Project to attention dimension
attended = self.attention_projection(attended)
# Process through original classification layers
for layer in self.layers:
attended = layer(attended)
# Final classification layer
logits = self.classifier(attended)
return logits
def create_model(
model_type: str = "standard",
input_dim: int = 2048,
hidden_dims: Optional[Tuple[int, ...]] = None,
**kwargs
@ -276,7 +198,6 @@ def create_model(
Factory function to create embedding classifier models
Args:
model_type: Type of model ('standard', 'attention')
input_dim: Input embedding dimension
hidden_dims: Hidden layer dimensions
**kwargs: Additional model arguments
@ -284,18 +205,12 @@ def create_model(
Returns:
Initialized model
"""
if model_type == "standard":
return EmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)
elif model_type == "attention":
return AttentionEmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)
else:
raise ValueError(f"Unknown model type: {model_type}")
return EmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs)
if __name__ == "__main__":
# Test model creation and forward pass
model = create_model(
model_type="standard",
input_dim=2048,
hidden_dims=(512, 256, 128),
dropout_rate=0.3

ml_new/training/test.py (new file, 636 lines)
View File

@ -0,0 +1,636 @@
#!/usr/bin/env python3
"""
Test script for evaluating trained models on a dataset
"""
import argparse
import json
import sys
from pathlib import Path
import numpy as np
import torch
import aiohttp
import asyncio
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
from typing import Optional, List, Dict, Any
# Add the parent directory to the path to import ml_new modules
sys.path.append(str(Path(__file__).parent.parent))
from ml_new.training.models import create_model
from ml_new.training.data_loader import DatasetLoader, EmbeddingDataset
from ml_new.config.logger_config import get_logger
logger = get_logger(__name__)
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(
description="Test embedding classification model",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
# Required arguments
parser.add_argument(
"--dataset-id",
type=str,
required=True,
help="ID of the dataset to use for testing"
)
parser.add_argument(
"--experiment",
type=str,
help="Name of the experiment to load model from"
)
# Optional arguments
parser.add_argument(
"--datasets-dir",
type=str,
default="training/datasets",
help="Directory containing dataset files"
)
parser.add_argument(
"--checkpoints-dir",
type=str,
default="training/checkpoints",
help="Directory containing model checkpoints"
)
parser.add_argument(
"--checkpoint-file",
type=str,
default="best_model.pth",
help="Checkpoint file to load (relative to experiment dir)"
)
parser.add_argument(
"--batch-size",
type=int,
default=32,
help="Batch size for testing"
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of worker processes for data loading"
)
parser.add_argument(
"--device",
type=str,
default="auto",
help="Device to use (auto, cpu, cuda)"
)
parser.add_argument(
"--normalize",
action="store_true",
default=False,
help="Normalize embeddings during testing"
)
parser.add_argument(
"--output",
type=str,
default=None,
help="Output file for detailed results (JSON)"
)
parser.add_argument(
"--threshold",
type=float,
default=0.5,
help="Classification threshold"
)
parser.add_argument(
"--use-api",
action="store_true",
default=False,
help="Use API model instead of local model"
)
parser.add_argument(
"--api-url",
type=str,
default="http://localhost:8544",
help="API base URL"
)
parser.add_argument(
"--misclassified-output",
type=str,
default=None,
help="Output file for misclassified samples (FN and FP aids)"
)
return parser.parse_args()
def setup_device(device_arg: str):
"""Setup device"""
if device_arg == "auto":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(device_arg)
logger.info(f"Using device: {device}")
return device
def load_model_from_experiment(
checkpoints_dir: str,
experiment_name: str,
checkpoint_file: str,
device: torch.device
):
"""
Load a trained model from an experiment checkpoint
Args:
checkpoints_dir: Directory containing checkpoints
experiment_name: Name of the experiment
checkpoint_file: Checkpoint file name
device: Device to load model to
Returns:
Loaded model
"""
checkpoint_path = Path(checkpoints_dir) / experiment_name / checkpoint_file
if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
logger.info(f"Loading checkpoint from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
# Get model config from checkpoint
model_config = checkpoint.get('model_config', {})
# Create model with saved config
model = create_model(
input_dim=model_config.get('input_dim', 2048),
hidden_dims=tuple(model_config.get('hidden_dims', [512, 256, 128])),
dropout_rate=model_config.get('dropout_rate', 0.3),
batch_norm=model_config.get('batch_norm', True)
)
# Load state dict
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
logger.info(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
logger.info(f"Model config: {model_config}")
return model, model_config
def evaluate_model(
model,
test_loader: DataLoader,
device: torch.device,
threshold: float = 0.5
):
"""
Evaluate model on test set
Args:
model: The model to evaluate
test_loader: Test data loader
device: Device to use
threshold: Classification threshold
Returns:
Tuple of (metrics, predictions, probabilities, true_labels, fn_aids, fp_aids)
"""
model.eval()
criterion = torch.nn.BCEWithLogitsLoss()
total_loss = 0.0
all_predictions = []
all_labels = []
all_probabilities = []
all_metadata = []
fn_aids = []
fp_aids = []
with torch.no_grad():
for batch_idx, (embeddings, labels, metadata) in enumerate(test_loader):
embeddings = embeddings.to(device)
labels = labels.to(device).float()
# Forward pass
outputs = model(embeddings)
loss = criterion(outputs.squeeze(), labels)
# Collect statistics
total_loss += loss.item()
# Get predictions and probabilities
probabilities = torch.sigmoid(outputs).squeeze()
predictions = (probabilities > threshold).long()
all_predictions.extend(predictions.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
all_probabilities.extend(probabilities.cpu().numpy())
# Collect metadata and track FN/FP
batch_metadata = metadata if isinstance(metadata, list) else [metadata]
all_metadata.extend(batch_metadata)
# Track FN and FP aids for this batch
for i, (true_label, pred_label) in enumerate(zip(labels.cpu().numpy(), predictions.cpu().numpy())):
if isinstance(batch_metadata[i], dict) and 'aid' in batch_metadata[i]:
aid = batch_metadata[i]['aid']
if true_label == 1 and pred_label == 0: # False Negative
fn_aids.append(aid)
elif true_label == 0 and pred_label == 1: # False Positive
fp_aids.append(aid)
if (batch_idx + 1) % 10 == 0:
logger.info(f"Processed {batch_idx + 1}/{len(test_loader)} batches")
# Calculate metrics
test_loss = total_loss / len(test_loader)
test_accuracy = accuracy_score(all_labels, all_predictions)
precision, recall, f1, _ = precision_recall_fscore_support(
all_labels, all_predictions, average='binary', zero_division=0
)
try:
test_auc = roc_auc_score(all_labels, all_probabilities)
except ValueError:
test_auc = 0.0
# Confusion matrix
cm = confusion_matrix(all_labels, all_predictions)
if cm.size == 4:
tn, fp, fn, tp = cm.ravel()
else:
tn, fp, fn, tp = 0, 0, 0, 0
metrics = {
'loss': test_loss,
'accuracy': test_accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
'auc': test_auc,
'true_negatives': int(tn),
'false_positives': int(fp),
'false_negatives': int(fn),
'true_positives': int(tp),
'total_samples': len(all_labels),
'threshold': threshold
}
# Add class distribution
unique, counts = np.unique(all_labels, return_counts=True)
metrics['class_distribution'] = {int(k): int(v) for k, v in zip(unique, counts)}
return metrics, all_predictions, all_probabilities, all_labels, fn_aids, fp_aids
async def call_api_batch(session: aiohttp.ClientSession, api_url: str, requests: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]:
"""Call the classification API for batch requests"""
try:
url = f"{api_url}/classify_batch"
async with session.post(url, json=requests) as response:
if response.status == 200:
result = await response.json()
return result.get('results', [])
else:
logger.warning(f"Batch API request failed with status {response.status}")
return None
except Exception as e:
logger.warning(f"Batch API request failed: {e}")
return None
def convert_api_label_to_bool(api_label: int) -> int:
"""Convert API label to boolean (non-zero = true)"""
return 1 if api_label != 0 else 0
async def evaluate_with_api(
embeddings: np.ndarray,
labels: np.ndarray,
metadata: List[Dict[str, Any]],
api_url: str,
batch_size: int = 32
):
"""
Evaluate using the API instead of local model
Args:
embeddings: Array of embeddings (not used for API calls)
labels: Ground truth labels
metadata: Metadata containing title, description, tags, aid
api_url: API base URL
batch_size: Number of requests per API batch call
Returns:
Tuple of (metrics, predictions, probabilities, true_labels, fn_aids, fp_aids)
"""
logger.info(f"Using API at {api_url} for evaluation")
# Prepare API requests
requests = []
for i, meta in enumerate(metadata):
# Extract metadata fields for API
title = meta.get('title', '')
description = meta.get('description', '')
tags = meta.get('tags', '')
aid = meta.get('aid', i)
# Handle missing or empty fields
if not title:
title = f"Video {aid}"
if not description:
description = ""
if not tags:
tags = ""
request_data = {
"title": title,
"description": description,
"tags": tags,
"aid": aid
}
requests.append(request_data)
# Split requests into batches
num_batches = (len(requests) + batch_size - 1) // batch_size
logger.info(f"Making {num_batches} batch API requests with batch_size={batch_size} for {len(requests)} total requests")
# Process all batches
all_predictions = []
all_probabilities = []
all_labels = labels.tolist()
all_aids = [meta.get('aid', i) for i, meta in enumerate(metadata)]
failed_requests = 0
fn_aids = []
fp_aids = []
async with aiohttp.ClientSession() as session:
for batch_idx in range(num_batches):
start_idx = batch_idx * batch_size
end_idx = min(start_idx + batch_size, len(requests))
batch_requests = requests[start_idx:end_idx]
logger.info(f"Processing batch {batch_idx + 1}/{num_batches} ({len(batch_requests)} requests)")
results = await call_api_batch(session, api_url, batch_requests)
if results is None:
logger.error(f"Batch {batch_idx + 1} API request failed completely")
# Create dummy results for this batch
all_predictions.extend([0] * len(batch_requests))
all_probabilities.extend([0.0] * len(batch_requests))
failed_requests += len(batch_requests)
continue
for i, result in enumerate(results):
global_idx = start_idx + i
if not isinstance(result, dict) or 'error' in result:
error_msg = result.get('error', 'Unknown error') if isinstance(result, dict) else 'Invalid result'
logger.warning(f"Failed to get API prediction for request {global_idx}: {error_msg}")
failed_requests += 1
all_predictions.append(0)
all_probabilities.append(0.0)
continue
# Convert API response to our format
api_label = result.get('label', -1)
probabilities = result.get('probabilities')
# Convert to boolean (non-zero = true)
prediction = convert_api_label_to_bool(api_label)
# Use the probability of the positive class
if probabilities and len(probabilities) > 0:
positive_prob = 1 - probabilities[0]
else:
logger.warning(f"No probabilities for request {global_idx}")
failed_requests += 1
all_predictions.append(0)
all_probabilities.append(0.0)
continue
all_predictions.append(prediction)
all_probabilities.append(positive_prob)
if failed_requests > 0:
logger.warning(f"Failed to get API predictions for {failed_requests} requests")
# Collect FN and FP aids
for i, (true_label, pred_label) in enumerate(zip(all_labels, all_predictions)):
aid = all_aids[i]
if true_label == 1 and pred_label == 0: # False Negative
fn_aids.append(aid)
elif true_label == 0 and pred_label == 1: # False Positive
fp_aids.append(aid)
# Calculate metrics
test_accuracy = accuracy_score(all_labels, all_predictions)
precision, recall, f1, _ = precision_recall_fscore_support(
all_labels, all_predictions, average='binary', zero_division=0
)
try:
test_auc = roc_auc_score(all_labels, all_probabilities)
except ValueError:
test_auc = 0.0
# Confusion matrix
cm = confusion_matrix(all_labels, all_predictions)
if cm.size == 4:
tn, fp, fn, tp = cm.ravel()
else:
tn, fp, fn, tp = 0, 0, 0, 0
metrics = {
'accuracy': test_accuracy,
'precision': precision,
'recall': recall,
'f1': f1,
'auc': test_auc,
'true_negatives': int(tn),
'false_positives': int(fp),
'false_negatives': int(fn),
'true_positives': int(tp),
'total_samples': len(all_labels),
'failed_requests': failed_requests
}
# Add class distribution
unique, counts = np.unique(all_labels, return_counts=True)
metrics['class_distribution'] = {int(k): int(v) for k, v in zip(unique, counts)}
return metrics, all_predictions, all_probabilities, all_labels, fn_aids, fp_aids
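def _demo_api_label_mapping():
    # Illustrative sketch only (not part of this commit): how the multi-class API
    # response above is collapsed into a binary prediction, using made-up values.
    api_result = {"label": 2, "probabilities": [0.1, 0.3, 0.6]}
    prediction = convert_api_label_to_bool(api_result["label"])  # any non-zero class -> 1
    positive_prob = 1 - api_result["probabilities"][0]  # 1 - P(class 0) = 0.9
    return prediction, positive_prob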
def main():
"""Main test function"""
args = parse_args()
# Setup device
device = setup_device(args.device)
# Check if dataset exists
loader = DatasetLoader(args.datasets_dir)
datasets = loader.list_datasets()
if args.dataset_id not in datasets:
logger.error(f"Dataset '{args.dataset_id}' not found in {args.datasets_dir}")
logger.info(f"Available datasets: {datasets}")
sys.exit(1)
# Load dataset (use entire dataset as test set)
try:
logger.info(f"Loading dataset {args.dataset_id}...")
embeddings, labels, metadata = loader.load_dataset(args.dataset_id)
logger.info(f"Dataset loaded: {len(embeddings)} samples")
logger.info(f"Embedding dimension: {embeddings.shape[1]}")
except Exception as e:
logger.error(f"Failed to load dataset: {e}")
sys.exit(1)
# Choose evaluation method
if args.use_api:
# Use API for evaluation
logger.info("Using API-based evaluation")
# Run async evaluation
metrics, predictions, probabilities, true_labels, fn_aids, fp_aids = asyncio.run(
evaluate_with_api(
embeddings, labels, metadata,
args.api_url,
args.batch_size
)
)
# For API mode, we don't have model_config
model_config = {"type": "api", "api_url": args.api_url}
else:
# Use local model for evaluation
# Check if experiment exists
experiment_dir = Path(args.checkpoints_dir) / args.experiment
if not experiment_dir.exists():
logger.error(f"Experiment '{args.experiment}' not found in {args.checkpoints_dir}")
available = [d.name for d in Path(args.checkpoints_dir).iterdir() if d.is_dir()]
logger.info(f"Available experiments: {available}")
sys.exit(1)
# Load model
try:
model, model_config = load_model_from_experiment(
args.checkpoints_dir,
args.experiment,
args.checkpoint_file,
device
)
except Exception as e:
logger.error(f"Failed to load model: {e}")
sys.exit(1)
# Create test dataset and loader
test_dataset = EmbeddingDataset(
embeddings, labels, metadata,
normalize=args.normalize
)
test_loader = DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.num_workers
)
# Evaluate model
logger.info("Starting local model evaluation...")
metrics, predictions, probabilities, true_labels, fn_aids, fp_aids = evaluate_model(
model, test_loader, device, args.threshold
)
# Print results
logger.info("=" * 50)
logger.info("Test Results")
logger.info("=" * 50)
logger.info(f"Dataset: {args.dataset_id}")
if args.use_api:
logger.info(f"Method: API ({args.api_url})")
else:
logger.info(f"Experiment: {args.experiment}")
logger.info(f"Total samples: {metrics['total_samples']}")
logger.info(f"Class distribution: {metrics['class_distribution']}")
if 'failed_requests' in metrics:
logger.info(f"Failed API requests: {metrics['failed_requests']}")
logger.info("-" * 50)
if 'loss' in metrics:
logger.info(f"Loss: {metrics['loss']:.4f}")
logger.info(f"Accuracy: {metrics['accuracy']:.4f}")
logger.info(f"Precision: {metrics['precision']:.4f}")
logger.info(f"Recall: {metrics['recall']:.4f}")
logger.info(f"F1 Score: {metrics['f1']:.4f}")
logger.info(f"AUC: {metrics['auc']:.4f}")
logger.info("-" * 50)
logger.info(f"True Positives: {metrics['true_positives']}")
logger.info(f"True Negatives: {metrics['true_negatives']}")
logger.info(f"False Positives: {metrics['false_positives']}")
logger.info(f"False Negatives: {metrics['false_negatives']}")
logger.info("=" * 50)
# Save detailed results if requested
if args.output:
output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
results = {
'dataset_id': args.dataset_id,
'experiment': args.experiment,
'checkpoint': args.checkpoint_file,
'model_config': model_config,
'metrics': metrics,
'predictions': [int(p) for p in predictions],
'probabilities': [float(p) for p in probabilities],
'labels': [int(l) for l in true_labels]
}
with open(output_path, 'w') as f:
json.dump(results, f, indent=2)
logger.info(f"Detailed results saved to {output_path}")
# Save misclassified samples (FN and FP aids) if requested
if args.misclassified_output:
misclassified_path = Path(args.misclassified_output)
misclassified_path.parent.mkdir(parents=True, exist_ok=True)
misclassified_data = {
'false_negatives': fn_aids,
'false_positives': fp_aids,
'fn_count': len(fn_aids),
'fp_count': len(fp_aids),
'total_misclassified': len(fn_aids) + len(fp_aids)
}
with open(misclassified_path, 'w') as f:
json.dump(misclassified_data, f, indent=2)
logger.info(f"Misclassified samples (FN: {len(fn_aids)}, FP: {len(fp_aids)}) saved to {misclassified_path}")
if __name__ == "__main__":
main()
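For orientation, two plausible invocations of this script based on the arguments defined in parse_args; the dataset id, experiment name, and output paths are placeholders and depend on the local setup:

python ml_new/training/test.py --dataset-id <dataset_id> --experiment <experiment_name> --output results/test_metrics.json --misclassified-output results/misclassified.json
python ml_new/training/test.py --dataset-id <dataset_id> --use-api --api-url http://localhost:8544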

View File

@ -62,15 +62,6 @@ def parse_args():
help="Directory containing dataset files"
)
# Model arguments
parser.add_argument(
"--model-type",
type=str,
choices=["standard", "attention"],
default="standard",
help="Type of model architecture"
)
parser.add_argument(
"--input-dim",
type=int,
@ -353,11 +344,11 @@ def main():
# Create experiment name if not provided
if args.experiment_name is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
args.experiment_name = f"{args.model_type}_{args.dataset_id}_{timestamp}"
args.experiment_name = f"{timestamp}_{args.dataset_id}"
logger.info(f"Starting experiment: {args.experiment_name}")
logger.info(f"Dataset: {args.dataset_id}")
logger.info(f"Model: {args.model_type} with hidden dims {args.hidden_dims}")
logger.info(f"Model: hidden dims {args.hidden_dims}")
# Load dataset and create data loaders
try:
@ -387,7 +378,6 @@ def main():
try:
logger.info("Creating model...")
model = create_model(
model_type=args.model_type,
input_dim=args.input_dim,
hidden_dims=tuple(args.hidden_dims),
dropout_rate=args.dropout_rate,
@ -452,7 +442,6 @@ def main():
"experiment_name": args.experiment_name,
"dataset_id": args.dataset_id,
"model_config": {
"model_type": args.model_type,
"input_dim": args.input_dim,
"hidden_dims": args.hidden_dims,
"dropout_rate": args.dropout_rate,

View File

@ -417,7 +417,7 @@ class ModelTrainer:
def load_checkpoint(self, checkpoint_path: str, load_optimizer: bool = True) -> None:
"""Load model from checkpoint"""
checkpoint = torch.load(checkpoint_path, map_location=self.device)
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
self.model.load_state_dict(checkpoint['model_state_dict'])
@ -504,7 +504,6 @@ if __name__ == "__main__":
# Create dummy model and data
model = create_model(
model_type="standard",
input_dim=2048,
hidden_dims=(512, 256, 128)
)

View File

@ -1,10 +1,9 @@
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { CardDescription, CardTitle } from "@/components/ui/card";
import { CardTitle } from "@/components/ui/card";
import { DatasetManager } from "@/components/DatasetManager";
import { TaskMonitor } from "@/components/TaskMonitor";
import { SamplingPanel } from "@/components/SamplingPanel";
import { Database, Activity, Settings } from "lucide-react";
import { Database, Activity } from "lucide-react";
const queryClient = new QueryClient({
defaultOptions: {
@ -21,22 +20,17 @@ function App() {
<div className="min-h-screen flex justify-center">
<div className="container lg:max-w-3xl xl:max-w-4xl bg-background py-8 px-3">
<div className="mb-8">
<h1 className="text-3xl font-bold tracking-tight">ML Dataset Management Panel</h1>
<p className="text-muted-foreground">
Create and manage machine learning datasets with multiple sampling strategies and task monitoring
</p>
<h1 className="text-3xl font-bold tracking-tight">
CVSA Machine Learning Panel
</h1>
</div>
<Tabs defaultValue="datasets" className="space-y-4">
<TabsList className="grid w-full grid-cols-3">
<TabsList className="grid w-full grid-cols-2">
<TabsTrigger value="datasets" className="flex items-center gap-2">
<Database className="h-4 w-4" />
Datasets
</TabsTrigger>
<TabsTrigger value="sampling" className="flex items-center gap-2">
<Settings className="h-4 w-4" />
Sampling
</TabsTrigger>
<TabsTrigger value="monitor" className="flex items-center gap-2">
<Activity className="h-4 w-4" />
Tasks
@ -44,22 +38,11 @@ function App() {
</TabsList>
<TabsContent value="datasets" className="space-y-4">
<CardTitle>Dataset Management</CardTitle>
<CardDescription>View, create and manage your machine learning datasets</CardDescription>
<DatasetManager />
</TabsContent>
<TabsContent value="sampling" className="space-y-4">
<CardTitle>Sampling Strategy Configuration</CardTitle>
<CardDescription>
Configure different data sampling strategies to create balanced datasets
</CardDescription>
<SamplingPanel />
</TabsContent>
<TabsContent value="monitor" className="space-y-4">
<CardTitle>Task Monitor</CardTitle>
<CardDescription>Monitor real-time status and progress of dataset building tasks</CardDescription>
<TaskMonitor />
</TabsContent>
</Tabs>

View File

@ -1,4 +1,4 @@
import { useState } from "react";
import { useState } from "react";
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { Button } from "@/components/ui/button";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
@ -20,19 +20,22 @@ import {
SelectValue
} from "@/components/ui/select";
import { Textarea } from "@/components/ui/textarea";
import { Trash2, Plus, Database, FileText, Calendar, Activity } from "lucide-react";
import { Trash2, Plus, Database, Upload } from "lucide-react";
import { apiClient } from "@/lib/api";
import { toast } from "sonner";
import { Spinner } from "@/components/ui/spinner"
import { Spinner } from "@/components/ui/spinner";
export function DatasetManager() {
const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false);
const [createFormData, setCreateFormData] = useState({
creationMode: "sampling", // "sampling" or "aidList"
strategy: "all",
limit: "",
embeddingModel: "",
description: "",
forceRegenerate: false
forceRegenerate: false,
aidListFile: null as File | null,
aidList: [] as number[]
});
const queryClient = useQueryClient();
@ -45,7 +48,7 @@ export function DatasetManager() {
});
// Fetch embedding models
const { data: modelsData, isLoading: modelsLoading } = useQuery({
const { data: modelsData } = useQuery({
queryKey: ["embedding-models"],
queryFn: () => apiClient.getEmbeddingModels()
});
@ -57,11 +60,14 @@ export function DatasetManager() {
toast.success("Dataset creation task started");
setIsCreateDialogOpen(false);
setCreateFormData({
creationMode: "sampling",
strategy: "all",
limit: "",
embeddingModel: "",
description: "",
forceRegenerate: false
forceRegenerate: false,
aidListFile: null,
aidList: []
});
queryClient.invalidateQueries({ queryKey: ["datasets"] });
queryClient.invalidateQueries({ queryKey: ["tasks"] });
@ -83,23 +89,68 @@ export function DatasetManager() {
}
});
// Build dataset mutation
const buildDatasetMutation = useMutation({
mutationFn: (data: {
aid_list: number[];
embedding_model: string;
force_regenerate?: boolean;
description?: string;
}) => apiClient.buildDataset(data),
onSuccess: () => {
toast.success("Dataset build task started");
setIsCreateDialogOpen(false);
setCreateFormData({
creationMode: "sampling",
strategy: "all",
limit: "",
embeddingModel: "",
description: "",
forceRegenerate: false,
aidListFile: null,
aidList: []
});
queryClient.invalidateQueries({ queryKey: ["datasets"] });
queryClient.invalidateQueries({ queryKey: ["tasks"] });
},
onError: (error: Error) => {
toast.error(`Build failed: ${error.message}`);
}
});
const handleCreateDataset = () => {
if (!createFormData.embeddingModel) {
toast.error("Please select an embedding model");
return;
}
const requestData = {
sampling: {
strategy: createFormData.strategy,
...(createFormData.limit && { limit: parseInt(createFormData.limit) })
},
embedding_model: createFormData.embeddingModel,
force_regenerate: createFormData.forceRegenerate,
description: createFormData.description || undefined
};
if (createFormData.creationMode === "sampling") {
const requestData = {
sampling: {
strategy: createFormData.strategy,
...(createFormData.limit && { limit: parseInt(createFormData.limit) })
},
embedding_model: createFormData.embeddingModel,
force_regenerate: createFormData.forceRegenerate,
description: createFormData.description || undefined
};
createDatasetMutation.mutate(requestData);
createDatasetMutation.mutate(requestData);
} else if (createFormData.creationMode === "aidList") {
if (createFormData.aidList.length === 0) {
toast.error("Please upload an aid list file");
return;
}
const requestData = {
aid_list: createFormData.aidList,
embedding_model: createFormData.embeddingModel,
force_regenerate: createFormData.forceRegenerate,
description: createFormData.description || undefined
};
buildDatasetMutation.mutate(requestData);
}
};
const handleDeleteDataset = (datasetId: string) => {
@ -108,16 +159,67 @@ export function DatasetManager() {
}
};
const formatDate = (dateString: string) => {
return new Date(dateString).toLocaleString("en-US");
// Parse aid list file
const parseAidListFile = (file: File): Promise<number[]> => {
return new Promise((resolve, reject) => {
const reader = new FileReader();
reader.onload = (e) => {
try {
const content = e.target?.result as string;
const lines = content.split("\n").filter((line) => line.trim());
const aidList: number[] = [];
for (const line of lines) {
const trimmed = line.trim();
if (trimmed) {
const aid = parseInt(trimmed, 10);
if (!isNaN(aid)) {
aidList.push(aid);
}
}
}
resolve(aidList);
} catch (error) {
reject(new Error("Failed to parse file"));
}
};
reader.onerror = () => reject(new Error("Failed to read file"));
reader.readAsText(file);
});
};
const formatFileSize = (bytes: number) => {
if (bytes === 0) return "0 Bytes";
const k = 1024;
const sizes = ["Bytes", "KB", "MB", "GB"];
const i = Math.floor(Math.log(bytes) / Math.log(k));
return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + " " + sizes[i];
// Handle file upload
const handleFileUpload = async (event: React.ChangeEvent<HTMLInputElement>) => {
const file = event.target.files?.[0];
if (!file) return;
if (!file.name.endsWith(".txt") && !file.name.endsWith(".csv")) {
toast.error("Please upload a .txt or .csv file");
return;
}
try {
const aidList = await parseAidListFile(file);
if (aidList.length === 0) {
toast.error("No valid AIDs found in the file");
return;
}
setCreateFormData((prev) => ({
...prev,
aidListFile: file,
aidList: aidList
}));
toast.success(`Loaded ${aidList.length} AIDs from file`);
} catch (error) {
toast.error("Failed to parse aid list file");
}
};
const formatDate = (dateString: string) => {
return new Date(dateString).toLocaleString("en-US");
};
if (datasetsLoading) {
@ -150,43 +252,120 @@ export function DatasetManager() {
<DialogHeader>
<DialogTitle>Create New Dataset</DialogTitle>
<DialogDescription>
Select sampling strategy and configuration parameters to create a new dataset
Select sampling strategy and configuration parameters to create a
new dataset
</DialogDescription>
</DialogHeader>
<div className="grid gap-4 py-4">
<div className="grid gap-2">
<Label htmlFor="strategy">Sampling Strategy</Label>
<Label htmlFor="creationMode">Creation Mode</Label>
<Select
value={createFormData.strategy}
value={createFormData.creationMode}
onValueChange={(value) =>
setCreateFormData((prev) => ({ ...prev, strategy: value }))
setCreateFormData((prev) => ({
...prev,
creationMode: value,
// Reset aid list when switching modes
aidListFile: null,
aidList: []
}))
}
>
<SelectTrigger>
<SelectValue placeholder="Select sampling strategy" />
<SelectValue placeholder="Select creation mode" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all">All Videos</SelectItem>
<SelectItem value="random">Random Sampling</SelectItem>
<SelectItem value="sampling">Sampling Strategy</SelectItem>
<SelectItem value="aidList">Upload Aid List</SelectItem>
</SelectContent>
</Select>
</div>
{createFormData.strategy === "random" && (
{createFormData.creationMode === "sampling" && (
<div className="grid gap-2">
<Label htmlFor="limit">Sample Count</Label>
<Textarea
id="limit"
placeholder="Enter number of samples, e.g., 1000"
value={createFormData.limit}
onChange={(e) =>
<Label htmlFor="strategy">Sampling Strategy</Label>
<Select
value={createFormData.strategy}
onValueChange={(value) =>
setCreateFormData((prev) => ({
...prev,
limit: e.target.value
strategy: value
}))
}
>
<SelectTrigger>
<SelectValue placeholder="Select sampling strategy" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all">All Videos</SelectItem>
<SelectItem value="random">Random Sampling</SelectItem>
</SelectContent>
</Select>
</div>
)}
{createFormData.creationMode === "sampling" &&
createFormData.strategy === "random" && (
<div className="grid gap-2">
<Label htmlFor="limit">Sample Count</Label>
<Textarea
id="limit"
placeholder="Enter number of samples, e.g., 1000"
value={createFormData.limit}
onChange={(e) =>
setCreateFormData((prev) => ({
...prev,
limit: e.target.value
}))
}
/>
</div>
)}
{createFormData.creationMode === "aidList" && (
<div className="grid gap-2">
<Label htmlFor="aidListFile">Aid List File</Label>
<div
className="border-2 border-dashed rounded-lg p-4 cursor-pointer"
onClick={() =>
document.getElementById("aidListFile")?.click()
}
>
<div className="flex flex-col items-center space-y-2">
<Upload className="h-8 w-8text-secondary-foreground" />
<div className="text-sm text-secondary-foreground text-center">
{createFormData.aidListFile
? `${createFormData.aidListFile.name} (${createFormData.aidList.length} AIDs loaded)`
: "Click to upload a .txt or .csv file containing AIDs (one per line)"}
</div>
<Button
type="button"
variant="outline"
size="sm"
className="mt-2"
onClick={(e) => {
e.stopPropagation();
document.getElementById("aidListFile")?.click();
}}
>
Choose File
</Button>
</div>
</div>
<input
id="aidListFile"
type="file"
accept=".txt,.csv"
onChange={handleFileUpload}
className="hidden"
/>
{createFormData.aidList.length > 0 && (
<div className="text-sm text-green-600">
Loaded {createFormData.aidList.length} AIDs from{" "}
{createFormData.aidListFile?.name}
</div>
)}
</div>
)}
@ -253,9 +432,18 @@ export function DatasetManager() {
</Button>
<Button
onClick={handleCreateDataset}
disabled={createDatasetMutation.isPending}
disabled={
createDatasetMutation.isPending ||
buildDatasetMutation.isPending ||
(createFormData.creationMode === "aidList" &&
createFormData.aidList.length === 0)
}
>
{createDatasetMutation.isPending ? "Creating..." : "Create Dataset"}
{createDatasetMutation.isPending || buildDatasetMutation.isPending
? "Creating..."
: createFormData.creationMode === "sampling"
? "Create Dataset"
: "Build Dataset"}
</Button>
</DialogFooter>
</DialogContent>
@ -270,8 +458,8 @@ export function DatasetManager() {
<CardHeader className="pb-3">
<div className="flex items-start justify-between">
<div className="flex items-center space-x-2">
<CardTitle className="text-base">
{dataset.dataset_id.slice(0, 8)}...{dataset.dataset_id.slice(-8)}
<CardTitle className="text-base line-clamp-1">
{dataset.dataset_id}
</CardTitle>
</div>
<Button
@ -288,22 +476,13 @@ export function DatasetManager() {
)}
</CardHeader>
<CardContent>
<div className="grid grid-cols-2 lg:grid-cols-3 xl:grid-cols-4 gap-4 text-sm">
<div className="flex items-center space-x-2">
<span>{dataset.stats.total_records} records</span>
</div>
<div className="flex items-center space-x-2">
<span>{dataset.stats.embedding_model}</span>
</div>
<div className="flex items-center space-x-2">
<span>{formatDate(dataset.created_at)}</span>
</div>
<div className="flex items-center space-x-2">
<span className="text-muted-foreground">
New: {dataset.stats.new_embeddings}
</span>
</div>
<div className="flex flex-wrap gap-5 text-sm leading-1">
<span>{dataset.stats.total_records} records</span>
<span>{dataset.stats.embedding_model}</span>
<span>{formatDate(dataset.created_at)}</span>
<span className="text-muted-foreground">
New: {dataset.stats.new_embeddings}
</span>
</div>
</CardContent>
</Card>

View File

@ -1,234 +0,0 @@
import { useState } from "react";
import { useMutation, useQuery } from "@tanstack/react-query";
import { Button } from "@/components/ui/button";
import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue
} from "@/components/ui/select";
import { Textarea } from "@/components/ui/textarea";
import { Badge } from "@/components/ui/badge";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { Alert, AlertDescription } from "@/components/ui/alert";
import { Database, Play, TestTube, Settings, BarChart3 } from "lucide-react";
import { apiClient } from "@/lib/api";
import type { SamplingResponse, DatasetCreateResponse } from "@/types/api";
interface SamplingConfig {
strategy: string;
limit?: number;
}
export function SamplingPanel() {
const [samplingConfig, setSamplingConfig] = useState<SamplingConfig>({
strategy: "all",
limit: undefined,
});
const [embeddingModel, setEmbeddingModel] = useState<string>("");
const [description, setDescription] = useState<string>("");
// Test sampling mutation
const testSamplingMutation = useMutation({
mutationFn: (config: SamplingConfig) => apiClient.sampleDataset(config),
onSuccess: (data: SamplingResponse) => {
console.log("Sampling test successful:", data);
},
onError: (error: Error) => {
console.error("Sampling test failed:", error);
}
});
// Create dataset with sampling mutation
const createDatasetMutation = useMutation({
mutationFn: (config: {
sampling: SamplingConfig;
embedding_model: string;
description?: string;
}) => apiClient.createDatasetWithSampling(config),
onSuccess: (data: DatasetCreateResponse) => {
console.log("Dataset created successfully:", data);
},
onError: (error: Error) => {
console.error("Dataset creation failed:", error);
}
});
const handleStrategyChange = (strategy: string) => {
setSamplingConfig((prev) => ({ ...prev, strategy }));
};
const handleLimitChange = (limit: string) => {
setSamplingConfig((prev) => ({
...prev,
limit: limit ? parseInt(limit) : undefined
}));
};
const handleTestSampling = () => {
testSamplingMutation.mutate(samplingConfig);
};
const handleCreateDataset = () => {
if (!embeddingModel) {
alert("Please select an embedding model");
return;
}
createDatasetMutation.mutate({
sampling: samplingConfig,
embedding_model: embeddingModel,
description: description || undefined
});
};
const getStrategyDescription = (strategy: string) => {
switch (strategy) {
case "all":
return "Sample all labeled videos";
case "random":
return "Randomly sample specified number of labeled videos";
default:
return "Unknown strategy";
}
};
return (
<div className="space-y-6">
<Tabs defaultValue="configure" className="w-full">
<TabsList className="w-full mb-4">
<TabsTrigger value="configure">
<Settings className="h-4 w-4 mr-2" />
Configure Sampling
</TabsTrigger>
<TabsTrigger value="test">
<TestTube className="h-4 w-4 mr-2" />
Test Sampling
</TabsTrigger>
</TabsList>
<TabsContent value="configure" className="space-y-4">
<Card>
<CardHeader>
<CardTitle>Sampling Strategy Configuration</CardTitle>
<CardDescription>Select data sampling strategy and parameters</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
<Label htmlFor="strategy">Sampling Strategy</Label>
<Select
value={samplingConfig.strategy}
onValueChange={handleStrategyChange}
>
<SelectTrigger>
<SelectValue placeholder="Select strategy" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all">All Labeled Videos</SelectItem>
<SelectItem value="random">Random Sampling</SelectItem>
</SelectContent>
</Select>
<p className="text-sm text-muted-foreground">
{getStrategyDescription(samplingConfig.strategy)}
</p>
</div>
{samplingConfig.strategy === "random" && (
<div className="space-y-2">
<Label htmlFor="limit">Sample Count</Label>
<Input
id="limit"
type="number"
placeholder="e.g., 1000"
value={samplingConfig.limit || ""}
onChange={(e) => handleLimitChange(e.target.value)}
/>
</div>
)}
</div>
</CardContent>
</Card>
</TabsContent>
<TabsContent value="test" className="space-y-4">
<Card>
<CardHeader>
<CardTitle>Test Sampling</CardTitle>
<CardDescription>
Test sampling strategy and view data statistics for sampling
</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
<div className="flex space-x-4">
<Button
onClick={handleTestSampling}
disabled={testSamplingMutation.isPending}
className="flex-1"
>
<Play className="h-4 w-4 mr-2" />
{testSamplingMutation.isPending ? "Testing..." : "Start Test"}
</Button>
</div>
{testSamplingMutation.isSuccess && testSamplingMutation.data && (
<Alert>
<BarChart3 className="h-4 w-4" />
<AlertDescription>
<div className="space-y-2">
<div className="flex items-center justify-between">
<span>Total available data:</span>
<Badge variant="outline">
{(
testSamplingMutation.data as SamplingResponse
).total_available.toLocaleString()}
</Badge>
</div>
<div className="flex items-center justify-between">
<span>Will sample:</span>
<Badge>
{(
testSamplingMutation.data as SamplingResponse
).sampled_count.toLocaleString()}
</Badge>
</div>
<div className="flex items-center justify-between">
<span>Sampling ratio:</span>
<Badge variant="secondary">
{(
((
testSamplingMutation.data as SamplingResponse
).sampled_count /
(
testSamplingMutation.data as SamplingResponse
).total_available) *
100
).toFixed(1)}
%
</Badge>
</div>
</div>
</AlertDescription>
</Alert>
)}
{testSamplingMutation.isError && (
<Alert variant="destructive">
<AlertDescription>
Test failed: {(testSamplingMutation.error as Error).message}
</AlertDescription>
</Alert>
)}
</CardContent>
</Card>
</TabsContent>
</Tabs>
</div>
);
}

View File

@ -7,14 +7,14 @@ import {
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
SelectValue
} from "@/components/ui/select";
import { Progress } from "@/components/ui/progress";
import { Badge } from "@/components/ui/badge";
import { RefreshCw, Play, Pause, CheckCircle, XCircle, Clock } from "lucide-react";
import { RefreshCw, Clock } from "lucide-react";
import { apiClient } from "@/lib/api";
import type { TasksResponse } from "@/types/api";
import { Spinner } from "@/components/ui/spinner"
import { Spinner } from "@/components/ui/spinner";
export function TaskMonitor() {
const [statusFilter, setStatusFilter] = useState<string>("all");
@ -33,21 +33,6 @@ export function TaskMonitor() {
refetchInterval: 500
});
const getStatusIcon = (status: string) => {
switch (status) {
case "running":
return <Play className="h-4 w-4 text-blue-500" />;
case "completed":
return <CheckCircle className="h-4 w-4 text-green-500" />;
case "failed":
return <XCircle className="h-4 w-4 text-red-500" />;
case "pending":
return <Clock className="h-4 w-4 text-yellow-500" />;
default:
return <Pause className="h-4 w-4 text-gray-500" />;
}
};
const getStatusBadgeVariant = (status: string) => {
switch (status) {
case "running":
@ -80,7 +65,7 @@ export function TaskMonitor() {
if (tasksLoading) {
return (
<div className="flex items-center justify-center h-64">
<Spinner/>
<Spinner />
</div>
);
}
@ -118,13 +103,10 @@ export function TaskMonitor() {
<CardContent className="p-4">
<div className="flex items-start justify-between mb-3">
<div className="flex items-center space-x-2">
{getStatusIcon(task.status)}
<span className="font-mono text-sm">
{task.task_id.slice(0, 8)}...
</span>
<Badge variant={getStatusBadgeVariant(task.status)}>
{task.status}
</Badge>
<span className="font-mono text-sm">{task.task_id}</span>
</div>
<div className="text-sm text-muted-foreground">
{formatDate(task.created_at)}
@ -152,14 +134,18 @@ export function TaskMonitor() {
<div className="grid grid-cols-2 md:grid-cols-4 gap-4 text-sm">
{task.started_at && (
<div>
<span className="text-muted-foreground">Start Time:</span>
<span className="text-muted-foreground">
Start Time:
</span>
<br />
{formatDate(task.started_at)}
</div>
)}
{task.completed_at && (
<div>
<span className="text-muted-foreground">Complete Time:</span>
<span className="text-muted-foreground">
Complete Time:
</span>
<br />
{formatDate(task.completed_at)}
</div>

View File

@ -4,7 +4,6 @@ import type {
EmbeddingModelsResponse,
DatasetsResponse,
DatasetDetail,
SamplingStats,
SamplingResponse,
DatasetCreateResponse,
Task,