From 3d96f4986d30031df6e76d448a04d98f01f13c7b Mon Sep 17 00:00:00 2001
From: alikia2x
Date: Thu, 11 Dec 2025 00:58:37 +0800
Subject: [PATCH] add: script for testing model

---
 ml/api/main.py                                | 176 +++--
 ml_new/data/dataset_service.py                |   1 +
 ml_new/routes/main.py                         |  11 +-
 ml_new/training/models.py                     |  87 +--
 ml_new/training/test.py                       | 636 ++++++++++++++++++
 ml_new/training/train.py                      |  15 +-
 ml_new/training/trainer.py                    |   3 +-
 packages/ml_panel/src/App.tsx                 |  29 +-
 .../src/components/DatasetManager.tsx         | 297 ++++++--
 .../ml_panel/src/components/SamplingPanel.tsx | 234 -------
 .../ml_panel/src/components/TaskMonitor.tsx   |  36 +-
 packages/ml_panel/src/lib/api.ts              |   1 -
 12 files changed, 1009 insertions(+), 517 deletions(-)
 create mode 100644 ml_new/training/test.py
 delete mode 100644 packages/ml_panel/src/components/SamplingPanel.tsx

diff --git a/ml/api/main.py b/ml/api/main.py
index 462f60a..d4112b0 100644
--- a/ml/api/main.py
+++ b/ml/api/main.py
@@ -18,68 +18,78 @@
 app = FastAPI(title="CVSA ML API", version="1.0.0")
 tokenizer = None
 classifier_model = None
 
+
 class ClassificationRequest(BaseModel):
     title: str
     description: str
     tags: str
     aid: int = None
 
+
 class ClassificationResponse(BaseModel):
     label: int
     probabilities: List[float]
     aid: int = None
 
+
 class HealthResponse(BaseModel):
     status: str
     models_loaded: bool
 
+
 def load_models():
     """Load the tokenizer and classifier models"""
     global tokenizer, classifier_model
-    
+
     try:
         # Load tokenizer
         logger.info("Loading tokenizer...")
         tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3")
-        
+
         # Load classifier model
         logger.info("Loading classifier model...")
         from model_config import VideoClassifierV3_15
-        
+
         model_path = "../../model/akari/3.17.pt"
         classifier_model = VideoClassifierV3_15()
-        classifier_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+        classifier_model.load_state_dict(
+            torch.load(model_path, map_location=torch.device("cpu"))
+        )
         classifier_model.eval()
-        
+
         logger.info("All models loaded successfully")
         return True
-        
+
     except Exception as e:
         logger.error(f"Failed to load models: {str(e)}")
         return False
 
+
 def softmax(logits: np.ndarray) -> np.ndarray:
     """Apply softmax to logits"""
     exp_logits = np.exp(logits - np.max(logits))
     return exp_logits / np.sum(exp_logits)
 
+
 def get_jina_embeddings_1024(texts: List[str]) -> np.ndarray:
     """Get Jina embeddings using tokenizer and ONNX-like processing"""
     if tokenizer is None:
         raise ValueError("Tokenizer not loaded")
-    
+
     import onnxruntime as ort
-    
+
     session = ort.InferenceSession("../../model/embedding/model.onnx")
-    
+
     encoded_inputs = tokenizer(
         texts,
         add_special_tokens=False,  # key: do not add special tokens (consistent with the JS side)
         return_attention_mask=False,
-        return_tensors=None  # return native Python lists for easier downstream processing
+        return_tensors=None,  # return native Python lists for easier downstream processing
     )
-    input_ids = encoded_inputs["input_ids"]  # shape: [batch_size, seq_len_i] (each sample may have a different length)
-    
+    input_ids = encoded_inputs[
+        "input_ids"
+    ]  # shape: [batch_size, seq_len_i] (each sample may have a different length)
+
     # 2. Compute offsets (identical to the JS cumsum logic)
     # First get the token length of each sample
     lengths = [len(ids) for ids in input_ids]
@@ -91,25 +101,28 @@ def get_jina_embeddings_1024(texts: List[str]) -> np.ndarray:
     cumsum.append(current_sum)
     # Build offsets: start at 0, followed by the cumulative sums
     offsets = [0] + cumsum  # shape: [batch_size]
-    
+
     # 3. Flatten input_ids into a one-dimensional array
     flattened_input_ids = []
     for ids in input_ids:
         flattened_input_ids.extend(ids)  # concatenate all token ids directly
     flattened_input_ids = np.array(flattened_input_ids, dtype=np.int64)
-    
+
     # 4. Prepare ONNX inputs (tensor shapes consistent with the JS side)
     inputs = {
         "input_ids": ort.OrtValue.ortvalue_from_numpy(flattened_input_ids),
-        "offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64))
+        "offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64)),
     }
-    
+
     # 5. Run model inference
     outputs = session.run(None, inputs)
-    embeddings = outputs[0]  # assume the first output is the embeddings, shape: [batch_size, embedding_dim]
-    
+    embeddings = outputs[
+        0
+    ]  # assume the first output is the embeddings, shape: [batch_size, embedding_dim]
+
     return torch.tensor(embeddings, dtype=torch.float32).numpy()
 
+
 @app.on_event("startup")
 async def startup_event():
     """Load models on startup"""
