cvsa/filter/modelV3_15.py

97 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
class VideoClassifierV3_15(nn.Module):
    """Classifier over three per-video text embeddings (title, description, tags).

    Each channel embedding is scaled by a learnable, sigmoid-squashed weight,
    the scaled channels are concatenated, and the result is fed through a
    small MLP (Linear -> BatchNorm1d -> Dropout -> GELU -> Linear) that
    produces ``output_dim`` logits.

    Args:
        embedding_dim: Size of each channel's embedding vector.
        hidden_dim: Half the width of the hidden layer (the hidden layer has
            ``hidden_dim * 2`` units).
        output_dim: Number of output logits.
        temperature: Initial value of the learnable temperature used by
            :meth:`get_channel_weights`.
    """

    def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3, temperature=1.7):
        super().__init__()
        self.num_channels = 3
        self.channel_names = ['title', 'description', 'tags']

        # Learnable temperature coefficient. float() keeps the parameter a
        # floating-point tensor even when an int is passed; an integer tensor
        # cannot be wrapped in nn.Parameter because it cannot require grad.
        self.temperature = nn.Parameter(torch.tensor(float(temperature)))

        # Raw channel weights; forward() constrains them with a sigmoid
        # (instead of a softmax), so each channel is scaled independently.
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))

        # Non-linear classification head.
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim * self.num_channels, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.Dropout(0.2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, output_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """Initialise every Linear layer in ``self.fc``.

        Weights use Kaiming-normal initialisation with the ReLU gain, which
        serves as an approximation for GELU; biases are zeroed.
        """
        for layer in self.fc:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
                nn.init.zeros_(layer.bias)

    def forward(self, channel_features: torch.Tensor):
        """Compute class logits from per-channel embeddings.

        Args:
            channel_features: Tensor of shape
                ``[batch_size, num_channels, embedding_dim]``.

        Returns:
            Tensor of shape ``[batch_size, output_dim]``.
        """
        # Adaptive channel weights, constrained to the (0, 1) range.
        weights = torch.sigmoid(self.channel_weights)
        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)

        # Concatenate the channels. flatten() falls back to a copy when the
        # product is not contiguous (e.g. the input was permuted), whereas
        # view() would raise in that case.
        combined = weighted_features.flatten(start_dim=1)
        return self.fc(combined)

    def get_channel_weights(self):
        """Return temperature-scaled channel weights as a NumPy array.

        The result is ``softmax(channel_weights / temperature)``, detached and
        moved to the CPU.

        NOTE(review): forward() scales channels with ``sigmoid(channel_weights)``
        rather than this softmax, so the values returned here are a normalised
        summary and not the exact multipliers applied to the features.
        """
        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()
class AdaptiveRecallLoss(nn.Module):
    """Class-weighted focal loss with a recall factor and a class-0 false-positive penalty.

    Per sample, with ``pt`` the softmax probability of the target class and
    ``w`` the weight of the target class::

        ce     = -w * log(pt)
        focal  = (1 - pt) ** gamma * ce
        recall = (alpha + (1 - alpha) * pt) * focal

    The returned scalar is ``mean(recall)`` plus
    ``fp_penalty * sum(p0 ** 2) / batch_size``, where the sum runs over the
    samples whose target is not class 0 but whose arg-max prediction is
    class 0, and ``p0`` is their predicted class-0 probability.
    """

    def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
        """
        Args:
            class_weights (torch.Tensor): Per-class weights. A plain sequence
                of floats is also accepted; ``None`` disables class weighting.
            alpha (float): Recall adjustment factor, in the range 0-1.
            gamma (float): Focal-loss focusing parameter.
            fp_penalty (float): Strength of the class-0 false-positive penalty.
        """
        super().__init__()
        self.class_weights = class_weights
        self.alpha = alpha
        self.gamma = gamma
        self.fp_penalty = fp_penalty

    def forward(self, logits, targets):
        """Compute the scalar loss.

        Args:
            logits: Unnormalised class scores; softmax is taken over dim 1.
            targets: Integer class index for each sample.

        Returns:
            Zero-dimensional tensor holding the total loss.
        """
        # Unweighted cross-entropy, so that pt below is the actual probability
        # assigned to the target class. Deriving pt from the class-weighted
        # loss would give pt ** w and distort the focal and recall factors.
        nll = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-nll)

        # Apply the class weight to the loss term only. The weights are
        # aligned to the logits' device and dtype here because they are stored
        # as a plain attribute and therefore do not follow Module.to().
        if self.class_weights is None:
            ce_loss = nll
        else:
            weights = torch.as_tensor(
                self.class_weights, dtype=logits.dtype, device=logits.device
            )
            ce_loss = nll * weights[targets]

        # Focal-loss component: down-weights well-classified samples.
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        # Recall factor: alpha for the hardest samples (pt -> 0), rising
        # linearly to 1 for the most confident ones (pt -> 1).
        recall_loss = (self.alpha + (1 - self.alpha) * pt) * focal_loss

        # Class-0 false-positive penalty: samples that are not class 0 but are
        # predicted as class 0 pay the squared class-0 probability.
        probs = F.softmax(logits, dim=1)
        fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
        fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()

        # Total loss, with the penalty averaged over the batch.
        total_loss = recall_loss.mean() + fp_loss / len(targets)
        return total_loss