From 3842c63ad1844f47236f08d24c1fd7a5f96c4c44 Mon Sep 17 00:00:00 2001
From: alikia2x
Date: Mon, 3 Mar 2025 02:09:45 +0800
Subject: [PATCH] add: filter model V3.17 with deployment

---
 filter/RunningLogs.txt       |  3 +-
 filter/dataset.py            |  3 +-
 filter/embedding.py          |  5 +-
 filter/modelV3_15.py         | 97 ++++++++++++++++++++++++++++++++++++
 filter/onnx_export.py        | 10 ++--
 filter/train.py              | 14 ++++--
 lib/ml/benchmark.ts          | 17 +++++--
 lib/ml/filter_inference.ts   |  8 ++-
 lib/mq/exec/classifyVideo.ts |  3 +-
 9 files changed, 132 insertions(+), 28 deletions(-)
 create mode 100644 filter/modelV3_15.py

diff --git a/filter/RunningLogs.txt b/filter/RunningLogs.txt
index 80640f4..65b0d04 100644
--- a/filter/RunningLogs.txt
+++ b/filter/RunningLogs.txt
@@ -26,4 +26,5 @@ Note
 2337: V3.12 # cascade classification
 2350: V3.13 # V3.12, switched to plain cross-entropy loss
 0012: V3.11 # switched to plain cross-entropy loss
-0039: V3.11 # cascade classification, but with two independent models
\ No newline at end of file
+0039: V3.11 # cascade classification, but with two independent models
+0122: V3.15 # removed the author_info channel
\ No newline at end of file
diff --git a/filter/dataset.py b/filter/dataset.py
index 7a4edc1..4f992b0 100644
--- a/filter/dataset.py
+++ b/filter/dataset.py
@@ -103,8 +103,7 @@ class MultiChannelDataset(Dataset):
         texts = {
             'title': example['title'],
             'description': example['description'],
-            'tags': tags_text,
-            'author_info': example['author_info']
+            'tags': tags_text
         }
 
         return {
diff --git a/filter/embedding.py b/filter/embedding.py
index 7e9dfc6..ccecc9a 100644
--- a/filter/embedding.py
+++ b/filter/embedding.py
@@ -11,8 +11,7 @@ def prepare_batch(batch_data, device="cpu"):
         batch_data (dict): input batch data, in the format {
             "title": [text1, text2, ...],
             "description": [text1, text2, ...],
-            "tags": [text1, text2, ...],
-            "author_info": [text1, text2, ...]
+            "tags": [text1, text2, ...]
         }
         device (str): device the model runs on (e.g. "cpu" or "cuda").
 
     Returns:
@@ -22,7 +21,7 @@
     # 1. Encode the texts of each channel separately
     channel_embeddings = []
     model = StaticModel.from_pretrained("./model/embedding_1024/")
-    for channel in ["title", "description", "tags", "author_info"]:
+    for channel in ["title", "description", "tags"]:
         texts = batch_data[channel]  # list of texts for the current channel
         embeddings = torch.from_numpy(model.encode(texts)).to(torch.float32).to(device)  # encode to [batch_size, embedding_dim]
         channel_embeddings.append(embeddings)
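For reference, a minimal usage sketch of the three-channel prepare_batch after this change. It assumes the function stacks the per-channel embeddings along dim=1 (the stacking happens below this hunk), which is what the classifier's [batch_size, num_channels, embedding_dim] input implies; the sample texts and the working directory (filter/, with the embedding model under ./model/) are assumptions:

    import torch
    from embedding import prepare_batch  # assumes this runs from the filter/ directory

    # Hypothetical two-sample batch; note there is no "author_info" key anymore.
    batch = {
        "title": ["title a", "title b"],
        "description": ["description a", "description b"],
        "tags": ["tag1,tag2", "tag3"],
    }
    features = prepare_batch(batch, device="cpu")
    print(features.shape)  # expected: torch.Size([2, 3, 1024])
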
diff --git a/filter/modelV3_15.py b/filter/modelV3_15.py
new file mode 100644
index 0000000..9e6be19
--- /dev/null
+++ b/filter/modelV3_15.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class VideoClassifierV3_15(nn.Module):
+    def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3, temperature=1.7):
+        super().__init__()
+        self.num_channels = 3
+        self.channel_names = ['title', 'description', 'tags']
+
+        # Learnable temperature coefficient
+        self.temperature = nn.Parameter(torch.tensor(temperature))
+
+        # Constrained channel weights (Sigmoid instead of Softmax)
+        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
+
+        # Enhanced non-linear head
+        self.fc = nn.Sequential(
+            nn.Linear(embedding_dim * self.num_channels, hidden_dim * 2),
+            nn.BatchNorm1d(hidden_dim * 2),
+            nn.Dropout(0.2),
+            nn.GELU(),
+            nn.Linear(hidden_dim * 2, output_dim)
+        )
+
+        # Weight initialization
+        self._init_weights()
+
+    def _init_weights(self):
+        for layer in self.fc:
+            if isinstance(layer, nn.Linear):
+                # Kaiming init with the ReLU gain (as an approximation for GELU)
+                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
+
+                # Alternatively, Xavier init (better suited to the general case):
+                # nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('relu'))
+
+                nn.init.zeros_(layer.bias)
+
+
+    def forward(self, channel_features: torch.Tensor):
+        """
+        Input shape:  [batch_size, num_channels, embedding_dim]
+        Output shape: [batch_size, output_dim]
+        """
+
+        # Adaptive channel weights (Sigmoid-constrained)
+        weights = torch.sigmoid(self.channel_weights)  # in the range [0, 1]
+        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)
+
+        # Concatenate the weighted channel features
+        combined = weighted_features.view(weighted_features.size(0), -1)
+
+        return self.fc(combined)
+
+    def get_channel_weights(self):
+        """Return the per-channel weights (with temperature scaling)."""
+        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()
+
+
+class AdaptiveRecallLoss(nn.Module):
+    def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
+        """
+        Args:
+            class_weights (torch.Tensor): per-class weights
+            alpha (float): recall adjustment factor (0-1)
+            gamma (float): Focal Loss focusing parameter
+            fp_penalty (float): penalty strength for class-0 false positives
+        """
+        super().__init__()
+        self.class_weights = class_weights
+        self.alpha = alpha
+        self.gamma = gamma
+        self.fp_penalty = fp_penalty
+
+    def forward(self, logits, targets):
+        # Base cross-entropy loss
+        ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')
+
+        # Focal Loss component
+        pt = torch.exp(-ce_loss)
+        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
+
+        # Recall enhancement (up-weight hard samples)
+        class_mask = F.one_hot(targets, num_classes=len(self.class_weights))
+        class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask
+        recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1)
+
+        # Penalty for class-0 false positives
+        probs = F.softmax(logits, dim=1)
+        fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
+        fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()
+
+        # Total loss
+        total_loss = recall_loss.mean() + fp_loss / len(targets)
+
+        return total_loss
\ No newline at end of file
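A quick smoke test of the new module, using only names, defaults, and shapes that appear in the file above; the random tensors and batch size are placeholders:

    import torch
    from modelV3_15 import VideoClassifierV3_15, AdaptiveRecallLoss

    model = VideoClassifierV3_15()       # defaults: embedding_dim=1024, 3 channels, 3 classes
    features = torch.randn(8, 3, 1024)   # [batch_size, num_channels, embedding_dim]
    logits = model(features)             # [8, 3]

    criterion = AdaptiveRecallLoss(class_weights=torch.ones(3))
    loss = criterion(logits, torch.randint(0, 3, (8,)))
    loss.backward()

    print(model.get_channel_weights())   # temperature-scaled softmax over the 3 channels
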
diff --git a/filter/onnx_export.py b/filter/onnx_export.py
index cd3c615..848cda0 100644
--- a/filter/onnx_export.py
+++ b/filter/onnx_export.py
@@ -1,16 +1,16 @@
 import torch
-from modelV3_12 import VideoClassifierV3_12
+from modelV3_15 import VideoClassifierV3_15
 
-def export_onnx(model_path="./filter/checkpoints/best_model_V3.13.pt",
-                onnx_path="./model/video_classifier_v3_13.onnx"):
+def export_onnx(model_path="./filter/checkpoints/best_model_V3.17.pt",
+                onnx_path="./model/video_classifier_v3_17.onnx"):
     # Initialize the model
-    model = VideoClassifierV3_12()
+    model = VideoClassifierV3_15()
     model.load_state_dict(torch.load(model_path))
     model.eval()
 
     # Create a dummy input matching the input spec
-    dummy_input = torch.randn(1, 4, 1024)  # [batch=1, channels=4, embedding_dim=1024]
+    dummy_input = torch.randn(1, 3, 1024)  # [batch=1, channels=3, embedding_dim=1024]
 
     # Export to ONNX
     torch.onnx.export(
diff --git a/filter/train.py b/filter/train.py
index 065d38e..dca219f 100644
--- a/filter/train.py
+++ b/filter/train.py
@@ -4,7 +4,7 @@ import numpy as np
 from torch.utils.data import DataLoader
 import torch.optim as optim
 from dataset import MultiChannelDataset
-from filter.modelV3_12 import VideoClassifierV3_12
+from filter.modelV3_15 import AdaptiveRecallLoss, VideoClassifierV3_15
 from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
 import os
 import torch
@@ -51,16 +51,20 @@ class_weights = torch.tensor(
 )
 
 # Initialize the model and the SentenceTransformer
-model = VideoClassifierV3_12()
-checkpoint_name = './filter/checkpoints/best_model_V3.12.pt'
+model = VideoClassifierV3_15()
+checkpoint_name = './filter/checkpoints/best_model_V3.17.pt'
 
 # Checkpoint save path
 os.makedirs('./filter/checkpoints', exist_ok=True)
 
 # Optimizer
 optimizer = optim.AdamW(model.parameters(), lr=4e-4)
 
-# Cross entropy loss
-criterion = nn.CrossEntropyLoss()
+criterion = AdaptiveRecallLoss(
+    class_weights=class_weights,
+    alpha=0.9,       # recall weighting
+    gamma=1.6,       # focus on hard samples
+    fp_penalty=0.8   # false-positive penalty strength
+)
 
 def count_trainable_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
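The training loop itself is outside these hunks and unchanged by this patch; a sketch of how one step is expected to look with the names defined above (model, criterion, optimizer). The train_loader and the batch keys are assumptions:

    from embedding import prepare_batch

    model.train()
    for batch in train_loader:                    # hypothetical DataLoader over MultiChannelDataset
        features = prepare_batch(batch['texts'])  # [batch_size, 3, 1024]
        logits = model(features)
        loss = criterion(logits, batch['label'])  # 'label' key is assumed
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
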
diff --git a/lib/ml/benchmark.ts b/lib/ml/benchmark.ts
index d15a2f8..478b224 100644
--- a/lib/ml/benchmark.ts
+++ b/lib/ml/benchmark.ts
@@ -4,9 +4,9 @@ import { softmax } from "lib/ml/filter_inference.ts";
 
 // Configuration
 const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
-const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
+const onnxClassifierPath = "./model/video_classifier_v3_17.onnx";
 const onnxEmbeddingPath = "./model/embedding_original.onnx";
-const testDataPath = "./data/filter/test.jsonl";
+const testDataPath = "./data/filter/test1.jsonl";
 
 // Initialize the sessions
 const [sessionClassifier, sessionEmbedding] = await Promise.all([
@@ -53,7 +53,7 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
 async function runClassification(embeddings: number[]): Promise<number[]> {
 	const inputTensor = new ort.Tensor(
 		Float32Array.from(embeddings),
-		[1, 4, 1024],
+		[1, 3, 1024],
 	);
 
 	const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
@@ -69,6 +69,14 @@ function calculateMetrics(labels: number[], predictions: number[], elapsedTime:
 	"Class 0 Prec": number;
 	speed: string;
 } {
+	// Log the indices where the label and the prediction disagree (mispredicted as class 0)
+	const mismatches: [number, number, number][] = [];
+	for (let i = 0; i < labels.length; i++) {
+		if (labels[i] !== predictions[i] && predictions[i] === 0) {
+			mismatches.push([i + 1, labels[i], predictions[i]]);
+		}
+	}
+	console.log(mismatches);
 	// Initialize the confusion matrix
 	const classCount = Math.max(...labels, ...predictions) + 1;
 	const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0));
@@ -138,8 +146,7 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
 		const embeddings = await getONNXEmbeddings([
 			sample.title,
 			sample.description,
-			sample.tags.join(","),
-			sample.author_info,
+			sample.tags.join(",")
 		], session);
 
 		const probabilities = await runClassification(embeddings);
diff --git a/lib/ml/filter_inference.ts b/lib/ml/filter_inference.ts
index 8758b4d..019061f 100644
--- a/lib/ml/filter_inference.ts
+++ b/lib/ml/filter_inference.ts
@@ -4,9 +4,9 @@ import logger from "lib/log/logger.ts";
 import { WorkerError } from "lib/mq/schema.ts";
 
 const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
-const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
+const onnxClassifierPath = "./model/video_classifier_v3_17.onnx";
 const onnxEmbeddingOriginalPath = "./model/model.onnx";
-export const modelVersion = "3.11";
+export const modelVersion = "3.17";
 
 let sessionClassifier: ort.InferenceSession | null = null;
 let sessionEmbedding: ort.InferenceSession | null = null;
@@ -72,7 +72,7 @@ async function runClassification(embeddings: number[]): Promise<number[]> {
 	}
 	const inputTensor = new ort.Tensor(
 		Float32Array.from(embeddings),
-		[1, 4, 1024],
+		[1, 3, 1024],
 	);
 
 	const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
@@ -83,7 +83,6 @@ export async function classifyVideo(
 	title: string,
 	description: string,
 	tags: string,
-	author_info: string,
 	aid: number,
 ): Promise<number> {
 	if (!sessionEmbedding) {
@@ -93,7 +92,6 @@
 		title,
 		description,
 		tags,
-		author_info,
 	], sessionEmbedding);
 	const probabilities = await runClassification(embeddings);
 	logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml");
diff --git a/lib/mq/exec/classifyVideo.ts b/lib/mq/exec/classifyVideo.ts
index 7641383..26d7053 100644
--- a/lib/mq/exec/classifyVideo.ts
+++ b/lib/mq/exec/classifyVideo.ts
@@ -17,8 +17,7 @@ export const classifyVideoWorker = async (job: Job) => {
 	const title = videoInfo.title?.trim() || "untitled";
 	const description = videoInfo.description?.trim() || "N/A";
 	const tags = videoInfo.tags?.trim() || "empty";
-	const authorInfo = videoInfo.author_info || "N/A";
-	const label = await classifyVideo(title, description, tags, authorInfo, aid);
+	const label = await classifyVideo(title, description, tags, aid);
 	if (label == -1) {
 		logger.warn(`Failed to classify video ${aid}`, "ml");
 	}
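To close the loop between the Python export and the TypeScript deployment, a hedged verification sketch: it feeds the same dummy tensor through the PyTorch checkpoint and the exported graph, using the input name (channel_features) and shape ([1, 3, 1024]) that the TypeScript callers above rely on. The single-output assumption mirrors the TS destructuring of { logits }; the tolerances are illustrative:

    import numpy as np
    import onnxruntime
    import torch
    from modelV3_15 import VideoClassifierV3_15

    model = VideoClassifierV3_15()
    model.load_state_dict(torch.load("./filter/checkpoints/best_model_V3.17.pt"))
    model.eval()

    dummy = torch.randn(1, 3, 1024)
    with torch.no_grad():
        torch_logits = model(dummy).numpy()

    session = onnxruntime.InferenceSession("./model/video_classifier_v3_17.onnx")
    (onnx_logits,) = session.run(None, {"channel_features": dummy.numpy()})

    # The two backends should agree up to export/runtime tolerance.
    np.testing.assert_allclose(torch_logits, onnx_logits, rtol=1e-3, atol=1e-5)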