@@ -117,91 +130,126 @@ async def startup_event():
     if not success:
         logger.error("Failed to load models during startup")
 
+
 @app.get("/health", response_model=HealthResponse)
 async def health_check():
     """Health check endpoint"""
     models_loaded = tokenizer is not None and classifier_model is not None
     return HealthResponse(
         status="healthy" if models_loaded else "models_not_loaded",
-        models_loaded=models_loaded
+        models_loaded=models_loaded,
     )
 
+
 @app.post("/classify", response_model=ClassificationResponse)
 async def classify_video(request: ClassificationRequest):
     """Classify a video based on title, description, and tags"""
     try:
         if tokenizer is None or classifier_model is None:
             raise HTTPException(status_code=503, detail="Models not loaded")
-        
+
         # Get embeddings for each channel
-        texts = [request.title, request.description, request.tags]
+        texts = [
+            request.title or "no title",
+            request.description or "no description",
+            request.tags or "no tags",
+        ]
         embeddings = get_jina_embeddings_1024(texts)
-        
+
         # Prepare input for classifier (batch_size=1, channels=3, embedding_dim=1024)
         channel_features = torch.tensor(embeddings).unsqueeze(0)  # [1, 3, 1024]
-        
+
         # Run inference
         with torch.no_grad():
             logits = classifier_model(channel_features)
             probabilities = softmax(logits.numpy()[0])
             predicted_label = int(np.argmax(probabilities))
-        
-        logger.info(f"Classification completed for aid {request.aid}: label={predicted_label}")
-        
+        logger.info(
+            f"Classification completed for aid {request.aid}: label={predicted_label}"
+        )
+
         return ClassificationResponse(
-            label=predicted_label,
-            probabilities=probabilities.tolist(),
-            aid=request.aid
+            label=predicted_label, probabilities=probabilities.tolist(), aid=request.aid
         )
-        
+
     except Exception as e:
         logger.error(f"Classification error for aid {request.aid}: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}")
 
+
 @app.post("/classify_batch")
 async def classify_video_batch(requests: List[ClassificationRequest]):
-    """Classify multiple videos in batch"""
+    """Classify multiple videos in batch using true batch processing"""
    try:
         if tokenizer is None or classifier_model is None:
             raise HTTPException(status_code=503, detail="Models not loaded")
-        
+
+        if not requests:
+            return {"results": []}
+
         results = []
+
+        # Collect all texts for batch processing
+        all_title_texts = []
+        all_desc_texts = []
+        all_tags_texts = []
+
         for request in requests:
