update: filter model V3.3
This commit is contained in:
parent
6c5dfaae8b
commit
5a83120ad6
@ -12,4 +12,5 @@ Note
|
|||||||
0125: V4.1-test3
|
0125: V4.1-test3
|
||||||
0133: V4.2-test3
|
0133: V4.2-test3
|
||||||
0138: V4.3-test3
|
0138: V4.3-test3
|
||||||
0155: V5-test3 # V4 的效果也不是特别好
|
0155: V5-test3 # V4 的效果也不是特别好
|
||||||
|
0229: V3.3-test3 # 重新回到V3迭代
|
@ -1,47 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
class VideoClassifierV1_5(nn.Module):
|
|
||||||
def __init__(self, embedding_dim=1024, hidden_dim=256, output_dim=3):
|
|
||||||
super().__init__()
|
|
||||||
self.num_channels = 4
|
|
||||||
self.channel_names = ['title', 'description', 'tags', 'author_info']
|
|
||||||
|
|
||||||
# 通道权重参数(可学习)
|
|
||||||
self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
|
|
||||||
|
|
||||||
# 全连接层
|
|
||||||
self.fc1 = nn.Linear(embedding_dim * self.num_channels, hidden_dim)
|
|
||||||
self.fc2 = nn.Linear(hidden_dim, output_dim)
|
|
||||||
self.log_softmax = nn.LogSoftmax(dim=1)
|
|
||||||
|
|
||||||
def forward(self, input_texts, sentence_transformer):
|
|
||||||
# 各通道特征提取
|
|
||||||
channel_features = []
|
|
||||||
for _, name in enumerate(self.channel_names):
|
|
||||||
# 获取当前通道的批量文本
|
|
||||||
batch_texts = input_texts[name]
|
|
||||||
|
|
||||||
# 使用SentenceTransformer生成嵌入
|
|
||||||
embeddings = torch.tensor(
|
|
||||||
sentence_transformer.encode(batch_texts, task="classification")
|
|
||||||
)
|
|
||||||
channel_features.append(embeddings)
|
|
||||||
|
|
||||||
# 将通道特征堆叠并加权
|
|
||||||
channel_features = torch.stack(channel_features, dim=1) # [batch_size, num_channels, embedding_dim]
|
|
||||||
channel_weights = torch.softmax(self.channel_weights, dim=0)
|
|
||||||
weighted_features = channel_features * channel_weights.unsqueeze(0).unsqueeze(-1)
|
|
||||||
|
|
||||||
# 拼接所有通道特征
|
|
||||||
combined_features = weighted_features.view(weighted_features.size(0), -1) # [batch_size, num_channels * embedding_dim]
|
|
||||||
|
|
||||||
# 全连接层
|
|
||||||
x = torch.relu(self.fc1(combined_features))
|
|
||||||
output = self.fc2(x)
|
|
||||||
output = self.log_softmax(output)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def get_channel_weights(self):
|
|
||||||
"""获取各通道的权重(用于解释性分析)"""
|
|
||||||
return torch.softmax(self.channel_weights, dim=0).detach().cpu().numpy()
|
|
@ -1,28 +1,26 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
class VideoClassifierV5(nn.Module):
|
class VideoClassifierV3_3(nn.Module):
|
||||||
def __init__(self, embedding_dim=1024, hidden_dim=640, output_dim=3):
|
def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_channels = 4
|
self.num_channels = 4
|
||||||
self.channel_names = ['title', 'description', 'tags', 'author_info']
|
self.channel_names = ['title', 'description', 'tags', 'author_info']
|
||||||
|
|
||||||
# 改进1:带温度系数的通道权重(比原始固定权重更灵活)
|
# 带温度系数的通道权重(比原始固定权重更灵活)
|
||||||
self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
|
self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
|
||||||
self.temperature = 1.4 # 可调节的平滑系数
|
self.temperature = 1.7 # 可调节的平滑系数
|
||||||
|
|
||||||
# 改进2:更稳健的全连接结构
|
# 改进后的非线性层
|
||||||
self.fc = nn.Sequential(
|
self.fc = nn.Sequential(
|
||||||
nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
|
nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
|
||||||
nn.BatchNorm1d(hidden_dim*2),
|
nn.BatchNorm1d(hidden_dim*2),
|
||||||
nn.Dropout(0.1),
|
nn.Dropout(0.1),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(hidden_dim*2, hidden_dim),
|
nn.Linear(hidden_dim*2, output_dim)
|
||||||
nn.LayerNorm(hidden_dim),
|
|
||||||
nn.Linear(hidden_dim, output_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 改进3:输出层初始化
|
# 输出层初始化
|
||||||
nn.init.xavier_uniform_(self.fc[-1].weight)
|
nn.init.xavier_uniform_(self.fc[-1].weight)
|
||||||
nn.init.zeros_(self.fc[-1].bias)
|
nn.init.zeros_(self.fc[-1].bias)
|
||||||
|
|
||||||
@ -55,8 +53,4 @@ class VideoClassifierV5(nn.Module):
|
|||||||
|
|
||||||
def get_channel_weights(self):
|
def get_channel_weights(self):
|
||||||
"""获取各通道权重(带温度调节)"""
|
"""获取各通道权重(带温度调节)"""
|
||||||
return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()
|
return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()
|
||||||
|
|
||||||
def set_temperature(self, temperature):
|
|
||||||
"""设置温度值"""
|
|
||||||
self.temperature = temperature
|
|
@ -3,7 +3,7 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from dataset import MultiChannelDataset
|
from dataset import MultiChannelDataset
|
||||||
from modelV5 import VideoClassifierV5
|
from modelV3_3 import VideoClassifierV3_3
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
|
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
|
||||||
@ -39,8 +39,8 @@ test_loader = DataLoader(test_dataset, batch_size=24, shuffle=False)
|
|||||||
|
|
||||||
# 初始化模型和SentenceTransformer
|
# 初始化模型和SentenceTransformer
|
||||||
sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
|
sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
|
||||||
model = VideoClassifierV5()
|
model = VideoClassifierV3_3()
|
||||||
checkpoint_name = './filter/checkpoints/best_model_V5.pt'
|
checkpoint_name = './filter/checkpoints/best_model_V3.3.pt'
|
||||||
|
|
||||||
# 模型保存路径
|
# 模型保存路径
|
||||||
os.makedirs('./filter/checkpoints', exist_ok=True)
|
os.makedirs('./filter/checkpoints', exist_ok=True)
|
||||||
@ -84,19 +84,12 @@ step = 0
|
|||||||
eval_interval = 50
|
eval_interval = 50
|
||||||
num_epochs = 8
|
num_epochs = 8
|
||||||
|
|
||||||
total_steps = num_epochs * len(train_loader) # 总训练步数
|
|
||||||
T_max = 1.4 # 初始温度
|
|
||||||
T_min = 0.15 # 最终温度
|
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
model.train()
|
model.train()
|
||||||
epoch_loss = 0
|
epoch_loss = 0
|
||||||
|
|
||||||
# 训练阶段
|
# 训练阶段
|
||||||
for batch_idx, batch in enumerate(train_loader):
|
for batch_idx, batch in enumerate(train_loader):
|
||||||
temperature = T_max - (T_max - T_min) * (step / total_steps)
|
|
||||||
model.set_temperature(temperature)
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# 传入文本字典和sentence_transformer
|
# 传入文本字典和sentence_transformer
|
||||||
|
Loading…
Reference in New Issue
Block a user