fix: test scripts in old and new ML package
parent fc06b3d69f
commit f4127d7c2e

bun.lock
@@ -1,6 +1,5 @@
 {
   "lockfileVersion": 1,
-  "configVersion": 0,
   "workspaces": {
     "": {
       "name": "cvsa",
@@ -86,6 +85,7 @@
     "name": "ml_panel",
     "version": "0.0.0",
     "dependencies": {
+      "@radix-ui/react-checkbox": "^1.3.3",
       "@radix-ui/react-dialog": "^1.1.15",
       "@radix-ui/react-label": "^2.1.8",
       "@radix-ui/react-progress": "^1.1.8",

@@ -56,6 +56,11 @@ class MultiChannelDataset(Dataset):
         self.max_length = max_length
         self.mode = mode

+        if self.mode == 'test' and os.path.exists(file_path):
+            with open(file_path, 'r', encoding='utf-8') as f:
+                self.examples = [json.loads(line) for line in f]
+            return
+
         # Check whether the train, eval, and test files exist
         train_file = os.path.join(os.path.dirname(file_path), 'train.jsonl')
         eval_file = os.path.join(os.path.dirname(file_path), 'eval.jsonl')
@@ -101,9 +106,9 @@ class MultiChannelDataset(Dataset):

         # Return the text dictionary
         texts = {
-            'title': example['title'],
-            'description': example['description'],
-            'tags': tags_text
+            'title': example['title'] or 'no title',
+            'description': example['description'] or 'no description',
+            'tags': tags_text or 'no tags'
         }

         return {

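The 'or' fallbacks above guard against empty fields: an empty string is falsy in Python, so a blank title or description would otherwise be embedded verbatim. A quick illustration:

    '' or 'no title'            # -> 'no title'
    'Some title' or 'no title'  # -> 'Some title'
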
@@ -1,6 +1,21 @@
 from typing import List
+import numpy as np
 import torch
 from model2vec import StaticModel
+import onnxruntime as ort
+from transformers import AutoTokenizer
+
+# Initialize the tokenizer and ONNX model (cached globally)
+_tokenizer = None
+_onnx_session = None
+
+def _get_tokenizer_and_session():
+    """Return the globally cached tokenizer and ONNX session"""
+    global _tokenizer, _onnx_session
+    if _tokenizer is None:
+        _tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3")
+    if _onnx_session is None:
+        _onnx_session = ort.InferenceSession("../model/embedding/model.onnx")
+    return _tokenizer, _onnx_session


 def prepare_batch(batch_data, device="cpu"):
@@ -18,21 +33,60 @@ def prepare_batch(batch_data, device="cpu"):
     Returns:
         torch.Tensor: a tensor of shape [batch_size, num_channels, embedding_dim].
     """
-    # 1. Encode each channel's texts separately
-    channel_embeddings = []
-    model = StaticModel.from_pretrained("./model/embedding_1024/")
-    for channel in ["title", "description", "tags"]:
-        texts = batch_data[channel]  # list of texts for the current channel
-        embeddings = torch.from_numpy(model.encode(texts)).to(torch.float32).to(device)  # encode to [batch_size, embedding_dim]
-        channel_embeddings.append(embeddings)
-
-    # 2. Stack the encodings into [batch_size, num_channels, embedding_dim]
-    batch_tensor = torch.stack(channel_embeddings, dim=1)  # stack along dim=1
-    return batch_tensor
+    title_embeddings = get_jina_embeddings_1024(batch_data['title'])
+    desc_embeddings = get_jina_embeddings_1024(batch_data['description'])
+    tags_embeddings = get_jina_embeddings_1024(batch_data['tags'])
+
+    return torch.stack([title_embeddings, desc_embeddings, tags_embeddings], dim=1).to(device)
+
+
+def get_jina_embeddings_1024(texts: List[str]) -> torch.Tensor:
+    """Get Jina embeddings using the cached tokenizer and ONNX session"""
+    tokenizer, session = _get_tokenizer_and_session()
+
+    # 1. Tokenize the batch
+    encoded_inputs = tokenizer(
+        texts,
+        add_special_tokens=False,
+        return_attention_mask=False,
+        return_tensors=None,  # return plain Python lists for easier post-processing
+    )
+    input_ids = encoded_inputs["input_ids"]  # shape: [batch_size, seq_len_i] (lengths may differ per sample)
+
+    # 2. Compute offsets (identical to the cumsum logic on the JS side)
+    # First get each sample's token length
+    lengths = [len(ids) for ids in input_ids]
+    # Cumulative sum, excluding the last sample
+    cumsum = []
+    current_sum = 0
+    for l in lengths[:-1]:  # only accumulate the first n-1 lengths
+        current_sum += l
+        cumsum.append(current_sum)
+    # Build offsets: 0 first, then the cumulative sums
+    offsets = [0] + cumsum  # shape: [batch_size]
+
+    # 3. Flatten input_ids into a one-dimensional array
+    flattened_input_ids = []
+    for ids in input_ids:
+        flattened_input_ids.extend(ids)  # concatenate all token ids
+    flattened_input_ids = np.array(flattened_input_ids, dtype=np.int64)
+
+    # 4. Prepare the ONNX inputs (tensor shapes match the JS side)
+    inputs = {
+        "input_ids": ort.OrtValue.ortvalue_from_numpy(flattened_input_ids),
+        "offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64)),
+    }
+
+    # 5. Run model inference
+    outputs = session.run(None, inputs)
+    embeddings = outputs[0]  # assume the first output is the embeddings, shape: [batch_size, embedding_dim]
+
+    return torch.tensor(embeddings, dtype=torch.float32)

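A minimal usage sketch of the reworked prepare_batch (the batch below is made up, and the 1024 embedding dimension assumes the Jina m2v ONNX model loaded above):

    batch = {
        "title": ["some title", "another title"],
        "description": ["first description", "second description"],
        "tags": ["tag_a tag_b", "tag_c"],
    }
    tensor = prepare_batch(batch, device="cpu")
    print(tensor.shape)  # expected: torch.Size([2, 3, 1024])
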
 import onnxruntime as ort
 from transformers import AutoTokenizer
 from itertools import accumulate

 def prepare_batch_per_token(batch_data, max_length=1024):
     """
@@ -67,6 +121,7 @@ def prepare_batch_per_token(batch_data, max_length=1024):
     input_ids_lengths = [len(enc["input_ids"][0]) for enc in encoded_inputs]

     # Generate offsets: [0, len1, len1+len2, ...]
+    from itertools import accumulate
     offsets = list(accumulate([0] + input_ids_lengths[:-1]))  # cumulative sum, excluding the last length

     # Flatten all input_ids into a one-dimensional array

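For intuition, the offsets are a zero-prefixed prefix sum of the per-sample token counts; the manual loop in get_jina_embeddings_1024 and the accumulate call here compute the same thing. A tiny check with made-up lengths:

    from itertools import accumulate

    lengths = [3, 5, 2]                             # token counts per sample
    offsets = list(accumulate([0] + lengths[:-1]))  # -> [0, 3, 8]
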
ml/filter/test_new.py (new file, 137 lines)
@@ -0,0 +1,137 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
from torch.utils.data import DataLoader
from dataset import MultiChannelDataset
from filter.modelV3_15 import VideoClassifierV3_15
from embedding import prepare_batch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import argparse


def load_model(checkpoint_path):
    """Load the model weights"""
    model = VideoClassifierV3_15()
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model.eval()
    return model


def load_test_data(test_file):
    """Load the test data"""
    test_dataset = MultiChannelDataset(test_file, mode='test')
    test_loader = DataLoader(test_dataset, batch_size=24, shuffle=False)
    return test_loader


def convert_to_binary_labels(labels):
    """Convert three-class labels to binary: merge classes 1 and 2 into class 1"""
    binary_labels = np.where(labels >= 1, 1, 0)
    return binary_labels

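For reference, the merge rule on a toy label array:

    labels = np.array([0, 1, 2, 2, 0])
    convert_to_binary_labels(labels)  # -> array([0, 1, 1, 1, 0])
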
def run_inference(model, test_loader):
    """Run inference"""
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            # Prepare the text data
            batch_tensor = prepare_batch(batch['texts'])

            # Forward pass
            logits = model(batch_tensor)
            preds = torch.argmax(logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch['label'].cpu().numpy())

    return np.array(all_preds), np.array(all_labels)


def calculate_metrics(y_true, y_pred):
    """Compute binary classification metrics"""
    # Convert to binary labels
    y_true_binary = convert_to_binary_labels(y_true)
    y_pred_binary = convert_to_binary_labels(y_pred)

    # Compute the confusion matrix
    cm = confusion_matrix(y_true_binary, y_pred_binary)

    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
    else:
        # Handle the case where only one class is present
        tn = fp = fn = tp = 0
        if y_true_binary.sum() == 0:
            tn = (y_true_binary == y_pred_binary).sum()
        else:
            tp = ((y_true_binary == 1) & (y_pred_binary == 1)).sum()
            fp = ((y_true_binary == 0) & (y_pred_binary == 1)).sum()
            fn = ((y_true_binary == 1) & (y_pred_binary == 0)).sum()

    # Compute the metrics
    accuracy = accuracy_score(y_true_binary, y_pred_binary)
    precision = precision_score(y_true_binary, y_pred_binary, zero_division=0)
    recall = recall_score(y_true_binary, y_pred_binary, zero_division=0)
    f1 = f1_score(y_true_binary, y_pred_binary, zero_division=0)

    return {
        'Acc': accuracy,
        'Prec': precision,
        'Recall': recall,
        'F1': f1,
        'TP': tp,
        'FP': fp,
        'TN': tn,
        'FN': fn
    }

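The shape check above matters because sklearn's confusion_matrix only returns a 2x2 matrix when two distinct labels appear; with a single class it collapses to 1x1:

    from sklearn.metrics import confusion_matrix

    confusion_matrix([0, 0, 0], [0, 0, 0]).shape  # (1, 1), not (2, 2)
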
def main():
    parser = argparse.ArgumentParser(description='Test model on JSONL data')
    parser.add_argument('--model_path', type=str, default='./filter/checkpoints/best_model_V3.17.pt',
                        help='Path to model checkpoint')
    parser.add_argument('--test_file', type=str, default='./data/filter/test1.jsonl',
                        help='Path to test JSONL file')
    args = parser.parse_args()

    # Load the model
    print(f"Loading model from {args.model_path}")
    model = load_model(args.model_path)

    # Load the test data
    print(f"Loading test data from {args.test_file}")
    test_loader = load_test_data(args.test_file)

    # Run inference
    print("Running inference...")
    y_pred, y_true = run_inference(model, test_loader)

    # Compute the metrics
    print("\nCalculating metrics...")
    metrics = calculate_metrics(y_true, y_pred)

    # Print the results
    print("\n=== Test Results (Binary Classification) ===")
    print(f"Accuracy (Acc): {metrics['Acc']:.4f}")
    print(f"Precision (Prec): {metrics['Prec']:.4f}")
    print(f"Recall: {metrics['Recall']:.4f}")
    print(f"F1 Score: {metrics['F1']:.4f}")
    print(f"\nConfusion Matrix:")
    print(f"  TP (True Positive): {metrics['TP']}")
    print(f"  TN (True Negative): {metrics['TN']}")
    print(f"  FP (False Positive): {metrics['FP']}")
    print(f"  FN (False Negative): {metrics['FN']}")

    # Show the original three-class label distribution
    print(f"\n=== Original Label Distribution ===")
    unique, counts = np.unique(y_true, return_counts=True)
    for label, count in zip(unique, counts):
        print(f"Class {label}: {count} samples")


if __name__ == '__main__':
    main()

@@ -186,6 +186,7 @@ class ParquetDatasetStorage:
             'dataset_id': dataset_id,
             'description': f'Auto-regenerated metadata for dataset {dataset_id}',
             'stats': {
                 'total_records': total_records,
+                'regenerated': True,
                 'regeneration_reason': 'missing_or_corrupted_metadata_file'
             },

@@ -404,7 +405,6 @@ class ParquetDatasetStorage:
         try:
             # Extract dataset_id from filename (remove .parquet extension)
             dataset_id = parquet_file.stem
-            print(dataset_id)

             # Try to load metadata, this will automatically regenerate if missing
             metadata = self.load_dataset_metadata(dataset_id)

@@ -2,14 +2,23 @@

 model = "jina-embedding-v3-m2v"

-[models.qwen3-embedding]
-name = "text-embedding-v4"
-dimensions = 2048
+# [models.qwen3-embedding]
+# name = "text-embedding-v4"
+# dimensions = 2048
+# type = "openai-compatible"
+# api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
+# max_tokens = 8192
+# max_batch_size = 10
+# api_key_env = "ALIYUN_KEY"

 [models.qwen3-embedding-8b]
 name = "qwen/qwen3-embedding-8b"
 dimensions = 4096
 type = "openai-compatible"
-api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
-max_tokens = 8192
-max_batch_size = 10
-api_key_env = "ALIYUN_KEY"
+api_endpoint = "https://openrouter.ai/api/v1"
+max_tokens = 32768
+max_batch_size = 100
+api_key_env = "OPENROUTER_KEY"

 [models.jina-embedding-v3-m2v]
 name = "jina-embedding-v3-m2v-1024"

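For context, a minimal sketch of how an api_key_env entry like this is typically resolved at load time (the file name and loader here are illustrative, not the repo's actual config code):

    import os
    import tomllib  # Python 3.11+

    with open("config.toml", "rb") as f:  # assumed file name
        cfg = tomllib.load(f)

    model_cfg = cfg["models"]["qwen3-embedding-8b"]
    api_key = os.environ[model_cfg["api_key_env"]]  # reads OPENROUTER_KEY
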
@@ -3,6 +3,7 @@ Data models for dataset building functionality
 """

 from typing import List, Optional, Dict, Any, Literal
+import uuid
 from pydantic import BaseModel, Field
 from datetime import datetime
 from enum import Enum

@@ -51,6 +52,7 @@ class DatasetBuildTaskStatus(BaseModel):

 class DatasetBuildRequest(BaseModel):
     """Request model for dataset building"""
+    id: Optional[str] = Field(str(uuid.uuid4()), description="Dataset ID")
     aid_list: List[int] = Field(..., description="List of video AIDs")
     embedding_model: str = Field(..., description="Embedding model name")
     force_regenerate: bool = Field(False, description="Whether to force regenerate embeddings")

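One caveat worth noting: Field(str(uuid.uuid4())) evaluates once at import time, so every request that omits id shares the same default UUID; the endpoints below sidestep this with request.id or str(uuid.uuid4()). A per-instance default would normally use default_factory, sketched here on a hypothetical model:

    import uuid
    from typing import Optional
    from pydantic import BaseModel, Field

    class DatasetBuildRequestSketch(BaseModel):
        # default_factory runs on every instantiation, yielding a fresh UUID each time
        id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()))
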
@@ -101,7 +101,7 @@ async def build_dataset_endpoint(request: DatasetBuildRequest):
     if request.embedding_model not in config_loader.get_embedding_models():
         raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")

-    dataset_id = str(uuid.uuid4())
+    dataset_id = request.id or str(uuid.uuid4())

     # Start task-based dataset building
     task_id = await dataset_builder.start_dataset_build_task(

@@ -335,7 +335,7 @@ async def create_dataset_with_sampling_endpoint(request: DatasetCreateRequest):
         raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")

     import uuid
-    dataset_id = str(uuid.uuid4())
+    dataset_id = request.id or str(uuid.uuid4())

     try:
         # First sample the AIDs

@@ -197,43 +197,9 @@ def load_model_from_experiment(
 def safe_extract_aid(metadata_entry):
     """Safely extract aid from metadata entry"""
     if isinstance(metadata_entry, dict) and 'aid' in metadata_entry:
-        return metadata_entry['aid']
+        return metadata_entry['aid'].tolist()
     return None

-def normalize_batch_metadata(metadata, expected_batch_size):
-    """
-    Normalize batch metadata to ensure consistent structure
-
-    Args:
-        metadata: Raw metadata from DataLoader (could be various formats)
-        expected_batch_size: Expected number of metadata entries
-
-    Returns:
-        List of metadata dictionaries
-    """
-    # Handle different metadata structures
-    if metadata is None:
-        return [{}] * expected_batch_size
-
-    if isinstance(metadata, dict):
-        # Single metadata object - duplicate for entire batch
-        return [metadata] * expected_batch_size
-
-    if isinstance(metadata, (list, tuple)):
-        if len(metadata) == expected_batch_size:
-            return list(metadata)
-        elif len(metadata) < expected_batch_size:
-            # Pad with empty dicts
-            padded = list(metadata) + [{}] * (expected_batch_size - len(metadata))
-            return padded
-        else:
-            # Truncate to expected size
-            return list(metadata[:expected_batch_size])
-
-    # Unknown format - return empty dicts
-    logger.warning(f"Unknown metadata format: {type(metadata)}")
-    return [{}] * expected_batch_size

 def evaluate_model(
     model,
     test_loader: DataLoader,
@@ -257,7 +223,6 @@ def evaluate_model(
     all_predictions = []
     all_labels = []
     all_probabilities = []
-    all_metadata = []
     fn_aids = []
     fp_aids = []

@@ -277,33 +242,9 @@ def evaluate_model(
             all_labels.extend(labels.cpu().numpy())
             all_probabilities.extend(probabilities.cpu().numpy())

-            # Collect metadata and track FN/FP
-            batch_size = len(labels)
-            batch_metadata = normalize_batch_metadata(metadata, batch_size)
-            all_metadata.extend(batch_metadata)
-
-            # Track FN and FP aids for this batch
-            logger.debug(f"Batch {batch_idx}: labels shape {labels.shape}, predictions shape {predictions.shape}, metadata structure: {type(batch_metadata)}")
-            if len(batch_metadata) != len(labels):
-                logger.warning(f"Metadata length mismatch: {len(batch_metadata)} metadata entries vs {len(labels)} samples")
-
             for i, (true_label, pred_label) in enumerate(zip(labels.cpu().numpy(), predictions.cpu().numpy())):
-                try:
-                    # Safely get metadata entry with bounds checking
-                    if i >= len(batch_metadata):
-                        logger.warning(f"Index {i} out of range for batch_metadata (length: {len(batch_metadata)})")
-                        continue
-
-                    meta_entry = batch_metadata[i]
-                    if not isinstance(meta_entry, dict):
-                        logger.warning(f"Metadata entry {i} is not a dict: {type(meta_entry)}")
-                        continue
-
-                    if 'aid' not in meta_entry:
-                        logger.debug(f"No 'aid' key in metadata entry {i}")
-                        continue
-
-                    aid = safe_extract_aid(meta_entry)
+                aid = metadata['aid'].tolist()[i]
                 if aid is not None:
                     if true_label == 1 and pred_label == 0:  # False Negative
                         fn_aids.append(aid)
@@ -620,30 +561,28 @@ def main():
     )

     # Print results
-    logger.info("=" * 50)
-    logger.info("Test Results")
-    logger.info("=" * 50)
-    logger.info(f"Dataset: {args.dataset_id}")
+    print("Test Results")
+    print("=" * 50)
+    print(f"Dataset: {args.dataset_id}")
     if args.use_api:
-        logger.info(f"Method: API ({args.api_url})")
+        print(f"Method: API ({args.api_url})")
     else:
-        logger.info(f"Experiment: {args.experiment}")
-    logger.info(f"Total samples: {metrics['total_samples']}")
-    logger.info(f"Class distribution: {metrics['class_distribution']}")
+        print(f"Experiment: {args.experiment}")
+    print(f"Total samples: {metrics['total_samples']}")
+    print(f"Class distribution: {metrics['class_distribution']}")
     if 'failed_requests' in metrics:
-        logger.info(f"Failed API requests: {metrics['failed_requests']}")
-    logger.info("-" * 50)
-    logger.info(f"Accuracy: {metrics['accuracy']:.4f}")
-    logger.info(f"Precision: {metrics['precision']:.4f}")
-    logger.info(f"Recall: {metrics['recall']:.4f}")
-    logger.info(f"F1 Score: {metrics['f1']:.4f}")
-    logger.info(f"AUC: {metrics['auc']:.4f}")
-    logger.info("-" * 50)
-    logger.info(f"True Positives: {metrics['true_positives']}")
-    logger.info(f"True Negatives: {metrics['true_negatives']}")
-    logger.info(f"False Positives: {metrics['false_positives']}")
-    logger.info(f"False Negatives: {metrics['false_negatives']}")
-    logger.info("=" * 50)
+        print(f"Failed API requests: {metrics['failed_requests']}")
+    print("-" * 50)
+    print(f"Accuracy: {metrics['accuracy']:.4f}")
+    print(f"Precision: {metrics['precision']:.4f}")
+    print(f"Recall: {metrics['recall']:.4f}")
+    print(f"F1 Score: {metrics['f1']:.4f}")
+    print(f"AUC: {metrics['auc']:.4f}")
+    print(f"True Positives: {metrics['true_positives']}")
+    print(f"True Negatives: {metrics['true_negatives']}")
+    print(f"False Positives: {metrics['false_positives']}")
+    print(f"False Negatives: {metrics['false_negatives']}")
+    print("=" * 50)

     # Save detailed results if requested
     if args.output:

packages/ml_panel/biome.json (new file, 7 lines)
@@ -0,0 +1,7 @@
{
	"root": false,
	"$schema": "https://biomejs.dev/schemas/2.3.8/schema.json",
	"linter": {
		"enabled": false
	}
}

@@ -10,6 +10,7 @@
 		"preview": "vite preview"
 	},
 	"dependencies": {
+		"@radix-ui/react-checkbox": "^1.3.3",
 		"@radix-ui/react-dialog": "^1.1.15",
 		"@radix-ui/react-label": "^2.1.8",
 		"@radix-ui/react-progress": "^1.1.8",

@@ -5,7 +5,6 @@ import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/com
 import {
 	Dialog,
 	DialogContent,
-	DialogDescription,
 	DialogFooter,
 	DialogHeader,
 	DialogTitle,
@@ -24,6 +23,8 @@ import { Trash2, Plus, Database, Upload } from "lucide-react";
 import { apiClient } from "@/lib/api";
 import { toast } from "sonner";
 import { Spinner } from "@/components/ui/spinner";
+import { Input } from "./ui/input";
+import { Checkbox } from "./ui/checkbox";

 export function DatasetManager() {
 	const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false);
@@ -37,14 +38,20 @@ export function DatasetManager() {
 		aidListFile: null as File | null,
 		aidList: [] as number[]
 	});
+	const [loadCost, setLoadCost] = useState(0);

 	const queryClient = useQueryClient();

 	// Fetch datasets
 	const { data: datasetsData, isLoading: datasetsLoading } = useQuery({
 		queryKey: ["datasets"],
-		queryFn: () => apiClient.getDatasets(),
-		refetchInterval: 30000 // Refresh every 30 seconds
+		queryFn: async () => {
+			const t = performance.now();
+			const r = await apiClient.getDatasets();
+			setLoadCost(performance.now() - t);
+			return r;
+		},
+		refetchInterval: 5000
 	});

 	// Fetch embedding models
@@ -235,9 +242,9 @@ export function DatasetManager() {
 			{/* Create Dataset Button */}
 			<div className="flex justify-between items-center">
 				<div>
-					<h3 className="text-lg font-medium">Dataset List</h3>
+					<h3 className="text-lg font-medium">Datasets</h3>
 					<p className="text-sm text-muted-foreground">
-						{datasetsData?.datasets?.length || 0} datasets created
+						{datasetsData?.datasets?.length || 0} datasets loaded. ({Math.round(loadCost)} ms)
 					</p>
 				</div>

@@ -245,21 +252,17 @@
 				<DialogTrigger asChild>
 					<Button>
 						<Plus className="h-4 w-4 mr-2" />
-						Create Dataset
+						New Dataset
 					</Button>
 				</DialogTrigger>
 				<DialogContent className="sm:max-w-[500px]">
 					<DialogHeader>
-						<DialogTitle>Create New Dataset</DialogTitle>
-						<DialogDescription>
-							Select sampling strategy and configuration parameters to create a
-							new dataset
-						</DialogDescription>
+						<DialogTitle>New Dataset</DialogTitle>
 					</DialogHeader>

 					<div className="grid gap-4 py-4">
 						<div className="grid gap-2">
-							<Label htmlFor="creationMode">Creation Mode</Label>
+							<Label htmlFor="creationMode">Create new database from</Label>
 							<Select
 								value={createFormData.creationMode}
 								onValueChange={(value) =>
@@ -273,18 +276,20 @@
 								}
 							>
 								<SelectTrigger>
-									<SelectValue placeholder="Select creation mode" />
+									<SelectValue placeholder="Create from..." />
 								</SelectTrigger>
 								<SelectContent>
-									<SelectItem value="sampling">Sampling Strategy</SelectItem>
-									<SelectItem value="aidList">Upload Aid List</SelectItem>
+									<SelectItem value="sampling">
+										sampling the database
+									</SelectItem>
+									<SelectItem value="aidList">given aid list</SelectItem>
 								</SelectContent>
 							</Select>
 						</div>

 						{createFormData.creationMode === "sampling" && (
 							<div className="grid gap-2">
-								<Label htmlFor="strategy">Sampling Strategy</Label>
+								<Label htmlFor="strategy">Sampling strategy</Label>
 								<Select
 									value={createFormData.strategy}
 									onValueChange={(value) =>
@@ -298,8 +303,8 @@
 									<SelectValue placeholder="Select sampling strategy" />
 								</SelectTrigger>
 								<SelectContent>
-									<SelectItem value="all">All Videos</SelectItem>
-									<SelectItem value="random">Random Sampling</SelectItem>
+									<SelectItem value="all">All videos</SelectItem>
+									<SelectItem value="random">Random sampling</SelectItem>
 								</SelectContent>
 							</Select>
 						</div>
@@ -308,8 +313,8 @@
 						{createFormData.creationMode === "sampling" &&
 							createFormData.strategy === "random" && (
 								<div className="grid gap-2">
-									<Label htmlFor="limit">Sample Count</Label>
-									<Textarea
+									<Label htmlFor="limit">Sample count</Label>
+									<Input
 										id="limit"
 										placeholder="Enter number of samples, e.g., 1000"
 										value={createFormData.limit}
@@ -325,7 +330,7 @@

 						{createFormData.creationMode === "aidList" && (
 							<div className="grid gap-2">
-								<Label htmlFor="aidListFile">Aid List File</Label>
+								<Label htmlFor="aidListFile">aid list file</Label>
 								<div
 									className="border-2 border-dashed rounded-lg p-4 cursor-pointer"
 									onClick={() =>
@@ -362,7 +367,7 @@
 								/>
 								{createFormData.aidList.length > 0 && (
 									<div className="text-sm text-green-600">
-										✓ Loaded {createFormData.aidList.length} AIDs from{" "}
+										Loaded {createFormData.aidList.length} aids from{" "}
 										{createFormData.aidListFile?.name}
 									</div>
 								)}
@@ -370,7 +375,7 @@
 						)}

 						<div className="grid gap-2">
-							<Label htmlFor="model">Embedding Model</Label>
+							<Label htmlFor="model">Embedding model</Label>
 							<Select
 								value={createFormData.embeddingModel}
 								onValueChange={(value) =>
@@ -411,14 +416,13 @@
 						</div>

 						<div className="flex items-center space-x-2">
-							<input
-								type="checkbox"
+							<Checkbox
 								id="forceRegenerate"
 								checked={createFormData.forceRegenerate}
-								onChange={(e) =>
+								onCheckedChange={(e) =>
 									setCreateFormData((prev) => ({
 										...prev,
-										forceRegenerate: e.target.checked
+										forceRegenerate: e ? true : false
 									}))
 								}
 							/>
@@ -480,9 +484,11 @@
 							<span>{dataset.stats.total_records} records</span>
 							<span>{dataset.stats.embedding_model}</span>
 							<span>{formatDate(dataset.created_at)}</span>
-							<span className="text-muted-foreground">
-								New: {dataset.stats.new_embeddings}
-							</span>
+							{dataset.stats.new_embeddings && (
+								<span className="text-muted-foreground">
+									New: {dataset.stats.new_embeddings}
+								</span>
+							)}
 						</div>
 					</CardContent>
 				</Card>

packages/ml_panel/src/components/ui/checkbox.tsx (new file, 30 lines)
@@ -0,0 +1,30 @@
import * as React from "react"
import * as CheckboxPrimitive from "@radix-ui/react-checkbox"
import { CheckIcon } from "lucide-react"

import { cn } from "@/lib/utils"

function Checkbox({
	className,
	...props
}: React.ComponentProps<typeof CheckboxPrimitive.Root>) {
	return (
		<CheckboxPrimitive.Root
			data-slot="checkbox"
			className={cn(
				"peer border-input dark:bg-input/30 data-[state=checked]:bg-primary data-[state=checked]:text-primary-foreground dark:data-[state=checked]:bg-primary data-[state=checked]:border-primary focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive size-4 shrink-0 rounded-[4px] border shadow-xs transition-shadow outline-none focus-visible:ring-[3px] disabled:cursor-not-allowed disabled:opacity-50",
				className
			)}
			{...props}
		>
			<CheckboxPrimitive.Indicator
				data-slot="checkbox-indicator"
				className="grid place-content-center text-current transition-none"
			>
				<CheckIcon className="size-3.5" />
			</CheckboxPrimitive.Indicator>
		</CheckboxPrimitive.Root>
	)
}

export { Checkbox }

@@ -76,7 +76,6 @@ class ApiClient {
 		});
 	}

-
 	async sampleDataset(data: {
 		strategy: string;
 		limit?: number;