add: filter model V3.17 with deployment

alikia2x (寒寒) 2025-03-03 02:09:45 +08:00
parent 748e2e2aaa
commit 3842c63ad1
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
9 changed files with 132 additions and 28 deletions

View File

@@ -27,3 +27,4 @@ Note
 2350: V3.13 # V3.12, switched to plain cross-entropy loss
 0012: V3.11 # switched to plain cross-entropy loss
 0039: V3.11 # cascade classification, but with two independent models
+0122: V3.15 # removed the author_info channel

View File

@@ -103,8 +103,7 @@ class MultiChannelDataset(Dataset):
         texts = {
             'title': example['title'],
             'description': example['description'],
-            'tags': tags_text,
-            'author_info': example['author_info']
+            'tags': tags_text
         }

         return {

View File

@@ -11,8 +11,7 @@ def prepare_batch(batch_data, device="cpu"):
         batch_data (dict): input batch data, in the format {
             "title": [text1, text2, ...],
             "description": [text1, text2, ...],
-            "tags": [text1, text2, ...],
-            "author_info": [text1, text2, ...]
+            "tags": [text1, text2, ...]
         }
         device (str): device to run the model on ("cpu" or "cuda")
@@ -22,7 +21,7 @@ def prepare_batch(batch_data, device="cpu"):
     # 1. Encode the texts of each channel separately
     channel_embeddings = []
     model = StaticModel.from_pretrained("./model/embedding_1024/")
-    for channel in ["title", "description", "tags", "author_info"]:
+    for channel in ["title", "description", "tags"]:
         texts = batch_data[channel]  # list of texts for the current channel
         embeddings = torch.from_numpy(model.encode(texts)).to(torch.float32).to(device)  # encode to [batch_size, embedding_dim]
         channel_embeddings.append(embeddings)
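
The hunk ends before the per-channel tensors are combined; downstream code expects a [batch_size, 3, embedding_dim] input, so presumably they are stacked along a new channel dimension. A minimal sketch of that final step (the torch.stack call is assumed, not shown in the diff):

import torch

# channel_embeddings holds three [batch_size, embedding_dim] tensors,
# one per remaining channel: "title", "description", "tags".
# Stacking along dim=1 yields the [batch_size, 3, embedding_dim]
# tensor that VideoClassifierV3_15.forward() expects.
channel_features = torch.stack(channel_embeddings, dim=1)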

filter/modelV3_15.py Normal file (97 lines)
View File

@@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class VideoClassifierV3_15(nn.Module):
    def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3, temperature=1.7):
        super().__init__()
        self.num_channels = 3
        self.channel_names = ['title', 'description', 'tags']

        # Learnable temperature coefficient
        self.temperature = nn.Parameter(torch.tensor(temperature))

        # Constrained channel weights (Sigmoid instead of Softmax)
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))

        # Enhanced non-linear head
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim * self.num_channels, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.Dropout(0.2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, output_dim)
        )

        # Weight initialization
        self._init_weights()

    def _init_weights(self):
        for layer in self.fc:
            if isinstance(layer, nn.Linear):
                # Use ReLU's Kaiming init parameters (an approximation for GELU)
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
                # Alternatively, Xavier initialization (better suited to the general case):
                # nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('relu'))
                nn.init.zeros_(layer.bias)

    def forward(self, channel_features: torch.Tensor):
        """
        Input shape:  [batch_size, num_channels, embedding_dim]
        Output shape: [batch_size, output_dim]
        """
        # Adaptive channel weights (Sigmoid-constrained to the [0, 1] range)
        weights = torch.sigmoid(self.channel_weights)
        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)

        # Concatenate the weighted channel features
        combined = weighted_features.view(weighted_features.size(0), -1)

        return self.fc(combined)

    def get_channel_weights(self):
        """Return the per-channel weights (with temperature scaling)."""
        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()


class AdaptiveRecallLoss(nn.Module):
    def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
        """
        Args:
            class_weights (torch.Tensor): per-class weights
            alpha (float): recall adjustment factor (0-1)
            gamma (float): Focal Loss parameter
            fp_penalty (float): penalty strength for class-0 false positives
        """
        super().__init__()
        self.class_weights = class_weights
        self.alpha = alpha
        self.gamma = gamma
        self.fp_penalty = fp_penalty

    def forward(self, logits, targets):
        # Base cross-entropy loss
        ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')

        # Focal Loss component
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        # Recall enhancement (up-weight hard samples)
        class_mask = F.one_hot(targets, num_classes=len(self.class_weights))
        class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask
        recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1)

        # Penalty for class-0 false positives
        probs = F.softmax(logits, dim=1)
        fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
        fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()

        # Total loss
        total_loss = recall_loss.mean() + fp_loss / len(targets)
        return total_loss
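
A quick shape check for the two new classes (a sketch with an arbitrary batch size and random inputs; the three classes follow output_dim=3):

import torch

model = VideoClassifierV3_15()
criterion = AdaptiveRecallLoss(class_weights=torch.ones(3))

features = torch.randn(8, 3, 1024)    # [batch, channels, embedding_dim]
labels = torch.randint(0, 3, (8,))    # one of three classes per sample

logits = model(features)              # -> [8, 3]
loss = criterion(logits, labels)      # -> scalar
print(logits.shape, loss.item())
print(model.get_channel_weights())    # temperature-scaled softmax weights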

View File

@@ -1,16 +1,16 @@
 import torch
-from modelV3_12 import VideoClassifierV3_12
+from modelV3_15 import VideoClassifierV3_15

-def export_onnx(model_path="./filter/checkpoints/best_model_V3.13.pt",
-                onnx_path="./model/video_classifier_v3_13.onnx"):
+def export_onnx(model_path="./filter/checkpoints/best_model_V3.17.pt",
+                onnx_path="./model/video_classifier_v3_17.onnx"):
     # Initialize the model
-    model = VideoClassifierV3_12()
+    model = VideoClassifierV3_15()
     model.load_state_dict(torch.load(model_path))
     model.eval()

     # Create a dummy input matching the input spec
-    dummy_input = torch.randn(1, 4, 1024)  # [batch=1, channels=4, embedding_dim=1024]
+    dummy_input = torch.randn(1, 3, 1024)  # [batch=1, channels=3, embedding_dim=1024]

     # Export to ONNX
     torch.onnx.export(
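
The diff truncates the torch.onnx.export(...) call itself. A sketch of what a complete call for this model might look like; the channel_features/logits names match what the TypeScript inference code binds, while the dynamic axis and opset version are assumptions:

torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    input_names=["channel_features"],  # name bound by sessionClassifier.run({ channel_features: ... })
    output_names=["logits"],
    dynamic_axes={"channel_features": {0: "batch_size"}},  # assumption: keep the batch dim dynamic
    opset_version=17,                                      # assumption
)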

View File

@@ -4,7 +4,7 @@ import numpy as np
 from torch.utils.data import DataLoader
 import torch.optim as optim
 from dataset import MultiChannelDataset
-from filter.modelV3_12 import VideoClassifierV3_12
+from filter.modelV3_15 import AdaptiveRecallLoss, VideoClassifierV3_15
 from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
 import os
 import torch
@@ -51,16 +51,20 @@ class_weights = torch.tensor(
 )

 # Initialize the model and the SentenceTransformer
-model = VideoClassifierV3_12()
-checkpoint_name = './filter/checkpoints/best_model_V3.12.pt'
+model = VideoClassifierV3_15()
+checkpoint_name = './filter/checkpoints/best_model_V3.17.pt'

 # Model save path
 os.makedirs('./filter/checkpoints', exist_ok=True)

 # Optimizer
 optimizer = optim.AdamW(model.parameters(), lr=4e-4)

-# Cross-entropy loss
-criterion = nn.CrossEntropyLoss()
+criterion = AdaptiveRecallLoss(
+    class_weights=class_weights,
+    alpha=0.9,       # recall weight
+    gamma=1.6,       # focus on hard samples
+    fp_penalty=0.8   # false-positive penalty strength
+)

 def count_trainable_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
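
With the model, optimizer, and new criterion from this file, a single optimization step looks roughly like this (a sketch; the random tensors stand in for the stacked channel embeddings and labels produced by the data pipeline):

features = torch.randn(32, 3, 1024)   # stand-in for [batch, channels, embedding_dim] embeddings
labels = torch.randint(0, 3, (32,))   # stand-in for class labels

optimizer.zero_grad()
logits = model(features)
loss = criterion(logits, labels)
loss.backward()
optimizer.step()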

View File

@@ -4,9 +4,9 @@ import { softmax } from "lib/ml/filter_inference.ts";
 // Configuration
 const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
-const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
+const onnxClassifierPath = "./model/video_classifier_v3_17.onnx";
 const onnxEmbeddingPath = "./model/embedding_original.onnx";
-const testDataPath = "./data/filter/test.jsonl";
+const testDataPath = "./data/filter/test1.jsonl";

 // Initialize the sessions
 const [sessionClassifier, sessionEmbedding] = await Promise.all([
@@ -53,7 +53,7 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
 async function runClassification(embeddings: number[]): Promise<number[]> {
     const inputTensor = new ort.Tensor(
         Float32Array.from(embeddings),
-        [1, 4, 1024],
+        [1, 3, 1024],
     );
     const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
@@ -69,6 +69,14 @@ function calculateMetrics(labels: number[], predictions: number[], elapsedTime:
     "Class 0 Prec": number;
     speed: string;
 } {
+    // Log the indices where the label and prediction disagree and the prediction is class 0
+    const arr = [];
+    for (let i = 0; i < labels.length; i++) {
+        if (labels[i] !== predictions[i] && predictions[i] == 0) {
+            arr.push([i + 1, labels[i], predictions[i]]);
+        }
+    }
+    console.log(arr);
     // Initialize the confusion matrix
     const classCount = Math.max(...labels, ...predictions) + 1;
     const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0));
@@ -138,8 +146,7 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
         const embeddings = await getONNXEmbeddings([
             sample.title,
             sample.description,
-            sample.tags.join(","),
-            sample.author_info,
+            sample.tags.join(",")
         ], session);

         const probabilities = await runClassification(embeddings);

View File

@@ -4,9 +4,9 @@ import logger from "lib/log/logger.ts";
 import { WorkerError } from "lib/mq/schema.ts";

 const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
-const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
+const onnxClassifierPath = "./model/video_classifier_v3_17.onnx";
 const onnxEmbeddingOriginalPath = "./model/model.onnx";

-export const modelVersion = "3.11";
+export const modelVersion = "3.17";

 let sessionClassifier: ort.InferenceSession | null = null;
 let sessionEmbedding: ort.InferenceSession | null = null;
@@ -72,7 +72,7 @@ async function runClassification(embeddings: number[]): Promise<number[]> {
     }
     const inputTensor = new ort.Tensor(
         Float32Array.from(embeddings),
-        [1, 4, 1024],
+        [1, 3, 1024],
     );
     const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
@@ -83,7 +83,6 @@ export async function classifyVideo(
     title: string,
     description: string,
     tags: string,
-    author_info: string,
     aid: number,
 ): Promise<number> {
     if (!sessionEmbedding) {
@@ -93,7 +92,6 @@ export async function classifyVideo(
         title,
         description,
         tags,
-        author_info,
     ], sessionEmbedding);
     const probabilities = await runClassification(embeddings);
     logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml");

View File

@@ -17,8 +17,7 @@ export const classifyVideoWorker = async (job: Job) => {
     const title = videoInfo.title?.trim() || "untitled";
     const description = videoInfo.description?.trim() || "N/A";
     const tags = videoInfo.tags?.trim() || "empty";
-    const authorInfo = videoInfo.author_info || "N/A";
-    const label = await classifyVideo(title, description, tags, authorInfo, aid);
+    const label = await classifyVideo(title, description, tags, aid);
     if (label == -1) {
         logger.warn(`Failed to classify video ${aid}`, "ml");
     }