update: filter model v3.8
parent 546a9f7bc8
commit 175f3c3f6b
@@ -16,3 +16,6 @@ Note
 0229: V3.3-test3 # back to iterating on V3
 0316: V3.4-test3 # model architecture changes: custom loss and modified FC layers
 0324: V3.5-test3 # try going back to the 3.2 FC layer
+0331: V3.6-test3 # 3.5 is not great; trying to tune the hyperparameters
+0335: V3.7-test3 # 3.6 is decent; tuning the hyperparameters a bit further
+0352: V3.8-test3 # 3.7 did not work out; re-tuning from the 3.6 baseline
@@ -1,109 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class VideoClassifierV3_5(nn.Module):
-    def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3):
-        super().__init__()
-        self.num_channels = 4
-        self.channel_names = ['title', 'description', 'tags', 'author_info']
-
-        # Learnable temperature coefficient
-        self.temperature = nn.Parameter(torch.tensor(1.7))
-
-        # Constrained channel weights (Sigmoid in place of Softmax)
-        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
-
-        # Deeper non-linear classification head
-        self.fc = nn.Sequential(
-            nn.Linear(embedding_dim * self.num_channels, hidden_dim * 2),
-            nn.BatchNorm1d(hidden_dim * 2),
-            nn.Dropout(0.1),
-            nn.ReLU(),
-            nn.Linear(hidden_dim * 2, hidden_dim),
-            nn.LayerNorm(hidden_dim),
-            nn.Linear(hidden_dim, output_dim)
-        )
-
-        # Weight initialization
-        self._init_weights()
-
-    def _init_weights(self):
-        for layer in self.fc:
-            if isinstance(layer, nn.Linear):
-                # Kaiming init tuned for ReLU (a close approximation for GELU)
-                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')  # changed here
-                # Alternative: Xavier init (better suited to generic settings)
-                # nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('relu'))
-                nn.init.zeros_(layer.bias)
-
-    def forward(self, input_texts, sentence_transformer):
-        # Merge all channels' texts for one batched encoding pass
-        all_texts = [text for channel in self.channel_names for text in input_texts[channel]]
-
-        # Frozen text encoder
-        with torch.no_grad():
-            embeddings = torch.tensor(
-                sentence_transformer.encode(all_texts),
-                device=next(self.parameters()).device
-            )
-
-        # Split the batch back into per-channel features
-        split_sizes = [len(input_texts[name]) for name in self.channel_names]
-        channel_features = torch.split(embeddings, split_sizes, dim=0)
-        channel_features = torch.stack(channel_features, dim=1)
-
-        # Adaptive channel weights (Sigmoid-constrained to [0, 1])
-        weights = torch.sigmoid(self.channel_weights)
-        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)
-
-        # Concatenate the weighted channel features
-        combined = weighted_features.view(weighted_features.size(0), -1)
-
-        return self.fc(combined)
-
-    def get_channel_weights(self):
-        """Return the per-channel weights (temperature-scaled)."""
-        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()
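
For context on the interface being deleted: the classifier took a dict of per-channel text lists plus a SentenceTransformer instance. A minimal sketch of a forward pass, assuming the class above were still importable (the sample texts below are placeholders, not real data):

import torch
from sentence_transformers import SentenceTransformer

st = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
model = VideoClassifierV3_5()
model.eval()

# One list per channel; all four lists share the batch size (2 here).
input_texts = {
    'title':       ['Video A', 'Video B'],
    'description': ['first demo description', 'second demo description'],
    'tags':        ['tag1, tag2', 'tag3'],
    'author_info': ['Author A', 'Author B'],
}

with torch.no_grad():
    logits = model(input_texts, st)      # shape [2, 3]
preds = torch.argmax(logits, dim=1)      # predicted class per sample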
-
-
-class AdaptiveRecallLoss(nn.Module):
-    def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
-        """
-        Args:
-            class_weights (torch.Tensor): per-class weights
-            alpha (float): recall adjustment factor (0-1)
-            gamma (float): Focal Loss exponent
-            fp_penalty (float): penalty strength for class-0 false positives
-        """
-        super().__init__()
-        self.class_weights = class_weights
-        self.alpha = alpha
-        self.gamma = gamma
-        self.fp_penalty = fp_penalty
-
-    def forward(self, logits, targets):
-        # Base cross-entropy loss
-        ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')
-
-        # Focal Loss component
-        pt = torch.exp(-ce_loss)
-        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
-
-        # Recall boost: up-weight hard samples of each class
-        class_mask = F.one_hot(targets, num_classes=len(self.class_weights))
-        class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask
-        recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1)
-
-        # Penalty for class-0 false positives
-        probs = F.softmax(logits, dim=1)
-        fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
-        fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()
-
-        # Total loss
-        total_loss = recall_loss.mean() + fp_loss / len(targets)
-
-        return total_loss
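
A side note on the focal term above, since gamma is exactly what the V3.6-V3.8 runs retune: the (1 - pt) ** gamma factor suppresses easy samples far more than hard ones. A self-contained illustration with made-up loss values (not from the training data):

import torch

# Cross-entropy for one easy (pt ~ 0.90) and one hard (pt ~ 0.10) sample
ce = torch.tensor([0.105, 2.303])
pt = torch.exp(-ce)
for gamma in (1.5, 1.6, 2.0):
    focal = (1 - pt) ** gamma * ce
    print(f"gamma={gamma}: easy={focal[0]:.4f}, hard={focal[1]:.4f}")
# At gamma=1.5 the easy sample's contribution shrinks roughly 30x,
# while the hard sample's loss drops by only about 15%.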
@@ -4,7 +4,7 @@ import numpy as np
 from torch.utils.data import DataLoader
 import torch.optim as optim
 from dataset import MultiChannelDataset
-from modelV3_5 import VideoClassifierV3_5, AdaptiveRecallLoss
+from modelV3_4 import VideoClassifierV3_4, AdaptiveRecallLoss
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
 import os
@@ -52,18 +52,18 @@ class_weights = torch.tensor(
 
 # Initialize the model and the SentenceTransformer
 sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
-model = VideoClassifierV3_5()
-checkpoint_name = './filter/checkpoints/best_model_V3.5.pt'
+model = VideoClassifierV3_4()
+checkpoint_name = './filter/checkpoints/best_model_V3.8.pt'
 
 # Model checkpoint directory
 os.makedirs('./filter/checkpoints', exist_ok=True)
 
 # Optimizer
-optimizer = optim.AdamW(model.parameters(), lr=2e-3)
+optimizer = optim.AdamW(model.parameters(), lr=4e-4)
 criterion = AdaptiveRecallLoss(
     class_weights=class_weights,
-    alpha=0.7,        # recall weight
-    gamma=1.5,        # focus on hard samples
+    alpha=0.9,        # recall weight
+    gamma=1.6,        # focus on hard samples
     fp_penalty=0.8    # false-positive penalty strength
 )
 
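
For orientation, these pieces combine in an ordinary PyTorch loop. A sketch of one training iteration under the retuned settings (the batch keys 'texts' and 'label' are assumptions about MultiChannelDataset, not taken from the script):

model.train()
for batch in train_loader:                  # DataLoader over MultiChannelDataset
    optimizer.zero_grad()
    logits = model(batch['texts'], sentence_transformer)
    loss = criterion(logits, batch['label'])
    loss.backward()
    optimizer.step()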