cvsa/filter/modelV6_3.py

import torch
import torch.nn as nn


class VideoClassifierV6_3(nn.Module):
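    """Three-channel (title / description / tags) video classifier.

    Each channel is encoded by its own attention-based processor, the pooled
    channel features are scaled by learnable, sigmoid-constrained channel
    weights, and the concatenated result is passed to a fully connected
    classification head.
    """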
    def __init__(self, embedding_dim=256, hidden_dim=256, output_dim=3, num_heads=4, num_layers=2):
        super().__init__()
        self.num_channels = 3
        self.channel_names = ['title', 'description', 'tags']
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # Channel-independent processing: one Transformer-style encoder per channel
        self.channel_processors = nn.ModuleList()
        for _ in range(self.num_channels):
            layers = []
            # First project the input embedding to hidden_dim
            layers.extend([
                nn.Linear(embedding_dim, hidden_dim),
                nn.GELU(),
                nn.LayerNorm(hidden_dim)
            ])
            # Stack num_layers Transformer-style blocks
            for _ in range(num_layers):
                layers.extend([
                    # Self-attention layer; embed_dim matches hidden_dim
                    nn.MultiheadAttention(
                        embed_dim=hidden_dim,
                        num_heads=num_heads,
                        dropout=0.1
                    ),
                    nn.LayerNorm(hidden_dim),
                    # Feed-forward part
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.GELU(),
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.LayerNorm(hidden_dim)
                ])
            self.channel_processors.append(nn.Sequential(*layers))
        # Learnable channel weights, constrained to (0, 1) via sigmoid in forward()
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))

        # Fully connected classification head (expand, then reduce the dimension)
        self.fc = nn.Sequential(
            nn.Linear(self.num_channels * hidden_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),
            nn.GELU(),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Dropout(0.2),
            nn.GELU(),
            nn.Linear(512, output_dim)
        )

        self._init_weights()
    def _init_weights(self):
        """Xavier (Glorot) weight initialization."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.MultiheadAttention):
                # Initialize the input and output projections of MultiheadAttention
                for name, param in m.named_parameters():
                    if "in_proj" in name or "out_proj" in name:
                        if "weight" in name:
                            nn.init.xavier_uniform_(param)
                        elif "bias" in name:
                            nn.init.zeros_(param)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
    def forward(self, channel_features: torch.Tensor):
        """
        Input shape:  [batch_size, num_channels, seq_length, embedding_dim]
        Output shape: [batch_size, output_dim]
        """
        batch_size = channel_features.size(0)
        processed_channels = []
        for c in range(self.num_channels):
            c_data = channel_features[:, c].permute(1, 0, 2)  # to [seq_length, batch_size, embedding_dim]
            # Run the channel through its own encoder
            x = c_data
            for layer in self.channel_processors[c]:
                if isinstance(layer, nn.MultiheadAttention):
                    # Self-attention needs query, key and value passed explicitly
                    x = layer(x, x, x)[0]
                else:
                    x = layer(x)
            # Back to [batch_size, seq_length, hidden_dim], then global average pooling over the sequence
            x = x.permute(1, 0, 2)
            pooled = x.mean(dim=1)
            processed_channels.append(pooled)

        # Stack channel features: [batch_size, num_channels, hidden_dim]
        processed_channels = torch.stack(processed_channels, dim=1)
        # Apply sigmoid-constrained channel weights
        weights = torch.sigmoid(self.channel_weights).view(1, -1, 1)
        weighted_features = processed_channels * weights
        # Flatten (concatenate) all weighted channel features
        combined = weighted_features.view(batch_size, -1)
        return self.fc(combined)
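

# Minimal usage sketch with the default hyperparameters. The random tensors
# stand in for real title/description/tags embeddings; the batch size of 4 and
# sequence length of 16 are arbitrary illustration values, not requirements.
if __name__ == "__main__":
    model = VideoClassifierV6_3()
    model.eval()  # use BatchNorm running statistics for inference
    dummy = torch.randn(4, 3, 16, 256)  # [batch, num_channels, seq_length, embedding_dim]
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([4, 3])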