update: mps support for model training

commit a6319f4303
parent 0ed59f60d0
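
The changes below all apply the same pattern: select the MPS backend when it is available, fall back to CPU otherwise, and move the model and all tensors to that device. A minimal sketch of the pattern (the `nn.Linear` model and zero batch here are illustrative stand-ins, not the project's classifier):

import torch
import torch.nn as nn

# Prefer Apple's Metal Performance Shaders backend when available, otherwise CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Illustrative model and batch; the real ones are VideoClassifierV6_1 and the prepared batch tensor.
model = nn.Linear(256, 3).to(device)
batch = torch.zeros(8, 256, device=device)
logits = model(batch)  # both operands now live on the same device
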
@@ -32,7 +32,7 @@ def prepare_batch(batch_data, device="cpu"):
 
 import onnxruntime as ort
 
-def prepare_batch_per_token(session, tokenizer, batch_data, max_length=1024):
+def prepare_batch_per_token(session, tokenizer, batch_data, device='cpu', max_length=1024):
     """
     Convert the input batch_data into the input format expected by the model: [batch_size, num_channels, seq_length, embedding_dim].
 
@@ -49,23 +49,23 @@ def prepare_batch_per_token(session, tokenizer, batch_data, max_length=1024):
     """
 
     batch_size = len(batch_data["title"])
-    batch_tensor = torch.zeros(batch_size, 3, max_length, 256)
+    batch_tensor = torch.zeros(batch_size, 3, max_length, 256, device=device)
     for i in range(batch_size):
-        channel_embeddings = torch.zeros((3, 1024, 256))
+        channel_embeddings = torch.zeros((3, 1024, 256), device=device)
         for j, channel in enumerate(["title", "description", "tags"]):
             # Get the text for the current channel
             text = batch_data[channel][i]
             encoded_inputs = tokenizer(text, truncation=True, max_length=max_length, return_tensors='np')
 
             # embeddings: [max_length, embedding_dim]
-            embeddings = torch.zeros((1024, 256))
+            embeddings = torch.zeros((1024, 256), device=device)
             for idx, token in enumerate(encoded_inputs['input_ids'][0]):
                 inputs = {
                     "input_ids": ort.OrtValue.ortvalue_from_numpy(np.array([token])),
                     "offsets": ort.OrtValue.ortvalue_from_numpy(np.array([0], dtype=np.int64))
                 }
                 output = session.run(None, inputs)[0]
-                embeddings[idx] = torch.from_numpy(output)
+                embeddings[idx] = torch.from_numpy(output).to(device)
             channel_embeddings[j] = embeddings
         batch_tensor[i] = channel_embeddings
 
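
With the new `device` parameter, the embedding buffers can be allocated on MPS up front instead of being filled on CPU and copied afterwards. A hedged usage sketch (the training loop further down still calls the function without `device` and moves the result with `.to(device)`; `session`, `tokenizer`, `batch`, and `model` stand for the objects already set up in the training script):

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Allocate batch_tensor directly on the target device.
batch_tensor = prepare_batch_per_token(session, tokenizer, batch['texts'], device=device)
logits = model(batch_tensor)
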
@@ -1,68 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class VideoClassifierV6_0(nn.Module):
-    def __init__(self, embedding_dim=256, seq_length=1024, hidden_dim=512, output_dim=3):
-        super().__init__()
-        self.num_channels = 3
-        self.channel_names = ['title', 'description', 'tags']
-
-        # CNN feature-extraction layers
-        self.conv_layers = nn.Sequential(
-            # First convolution block
-            nn.Conv2d(self.num_channels, 64, kernel_size=(3, 3), padding=1),
-            nn.BatchNorm2d(64),
-            nn.GELU(),
-            nn.MaxPool2d(kernel_size=(2, 2)),
-
-            # Second convolution block
-            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
-            nn.BatchNorm2d(128),
-            nn.GELU(),
-            nn.MaxPool2d(kernel_size=(2, 2)),
-
-            # Third convolution block
-            nn.Conv2d(128, 256, kernel_size=(3, 3), padding=1),
-            nn.BatchNorm2d(256),
-            nn.GELU(),
-
-            # Global average pooling
-            # Output shape: [batch_size, 256, 1, 1]
-            nn.AdaptiveAvgPool2d((1, 1))
-        )
-
-        # Feature dimension after global pooling is fixed at 256
-        self.feature_dim = 256
-
-        # Fully connected classifier head
-        self.fc = nn.Sequential(
-            nn.Linear(self.feature_dim, hidden_dim),
-            nn.BatchNorm1d(hidden_dim),
-            nn.Dropout(0.2),
-            nn.GELU(),
-            nn.Linear(hidden_dim, output_dim)
-        )
-
-        self._init_weights()
-
-    def _init_weights(self):
-        for module in self.modules():
-            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
-                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
-                if module.bias is not None:
-                    nn.init.zeros_(module.bias)
-
-    def forward(self, channel_features: torch.Tensor):
-        """
-        Input shape:  [batch_size, num_channels, seq_length, embedding_dim]
-        Output shape: [batch_size, output_dim]
-        """
-        # CNN feature extraction
-        conv_features = self.conv_layers(channel_features)
-
-        # Flatten (after global pooling the shape is [batch_size, 256, 1, 1])
-        flat_features = conv_features.view(conv_features.size(0), -1)  # [batch_size, 256]
-
-        # Classify with the fully connected head
-        return self.fc(flat_features)
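
For reference, the deleted V6.0 model reduces the [batch_size, 3, seq_length, embedding_dim] input to a fixed 256-dimensional feature through adaptive average pooling before the fully connected head. A minimal shape-check sketch of the same conv stack (normalization layers omitted for brevity; they do not change shapes):

import torch
import torch.nn as nn

# Mirrors the removed VideoClassifierV6_0 conv stack, minus BatchNorm.
conv = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=(3, 3), padding=1), nn.GELU(), nn.MaxPool2d((2, 2)),
    nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1), nn.GELU(), nn.MaxPool2d((2, 2)),
    nn.Conv2d(128, 256, kernel_size=(3, 3), padding=1), nn.GELU(),
    nn.AdaptiveAvgPool2d((1, 1)),
)
x = torch.zeros(2, 3, 1024, 256)  # [batch, channels, seq_length, embedding_dim]
print(conv(x).shape)              # torch.Size([2, 256, 1, 1]), flattened to [2, 256]
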
@@ -1,9 +1,8 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 class VideoClassifierV6_1(nn.Module):
-    def __init__(self, embedding_dim=256, seq_length=1024, hidden_dim=256, output_dim=3, num_heads=4):
+    def __init__(self, embedding_dim=256, hidden_dim=256, output_dim=3, num_heads=4):
         super().__init__()
         self.num_channels = 3
         self.channel_names = ['title', 'description', 'tags']
@@ -45,17 +45,20 @@ train_labels = []
 for batch in train_loader:
     train_labels.extend(batch['label'].tolist())
 
+device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+print(f"Using device: {device}")
+
 # Compute adaptive class weights
 class_counts = np.bincount(train_labels)
 median_freq = np.median(class_counts)
 class_weights = torch.tensor(
     [median_freq / count for count in class_counts],
     dtype=torch.float32,
-    device='cpu'
+    device=device
 )
 
-model = VideoClassifierV6_1()
+model = VideoClassifierV6_1().to(device)
-checkpoint_name = './filter/checkpoints/best_model_V6.2-test2.pt'
+checkpoint_name = './filter/checkpoints/best_model_V6.2-mps.pt'
 
 # Initialize the tokenizer and embedding model
 tokenizer = AutoTokenizer.from_pretrained("alikia2x/jina-embedding-v3-m2v-1024")
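
The class weights here are the median class frequency divided by each class count, so under-represented classes receive proportionally larger weights in the loss. A small worked example with illustrative counts (not the real dataset):

import numpy as np
import torch

class_counts = np.array([900, 300, 60])  # illustrative counts only
median_freq = np.median(class_counts)    # 300.0
class_weights = torch.tensor(
    [median_freq / count for count in class_counts],
    dtype=torch.float32,
)
print(class_weights)                     # tensor([0.3333, 1.0000, 5.0000])
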
@@ -73,7 +76,7 @@ optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
 cosine_annealing_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps - int(total_steps * warmup_rate))
 warmup_scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=int(total_steps * warmup_rate))
 scheduler = optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_annealing_scheduler], milestones=[int(total_steps * warmup_rate)])
-criterion = nn.CrossEntropyLoss(weight=class_weights)
+criterion = nn.CrossEntropyLoss(weight=class_weights).to(device)
 
 def count_trainable_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -85,11 +88,11 @@ def evaluate(model, dataloader):
 
     with torch.no_grad():
         for batch in dataloader:
-            batch_tensor = prepare_batch_per_token(session, tokenizer, batch['texts'])
+            batch_tensor = prepare_batch_per_token(session, tokenizer, batch['texts']).to(device)
             logits = model(batch_tensor)
             preds = torch.argmax(logits, dim=1)
             all_preds.extend(preds.cpu().numpy())
-            all_labels.extend(batch['label'].cpu().numpy())
+            all_labels.extend(batch['label'].to(device).cpu().numpy())
 
     # Compute per-class F1, Recall, Precision, and Accuracy
     f1 = f1_score(all_labels, all_preds, average='weighted')
@@ -117,11 +120,11 @@ for epoch in range(num_epochs):
         optimizer.zero_grad()
 
 
-        batch_tensor = prepare_batch_per_token(session, tokenizer, batch['texts'])
+        batch_tensor = prepare_batch_per_token(session, tokenizer, batch['texts']).to(device)
 
         logits = model(batch_tensor)
 
-        loss = criterion(logits, batch['label'])
+        loss = criterion(logits, batch['label'].to(device))
         loss.backward()
         optimizer.step()
         epoch_loss += loss.item()
@@ -187,6 +190,7 @@ for epoch in range(num_epochs):
 # Test phase
 print("\nTesting...")
 model.load_state_dict(torch.load(checkpoint_name))
+model.to(device)
 test_f1, test_recall, test_precision, test_accuracy, test_class_report = evaluate(model, test_loader)
 writer.add_scalar('Test/F1', test_f1, step)
 writer.add_scalar('Test/Recall', test_recall, step)
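
An equivalent way to bring the checkpoint onto the target device is to pass `map_location` to `torch.load`; this is an optional variant, not part of the commit, and it reuses `checkpoint_name`, `model`, and `device` from the script above:

# Optional variant: load the weights directly onto the target device.
state_dict = torch.load(checkpoint_name, map_location=device)
model.load_state_dict(state_dict)
model.to(device)  # no-op if the parameters already live on `device`
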