diff --git a/filter/RunningLogs.txt b/filter/RunningLogs.txt index 7dd0260..93309b3 100644 --- a/filter/RunningLogs.txt +++ b/filter/RunningLogs.txt @@ -13,4 +13,6 @@ Note 0133: V4.2-test3 0138: V4.3-test3 0155: V5-test3 # V4 的效果也不是特别好 -0229: V3.3-test3 # 重新回到V3迭代 \ No newline at end of file +0229: V3.3-test3 # 重新回到V3迭代 +0316: V3.4-test3 # 模型架构修改,自定义Loss与FC层修改 +0324: V3.5-test3 # 用回3.2的FC层试试 \ No newline at end of file diff --git a/filter/modelV3_4.py b/filter/modelV3_4.py new file mode 100644 index 0000000..4972d51 --- /dev/null +++ b/filter/modelV3_4.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class VideoClassifierV3_4(nn.Module): + def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3): + super().__init__() + self.num_channels = 4 + self.channel_names = ['title', 'description', 'tags', 'author_info'] + + # 可学习温度系数 + self.temperature = nn.Parameter(torch.tensor(1.7)) + + # 带约束的通道权重(使用Sigmoid替代Softmax) + self.channel_weights = nn.Parameter(torch.ones(self.num_channels)) + + # 增强的非线性层 + self.fc = nn.Sequential( + nn.Linear(embedding_dim * self.num_channels, hidden_dim*2), + nn.BatchNorm1d(hidden_dim*2), + nn.Dropout(0.3), + nn.GELU(), + nn.Linear(hidden_dim*2, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.Dropout(0.2), + nn.GELU(), + nn.Linear(hidden_dim, output_dim) + ) + + # 权重初始化 + self._init_weights() + + def _init_weights(self): + for layer in self.fc: + if isinstance(layer, nn.Linear): + # 使用ReLU的初始化参数(GELU的近似) + nn.init.kaiming_normal_(layer.weight, nonlinearity='relu') # 修改这里 + + # 或者使用Xavier初始化(更适合通用场景) + # nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('relu')) + + nn.init.zeros_(layer.bias) + + + def forward(self, input_texts, sentence_transformer): + # 合并文本进行批量编码 + all_texts = [text for channel in self.channel_names for text in input_texts[channel]] + + # 冻结的文本编码 + with torch.no_grad(): + embeddings = torch.tensor( + sentence_transformer.encode(all_texts), + device=next(self.parameters()).device + ) + + # 分割并加权通道特征 + split_sizes = [len(input_texts[name]) for name in self.channel_names] + channel_features = torch.split(embeddings, split_sizes, dim=0) + channel_features = torch.stack(channel_features, dim=1) + + # 自适应通道权重(Sigmoid约束) + weights = torch.sigmoid(self.channel_weights) # [0,1]范围 + weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1) + + # 特征拼接 + combined = weighted_features.view(weighted_features.size(0), -1) + + return self.fc(combined) + + def get_channel_weights(self): + """获取各通道权重(带温度调节)""" + return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy() + + +class AdaptiveRecallLoss(nn.Module): + def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5): + """ + Args: + class_weights (torch.Tensor): 类别权重 + alpha (float): 召回率调节因子(0-1) + gamma (float): Focal Loss参数 + fp_penalty (float): 类别0假阳性惩罚强度 + """ + super().__init__() + self.class_weights = class_weights + self.alpha = alpha + self.gamma = gamma + self.fp_penalty = fp_penalty + + def forward(self, logits, targets): + # 基础交叉熵损失 + ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none') + + # Focal Loss组件 + pt = torch.exp(-ce_loss) + focal_loss = ((1 - pt) ** self.gamma) * ce_loss + + # 召回率增强(对困难样本加权) + class_mask = F.one_hot(targets, num_classes=len(self.class_weights)) + class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask + recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1) + + # 类别0假阳性惩罚 + probs = F.softmax(logits, dim=1) + fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0) + fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum() + + # 总损失 + total_loss = recall_loss.mean() + fp_loss / len(targets) + + return total_loss \ No newline at end of file diff --git a/filter/modelV3_5.py b/filter/modelV3_5.py new file mode 100644 index 0000000..905e079 --- /dev/null +++ b/filter/modelV3_5.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class VideoClassifierV3_5(nn.Module): + def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3): + super().__init__() + self.num_channels = 4 + self.channel_names = ['title', 'description', 'tags', 'author_info'] + + # 可学习温度系数 + self.temperature = nn.Parameter(torch.tensor(1.7)) + + # 带约束的通道权重(使用Sigmoid替代Softmax) + self.channel_weights = nn.Parameter(torch.ones(self.num_channels)) + + # 增强的非线性层 + self.fc = nn.Sequential( + nn.Linear(embedding_dim * self.num_channels, hidden_dim*2), + nn.BatchNorm1d(hidden_dim*2), + nn.Dropout(0.1), + nn.ReLU(), + nn.Linear(hidden_dim*2, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim, output_dim) + ) + + # 权重初始化 + self._init_weights() + + def _init_weights(self): + for layer in self.fc: + if isinstance(layer, nn.Linear): + # 使用ReLU的初始化参数(GELU的近似) + nn.init.kaiming_normal_(layer.weight, nonlinearity='relu') # 修改这里 + + # 或者使用Xavier初始化(更适合通用场景) + # nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('relu')) + + nn.init.zeros_(layer.bias) + + + def forward(self, input_texts, sentence_transformer): + # 合并文本进行批量编码 + all_texts = [text for channel in self.channel_names for text in input_texts[channel]] + + # 冻结的文本编码 + with torch.no_grad(): + embeddings = torch.tensor( + sentence_transformer.encode(all_texts), + device=next(self.parameters()).device + ) + + # 分割并加权通道特征 + split_sizes = [len(input_texts[name]) for name in self.channel_names] + channel_features = torch.split(embeddings, split_sizes, dim=0) + channel_features = torch.stack(channel_features, dim=1) + + # 自适应通道权重(Sigmoid约束) + weights = torch.sigmoid(self.channel_weights) # [0,1]范围 + weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1) + + # 特征拼接 + combined = weighted_features.view(weighted_features.size(0), -1) + + return self.fc(combined) + + def get_channel_weights(self): + """获取各通道权重(带温度调节)""" + return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy() + + +class AdaptiveRecallLoss(nn.Module): + def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5): + """ + Args: + class_weights (torch.Tensor): 类别权重 + alpha (float): 召回率调节因子(0-1) + gamma (float): Focal Loss参数 + fp_penalty (float): 类别0假阳性惩罚强度 + """ + super().__init__() + self.class_weights = class_weights + self.alpha = alpha + self.gamma = gamma + self.fp_penalty = fp_penalty + + def forward(self, logits, targets): + # 基础交叉熵损失 + ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none') + + # Focal Loss组件 + pt = torch.exp(-ce_loss) + focal_loss = ((1 - pt) ** self.gamma) * ce_loss + + # 召回率增强(对困难样本加权) + class_mask = F.one_hot(targets, num_classes=len(self.class_weights)) + class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask + recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1) + + # 类别0假阳性惩罚 + probs = F.softmax(logits, dim=1) + fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0) + fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum() + + # 总损失 + total_loss = recall_loss.mean() + fp_loss / len(targets) + + return total_loss \ No newline at end of file diff --git a/filter/train.py b/filter/train.py index f4b7ee1..38c4902 100644 --- a/filter/train.py +++ b/filter/train.py @@ -1,11 +1,11 @@ import os os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1" +import numpy as np from torch.utils.data import DataLoader import torch.optim as optim from dataset import MultiChannelDataset -from modelV3_3 import VideoClassifierV3_3 +from modelV3_5 import VideoClassifierV3_5, AdaptiveRecallLoss from sentence_transformers import SentenceTransformer -import torch.nn as nn from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report import os import torch @@ -37,17 +37,35 @@ train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True) eval_loader = DataLoader(eval_dataset, batch_size=24, shuffle=False) test_loader = DataLoader(test_dataset, batch_size=24, shuffle=False) +train_labels = [] +for batch in train_loader: + train_labels.extend(batch['label'].tolist()) + +# 计算自适应类别权重 +class_counts = np.bincount(train_labels) +median_freq = np.median(class_counts) +class_weights = torch.tensor( + [median_freq / count for count in class_counts], + dtype=torch.float32, + device='cpu' +) + # 初始化模型和SentenceTransformer sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024") -model = VideoClassifierV3_3() -checkpoint_name = './filter/checkpoints/best_model_V3.3.pt' +model = VideoClassifierV3_5() +checkpoint_name = './filter/checkpoints/best_model_V3.5.pt' # 模型保存路径 os.makedirs('./filter/checkpoints', exist_ok=True) # 优化器 optimizer = optim.AdamW(model.parameters(), lr=2e-3) -criterion = nn.CrossEntropyLoss() +criterion = AdaptiveRecallLoss( + class_weights=class_weights, + alpha=0.7, # 召回率权重 + gamma=1.5, # 困难样本聚焦 + fp_penalty=0.8 # 假阳性惩罚强度 +) def count_trainable_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) @@ -81,7 +99,7 @@ print(f"Trainable parameters: {count_trainable_parameters(model)}") # 训练循环 best_f1 = 0 step = 0 -eval_interval = 50 +eval_interval = 20 num_epochs = 8 for epoch in range(num_epochs):