-            try:
-                # Get embeddings for each channel
-                texts = [request.title, request.description, request.tags]
-                embeddings = get_jina_embeddings_1024(texts)
-                
-                # Prepare input for classifier
-                channel_features = torch.tensor(embeddings).unsqueeze(0)
-                
-                # Run inference
-                with torch.no_grad():
-                    logits = classifier_model(channel_features)
-                    probabilities = softmax(logits.numpy()[0])
-                    predicted_label = int(np.argmax(probabilities))
-                
-                results.append({
-                    "aid": request.aid,
-                    
"label": predicted_label, - "probabilities": probabilities.tolist() - }) - - except Exception as e: - logger.error(f"Batch classification error for aid {request.aid}: {str(e)}") - results.append({ - "aid": request.aid, - "label": -1, - "probabilities": [], - "error": str(e) - }) - + # Handle missing or empty fields + title = request.title or "no title" + description = request.description or "no description" + tags = request.tags or "no tags" + + all_title_texts.append(title) + all_desc_texts.append(description) + all_tags_texts.append(tags) + + # Process all titles in batch + title_embeddings = get_jina_embeddings_1024(all_title_texts) + + # Process all descriptions in batch + desc_embeddings = get_jina_embeddings_1024(all_desc_texts) + + # Process all tags in batch + tags_embeddings = get_jina_embeddings_1024(all_tags_texts) + + # Stack embeddings: [batch_size, 3, embedding_dim] + batch_features = np.stack( + [title_embeddings, desc_embeddings, tags_embeddings], axis=1 + ) + + # Convert to tensor and run inference for entire batch + channel_features = torch.tensor(batch_features, dtype=torch.float32) + + print(channel_features.shape) + + with torch.no_grad(): + logits = classifier_model(channel_features) # [batch_size, num_classes] + probabilities_batch = softmax(logits.numpy()) + predicted_labels = np.argmax(probabilities_batch, axis=1) + + # Prepare results + for i, request in enumerate(requests): + results.append( + { + "aid": request.aid, + "label": int(predicted_labels[i]), + "probabilities": probabilities_batch[i].tolist(), + } + ) + + logger.info(f"Batch classification completed for {len(requests)} requests") + return {"results": results} - + except Exception as e: logger.error(f"Batch classification failed: {str(e)}") - raise HTTPException(status_code=500, detail=f"Batch classification failed: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Batch classification failed: {str(e)}" + ) + if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8544) \ No newline at end of file + + uvicorn.run(app, host="0.0.0.0", port=8544) diff --git a/ml_new/data/dataset_service.py b/ml_new/data/dataset_service.py index 43c70b1..ada8fcc 100644 --- a/ml_new/data/dataset_service.py +++ b/ml_new/data/dataset_service.py @@ -475,6 +475,7 @@ class DatasetBuilder: def list_datasets(self) -> List[Dict[str, Any]]: """List all datasets with their basic information""" + self.storage._load_metadata_cache() return self.storage.list_datasets() def get_dataset_stats(self) -> Dict[str, Any]: diff --git a/ml_new/routes/main.py b/ml_new/routes/main.py index 66aa7c5..f40ccc8 100644 --- a/ml_new/routes/main.py +++ b/ml_new/routes/main.py @@ -153,16 +153,7 @@ async def list_datasets_endpoint(): raise HTTPException(status_code=503, detail="Dataset builder not available") datasets = dataset_builder.list_datasets() - # Add description to each dataset - datasets_with_description = [] - for dataset in datasets: - dataset_info = dataset_builder.get_dataset(dataset["dataset_id"]) - if dataset_info and "description" in dataset_info: - dataset["description"] = dataset_info["description"] - else: - dataset["description"] = None - datasets_with_description.append(dataset) - return {"datasets": datasets_with_description} + return {"datasets": datasets} @router.delete("/dataset/{dataset_id}") diff --git a/ml_new/training/models.py b/ml_new/training/models.py index e481260..740424b 100644 --- a/ml_new/training/models.py +++ b/ml_new/training/models.py @@ -189,85 +189,7 @@ class 
EmbeddingClassifier(nn.Module): } -class AttentionEmbeddingClassifier(EmbeddingClassifier): - """ - Enhanced classifier with self-attention mechanism - """ - - def __init__( - self, - input_dim: int = 2048, - hidden_dims: Optional[Tuple[int, ...]] = None, - dropout_rate: float = 0.3, - batch_norm: bool = True, - activation: str = "relu", - attention_dim: int = 512 - ): - super().__init__(input_dim, hidden_dims, dropout_rate, batch_norm, activation) - - # Self-attention mechanism - self.attention_dim = attention_dim - self.attention = nn.MultiheadAttention( - embed_dim=input_dim, - num_heads=8, - dropout=dropout_rate, - batch_first=True - ) - - # Attention projection layer - self.attention_projection = nn.Linear(input_dim, attention_dim) - - # Re-initialize attention weights - self._initialize_attention_weights() - - logger.info(f"Initialized AttentionEmbeddingClassifier with attention_dim={attention_dim}") - - def _initialize_attention_weights(self): - """Initialize attention mechanism weights""" - for module in self.attention.modules(): - if isinstance(module, nn.Linear): - nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass with attention mechanism - - Args: - x: Input embeddings of shape (batch_size, input_dim) - - Returns: - logits of shape (batch_size, 1) - """ - # Ensure input is float tensor - if not x.dtype == torch.float32: - x = x.float() - - # Add sequence dimension for attention (batch_size, 1, input_dim) - x_expanded = x.unsqueeze(1) - - # Apply self-attention - attended, attention_weights = self.attention(x_expanded, x_expanded, x_expanded) - - # Remove sequence dimension (batch_size, input_dim) - attended = attended.squeeze(1) - - # Project to attention dimension - attended = self.attention_projection(attended) - - # Process through original classification layers - for layer in self.layers: - attended = layer(attended) - - # Final classification layer - logits = self.classifier(attended) - - return logits - - def create_model( - model_type: str = "standard", input_dim: int = 2048, hidden_dims: Optional[Tuple[int, ...]] = None, **kwargs @@ -276,7 +198,6 @@ def create_model( Factory function to create embedding classifier models Args: - model_type: Type of model ('standard', 'attention') input_dim: Input embedding dimension hidden_dims: Hidden layer dimensions **kwargs: Additional model arguments @@ -284,18 +205,12 @@ def create_model( Returns: Initialized model """ - if model_type == "standard": - return EmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs) - elif model_type == "attention": - return AttentionEmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs) - else: - raise ValueError(f"Unknown model type: {model_type}") + return EmbeddingClassifier(input_dim=input_dim, hidden_dims=hidden_dims, **kwargs) if __name__ == "__main__": # Test model creation and forward pass model = create_model( - model_type="standard", input_dim=2048, hidden_dims=(512, 256, 128), dropout_rate=0.3 diff --git a/ml_new/training/test.py b/ml_new/training/test.py new file mode 100644 index 0000000..1853074 --- /dev/null +++ b/ml_new/training/test.py @@ -0,0 +1,636 @@ +#!/usr/bin/env python3 +""" +Test script for evaluating trained models on a dataset +""" + +import argparse +import json +import sys +from pathlib import Path + +import numpy as np +import torch +import aiohttp +import asyncio +from torch.utils.data import 
DataLoader +from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix +from typing import Optional, List, Dict, Any + +# Add the parent directory to the path to import ml_new modules +sys.path.append(str(Path(__file__).parent.parent)) + +from ml_new.training.models import create_model +from ml_new.training.data_loader import DatasetLoader, EmbeddingDataset +from ml_new.config.logger_config import get_logger + +logger = get_logger(__name__) + + +def parse_args(): + """Parse command line arguments""" + parser = argparse.ArgumentParser( + description="Test embedding classification model", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + # Required arguments + parser.add_argument( + "--dataset-id", + type=str, + required=True, + help="ID of the dataset to use for testing" + ) + + parser.add_argument( + "--experiment", + type=str, + help="Name of the experiment to load model from" + ) + + # Optional arguments + parser.add_argument( + "--datasets-dir", + type=str, + default="training/datasets", + help="Directory containing dataset files" + ) + + parser.add_argument( + "--checkpoints-dir", + type=str, + default="training/checkpoints", + help="Directory containing model checkpoints" + ) + + parser.add_argument( + "--checkpoint-file", + type=str, + default="best_model.pth", + help="Checkpoint file to load (relative to experiment dir)" + ) + + parser.add_argument( + "--batch-size", + type=int, + default=32, + help="Batch size for testing" + ) + + parser.add_argument( + "--num-workers", + type=int, + default=4, + help="Number of worker processes for data loading" + ) + + parser.add_argument( + "--device", + type=str, + default="auto", + help="Device to use (auto, cpu, cuda)" + ) + + parser.add_argument( + "--normalize", + action="store_true", + default=False, + help="Normalize embeddings during testing" + ) + + parser.add_argument( + "--output", + type=str, + default=None, + help="Output file for detailed results (JSON)" + ) + + parser.add_argument( + "--threshold", + type=float, + default=0.5, + help="Classification threshold" + ) + + parser.add_argument( + "--use-api", + action="store_true", + default=False, + help="Use API model instead of local model" + ) + + parser.add_argument( + "--api-url", + type=str, + default="http://localhost:8544", + help="API base URL" + ) + + parser.add_argument( + "--misclassified-output", + type=str, + default=None, + help="Output file for misclassified samples (FN and FP aids)" + ) + + return parser.parse_args() + + +def setup_device(device_arg: str): + """Setup device""" + if device_arg == "auto": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + device = torch.device(device_arg) + + logger.info(f"Using device: {device}") + return device + + +def load_model_from_experiment( + checkpoints_dir: str, + experiment_name: str, + checkpoint_file: str, + device: torch.device +): + """ + Load a trained model from an experiment checkpoint + + Args: + checkpoints_dir: Directory containing checkpoints + experiment_name: Name of the experiment + checkpoint_file: Checkpoint file name + device: Device to load model to + + Returns: + Loaded model + """ + checkpoint_path = Path(checkpoints_dir) / experiment_name / checkpoint_file + + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + logger.info(f"Loading checkpoint from {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + + # 
Get model config from checkpoint + model_config = checkpoint.get('model_config', {}) + + # Create model with saved config + model = create_model( + input_dim=model_config.get('input_dim', 2048), + hidden_dims=tuple(model_config.get('hidden_dims', [512, 256, 128])), + dropout_rate=model_config.get('dropout_rate', 0.3), + batch_norm=model_config.get('batch_norm', True) + ) + + # Load state dict + model.load_state_dict(checkpoint['model_state_dict']) + model.to(device) + model.eval() + + logger.info(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}") + logger.info(f"Model config: {model_config}") + + return model, model_config + + +def evaluate_model( + model, + test_loader: DataLoader, + device: torch.device, + threshold: float = 0.5 +): + """ + Evaluate model on test set + + Args: + model: The model to evaluate + test_loader: Test data loader + device: Device to use + threshold: Classification threshold + + Returns: + Tuple of (metrics, predictions, probabilities, true_labels, fn_aids, fp_aids) + """ + model.eval() + criterion = torch.nn.BCEWithLogitsLoss() + + total_loss = 0.0 + all_predictions = [] + all_labels = [] + all_probabilities = [] + all_metadata = [] + fn_aids = [] + fp_aids = [] + + with torch.no_grad(): + for batch_idx, (embeddings, labels, metadata) in enumerate(test_loader): + embeddings = embeddings.to(device) + labels = labels.to(device).float() + + # Forward pass + outputs = model(embeddings) + loss = criterion(outputs.squeeze(), labels) + + # Collect statistics + total_loss += loss.item() + + # Get predictions and probabilities + probabilities = torch.sigmoid(outputs).squeeze() + predictions = (probabilities > threshold).long() + + all_predictions.extend(predictions.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + all_probabilities.extend(probabilities.cpu().numpy()) + + # Collect metadata and track FN/FP + batch_metadata = metadata if isinstance(metadata, list) else [metadata] + all_metadata.extend(batch_metadata) + + # Track FN and FP aids for this batch + for i, (true_label, pred_label) in enumerate(zip(labels.cpu().numpy(), predictions.cpu().numpy())): + if isinstance(batch_metadata[i], dict) and 'aid' in batch_metadata[i]: + aid = batch_metadata[i]['aid'] + if true_label == 1 and pred_label == 0: # False Negative + fn_aids.append(aid) + elif true_label == 0 and pred_label == 1: # False Positive + fp_aids.append(aid) + + if (batch_idx + 1) % 10 == 0: + logger.info(f"Processed {batch_idx + 1}/{len(test_loader)} batches") + + # Calculate metrics + test_loss = total_loss / len(test_loader) + test_accuracy = accuracy_score(all_labels, all_predictions) + precision, recall, f1, _ = precision_recall_fscore_support( + all_labels, all_predictions, average='binary', zero_division=0 + ) + + try: + test_auc = roc_auc_score(all_labels, all_probabilities) + except ValueError: + test_auc = 0.0 + + # Confusion matrix + cm = confusion_matrix(all_labels, all_predictions) + if cm.size == 4: + tn, fp, fn, tp = cm.ravel() + else: + tn, fp, fn, tp = 0, 0, 0, 0 + + metrics = { + 'loss': test_loss, + 'accuracy': test_accuracy, + 'precision': precision, + 'recall': recall, + 'f1': f1, + 'auc': test_auc, + 'true_negatives': int(tn), + 'false_positives': int(fp), + 'false_negatives': int(fn), + 'true_positives': int(tp), + 'total_samples': len(all_labels), + 'threshold': threshold + } + + # Add class distribution + unique, counts = np.unique(all_labels, return_counts=True) + metrics['class_distribution'] = {int(k): int(v) for k, v in zip(unique, counts)} + + return 
metrics, all_predictions, all_probabilities, all_labels, fn_aids, fp_aids + + +async def call_api_batch(session: aiohttp.ClientSession, api_url: str, requests: List[Dict[str, Any]]) -> Optional[List[Dict[str, Any]]]: + """Call the classification API for batch requests""" + try: + url = f"{api_url}/classify_batch" + async with session.post(url, json=requests) as response: + if response.status == 200: + result = await response.json() + return result.get('results', []) + else: + logger.warning(f"Batch API request failed with status {response.status}") + return None + except Exception as e: + logger.warning(f"Batch API request failed: {e}") + return None + + +def convert_api_label_to_bool(api_label: int) -> int: + """Convert API label to boolean (non-zero = true)""" + return 1 if api_label != 0 else 0 + + +async def evaluate_with_api( + embeddings: np.ndarray, + labels: np.ndarray, + metadata: List[Dict[str, Any]], + api_url: str, + batch_size: int = 32 +): + """ + Evaluate using the API instead of local model + + Args: + embeddings: Array of embeddings (not used for API calls) + labels: Ground truth labels + metadata: Metadata containing title, description, tags, aid + api_url: API base URL + batch_size: Number of requests per API batch call + + Returns: + Tuple of (metrics, predictions, probabilities, true_labels, fn_aids, fp_aids) + """ + logger.info(f"Using API at {api_url} for evaluation") + + # Prepare API requests + requests = [] + for i, meta in enumerate(metadata): + # Extract metadata fields for API + title = meta.get('title', '') + description = meta.get('description', '') + tags = meta.get('tags', '') + aid = meta.get('aid', i) + + # Handle missing or empty fields + if not title: + title = f"Video {aid}" + if not description: + description = "" + if not tags: + tags = "" + + request_data = { + "title": title, + "description": description, + "tags": tags, + "aid": aid + } + requests.append(request_data) + + # Split requests into batches + num_batches = (len(requests) + batch_size - 1) // batch_size + logger.info(f"Making {num_batches} batch API requests with batch_size={batch_size} for {len(requests)} total requests") + + # Process all batches + all_predictions = [] + all_probabilities = [] + all_labels = labels.tolist() + all_aids = [meta.get('aid', i) for i, meta in enumerate(metadata)] + failed_requests = 0 + fn_aids = [] + fp_aids = [] + + async with aiohttp.ClientSession() as session: + for batch_idx in range(num_batches): + start_idx = batch_idx * batch_size + end_idx = min(start_idx + batch_size, len(requests)) + batch_requests = requests[start_idx:end_idx] + + logger.info(f"Processing batch {batch_idx + 1}/{num_batches} ({len(batch_requests)} requests)") + + results = await call_api_batch(session, api_url, batch_requests) + + if results is None: + logger.error(f"Batch {batch_idx + 1} API request failed completely") + # Create dummy results for this batch + all_predictions.extend([0] * len(batch_requests)) + all_probabilities.extend([0.0] * len(batch_requests)) + failed_requests += len(batch_requests) + continue + + for i, result in enumerate(results): + global_idx = start_idx + i + if not isinstance(result, dict) or 'error' in result: + error_msg = result.get('error', 'Unknown error') if isinstance(result, dict) else 'Invalid result' + logger.warning(f"Failed to get API prediction for request {global_idx}: {error_msg}") + failed_requests += 1 + all_predictions.append(0) + all_probabilities.append(0.0) + continue + + # Convert API response to our format + api_label = 
result.get('label', -1) + probabilities = result.get('probabilities') + + # Convert to boolean (non-zero = true) + prediction = convert_api_label_to_bool(api_label) + # Use the probability of the positive class + if probabilities and len(probabilities) > 0: + positive_prob = 1 - probabilities[0] + else: + logger.warning(f"No probabilities for request {global_idx}") + failed_requests += 1 + all_predictions.append(0) + all_probabilities.append(0.0) + continue + + all_predictions.append(prediction) + all_probabilities.append(positive_prob) + + if failed_requests > 0: + logger.warning(f"Failed to get API predictions for {failed_requests} requests") + + # Collect FN and FP aids + for i, (true_label, pred_label) in enumerate(zip(all_labels, all_predictions)): + aid = all_aids[i] + if true_label == 1 and pred_label == 0: # False Negative + fn_aids.append(aid) + elif true_label == 0 and pred_label == 1: # False Positive + fp_aids.append(aid) + + # Calculate metrics + test_accuracy = accuracy_score(all_labels, all_predictions) + precision, recall, f1, _ = precision_recall_fscore_support( + all_labels, all_predictions, average='binary', zero_division=0 + ) + + try: + test_auc = roc_auc_score(all_labels, all_probabilities) + except ValueError: + test_auc = 0.0 + + # Confusion matrix + cm = confusion_matrix(all_labels, all_predictions) + if cm.size == 4: + tn, fp, fn, tp = cm.ravel() + else: + tn, fp, fn, tp = 0, 0, 0, 0 + + metrics = { + 'accuracy': test_accuracy, + 'precision': precision, + 'recall': recall, + 'f1': f1, + 'auc': test_auc, + 'true_negatives': int(tn), + 'false_positives': int(fp), + 'false_negatives': int(fn), + 'true_positives': int(tp), + 'total_samples': len(all_labels), + 'failed_requests': failed_requests + } + + # Add class distribution + unique, counts = np.unique(all_labels, return_counts=True) + metrics['class_distribution'] = {int(k): int(v) for k, v in zip(unique, counts)} + + return metrics, all_predictions, all_probabilities, all_labels, fn_aids, fp_aids + + +def main(): + """Main test function""" + args = parse_args() + + # Setup device + device = setup_device(args.device) + + # Check if dataset exists + loader = DatasetLoader(args.datasets_dir) + datasets = loader.list_datasets() + + if args.dataset_id not in datasets: + logger.error(f"Dataset '{args.dataset_id}' not found in {args.datasets_dir}") + logger.info(f"Available datasets: {datasets}") + sys.exit(1) + + # Load dataset (use entire dataset as test set) + try: + logger.info(f"Loading dataset {args.dataset_id}...") + embeddings, labels, metadata = loader.load_dataset(args.dataset_id) + + logger.info(f"Dataset loaded: {len(embeddings)} samples") + logger.info(f"Embedding dimension: {embeddings.shape[1]}") + + except Exception as e: + logger.error(f"Failed to load dataset: {e}") + sys.exit(1) + + # Choose evaluation method + if args.use_api: + # Use API for evaluation + logger.info("Using API-based evaluation") + + # Run async evaluation + metrics, predictions, probabilities, true_labels, fn_aids, fp_aids = asyncio.run( + evaluate_with_api( + embeddings, labels, metadata, + args.api_url, + args.batch_size + ) + ) + + # For API mode, we don't have model_config + model_config = {"type": "api", "api_url": args.api_url} + + else: + # Use local model for evaluation + # Check if experiment exists + experiment_dir = Path(args.checkpoints_dir) / args.experiment + if not experiment_dir.exists(): + logger.error(f"Experiment '{args.experiment}' not found in {args.checkpoints_dir}") + available = [d.name for d in 
Path(args.checkpoints_dir).iterdir() if d.is_dir()] + logger.info(f"Available experiments: {available}") + sys.exit(1) + + # Load model + try: + model, model_config = load_model_from_experiment( + args.checkpoints_dir, + args.experiment, + args.checkpoint_file, + device + ) + except Exception as e: + logger.error(f"Failed to load model: {e}") + sys.exit(1) + + # Create test dataset and loader + test_dataset = EmbeddingDataset( + embeddings, labels, metadata, + normalize=args.normalize + ) + + test_loader = DataLoader( + test_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers + ) + + # Evaluate model + logger.info("Starting local model evaluation...") + metrics, predictions, probabilities, true_labels, fn_aids, fp_aids = evaluate_model( + model, test_loader, device, args.threshold + ) + + # Print results + logger.info("=" * 50) + logger.info("Test Results") + logger.info("=" * 50) + logger.info(f"Dataset: {args.dataset_id}") + if args.use_api: + logger.info(f"Method: API ({args.api_url})") + else: + logger.info(f"Experiment: {args.experiment}") + logger.info(f"Total samples: {metrics['total_samples']}") + logger.info(f"Class distribution: {metrics['class_distribution']}") + if 'failed_requests' in metrics: + logger.info(f"Failed API requests: {metrics['failed_requests']}") + logger.info("-" * 50) + if 'loss' in metrics: + logger.info(f"Loss: {metrics['loss']:.4f}") + logger.info(f"Accuracy: {metrics['accuracy']:.4f}") + logger.info(f"Precision: {metrics['precision']:.4f}") + logger.info(f"Recall: {metrics['recall']:.4f}") + logger.info(f"F1 Score: {metrics['f1']:.4f}") + logger.info(f"AUC: {metrics['auc']:.4f}") + logger.info("-" * 50) + logger.info(f"True Positives: {metrics['true_positives']}") + logger.info(f"True Negatives: {metrics['true_negatives']}") + logger.info(f"False Positives: {metrics['false_positives']}") + logger.info(f"False Negatives: {metrics['false_negatives']}") + logger.info("=" * 50) + + # Save detailed results if requested + if args.output: + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + results = { + 'dataset_id': args.dataset_id, + 'experiment': args.experiment, + 'checkpoint': args.checkpoint_file, + 'model_config': model_config, + 'metrics': metrics, + 'predictions': [int(p) for p in predictions], + 'probabilities': [float(p) for p in probabilities], + 'labels': [int(l) for l in true_labels] + } + + with open(output_path, 'w') as f: + json.dump(results, f, indent=2) + + logger.info(f"Detailed results saved to {output_path}") + + # Save misclassified samples (FN and FP aids) if requested + if args.misclassified_output: + misclassified_path = Path(args.misclassified_output) + misclassified_path.parent.mkdir(parents=True, exist_ok=True) + + misclassified_data = { + 'false_negatives': fn_aids, + 'false_positives': fp_aids, + 'fn_count': len(fn_aids), + 'fp_count': len(fp_aids), + 'total_misclassified': len(fn_aids) + len(fp_aids) + } + + with open(misclassified_path, 'w') as f: + json.dump(misclassified_data, f, indent=2) + + logger.info(f"Misclassified samples (FN: {len(fn_aids)}, FP: {len(fp_aids)}) saved to {misclassified_path}") + + +if __name__ == "__main__": + main() diff --git a/ml_new/training/train.py b/ml_new/training/train.py index e4edcdb..eaecfa8 100644 --- a/ml_new/training/train.py +++ b/ml_new/training/train.py @@ -62,15 +62,6 @@ def parse_args(): help="Directory containing dataset files" ) - # Model arguments - parser.add_argument( - "--model-type", - type=str, - 
choices=["standard", "attention"], - default="standard", - help="Type of model architecture" - ) - parser.add_argument( "--input-dim", type=int, @@ -353,11 +344,11 @@ def main(): # Create experiment name if not provided if args.experiment_name is None: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - args.experiment_name = f"{args.model_type}_{args.dataset_id}_{timestamp}" + args.experiment_name = f"{timestamp}_{args.dataset_id}" logger.info(f"Starting experiment: {args.experiment_name}") logger.info(f"Dataset: {args.dataset_id}") - logger.info(f"Model: {args.model_type} with hidden dims {args.hidden_dims}") + logger.info(f"Model: hidden dims {args.hidden_dims}") # Load dataset and create data loaders try: @@ -387,7 +378,6 @@ def main(): try: logger.info("Creating model...") model = create_model( - model_type=args.model_type, input_dim=args.input_dim, hidden_dims=tuple(args.hidden_dims), dropout_rate=args.dropout_rate, @@ -452,7 +442,6 @@ def main(): "experiment_name": args.experiment_name, "dataset_id": args.dataset_id, "model_config": { - "model_type": args.model_type, "input_dim": args.input_dim, "hidden_dims": args.hidden_dims, "dropout_rate": args.dropout_rate, diff --git a/ml_new/training/trainer.py b/ml_new/training/trainer.py index e873bfe..81247d8 100644 --- a/ml_new/training/trainer.py +++ b/ml_new/training/trainer.py @@ -417,7 +417,7 @@ class ModelTrainer: def load_checkpoint(self, checkpoint_path: str, load_optimizer: bool = True) -> None: """Load model from checkpoint""" - checkpoint = torch.load(checkpoint_path, map_location=self.device) + checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False) self.model.load_state_dict(checkpoint['model_state_dict']) @@ -504,7 +504,6 @@ if __name__ == "__main__": # Create dummy model and data model = create_model( - model_type="standard", input_dim=2048, hidden_dims=(512, 256, 128) ) diff --git a/packages/ml_panel/src/App.tsx b/packages/ml_panel/src/App.tsx index ee5737a..2107cf4 100644 --- a/packages/ml_panel/src/App.tsx +++ b/packages/ml_panel/src/App.tsx @@ -1,10 +1,9 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; -import { CardDescription, CardTitle } from "@/components/ui/card"; +import { CardTitle } from "@/components/ui/card"; import { DatasetManager } from "@/components/DatasetManager"; import { TaskMonitor } from "@/components/TaskMonitor"; -import { SamplingPanel } from "@/components/SamplingPanel"; -import { Database, Activity, Settings } from "lucide-react"; +import { Database, Activity } from "lucide-react"; const queryClient = new QueryClient({ defaultOptions: { @@ -21,22 +20,17 @@ function App() {
-				ML Dataset Management Panel
-				Create and manage machine learning datasets with multiple sampling strategies and task monitoring
+				CVSA Machine Learning Panel
- + Datasets - - - Sampling - Tasks @@ -44,22 +38,11 @@ function App() { - Dataset Management - View, create and manage your machine learning datasets - - Sampling Strategy Configuration - - Configure different data sampling strategies to create balanced datasets - - - - Task Monitor - Monitor real-time status and progress of dataset building tasks diff --git a/packages/ml_panel/src/components/DatasetManager.tsx b/packages/ml_panel/src/components/DatasetManager.tsx index a5cef73..6f1dac4 100644 --- a/packages/ml_panel/src/components/DatasetManager.tsx +++ b/packages/ml_panel/src/components/DatasetManager.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useState } from "react"; import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query"; import { Button } from "@/components/ui/button"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; @@ -20,19 +20,22 @@ import { SelectValue } from "@/components/ui/select"; import { Textarea } from "@/components/ui/textarea"; -import { Trash2, Plus, Database, FileText, Calendar, Activity } from "lucide-react"; +import { Trash2, Plus, Database, Upload } from "lucide-react"; import { apiClient } from "@/lib/api"; import { toast } from "sonner"; -import { Spinner } from "@/components/ui/spinner" +import { Spinner } from "@/components/ui/spinner"; export function DatasetManager() { const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false); const [createFormData, setCreateFormData] = useState({ + creationMode: "sampling", // "sampling" or "aidList" strategy: "all", limit: "", embeddingModel: "", description: "", - forceRegenerate: false + forceRegenerate: false, + aidListFile: null as File | null, + aidList: [] as number[] }); const queryClient = useQueryClient(); @@ -45,7 +48,7 @@ export function DatasetManager() { }); // Fetch embedding models - const { data: modelsData, isLoading: modelsLoading } = useQuery({ + const { data: modelsData } = useQuery({ queryKey: ["embedding-models"], queryFn: () => apiClient.getEmbeddingModels() }); @@ -57,11 +60,14 @@ export function DatasetManager() { toast.success("Dataset creation task started"); setIsCreateDialogOpen(false); setCreateFormData({ + creationMode: "sampling", strategy: "all", limit: "", embeddingModel: "", description: "", - forceRegenerate: false + forceRegenerate: false, + aidListFile: null, + aidList: [] }); queryClient.invalidateQueries({ queryKey: ["datasets"] }); queryClient.invalidateQueries({ queryKey: ["tasks"] }); @@ -83,23 +89,68 @@ export function DatasetManager() { } }); + // Build dataset mutation + const buildDatasetMutation = useMutation({ + mutationFn: (data: { + aid_list: number[]; + embedding_model: string; + force_regenerate?: boolean; + description?: string; + }) => apiClient.buildDataset(data), + onSuccess: () => { + toast.success("Dataset build task started"); + setIsCreateDialogOpen(false); + setCreateFormData({ + creationMode: "sampling", + strategy: "all", + limit: "", + embeddingModel: "", + description: "", + forceRegenerate: false, + aidListFile: null, + aidList: [] + }); + queryClient.invalidateQueries({ queryKey: ["datasets"] }); + queryClient.invalidateQueries({ queryKey: ["tasks"] }); + }, + onError: (error: Error) => { + toast.error(`Build failed: ${error.message}`); + } + }); + const handleCreateDataset = () => { if (!createFormData.embeddingModel) { toast.error("Please select an embedding model"); return; } - const requestData = { - sampling: { - strategy: createFormData.strategy, - 
...(createFormData.limit && { limit: parseInt(createFormData.limit) }) - }, - embedding_model: createFormData.embeddingModel, - force_regenerate: createFormData.forceRegenerate, - description: createFormData.description || undefined - }; + if (createFormData.creationMode === "sampling") { + const requestData = { + sampling: { + strategy: createFormData.strategy, + ...(createFormData.limit && { limit: parseInt(createFormData.limit) }) + }, + embedding_model: createFormData.embeddingModel, + force_regenerate: createFormData.forceRegenerate, + description: createFormData.description || undefined + }; - createDatasetMutation.mutate(requestData); + createDatasetMutation.mutate(requestData); + } else if (createFormData.creationMode === "aidList") { + if (createFormData.aidList.length === 0) { + toast.error("Please upload an aid list file"); + return; + } + + const requestData = { + aid_list: createFormData.aidList, + embedding_model: createFormData.embeddingModel, + force_regenerate: createFormData.forceRegenerate, + description: createFormData.description || undefined + }; + + buildDatasetMutation.mutate(requestData); + } }; const handleDeleteDataset = (datasetId: string) => { @@ -108,16 +159,67 @@ export function DatasetManager() { } }; - const formatDate = (dateString: string) => { - return new Date(dateString).toLocaleString("en-US"); + // Parse aid list file + const parseAidListFile = (file: File): Promise => { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = (e) => { + try { + const content = e.target?.result as string; + const lines = content.split("\n").filter((line) => line.trim()); + const aidList: number[] = []; + + for (const line of lines) { + const trimmed = line.trim(); + if (trimmed) { + const aid = parseInt(trimmed, 10); + if (!isNaN(aid)) { + aidList.push(aid); + } + } + } + + resolve(aidList); + } catch (error) { + reject(new Error("Failed to parse file")); + } + }; + reader.onerror = () => reject(new Error("Failed to read file")); + reader.readAsText(file); + }); }; - const formatFileSize = (bytes: number) => { - if (bytes === 0) return "0 Bytes"; - const k = 1024; - const sizes = ["Bytes", "KB", "MB", "GB"]; - const i = Math.floor(Math.log(bytes) / Math.log(k)); - return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + " " + sizes[i]; + // Handle file upload + const handleFileUpload = async (event: React.ChangeEvent) => { + const file = event.target.files?.[0]; + if (!file) return; + + if (!file.name.endsWith(".txt") && !file.name.endsWith(".csv")) { + toast.error("Please upload a .txt or .csv file"); + return; + } + + try { + const aidList = await parseAidListFile(file); + if (aidList.length === 0) { + toast.error("No valid AIDs found in the file"); + return; + } + + setCreateFormData((prev) => ({ + ...prev, + aidListFile: file, + aidList: aidList + })); + + toast.success(`Loaded ${aidList.length} AIDs from file`); + } catch (error) { + toast.error("Failed to parse aid list file"); + } + }; + + const formatDate = (dateString: string) => { + return new Date(dateString).toLocaleString("en-US"); }; if (datasetsLoading) { @@ -150,43 +252,120 @@ export function DatasetManager() { Create New Dataset - Select sampling strategy and configuration parameters to create a new dataset + Select sampling strategy and configuration parameters to create a + new dataset
- +
- {createFormData.strategy === "random" && ( + {createFormData.creationMode === "sampling" && (
- -