diff --git a/bun.lock b/bun.lock
index f72d880..10fb5db 100644
--- a/bun.lock
+++ b/bun.lock
@@ -1,6 +1,5 @@
 {
   "lockfileVersion": 1,
-  "configVersion": 0,
   "workspaces": {
     "": {
       "name": "cvsa",
@@ -86,6 +85,7 @@
       "name": "ml_panel",
       "version": "0.0.0",
       "dependencies": {
+        "@radix-ui/react-checkbox": "^1.3.3",
         "@radix-ui/react-dialog": "^1.1.15",
         "@radix-ui/react-label": "^2.1.8",
         "@radix-ui/react-progress": "^1.1.8",
diff --git a/ml/filter/dataset.py b/ml/filter/dataset.py
index 4f992b0..fb725c7 100644
--- a/ml/filter/dataset.py
+++ b/ml/filter/dataset.py
@@ -56,6 +56,11 @@ class MultiChannelDataset(Dataset):
         self.max_length = max_length
         self.mode = mode
 
+        if self.mode == 'test' and os.path.exists(file_path):
+            with open(file_path, 'r', encoding='utf-8') as f:
+                self.examples = [json.loads(line) for line in f]
+            return
+
         # Check whether the train, eval and test files exist
         train_file = os.path.join(os.path.dirname(file_path), 'train.jsonl')
         eval_file = os.path.join(os.path.dirname(file_path), 'eval.jsonl')
@@ -101,9 +106,9 @@ class MultiChannelDataset(Dataset):
 
         # Return the per-channel text dict (with fallbacks for empty fields)
         texts = {
-            'title': example['title'],
-            'description': example['description'],
-            'tags': tags_text
+            'title': example['title'] or 'no title',
+            'description': example['description'] or 'no description',
+            'tags': tags_text or 'no tags'
        }

        return {
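A quick sketch of why the `or` fallbacks in `__getitem__` matter: both `None` and the empty string are falsy, so either collapses to the placeholder before the text reaches the embedding model. A minimal sketch of the assumed semantics (the helper name `with_fallbacks` is hypothetical):

```python
# Hypothetical helper illustrating the fallback semantics added in dataset.py:
# `or` replaces falsy values (None and ""), so every channel always yields
# a non-empty string for the embedding model.
def with_fallbacks(example: dict, tags_text: str) -> dict:
    return {
        'title': example.get('title') or 'no title',
        'description': example.get('description') or 'no description',
        'tags': tags_text or 'no tags',
    }

assert with_fallbacks({'title': '', 'description': None}, '') == {
    'title': 'no title', 'description': 'no description', 'tags': 'no tags'
}
```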
diff --git a/ml/filter/embedding.py b/ml/filter/embedding.py
index b97c342..4bf1e47 100644
--- a/ml/filter/embedding.py
+++ b/ml/filter/embedding.py
@@ -1,6 +1,21 @@
+from typing import List
 import numpy as np
 import torch
-from model2vec import StaticModel
+import onnxruntime as ort
+from transformers import AutoTokenizer
+
+# Tokenizer and ONNX model, initialized lazily and cached globally
+_tokenizer = None
+_onnx_session = None
+
+def _get_tokenizer_and_session():
+    """Return the globally cached tokenizer and ONNX session."""
+    global _tokenizer, _onnx_session
+    if _tokenizer is None:
+        _tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3")
+    if _onnx_session is None:
+        _onnx_session = ort.InferenceSession("../model/embedding/model.onnx")
+    return _tokenizer, _onnx_session
 
 
 def prepare_batch(batch_data, device="cpu"):
@@ -18,21 +33,60 @@ def prepare_batch(batch_data, device="cpu"):
 
     Returns:
         torch.Tensor: tensor of shape [batch_size, num_channels, embedding_dim].
     """
-    # 1. Encode each channel's texts separately
-    channel_embeddings = []
-    model = StaticModel.from_pretrained("./model/embedding_1024/")
-    for channel in ["title", "description", "tags"]:
-        texts = batch_data[channel]  # text list of the current channel
-        embeddings = torch.from_numpy(model.encode(texts)).to(torch.float32).to(device)  # encode to [batch_size, embedding_dim]
-        channel_embeddings.append(embeddings)
-    # 2. Stack the encodings into [batch_size, num_channels, embedding_dim]
-    batch_tensor = torch.stack(channel_embeddings, dim=1)  # stack on dim=1
-    return batch_tensor
+    title_embeddings = get_jina_embeddings_1024(batch_data['title'])
+    desc_embeddings = get_jina_embeddings_1024(batch_data['description'])
+    tags_embeddings = get_jina_embeddings_1024(batch_data['tags'])
+
+    return torch.stack([title_embeddings, desc_embeddings, tags_embeddings], dim=1).to(device)
+
+
+def get_jina_embeddings_1024(texts: List[str]) -> torch.Tensor:
+    """Get Jina embeddings via the tokenizer and the ONNX model."""
+    tokenizer, session = _get_tokenizer_and_session()
+
+    # 1. Tokenize without special tokens, returning plain Python lists
+    encoded_inputs = tokenizer(
+        texts,
+        add_special_tokens=False,
+        return_attention_mask=False,
+        return_tensors=None,  # native Python lists simplify the steps below
+    )
+    input_ids = encoded_inputs["input_ids"]  # shape: [batch_size, seq_len_i] (lengths may differ per sample)
+
+    # 2. Compute offsets (identical to the cumsum logic on the JS side)
+    # First get the token length of each sample
+    lengths = [len(ids) for ids in input_ids]
+    # Cumulative sum, excluding the last sample
+    cumsum = []
+    current_sum = 0
+    for l in lengths[:-1]:  # accumulate only the first n-1 lengths
+        current_sum += l
+        cumsum.append(current_sum)
+    # Build offsets: starts at 0, followed by the cumulative sums
+    offsets = [0] + cumsum  # shape: [batch_size]
+
+    # 3. Flatten input_ids into a one-dimensional array
+    flattened_input_ids = []
+    for ids in input_ids:
+        flattened_input_ids.extend(ids)  # concatenate all token ids
+    flattened_input_ids = np.array(flattened_input_ids, dtype=np.int64)
+
+    # 4. Prepare the ONNX inputs (same tensor shapes as on the JS side)
+    inputs = {
+        "input_ids": ort.OrtValue.ortvalue_from_numpy(flattened_input_ids),
+        "offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64)),
+    }
+
+    # 5. Run inference
+    outputs = session.run(None, inputs)
+    embeddings = outputs[0]  # assume the first output is the embeddings, shape: [batch_size, embedding_dim]
+
+    return torch.tensor(embeddings, dtype=torch.float32)
 
 
-import onnxruntime as ort
-from transformers import AutoTokenizer
-from itertools import accumulate
 
 def prepare_batch_per_token(batch_data, max_length=1024):
     """
@@ -67,6 +121,7 @@ def prepare_batch_per_token(batch_data, max_length=1024):
         input_ids_lengths = [len(enc["input_ids"][0]) for enc in encoded_inputs]
 
         # Generate offsets: [0, len1, len1+len2, ...]
+        from itertools import accumulate
         offsets = list(accumulate([0] + input_ids_lengths[:-1]))  # cumulative sum, excluding the last length
 
         # Flatten all input_ids into a one-dimensional array
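The manual cumsum in `get_jina_embeddings_1024` computes the same offsets that `prepare_batch_per_token` derives with `itertools.accumulate` further down; a compact, illustrative equivalent on a hypothetical batch:

```python
from itertools import accumulate

# Illustrative only: offsets[i] is the position in the flattened id array
# where sample i begins, matching the manual loop above.
input_ids = [[5, 9, 2], [7], [3, 3, 3, 3]]       # hypothetical tokenized batch
lengths = [len(ids) for ids in input_ids]        # [3, 1, 4]
offsets = list(accumulate([0] + lengths[:-1]))   # [0, 3, 4]
flat = [tok for ids in input_ids for tok in ids]
assert len(flat) == sum(lengths) and offsets[1] == lengths[0]
```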
diff --git a/ml/filter/test_new.py b/ml/filter/test_new.py
new file mode 100644
index 0000000..24f70a3
--- /dev/null
+++ b/ml/filter/test_new.py
@@ -0,0 +1,137 @@
+import os
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from dataset import MultiChannelDataset
+from filter.modelV3_15 import VideoClassifierV3_15
+from embedding import prepare_batch
+from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
+import argparse
+
+
+def load_model(checkpoint_path):
+    """Load the model weights."""
+    model = VideoClassifierV3_15()
+    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
+    model.eval()
+    return model
+
+
+def load_test_data(test_file):
+    """Load the test data."""
+    test_dataset = MultiChannelDataset(test_file, mode='test')
+    test_loader = DataLoader(test_dataset, batch_size=24, shuffle=False)
+    return test_loader
+
+
+def convert_to_binary_labels(labels):
+    """Collapse the three-class labels into two: classes 1 and 2 merge into class 1."""
+    binary_labels = np.where(labels >= 1, 1, 0)
+    return binary_labels
+
+
+def run_inference(model, test_loader):
+    """Run inference."""
+    all_preds = []
+    all_labels = []
+
+    with torch.no_grad():
+        for batch in test_loader:
+            # Prepare the text data
+            batch_tensor = prepare_batch(batch['texts'])
+
+            # Forward pass
+            logits = model(batch_tensor)
+            preds = torch.argmax(logits, dim=1)
+
+            all_preds.extend(preds.cpu().numpy())
+            all_labels.extend(batch['label'].cpu().numpy())
+
+    return np.array(all_preds), np.array(all_labels)
+
+
+def calculate_metrics(y_true, y_pred):
+    """Compute binary classification metrics."""
+    # Collapse to binary labels
+    y_true_binary = convert_to_binary_labels(y_true)
+    y_pred_binary = convert_to_binary_labels(y_pred)
+
+    # Confusion matrix
+    cm = confusion_matrix(y_true_binary, y_pred_binary)
+
+    if cm.shape == (2, 2):
+        tn, fp, fn, tp = cm.ravel()
+    else:
+        # Degenerate case: only one class present; count each cell directly
+        tp = int(((y_true_binary == 1) & (y_pred_binary == 1)).sum())
+        tn = int(((y_true_binary == 0) & (y_pred_binary == 0)).sum())
+        fp = int(((y_true_binary == 0) & (y_pred_binary == 1)).sum())
+        fn = int(((y_true_binary == 1) & (y_pred_binary == 0)).sum())
+
+    # Metrics
+    accuracy = accuracy_score(y_true_binary, y_pred_binary)
+    precision = precision_score(y_true_binary, y_pred_binary, zero_division=0)
+    recall = recall_score(y_true_binary, y_pred_binary, zero_division=0)
+    f1 = f1_score(y_true_binary, y_pred_binary, zero_division=0)
+
+    return {
+        'Acc': accuracy,
+        'Prec': precision,
+        'Recall': recall,
+        'F1': f1,
+        'TP': tp,
+        'FP': fp,
+        'TN': tn,
+        'FN': fn
+    }
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Test model on JSONL data')
+    parser.add_argument('--model_path', type=str, default='./filter/checkpoints/best_model_V3.17.pt',
+                        help='Path to model checkpoint')
+    parser.add_argument('--test_file', type=str, default='./data/filter/test1.jsonl',
+                        help='Path to test JSONL file')
+    args = parser.parse_args()
+
+    # Load the model
+    print(f"Loading model from {args.model_path}")
+    model = load_model(args.model_path)
+
+    # Load the test data
+    print(f"Loading test data from {args.test_file}")
+    test_loader = load_test_data(args.test_file)
+
+    # Run inference
+    print("Running inference...")
+    y_pred, y_true = run_inference(model, test_loader)
+
+    # Compute metrics
+    print("\nCalculating metrics...")
+    metrics = calculate_metrics(y_true, y_pred)
+
+    # Print the results
+    print("\n=== Test Results (Binary Classification) ===")
+    print(f"Accuracy (Acc): {metrics['Acc']:.4f}")
+    print(f"Precision (Prec): {metrics['Prec']:.4f}")
+    print(f"Recall: {metrics['Recall']:.4f}")
+    print(f"F1 Score: {metrics['F1']:.4f}")
+    print("\nConfusion Matrix:")
+    print(f"  TP (True Positive): {metrics['TP']}")
+    print(f"  TN (True Negative): {metrics['TN']}")
+    print(f"  FP (False Positive): {metrics['FP']}")
+    print(f"  FN (False Negative): {metrics['FN']}")
+
+    # Show the original three-class label distribution
+    print("\n=== Original Label Distribution ===")
+    unique, counts = np.unique(y_true, return_counts=True)
+    for label, count in zip(unique, counts):
+        print(f"Class {label}: {count} samples")
+
+
+if __name__ == '__main__':
+    main()
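Note that `sklearn.metrics.confusion_matrix` shrinks to 1×1 when only one class appears, which is what the fallback branch above guards against; passing `labels=[0, 1]` pins the shape and would make the branch unnecessary. A small sketch:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Pinning `labels` always yields a 2x2 matrix, even when the test split
# contains a single class, so the four cells can be unpacked directly.
y_true = np.array([0, 0, 0])
y_pred = np.array([0, 1, 0])
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
assert (tn, fp, fn, tp) == (2, 1, 0, 0)
```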
diff --git a/ml_new/data/dataset_storage_parquet.py b/ml_new/data/dataset_storage_parquet.py
index 6a7d449..0da4d5c 100644
--- a/ml_new/data/dataset_storage_parquet.py
+++ b/ml_new/data/dataset_storage_parquet.py
@@ -186,6 +186,7 @@ class ParquetDatasetStorage:
                 'dataset_id': dataset_id,
                 'description': f'Auto-regenerated metadata for dataset {dataset_id}',
                 'stats': {
+                    'total_records': total_records,
                     'regenerated': True,
                     'regeneration_reason': 'missing_or_corrupted_metadata_file'
                 },
@@ -404,7 +405,6 @@ class ParquetDatasetStorage:
             try:
                 # Extract dataset_id from filename (remove .parquet extension)
                 dataset_id = parquet_file.stem
-                print(dataset_id)
 
                 # Try to load metadata, this will automatically regenerate if missing
                 metadata = self.load_dataset_metadata(dataset_id)
diff --git a/ml_new/embedding_models.toml b/ml_new/embedding_models.toml
index 88a02f7..7c7df5b 100644
--- a/ml_new/embedding_models.toml
+++ b/ml_new/embedding_models.toml
@@ -2,14 +2,23 @@
 
 model = "jina-embedding-v3-m2v"
 
-[models.qwen3-embedding]
-name = "text-embedding-v4"
-dimensions = 2048
+# [models.qwen3-embedding]
+# name = "text-embedding-v4"
+# dimensions = 2048
+# type = "openai-compatible"
+# api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
+# max_tokens = 8192
+# max_batch_size = 10
+# api_key_env = "ALIYUN_KEY"
+
+[models.qwen3-embedding-8b]
+name = "qwen/qwen3-embedding-8b"
+dimensions = 4096
 type = "openai-compatible"
-api_endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
-max_tokens = 8192
-max_batch_size = 10
-api_key_env = "ALIYUN_KEY"
+api_endpoint = "https://openrouter.ai/api/v1"
+max_tokens = 32768
+max_batch_size = 100
+api_key_env = "OPENROUTER_KEY"
 
 [models.jina-embedding-v3-m2v]
 name = "jina-embedding-v3-m2v-1024"
diff --git a/ml_new/models.py b/ml_new/models.py
index 1dc1db1..86cd2ec 100644
--- a/ml_new/models.py
+++ b/ml_new/models.py
@@ -3,6 +3,7 @@ Data models for dataset building functionality
 """
 
 from typing import List, Optional, Dict, Any, Literal
+import uuid
 from pydantic import BaseModel, Field
 from datetime import datetime
 from enum import Enum
@@ -51,6 +52,7 @@ class DatasetBuildTaskStatus(BaseModel):
 
 class DatasetBuildRequest(BaseModel):
     """Request model for dataset building"""
+    id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()), description="Dataset ID")
     aid_list: List[int] = Field(..., description="List of video AIDs")
     embedding_model: str = Field(..., description="Embedding model name")
     force_regenerate: bool = Field(False, description="Whether to force regenerate embeddings")
diff --git a/ml_new/routes/main.py b/ml_new/routes/main.py
index f40ccc8..6664130 100644
--- a/ml_new/routes/main.py
+++ b/ml_new/routes/main.py
@@ -101,7 +101,7 @@ async def build_dataset_endpoint(request: DatasetBuildRequest):
     if request.embedding_model not in config_loader.get_embedding_models():
         raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")
 
-    dataset_id = str(uuid.uuid4())
+    dataset_id = request.id or str(uuid.uuid4())
 
     # Start task-based dataset building
     task_id = await dataset_builder.start_dataset_build_task(
@@ -335,7 +335,7 @@ async def create_dataset_with_sampling_endpoint(request: DatasetCreateRequest):
         raise HTTPException(status_code=400, detail=f"Invalid embedding model: {request.embedding_model}")
 
     import uuid
-    dataset_id = str(uuid.uuid4())
+    dataset_id = request.id or str(uuid.uuid4())
 
     try:
         # First sample the AIDs
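The `default_factory` in `DatasetBuildRequest.id` above is deliberate: a plain `Field(str(uuid.uuid4()))` default is evaluated once at class-definition time, so every request lacking an `id` would receive the same UUID. A minimal sketch of the difference:

```python
import uuid
from pydantic import BaseModel, Field

class Bad(BaseModel):
    id: str = Field(str(uuid.uuid4()))  # evaluated once when the class is defined

class Good(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))  # fresh per instance

assert Bad().id == Bad().id    # every instance shares the same default
assert Good().id != Good().id  # a new UUID is generated each time
```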
diff --git a/ml_new/training/test.py b/ml_new/training/test.py
index 37960b6..6a7e6d4 100644
--- a/ml_new/training/test.py
+++ b/ml_new/training/test.py
@@ -197,43 +197,9 @@ def load_model_from_experiment(
 def safe_extract_aid(metadata_entry):
     """Safely extract aid from metadata entry"""
     if isinstance(metadata_entry, dict) and 'aid' in metadata_entry:
-        return metadata_entry['aid']
+        return metadata_entry['aid'].tolist()
     return None
 
-def normalize_batch_metadata(metadata, expected_batch_size):
-    """
-    Normalize batch metadata to ensure consistent structure
-
-    Args:
-        metadata: Raw metadata from DataLoader (could be various formats)
-        expected_batch_size: Expected number of metadata entries
-
-    Returns:
-        List of metadata dictionaries
-    """
-    # Handle different metadata structures
-    if metadata is None:
-        return [{}] * expected_batch_size
-
-    if isinstance(metadata, dict):
-        # Single metadata object - duplicate for entire batch
-        return [metadata] * expected_batch_size
-
-    if isinstance(metadata, (list, tuple)):
-        if len(metadata) == expected_batch_size:
-            return list(metadata)
-        elif len(metadata) < expected_batch_size:
-            # Pad with empty dicts
-            padded = list(metadata) + [{}] * (expected_batch_size - len(metadata))
-            return padded
-        else:
-            # Truncate to expected size
-            return list(metadata[:expected_batch_size])
-
-    # Unknown format - return empty dicts
-    logger.warning(f"Unknown metadata format: {type(metadata)}")
-    return [{}] * expected_batch_size
-
 def evaluate_model(
     model,
     test_loader: DataLoader,
@@ -257,7 +223,6 @@ def evaluate_model(
     all_predictions = []
     all_labels = []
     all_probabilities = []
-    all_metadata = []
     fn_aids = []
     fp_aids = []
@@ -277,33 +242,9 @@ def evaluate_model(
             all_labels.extend(labels.cpu().numpy())
             all_probabilities.extend(probabilities.cpu().numpy())
 
-            # Collect metadata and track FN/FP
-            batch_size = len(labels)
-            batch_metadata = normalize_batch_metadata(metadata, batch_size)
-            all_metadata.extend(batch_metadata)
-
-            # Track FN and FP aids for this batch
-            logger.debug(f"Batch {batch_idx}: labels shape {labels.shape}, predictions shape {predictions.shape}, metadata structure: {type(batch_metadata)}")
-            if len(batch_metadata) != len(labels):
-                logger.warning(f"Metadata length mismatch: {len(batch_metadata)} metadata entries vs {len(labels)} samples")
-
             for i, (true_label, pred_label) in enumerate(zip(labels.cpu().numpy(), predictions.cpu().numpy())):
                 try:
-                    # Safely get metadata entry with bounds checking
-                    if i >= len(batch_metadata):
-                        logger.warning(f"Index {i} out of range for batch_metadata (length: {len(batch_metadata)})")
-                        continue
-
-                    meta_entry = batch_metadata[i]
-                    if not isinstance(meta_entry, dict):
-                        logger.warning(f"Metadata entry {i} is not a dict: {type(meta_entry)}")
-                        continue
-
-                    if 'aid' not in meta_entry:
-                        logger.debug(f"No 'aid' key in metadata entry {i}")
-                        continue
-
-                    aid = safe_extract_aid(meta_entry)
+                    aid = metadata['aid'].tolist()[i]
                     if aid is not None:
                         if true_label == 1 and pred_label == 0:  # False Negative
                             fn_aids.append(aid)
@@ -620,30 +561,28 @@ def main():
     )
 
     # Print results
-    logger.info("=" * 50)
-    logger.info("Test Results")
-    logger.info("=" * 50)
-    logger.info(f"Dataset: {args.dataset_id}")
+    print("Test Results")
+    print("=" * 50)
+    print(f"Dataset: {args.dataset_id}")
     if args.use_api:
-        logger.info(f"Method: API ({args.api_url})")
+        print(f"Method: API ({args.api_url})")
     else:
-        logger.info(f"Experiment: {args.experiment}")
-    logger.info(f"Total samples: {metrics['total_samples']}")
-    logger.info(f"Class distribution: {metrics['class_distribution']}")
+        print(f"Experiment: {args.experiment}")
+    print(f"Total samples: {metrics['total_samples']}")
+    print(f"Class distribution: {metrics['class_distribution']}")
     if 'failed_requests' in metrics:
         logger.info(f"Failed API requests: {metrics['failed_requests']}")
-    logger.info("-" * 50)
-    logger.info(f"Accuracy: {metrics['accuracy']:.4f}")
-    logger.info(f"Precision: {metrics['precision']:.4f}")
-    logger.info(f"Recall: {metrics['recall']:.4f}")
-    logger.info(f"F1 Score: {metrics['f1']:.4f}")
-    logger.info(f"AUC: {metrics['auc']:.4f}")
-    logger.info("-" * 50)
-    logger.info(f"True Positives: {metrics['true_positives']}")
-    logger.info(f"True Negatives: {metrics['true_negatives']}")
-    logger.info(f"False Positives: {metrics['false_positives']}")
-    logger.info(f"False Negatives: {metrics['false_negatives']}")
-    logger.info("=" * 50)
+    print("-" * 50)
+    print(f"Accuracy: {metrics['accuracy']:.4f}")
+    print(f"Precision: {metrics['precision']:.4f}")
+    print(f"Recall: {metrics['recall']:.4f}")
+    print(f"F1 Score: {metrics['f1']:.4f}")
+    print(f"AUC: {metrics['auc']:.4f}")
+    print(f"True Positives: {metrics['true_positives']}")
+    print(f"True Negatives: {metrics['true_negatives']}")
+    print(f"False Positives: {metrics['false_positives']}")
+    print(f"False Negatives: {metrics['false_negatives']}")
+    print("=" * 50)
 
     # Save detailed results if requested
     if args.output:
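The `.tolist()` calls above reflect how PyTorch's default collate function batches metadata: a list of per-sample dicts becomes one dict of stacked tensors, so `metadata['aid']` is a single tensor per batch rather than one entry per sample. An illustrative sketch (the `Toy` dataset is hypothetical):

```python
import torch
from torch.utils.data import DataLoader, Dataset

# Default collation stacks dict fields across the batch, including nested
# dicts, so scalar metadata becomes one tensor per batch.
class Toy(Dataset):
    def __len__(self): return 4
    def __getitem__(self, i): return {'x': torch.zeros(3), 'metadata': {'aid': i * 10}}

batch = next(iter(DataLoader(Toy(), batch_size=4)))
print(batch['metadata']['aid'])           # tensor([ 0, 10, 20, 30])
print(batch['metadata']['aid'].tolist())  # [0, 10, 20, 30]
```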
diff --git a/packages/ml_panel/biome.json b/packages/ml_panel/biome.json
new file mode 100644
index 0000000..fd88f3a
--- /dev/null
+++ b/packages/ml_panel/biome.json
@@ -0,0 +1,7 @@
+{
+  "root": false,
+  "$schema": "https://biomejs.dev/schemas/2.3.8/schema.json",
+  "linter": {
+    "enabled": false
+  }
+}
diff --git a/packages/ml_panel/package.json b/packages/ml_panel/package.json
index efe506d..786ab2b 100644
--- a/packages/ml_panel/package.json
+++ b/packages/ml_panel/package.json
@@ -10,6 +10,7 @@
     "preview": "vite preview"
   },
   "dependencies": {
+    "@radix-ui/react-checkbox": "^1.3.3",
     "@radix-ui/react-dialog": "^1.1.15",
     "@radix-ui/react-label": "^2.1.8",
     "@radix-ui/react-progress": "^1.1.8",
diff --git a/packages/ml_panel/src/components/DatasetManager.tsx b/packages/ml_panel/src/components/DatasetManager.tsx
index 6f1dac4..4094907 100644
--- a/packages/ml_panel/src/components/DatasetManager.tsx
+++ b/packages/ml_panel/src/components/DatasetManager.tsx
@@ -5,7 +5,6 @@ import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/com
 import {
   Dialog,
   DialogContent,
-  DialogDescription,
   DialogFooter,
   DialogHeader,
   DialogTitle,
@@ -24,6 +23,8 @@ import { Trash2, Plus, Database, Upload } from "lucide-react";
 import { apiClient } from "@/lib/api";
 import { toast } from "sonner";
 import { Spinner } from "@/components/ui/spinner";
+import { Input } from "@/components/ui/input";
+import { Checkbox } from "@/components/ui/checkbox";
 
 export function DatasetManager() {
   const [isCreateDialogOpen, setIsCreateDialogOpen] = useState(false);
@@ -37,14 +38,20 @@ export function DatasetManager() {
     aidListFile: null as File | null,
     aidList: [] as number[]
   });
+  const [loadCost, setLoadCost] = useState(0);
 
   const queryClient = useQueryClient();
 
   // Fetch datasets
   const { data: datasetsData, isLoading: datasetsLoading } = useQuery({
     queryKey: ["datasets"],
-    queryFn: () => apiClient.getDatasets(),
-    refetchInterval: 30000 // Refresh every 30 seconds
+    queryFn: async () => {
+      const t = performance.now();
+      const r = await apiClient.getDatasets();
+      setLoadCost(performance.now() - t);
+      return r;
+    },
+    refetchInterval: 5000
   });
 
   // Fetch embedding models
@@ -235,9 +242,9 @@
         {/* Create Dataset Button */}
-            <CardTitle>Dataset List</CardTitle>
+            <CardTitle>Datasets</CardTitle>
             <CardDescription>
-              {datasetsData?.datasets?.length || 0} datasets created
+              {datasetsData?.datasets?.length || 0} datasets loaded. ({Math.round(loadCost)} ms)
             </CardDescription>
@@ -245,21 +252,17 @@
-            <DialogTitle>Create New Dataset</DialogTitle>
-            <DialogDescription>
-              Select sampling strategy and configuration parameters to create a
-              new dataset
-            </DialogDescription>
+            <DialogTitle>New Dataset</DialogTitle>
-
+
             {createFormData.creationMode === "sampling" && (
-
+
@@ -308,8 +313,8 @@
           {createFormData.creationMode === "sampling" &&
             createFormData.strategy === "random" && (
-
-