add: filter model V3.17 with deployment
parent 748e2e2aaa, commit 3842c63ad1
@@ -27,3 +27,4 @@ Note
 2350: V3.13 # V3.12, switched to plain cross-entropy loss
 0012: V3.11 # switched to plain cross-entropy loss
 0039: V3.11 # cascaded classification, but using two independent models
+0122: V3.15 # removed the author_info channel
@@ -103,8 +103,7 @@ class MultiChannelDataset(Dataset):
         texts = {
             'title': example['title'],
             'description': example['description'],
-            'tags': tags_text,
-            'author_info': example['author_info']
+            'tags': tags_text
         }

         return {
@@ -11,8 +11,7 @@ def prepare_batch(batch_data, device="cpu"):
         batch_data (dict): the input batch data, in the format {
             "title": [text1, text2, ...],
             "description": [text1, text2, ...],
-            "tags": [text1, text2, ...],
-            "author_info": [text1, text2, ...]
+            "tags": [text1, text2, ...]
         }
         device (str): the device the model runs on (e.g. "cpu" or "cuda").

@@ -22,7 +21,7 @@ def prepare_batch(batch_data, device="cpu"):
     # 1. Encode each channel's texts separately
     channel_embeddings = []
     model = StaticModel.from_pretrained("./model/embedding_1024/")
-    for channel in ["title", "description", "tags", "author_info"]:
+    for channel in ["title", "description", "tags"]:
         texts = batch_data[channel]  # list of texts for the current channel
         embeddings = torch.from_numpy(model.encode(texts)).to(torch.float32).to(device)  # encoded as [batch_size, embedding_dim]
         channel_embeddings.append(embeddings)
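The hunk above is cut off before the channels are combined. As a minimal sketch, assuming the rest of prepare_batch stacks the per-channel embeddings into the [batch_size, num_channels, embedding_dim] layout the classifier expects (this continuation is inferred, not part of the diff):

    # Hypothetical continuation of prepare_batch (not shown in this commit's diff):
    # stack the per-channel embeddings into [batch_size, 3, embedding_dim]
    channel_features = torch.stack(channel_embeddings, dim=1)
    return channel_features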
filter/modelV3_15.py (new file, 97 lines)
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class VideoClassifierV3_15(nn.Module):
+    def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3, temperature=1.7):
+        super().__init__()
+        self.num_channels = 3
+        self.channel_names = ['title', 'description', 'tags']
+
+        # Learnable temperature coefficient
+        self.temperature = nn.Parameter(torch.tensor(temperature))
+
+        # Constrained channel weights (Sigmoid instead of Softmax)
+        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
+
+        # Enhanced non-linear head
+        self.fc = nn.Sequential(
+            nn.Linear(embedding_dim * self.num_channels, hidden_dim * 2),
+            nn.BatchNorm1d(hidden_dim * 2),
+            nn.Dropout(0.2),
+            nn.GELU(),
+            nn.Linear(hidden_dim * 2, output_dim)
+        )
+
+        # Weight initialization
+        self._init_weights()
+
+    def _init_weights(self):
+        for layer in self.fc:
+            if isinstance(layer, nn.Linear):
+                # Kaiming init with the ReLU gain (an approximation for GELU)
+                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
+                # Alternatively, Xavier init (better suited to the general case):
+                # nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('relu'))
+                nn.init.zeros_(layer.bias)
+
+    def forward(self, channel_features: torch.Tensor):
+        """
+        Input shape:  [batch_size, num_channels, embedding_dim]
+        Output shape: [batch_size, output_dim]
+        """
+        # Adaptive channel weights (Sigmoid-constrained to [0, 1])
+        weights = torch.sigmoid(self.channel_weights)
+        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)
+
+        # Concatenate the channel features
+        combined = weighted_features.view(weighted_features.size(0), -1)
+
+        return self.fc(combined)
+
+    def get_channel_weights(self):
+        """Return the per-channel weights (temperature-scaled)."""
+        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()
+
+
+class AdaptiveRecallLoss(nn.Module):
+    def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
+        """
+        Args:
+            class_weights (torch.Tensor): per-class weights
+            alpha (float): recall adjustment factor (0-1)
+            gamma (float): Focal Loss parameter
+            fp_penalty (float): penalty strength for class-0 false positives
+        """
+        super().__init__()
+        self.class_weights = class_weights
+        self.alpha = alpha
+        self.gamma = gamma
+        self.fp_penalty = fp_penalty
+
+    def forward(self, logits, targets):
+        # Base cross-entropy loss
+        ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')
+
+        # Focal Loss component
+        pt = torch.exp(-ce_loss)
+        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
+
+        # Recall boost (up-weight hard samples)
+        class_mask = F.one_hot(targets, num_classes=len(self.class_weights))
+        class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask
+        recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1)
+
+        # Penalty for class-0 false positives
+        probs = F.softmax(logits, dim=1)
+        fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
+        fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()
+
+        # Total loss
+        total_loss = recall_loss.mean() + fp_loss / len(targets)
+
+        return total_loss
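A minimal smoke test of the new module, assuming it is run from the filter/ directory; the batch size, random inputs, and uniform class weights below are illustrative only, not values used by this commit:

import torch
from modelV3_15 import VideoClassifierV3_15, AdaptiveRecallLoss

model = VideoClassifierV3_15()                         # 3 channels, 1024-dim embeddings, 3 classes
criterion = AdaptiveRecallLoss(class_weights=torch.ones(3))

channel_features = torch.randn(8, 3, 1024)             # [batch_size, num_channels, embedding_dim]
targets = torch.randint(0, 3, (8,))

logits = model(channel_features)                       # [8, 3]
loss = criterion(logits, targets)
loss.backward()
print(loss.item(), model.get_channel_weights())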
@@ -1,16 +1,16 @@
 import torch
-from modelV3_12 import VideoClassifierV3_12
+from modelV3_15 import VideoClassifierV3_15


-def export_onnx(model_path="./filter/checkpoints/best_model_V3.13.pt",
-                onnx_path="./model/video_classifier_v3_13.onnx"):
+def export_onnx(model_path="./filter/checkpoints/best_model_V3.17.pt",
+                onnx_path="./model/video_classifier_v3_17.onnx"):
     # Initialize the model
-    model = VideoClassifierV3_12()
+    model = VideoClassifierV3_15()
     model.load_state_dict(torch.load(model_path))
     model.eval()

     # Create a dummy input that matches the expected input spec
-    dummy_input = torch.randn(1, 4, 1024)  # [batch=1, channels=4, embedding_dim=1024]
+    dummy_input = torch.randn(1, 3, 1024)  # [batch=1, channels=3, embedding_dim=1024]

     # Export to ONNX
     torch.onnx.export(
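The hunk is cut off at the export call. A hedged sketch of how that call is typically completed follows; only the input and output names ("channel_features", "logits") are taken from the inference code in this commit, while the dynamic-axes setting and opset version are assumptions:

    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["channel_features"],
        output_names=["logits"],
        dynamic_axes={"channel_features": {0: "batch_size"}, "logits": {0: "batch_size"}},  # assumed
        opset_version=17,  # assumed
    )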
@@ -4,7 +4,7 @@ import numpy as np
 from torch.utils.data import DataLoader
 import torch.optim as optim
 from dataset import MultiChannelDataset
-from filter.modelV3_12 import VideoClassifierV3_12
+from filter.modelV3_15 import AdaptiveRecallLoss, VideoClassifierV3_15
 from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
 import os
 import torch
@@ -51,16 +51,20 @@ class_weights = torch.tensor(
 )

 # Initialize the model and SentenceTransformer
-model = VideoClassifierV3_12()
-checkpoint_name = './filter/checkpoints/best_model_V3.12.pt'
+model = VideoClassifierV3_15()
+checkpoint_name = './filter/checkpoints/best_model_V3.17.pt'

 # Model checkpoint path
 os.makedirs('./filter/checkpoints', exist_ok=True)

 # Optimizer
 optimizer = optim.AdamW(model.parameters(), lr=4e-4)
-# Cross entropy loss
-criterion = nn.CrossEntropyLoss()
+criterion = AdaptiveRecallLoss(
+    class_weights=class_weights,
+    alpha=0.9,        # recall weighting
+    gamma=1.6,        # focus on hard samples
+    fp_penalty=0.8    # penalty strength for class-0 false positives
+)

 def count_trainable_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
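For context, a minimal sketch of one training step with the new criterion; the loader variable, batch keys, and device handling are assumptions for illustration, not code from this commit:

model.train()
for batch in train_loader:                                            # assumed DataLoader over MultiChannelDataset
    channel_features = prepare_batch(batch['texts'], device="cpu")    # assumed key; yields [batch_size, 3, 1024]
    labels = batch['label']                                           # assumed key
    optimizer.zero_grad()
    logits = model(channel_features)
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()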
@@ -4,9 +4,9 @@ import { softmax } from "lib/ml/filter_inference.ts";

 // Configuration
 const sentenceTransformerModelName = "alikia2x/jina-embedding-v3-m2v-1024";
-const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
+const onnxClassifierPath = "./model/video_classifier_v3_17.onnx";
 const onnxEmbeddingPath = "./model/embedding_original.onnx";
-const testDataPath = "./data/filter/test.jsonl";
+const testDataPath = "./data/filter/test1.jsonl";

 // Initialize the sessions
 const [sessionClassifier, sessionEmbedding] = await Promise.all([
@@ -53,7 +53,7 @@ async function getONNXEmbeddings(texts: string[], session: ort.InferenceSession)
 async function runClassification(embeddings: number[]): Promise<number[]> {
     const inputTensor = new ort.Tensor(
         Float32Array.from(embeddings),
-        [1, 4, 1024],
+        [1, 3, 1024],
     );

     const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
@@ -69,6 +69,14 @@ function calculateMetrics(labels: number[], predictions: number[], elapsedTime:
     "Class 0 Prec": number;
     speed: string;
 } {
+    // Log the indices where the label and the prediction disagree and the prediction is class 0
+    const arr = [];
+    for (let i = 0; i < labels.length; i++) {
+        if (labels[i] !== predictions[i] && predictions[i] == 0) {
+            arr.push([i + 1, labels[i], predictions[i]]);
+        }
+    }
+    console.log(arr);
     // Initialize the confusion matrix
     const classCount = Math.max(...labels, ...predictions) + 1;
     const matrix = Array.from({ length: classCount }, () => Array.from({ length: classCount }, () => 0));
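The "Class 0 Prec" figure this test reports is the complement of the false-positive cases logged above. A small Python sketch of the same quantity, for reference (the function name and NumPy usage are illustrative, not part of this repository):

import numpy as np

def class0_precision(labels, predictions):
    # Of all samples predicted as class 0, the fraction whose true label is also 0.
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)
    predicted_zero = predictions == 0
    if predicted_zero.sum() == 0:
        return 0.0
    return float((labels[predicted_zero] == 0).mean())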
@@ -138,8 +146,7 @@ async function evaluateModel(session: ort.InferenceSession): Promise<{
         const embeddings = await getONNXEmbeddings([
             sample.title,
             sample.description,
-            sample.tags.join(","),
-            sample.author_info,
+            sample.tags.join(",")
         ], session);

         const probabilities = await runClassification(embeddings);
@@ -4,9 +4,9 @@ import logger from "lib/log/logger.ts";
 import { WorkerError } from "lib/mq/schema.ts";

 const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024";
-const onnxClassifierPath = "./model/video_classifier_v3_11.onnx";
+const onnxClassifierPath = "./model/video_classifier_v3_17.onnx";
 const onnxEmbeddingOriginalPath = "./model/model.onnx";
-export const modelVersion = "3.11";
+export const modelVersion = "3.17";

 let sessionClassifier: ort.InferenceSession | null = null;
 let sessionEmbedding: ort.InferenceSession | null = null;
@@ -72,7 +72,7 @@ async function runClassification(embeddings: number[]): Promise<number[]> {
     }
     const inputTensor = new ort.Tensor(
         Float32Array.from(embeddings),
-        [1, 4, 1024],
+        [1, 3, 1024],
     );

     const { logits } = await sessionClassifier.run({ channel_features: inputTensor });
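The same three-channel input contract can be exercised against the exported model outside the Deno runtime. A minimal Python sketch, assuming onnxruntime is installed and the exported file sits at the path used by the export script; the input name "channel_features" and output name "logits" come from the code above, the random input is illustrative:

import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("./model/video_classifier_v3_17.onnx")
channel_features = np.random.rand(1, 3, 1024).astype(np.float32)   # [batch, channels, embedding_dim]
(logits,) = session.run(["logits"], {"channel_features": channel_features})
shifted = logits - logits.max(axis=1, keepdims=True)               # numerically stable softmax
probabilities = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
print(probabilities)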
@@ -83,7 +83,6 @@ export async function classifyVideo(
     title: string,
     description: string,
     tags: string,
-    author_info: string,
     aid: number,
 ): Promise<number> {
     if (!sessionEmbedding) {
@@ -93,7 +92,6 @@ export async function classifyVideo(
         title,
         description,
         tags,
-        author_info,
     ], sessionEmbedding);
     const probabilities = await runClassification(embeddings);
     logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml");
@@ -17,8 +17,7 @@ export const classifyVideoWorker = async (job: Job) => {
     const title = videoInfo.title?.trim() || "untitled";
     const description = videoInfo.description?.trim() || "N/A";
     const tags = videoInfo.tags?.trim() || "empty";
-    const authorInfo = videoInfo.author_info || "N/A";
-    const label = await classifyVideo(title, description, tags, authorInfo, aid);
+    const label = await classifyVideo(title, description, tags, aid);
     if (label == -1) {
         logger.warn(`Failed to classify video ${aid}`, "ml");
     }