cvsa/filter/modelV3_12.py

80 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
class VideoClassifierV3_12(nn.Module):
    """Hierarchical 3-class video classifier over per-channel embeddings.

    The input holds one embedding per channel (``channel_names``: title,
    description, tags, author_info). Each channel is scaled by a learnable
    weight squashed through a sigmoid. The scaled channels are concatenated
    and fed to two independent binary heads:

    * ``first_classifier``  -- class 0 vs. classes {1, 2};
    * ``second_classifier`` -- class 1 vs. class 2.

    The heads are combined with the chain rule:
    ``P(0) = a0``, ``P(1) = a1 * b0``, ``P(2) = a1 * b1``, where ``(a0, a1)``
    and ``(b0, b1)`` are the two heads' softmax outputs. Each output row
    therefore sums to 1.
    """

    def __init__(self, embedding_dim: int = 1024, hidden_dim: int = 648):
        super().__init__()
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']
        # Learnable temperature. NOTE: neither forward() nor
        # get_channel_weights() uses it any more (channel weights are
        # sigmoid-gated, not softmax-normalised). It is kept so existing
        # checkpoints keep loading. Because it never receives a gradient,
        # DDP may need find_unused_parameters=True.
        self.temperature = nn.Parameter(torch.tensor(1.7))
        # Raw per-channel weights. forward() maps each one independently into
        # (0, 1) with a sigmoid (used instead of a softmax across channels).
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))

        input_dim = embedding_dim * self.num_channels
        # First binary head: class 0 vs. classes 1/2.
        self.first_classifier = self._make_head(input_dim, hidden_dim)
        # Second binary head: class 1 vs. class 2.
        self.second_classifier = self._make_head(input_dim, hidden_dim)

        self._init_weights()

    @staticmethod
    def _make_head(input_dim: int, hidden_dim: int) -> nn.Sequential:
        """Build one binary head: Linear -> BatchNorm -> Dropout -> GELU -> Linear(2)."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.Dropout(0.2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, 2),  # two logits for this binary decision
        )

    def _init_weights(self) -> None:
        """Kaiming-normal init for every Linear layer in both heads; zero the biases."""
        for head in (self.first_classifier, self.second_classifier):
            for layer in head:
                if isinstance(layer, nn.Linear):
                    # torch.nn.init has no GELU gain; the ReLU gain is used
                    # as an approximation.
                    nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
                    nn.init.zeros_(layer.bias)

    def forward(self, channel_features: torch.Tensor) -> torch.Tensor:
        """Return class probabilities for a batch.

        Args:
            channel_features: tensor of shape
                ``[batch_size, num_channels, embedding_dim]``. It may be
                non-contiguous.

        Returns:
            Tensor of shape ``[batch_size, 3]`` holding probabilities for
            classes 0, 1 and 2 (each row sums to 1). It is produced on the
            same device, and in the dtype of the computation, as the heads'
            outputs.
        """
        # Sigmoid-gated channel weights in (0, 1), broadcast over the batch
        # and embedding dimensions.
        weights = torch.sigmoid(self.channel_weights)
        weighted = channel_features * weights.view(1, -1, 1)
        # flatten() copies when needed, whereas view() raises on
        # non-contiguous input.
        combined = weighted.flatten(start_dim=1)

        # Columns: [P(0), P(1 or 2)]
        first_probs = F.softmax(self.first_classifier(combined), dim=1)
        # Columns: [P(1 | 1 or 2), P(2 | 1 or 2)]
        second_probs = F.softmax(self.second_classifier(combined), dim=1)

        # Chain rule over the two-stage decision. torch.stack builds the
        # result directly on the right device/dtype, with no host-side
        # zeros buffer.
        p_nonzero = first_probs[:, 1]
        return torch.stack(
            (
                first_probs[:, 0],
                p_nonzero * second_probs[:, 0],
                p_nonzero * second_probs[:, 1],
            ),
            dim=1,
        )

    def get_channel_weights(self):
        """Return the per-channel weights that forward() actually applies.

        The result is ``sigmoid(channel_weights)`` as a detached NumPy array of
        shape ``[num_channels]``. Its order presumably matches
        ``channel_names``; confirm against how callers build the input.
        Previously this method reported a temperature-scaled softmax, which
        did not match the sigmoid gating used in forward().
        """
        return torch.sigmoid(self.channel_weights).detach().cpu().numpy()