# cvsa/filter/modelV6_0.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class VideoClassifierV6_0(nn.Module):
    def __init__(self, embedding_dim=256, seq_length=1024, hidden_dim=512, output_dim=3):
        super().__init__()
        # embedding_dim and seq_length document the expected input shape; the
        # conv stack itself adapts to any spatial size via AdaptiveAvgPool2d.
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']

        # CNN feature-extraction layers
        self.conv_layers = nn.Sequential(
            # First conv block
            nn.Conv2d(self.num_channels, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=(2, 2)),
            # Second conv block
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=(2, 2)),
            # Third conv block
            nn.Conv2d(128, 256, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(256),
            nn.GELU(),
            # Global average pooling; output shape is [batch_size, 256, 1, 1]
            nn.AdaptiveAvgPool2d((1, 1))
        )

        # Feature dimension after global pooling is fixed at 256
        self.feature_dim = 256

        # Fully connected classification head
        self.fc = nn.Sequential(
            nn.Linear(self.feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.2),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim)
        )

        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # Kaiming init; the 'relu' gain is a common approximation
                # when the network actually uses GELU
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, channel_features: torch.Tensor):
        """
        Input shape:  [batch_size, num_channels, seq_length, embedding_dim]
        Output shape: [batch_size, output_dim]
        """
        # CNN feature extraction
        conv_features = self.conv_layers(channel_features)
        # Flatten (shape after global pooling is [batch_size, 256, 1, 1])
        flat_features = conv_features.view(conv_features.size(0), -1)  # [batch_size, 256]
        # Classify with the fully connected head
        return self.fc(flat_features)
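

# Hedged usage sketch (illustrative, not part of the original module): a quick
# shape check for the classifier. The batch size of 2, the random input, and
# CPU execution are assumptions; the channel count and spatial dims follow the
# constructor defaults above.
def _smoke_test_classifier():
    model = VideoClassifierV6_0()
    model.eval()  # eval mode so BatchNorm uses running statistics
    # [batch_size, num_channels, seq_length, embedding_dim]
    x = torch.randn(2, 4, 1024, 256)
    with torch.no_grad():
        logits = model(x)
    assert logits.shape == (2, 3)  # [batch_size, output_dim]
    return logits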


# The loss function is unchanged (carried over from the previous version)
class AdaptiveRecallLoss(nn.Module):
    def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
        super().__init__()
        self.class_weights = class_weights
        self.alpha = alpha
        self.gamma = gamma
        self.fp_penalty = fp_penalty

    def forward(self, logits, targets):
        # Per-sample weighted cross-entropy
        ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')
        # Focal modulation: down-weight easy examples (pt close to 1)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        # Confidence-adaptive weighting applied on the target class only
        class_mask = F.one_hot(targets, num_classes=len(self.class_weights))
        class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask
        recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1)
        # Extra penalty when a non-zero class is confidently predicted as class 0
        probs = F.softmax(logits, dim=1)
        fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
        fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()
        total_loss = recall_loss.mean() + fp_loss / len(targets)
        return total_loss
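

# Hedged usage sketch (illustrative, not part of the original module): computing
# the loss on random data. Uniform class weights and a batch of 8 are assumptions;
# real weights would come from the dataset's class distribution.
def _smoke_test_loss():
    class_weights = torch.ones(3)  # placeholder weights, one per class
    criterion = AdaptiveRecallLoss(class_weights)
    logits = torch.randn(8, 3)
    targets = torch.randint(0, 3, (8,))
    loss = criterion(logits, targets)
    assert loss.dim() == 0  # scalar loss
    return loss


if __name__ == "__main__":
    print("classifier logits shape:", tuple(_smoke_test_classifier().shape))
    print("loss:", _smoke_test_loss().item())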