
fix: test scripts in old and new ML package

This commit is contained in:
alikia2x (寒寒) 2025-12-17 03:02:05 +08:00
parent fc06b3d69f
commit f4127d7c2e
WARNING! Although there is a key with this ID in the database it does not verify this commit! This commit is SUSPICIOUS.
GPG Key ID: 56209E0CCD8420C6
14 changed files with 330 additions and 140 deletions

View File

@@ -1,6 +1,5 @@
{
"lockfileVersion": 1,
"configVersion": 0,
"workspaces": {
"": {
"name": "cvsa",
@@ -86,6 +85,7 @@
"name": "ml_panel",
"version": "0.0.0",
"dependencies": {
"@radix-ui/react-checkbox": "^1.3.3",
"@radix-ui/react-dialog": "^1.1.15",
"@radix-ui/react-label": "^2.1.8",
"@radix-ui/react-progress": "^1.1.8",

View File

@@ -56,6 +56,11 @@ class MultiChannelDataset(Dataset):
self.max_length = max_length
self.mode = mode
if self.mode == 'test' and os.path.exists(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
self.examples = [json.loads(line) for line in f]
return
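# Each JSONL row read above in test mode is one JSON object; an illustrative
# (assumed) example: {"title": "...", "description": "...", "tags": [...], "label": 1}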
# Check whether the train, eval, and test files exist
train_file = os.path.join(os.path.dirname(file_path), 'train.jsonl')
eval_file = os.path.join(os.path.dirname(file_path), 'eval.jsonl')
@@ -101,9 +106,9 @@ class MultiChannelDataset(Dataset):
# Return the text dictionary
texts = {
'title': example['title'],
'description': example['description'],
'tags': tags_text
'title': example['title'] or 'no title',
'description': example['description'] or 'no description',
'tags': tags_text or 'no tags'
}
return {

View File

@@ -1,6 +1,21 @@
from typing import List
import numpy as np
import torch
from model2vec import StaticModel
import onnxruntime as ort
from transformers import AutoTokenizer
# Initialize the tokenizer and ONNX model (cached globally)
_tokenizer = None
_onnx_session = None
def _get_tokenizer_and_session():
"""获取全局缓存的 tokenizer 和 ONNX session"""
global _tokenizer, _onnx_session
if _tokenizer is None:
_tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3")
if _onnx_session is None:
_onnx_session = ort.InferenceSession("../model/embedding/model.onnx")
return _tokenizer, _onnx_session
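# The first call pays the one-time cost of loading the tokenizer and the ONNX
# graph; subsequent calls reuse the cached pair, keeping per-batch overhead small.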
def prepare_batch(batch_data, device="cpu"):
@@ -18,21 +33,60 @@ def prepare_batch(batch_data, device="cpu"):
Returns:
    torch.Tensor: tensor of shape [batch_size, num_channels, embedding_dim]
"""
# 1. Encode each channel's text separately
channel_embeddings = []
model = StaticModel.from_pretrained("./model/embedding_1024/")
for channel in ["title", "description", "tags"]:
texts = batch_data[channel]  # list of texts for the current channel
embeddings = torch.from_numpy(model.encode(texts)).to(torch.float32).to(device)  # encode to [batch_size, embedding_dim]
channel_embeddings.append(embeddings)
# 2. Stack the encodings into [batch_size, num_channels, embedding_dim]
batch_tensor = torch.stack(channel_embeddings, dim=1)  # stack along dim=1
return batch_tensor
title_embeddings = get_jina_embeddings_1024(batch_data['title'])
desc_embeddings = get_jina_embeddings_1024(batch_data['description'])
tags_embeddings = get_jina_embeddings_1024(batch_data['tags'])
return torch.stack([title_embeddings, desc_embeddings, tags_embeddings], dim=1).to(device)
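# Illustrative shape check (assuming the 1024-dim Jina model used below): a batch
# of 2 samples stacks to [2, 3, 1024], one embedding per channel (title, description, tags).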
def get_jina_embeddings_1024(texts: List[str]) -> torch.Tensor:
    """Encode texts with the Jina tokenizer and the local ONNX embedding model"""
    tokenizer, session = _get_tokenizer_and_session()
encoded_inputs = tokenizer(
texts,
add_special_tokens=False,
return_attention_mask=False,
return_tensors=None,  # return native Python lists for easier downstream processing
)
input_ids = encoded_inputs["input_ids"]  # shape: [batch_size, seq_len_i] (per-sample lengths may differ)
# 2. Compute offsets, matching the JS cumsum logic exactly
# First get each sample's token length
lengths = [len(ids) for ids in input_ids]
# Cumulative sum (excluding the last sample)
cumsum = []
current_sum = 0
for length in lengths[:-1]:  # only accumulate the first n-1 sample lengths
    current_sum += length
    cumsum.append(current_sum)
# Build offsets: 0 first, then the cumulative sums
offsets = [0] + cumsum  # shape: [batch_size]
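# Worked example (illustrative): token lengths [3, 5, 2] give cumsum [3, 8], so
# offsets = [0, 3, 8]; each offset marks where a sample begins in the flattened
# id stream built below.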
# 3. Flatten input_ids into a one-dimensional array
flattened_input_ids = []
for ids in input_ids:
    flattened_input_ids.extend(ids)  # concatenate all token ids
flattened_input_ids = np.array(flattened_input_ids, dtype=np.int64)
# 4. Prepare ONNX inputs; tensor shapes stay consistent with the JS implementation
inputs = {
"input_ids": ort.OrtValue.ortvalue_from_numpy(flattened_input_ids),
"offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64)),
}
# 5. Run model inference
outputs = session.run(None, inputs)
embeddings = outputs[0]  # the first output is assumed to be the embeddings, shape: [batch_size, embedding_dim]
return torch.tensor(embeddings, dtype=torch.float32)
import onnxruntime as ort
from transformers import AutoTokenizer
from itertools import accumulate
def prepare_batch_per_token(batch_data, max_length=1024):
"""
@@ -67,6 +121,7 @@ def prepare_batch_per_token(batch_data, max_length=1024):
input_ids_lengths = [len(enc["input_ids"][0]) for enc in encoded_inputs]
# Generate offsets: [0, len1, len1+len2, ...]
from itertools import accumulate
offsets = list(accumulate([0] + input_ids_lengths[:-1]))  # cumulative sum, excluding the last length
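# e.g. lengths [3, 5, 2] -> accumulate([0, 3, 5]) -> offsets [0, 3, 8], identical
# to the manual cumsum loop in get_jina_embeddings_1024 above.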
# Flatten all input_ids into a one-dimensional array

ml/filter/test_new.py Normal file (137 lines)
View File

@@ -0,0 +1,137 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
import numpy as np
import torch
from torch.utils.data import DataLoader
from dataset import MultiChannelDataset
from filter.modelV3_15 import VideoClassifierV3_15
from embedding import prepare_batch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import argparse
def load_model(checkpoint_path):
"""加载模型权重"""
model = VideoClassifierV3_15()
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
model.eval()
return model
def load_test_data(test_file):
"""加载测试数据"""
test_dataset = MultiChannelDataset(test_file, mode='test')
test_loader = DataLoader(test_dataset, batch_size=24, shuffle=False)
return test_loader
def convert_to_binary_labels(labels):
"""将三分类标签转换为二分类类别1和2合并为类别1"""
binary_labels = np.where(labels >= 1, 1, 0)
return binary_labels
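# Illustrative mapping: labels [0, 1, 2, 1] -> [0, 1, 1, 1]; any class >= 1
# counts as positive.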
def run_inference(model, test_loader):
"""运行推理"""
all_preds = []
all_labels = []
with torch.no_grad():
for batch in test_loader:
# Prepare the text data
batch_tensor = prepare_batch(batch['texts'])
# Forward pass
logits = model(batch_tensor)
preds = torch.argmax(logits, dim=1)
all_preds.extend(preds.cpu().numpy())
all_labels.extend(batch['label'].cpu().numpy())
return np.array(all_preds), np.array(all_labels)
def calculate_metrics(y_true, y_pred):
"""计算二分类指标"""
# 转换为二分类
y_true_binary = convert_to_binary_labels(y_true)
y_pred_binary = convert_to_binary_labels(y_pred)
# Compute the confusion matrix
cm = confusion_matrix(y_true_binary, y_pred_binary)
if cm.shape == (2, 2):
tn, fp, fn, tp = cm.ravel()
else:
# Degenerate case: only one class is present
tn = fp = fn = tp = 0
if y_true_binary.sum() == 0:
tn = (y_true_binary == y_pred_binary).sum()
else:
tp = ((y_true_binary == 1) & (y_pred_binary == 1)).sum()
fp = ((y_true_binary == 0) & (y_pred_binary == 1)).sum()
fn = ((y_true_binary == 1) & (y_pred_binary == 0)).sum()
# Compute the metrics
accuracy = accuracy_score(y_true_binary, y_pred_binary)
precision = precision_score(y_true_binary, y_pred_binary, zero_division=0)
recall = recall_score(y_true_binary, y_pred_binary, zero_division=0)
f1 = f1_score(y_true_binary, y_pred_binary, zero_division=0)
return {
'Acc': accuracy,
'Prec': precision,
'Recall': recall,
'F1': f1,
'TP': tp,
'FP': fp,
'TN': tn,
'FN': fn
}
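# For reference: precision = TP / (TP + FP), recall = TP / (TP + FN),
# F1 = 2 * P * R / (P + R); zero_division=0 maps empty denominators to 0
# instead of raising.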
def main():
parser = argparse.ArgumentParser(description='Test model on JSONL data')
parser.add_argument('--model_path', type=str, default='./filter/checkpoints/best_model_V3.17.pt',
help='Path to model checkpoint')
parser.add_argument('--test_file', type=str, default='./data/filter/test1.jsonl',
help='Path to test JSONL file')
args = parser.parse_args()
# Load the model
print(f"Loading model from {args.model_path}")
model = load_model(args.model_path)
# Load the test data
print(f"Loading test data from {args.test_file}")
test_loader = load_test_data(args.test_file)
# Run inference
print("Running inference...")
y_pred, y_true = run_inference(model, test_loader)
# Compute metrics
print("\nCalculating metrics...")
metrics = calculate_metrics(y_true, y_pred)
# Print the results
print("\n=== Test Results (Binary Classification) ===")
print(f"Accuracy (Acc): {metrics['Acc']:.4f}")
print(f"Precision (Prec): {metrics['Prec']:.4f}")
print(f"Recall: {metrics['Recall']:.4f}")
print(f"F1 Score: {metrics['F1']:.4f}")
print(f"\nConfusion Matrix:")
print(f" TP (True Positive): {metrics['TP']}")
print(f" TN (True Negative): {metrics['TN']}")
print(f" FP (False Positive): {metrics['FP']}")
print(f" FN (False Negative): {metrics['FN']}")
# Show the original 3-class label distribution
print(f"\n=== Original Label Distribution ===")
unique, counts = np.unique(y_true, return_counts=True)
for label, count in zip(unique, counts):
print(f"Class {label}: {count} samples")
if __name__ == '__main__':
main()

View File

@@ -186,6 +186,7 @@ class ParquetDatasetStorage:
'dataset_id': dataset_id,
'description': f'Auto-regenerated metadata for dataset {dataset_id}',
'stats': {
'total_records': total_records,
'regenerated': True,
'regeneration_reason': 'missing_or_corrupted_metadata_file'
},
@@ -404,7 +405,6 @@ class ParquetDatasetStorage:
try:
# Extract dataset_id from filename (remove .parquet extension)
dataset_id = parquet_file.stem
print(dataset_id)
# Try to load metadata, this will automatically regenerate if missing
metadata = self.load_dataset_metadata(dataset_id)

View File

@@ -2,14 +2,23 @@
model = "jina-embedding-v3-m2v"
[models.qwen3-embedding]
name = "text-embedding-v4"
dimensions = 2048
# [models.qwen3-embedding]
# name = "text-embedding-v4"
# dimensions = 2048
# type = "openai-compatible"
# api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
# max_tokens = 8192
# max_batch_size = 10
# api_key_env = "ALIYUN_KEY"
[models.qwen3-embedding-8b]
name = "qwen/qwen3-embedding-8b"
dimensions = 4096
type = "openai-compatible"
api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
max_tokens = 8192
max_batch_size = 10
api_key_env = "ALIYUN_KEY"
api_endpoint = "https://openrouter.ai/api/v1"
max_tokens = 32768
max_batch_size = 100
api_key_env = "OPENROUTER_KEY"
[models.jina-embedding-v3-m2v]
name = "jina-embedding-v3-m2v-1024"

View File

@@ -3,6 +3,7 @@ Data models for dataset building functionality
"""
from typing import List, Optional, Dict, Any, Literal
import uuid
from pydantic import BaseModel, Field
from datetime import datetime
from enum import Enum
@@ -51,6 +52,7 @@ class DatasetBuildTaskStatus(BaseModel):
class DatasetBuildRequest(BaseModel):
"""Request model for dataset building"""
id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()), description="Dataset ID")
aid_list: List[int] = Field(..., description="List of video AIDs")
embedding_model: str = Field(..., description="Embedding model name")
force_regenerate: bool = Field(False, description="Whether to force regenerate embeddings")

View File

@@ -101,7 +101,7 @@ async def build_dataset_endpoint(request: DatasetBuildRequest):
if request.embedding_model not in config_loader.get_embedding_models():
raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")
dataset_id = str(uuid.uuid4())
dataset_id = request.id or str(uuid.uuid4())
# Start task-based dataset building
task_id = await dataset_builder.start_dataset_build_task(
@@ -335,7 +335,7 @@ async def create_dataset_with_sampling_endpoint(request: DatasetCreateRequest):
raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")
import uuid
dataset_id = str(uuid.uuid4())
dataset_id = request.id or str(uuid.uuid4())
try:
# First sample the AIDs

View File

@@ -197,43 +197,9 @@ def load_model_from_experiment(
def safe_extract_aid(metadata_entry):
"""Safely extract aid from metadata entry"""
if isinstance(metadata_entry, dict) and 'aid' in metadata_entry:
return metadata_entry['aid']
return metadata_entry['aid'].tolist()
return None
def normalize_batch_metadata(metadata, expected_batch_size):
"""
Normalize batch metadata to ensure consistent structure
Args:
metadata: Raw metadata from DataLoader (could be various formats)
expected_batch_size: Expected number of metadata entries
Returns:
List of metadata dictionaries
"""
# Handle different metadata structures
if metadata is None:
return [{}] * expected_batch_size
if isinstance(metadata, dict):
# Single metadata object - duplicate for entire batch
return [metadata] * expected_batch_size
if isinstance(metadata, (list, tuple)):
if len(metadata) == expected_batch_size:
return list(metadata)
elif len(metadata) < expected_batch_size:
# Pad with empty dicts
padded = list(metadata) + [{}] * (expected_batch_size - len(metadata))
return padded
else:
# Truncate to expected size
return list(metadata[:expected_batch_size])
# Unknown format - return empty dicts
logger.warning(f"Unknown metadata format: {type(metadata)}")
return [{}] * expected_batch_size
def evaluate_model(
model,
test_loader: DataLoader,
@@ -257,7 +223,6 @@ def evaluate_model(
all_predictions = []
all_labels = []
all_probabilities = []
all_metadata = []
fn_aids = []
fp_aids = []
@@ -277,33 +242,9 @@
all_labels.extend(labels.cpu().numpy())
all_probabilities.extend(probabilities.cpu().numpy())
# Collect metadata and track FN/FP
batch_size = len(labels)
batch_metadata = normalize_batch_metadata(metadata, batch_size)
all_metadata.extend(batch_metadata)
# Track FN and FP aids for this batch
logger.debug(f"Batch {batch_idx}: labels shape {labels.shape}, predictions shape {predictions.shape}, metadata structure: {type(batch_metadata)}")
if len(batch_metadata) != len(labels):
logger.warning(f"Metadata length mismatch: {len(batch_metadata)} metadata entries vs {len(labels)} samples")
for i, (true_label, pred_label) in enumerate(zip(labels.cpu().numpy(), predictions.cpu().numpy())):
try:
# Safely get metadata entry with bounds checking
if i >= len(batch_metadata):
logger.warning(f"Index {i} out of range for batch_metadata (length: {len(batch_metadata)})")
continue
meta_entry = batch_metadata[i]
if not isinstance(meta_entry, dict):
logger.warning(f"Metadata entry {i} is not a dict: {type(meta_entry)}")
continue
if 'aid' not in meta_entry:
logger.debug(f"No 'aid' key in metadata entry {i}")
continue
aid = safe_extract_aid(meta_entry)
aid = metadata['aid'].tolist()[i]
if aid is not None:
if true_label == 1 and pred_label == 0: # False Negative
fn_aids.append(aid)
@@ -620,30 +561,28 @@ def main():
)
# Print results
logger.info("=" * 50)
logger.info("Test Results")
logger.info("=" * 50)
logger.info(f"Dataset: {args.dataset_id}")
print("Test Results")
print("=" * 50)
print(f"Dataset: {args.dataset_id}")
if args.use_api:
logger.info(f"Method: API ({args.api_url})")
print(f"Method: API ({args.api_url})")
else:
logger.info(f"Experiment: {args.experiment}")
logger.info(f"Total samples: {metrics['total_samples']}")
logger.info(f"Class distribution: {metrics['class_distribution']}")
print(f"Experiment: {args.experiment}")
print(f"Total samples: {metrics['total_samples']}")
print(f"Class distribution: {metrics['class_distribution']}")
if 'failed_requests' in metrics:
logger.info(f"Failed API requests: {metrics['failed_requests']}")
logger.info("-" * 50)
logger.info(f"Accuracy: {metrics['accuracy']:.4f}")
logger.info(f"Precision: {metrics['precision']:.4f}")
logger.info(f"Recall: {metrics['recall']:.4f}")
logger.info(f"F1 Score: {metrics['f1']:.4f}")
logger.info(f"AUC: {metrics['auc']:.4f}")
logger.info("-" * 50)
logger.info(f"True Positives: {metrics['true_positives']}")
logger.info(f"True Negatives: {metrics['true_negatives']}")
logger.info(f"False Positives: {metrics['false_positives']}")
logger.info(f"False Negatives: {metrics['false_negatives']}")
logger.info("=" * 50)
print("-" * 50)
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall: {metrics['recall']:.4f}")
print(f"F1 Score: {metrics['f1']:.4f}")
print(f"AUC: {metrics['auc']:.4f}")
print(f"True Positives: {metrics['true_positives']}")
print(f"True Negatives: {metrics['true_negatives']}")
print(f"False Positives: {metrics['false_positives']}")
print(f"False Negatives: {metrics['false_negatives']}")
print("=" * 50)
# Save detailed results if requested
if args.output:

View File

@@ -0,0 +1,7 @@
{
"root": false,
"$schema": "https://biomejs.dev/schemas/2.3.8/schema.json",
"linter": {
"enabled": false
}
}

View File

@@ -10,6 +10,7 @@
"preview": "vite preview"
},
"dependencies": {
"@radix-ui/react-checkbox": "^1.3.3",
"@radix-ui/react-dialog": "^1.1.15",
"@radix-ui/react-label": "^2.1.8",
"@radix-ui/react-progress": "^1.1.8",

View File

@@ -5,7 +5,6 @@ import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/com
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
@@ -24,6 +23,8 @@ import { Trash2, Plus, Database, Upload } from "lucide-react";
import { apiClient } from "@/lib/api";
import { toast } from "sonner";
import { Spinner } from "@/components/ui/spinner";
import { Input } from "./ui/input";
import { Checkbox } from "./ui/checkbox";
export function DatasetManager() {
const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false);
@@ -37,14 +38,20 @@ export function DatasetManager() {
aidListFile: null as File | null,
aidList: [] as number[]
});
const [loadCost, setLoadCost] = useState(0);
const queryClient = useQueryClient();
// Fetch datasets
const { data: datasetsData, isLoading: datasetsLoading } = useQuery({
queryKey: ["datasets"],
queryFn: () => apiClient.getDatasets(),
refetchInterval: 30000 // Refresh every 30 seconds
queryFn: async () => {
const t = performance.now();
const r = await apiClient.getDatasets();
setLoadCost(performance.now() - t);
return r;
},
refetchInterval: 5000
});
// Fetch embedding models
@@ -235,9 +242,9 @@ export function DatasetManager() {
{/* Create Dataset Button */}
<div className="flex justify-between items-center">
<div>
<h3 className="text-lg font-medium">Dataset List</h3>
<h3 className="text-lg font-medium">Datasets</h3>
<p className="text-sm text-muted-foreground">
{datasetsData?.datasets?.length || 0} datasets created
{datasetsData?.datasets?.length || 0} datasets loaded. ({Math.round(loadCost)} ms)
</p>
</div>
@@ -245,21 +252,17 @@
<DialogTrigger asChild>
<Button>
<Plus className="h-4 w-4 mr-2" />
Create Dataset
New Dataset
</Button>
</DialogTrigger>
<DialogContent className="sm:max-w-[500px]">
<DialogHeader>
<DialogTitle>Create New Dataset</DialogTitle>
<DialogDescription>
Select sampling strategy and configuration parameters to create a
new dataset
</DialogDescription>
<DialogTitle>New Dataset</DialogTitle>
</DialogHeader>
<div className="grid gap-4 py-4">
<div className="grid gap-2">
<Label htmlFor="creationMode">Creation Mode</Label>
<Label htmlFor="creationMode">Create new database from</Label>
<Select
value={createFormData.creationMode}
onValueChange={(value) =>
@@ -273,18 +276,20 @@
}
>
<SelectTrigger>
<SelectValue placeholder="Select creation mode" />
<SelectValue placeholder="Create from..." />
</SelectTrigger>
<SelectContent>
<SelectItem value="sampling">Sampling Strategy</SelectItem>
<SelectItem value="aidList">Upload Aid List</SelectItem>
<SelectItem value="sampling">
sampling the database
</SelectItem>
<SelectItem value="aidList">given aid list</SelectItem>
</SelectContent>
</Select>
</div>
{createFormData.creationMode === "sampling" && (
<div className="grid gap-2">
<Label htmlFor="strategy">Sampling Strategy</Label>
<Label htmlFor="strategy">Sampling strategy</Label>
<Select
value={createFormData.strategy}
onValueChange={(value) =>
@@ -298,8 +303,8 @@
<SelectValue placeholder="Select sampling strategy" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all">All Videos</SelectItem>
<SelectItem value="random">Random Sampling</SelectItem>
<SelectItem value="all">All videos</SelectItem>
<SelectItem value="random">Random sampling</SelectItem>
</SelectContent>
</Select>
</div>
@@ -308,8 +313,8 @@
{createFormData.creationMode === "sampling" &&
createFormData.strategy === "random" && (
<div className="grid gap-2">
<Label htmlFor="limit">Sample Count</Label>
<Textarea
<Label htmlFor="limit">Sample count</Label>
<Input
id="limit"
placeholder="Enter number of samples, e.g., 1000"
value={createFormData.limit}
@@ -325,7 +330,7 @@
{createFormData.creationMode === "aidList" && (
<div className="grid gap-2">
<Label htmlFor="aidListFile">Aid List File</Label>
<Label htmlFor="aidListFile">aid list file</Label>
<div
className="border-2 border-dashed rounded-lg p-4 cursor-pointer"
onClick={() =>
@@ -362,7 +367,7 @@
/>
{createFormData.aidList.length > 0 && (
<div className="text-sm text-green-600">
Loaded {createFormData.aidList.length} AIDs from{" "}
Loaded {createFormData.aidList.length} aids from{" "}
{createFormData.aidListFile?.name}
</div>
)}
@@ -370,7 +375,7 @@
)}
<div className="grid gap-2">
<Label htmlFor="model">Embedding Model</Label>
<Label htmlFor="model">Embedding model</Label>
<Select
value={createFormData.embeddingModel}
onValueChange={(value) =>
@@ -411,14 +416,13 @@
</div>
<div className="flex items-center space-x-2">
<input
type="checkbox"
<Checkbox
id="forceRegenerate"
checked={createFormData.forceRegenerate}
onChange={(e) =>
onCheckedChange={(checked) =>
setCreateFormData((prev) => ({
...prev,
forceRegenerate: e.target.checked
forceRegenerate: checked === true
}))
}
/>
@@ -480,9 +484,11 @@
<span>{dataset.stats.total_records} records</span>
<span>{dataset.stats.embedding_model}</span>
<span>{formatDate(dataset.created_at)}</span>
<span className="text-muted-foreground">
New: {dataset.stats.new_embeddings}
</span>
{!!dataset.stats.new_embeddings && (
<span className="text-muted-foreground">
New: {dataset.stats.new_embeddings}
</span>
)}
</div>
</CardContent>
</Card>

View File

@@ -0,0 +1,30 @@
import * as React from "react"
import * as CheckboxPrimitive from "@radix-ui/react-checkbox"
import { CheckIcon } from "lucide-react"
import { cn } from "@/lib/utils"
function Checkbox({
className,
...props
}: React.ComponentProps<typeof CheckboxPrimitive.Root>) {
return (
<CheckboxPrimitive.Root
data-slot="checkbox"
className={cn(
"peer border-input dark:bg-input/30 data-[state=checked]:bg-primary data-[state=checked]:text-primary-foreground dark:data-[state=checked]:bg-primary data-[state=checked]:border-primary focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive size-4 shrink-0 rounded-[4px] border shadow-xs transition-shadow outline-none focus-visible:ring-[3px] disabled:cursor-not-allowed disabled:opacity-50",
className
)}
{...props}
>
<CheckboxPrimitive.Indicator
data-slot="checkbox-indicator"
className="grid place-content-center text-current transition-none"
>
<CheckIcon className="size-3.5" />
</CheckboxPrimitive.Indicator>
</CheckboxPrimitive.Root>
)
}
export { Checkbox }
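// Usage sketch (assumed consumer code): Radix passes a CheckedState
// (boolean | "indeterminate") to onCheckedChange, so coerce before storing it:
// <Checkbox id="forceRegenerate" checked={value} onCheckedChange={(v) => setValue(v === true)} />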

View File

@@ -76,7 +76,6 @@ class ApiClient {
});
}
async sampleDataset(data: {
strategy: string;
limit?: number;