# File viewer metadata: 24 lines, 758 B, Python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CompactPredictor(nn.Module):
    """Small MLP that maps a feature vector to a single scalar prediction.

    The network is stored in ``self.net`` as an ``nn.Sequential`` with the
    layers, in order:

        BatchNorm1d(input_size)
        Linear(input_size, 256) -> LeakyReLU(0.1) -> Dropout(0.3)
        Linear(256, 128)        -> LeakyReLU(0.1) -> Dropout(0.2)
        Linear(128, 64)         -> Tanh
        Linear(64, 1)

    The layer order (and therefore the ``net.<index>`` parameter names) is
    part of the checkpoint format; do not reorder layers without migrating
    saved state dicts.

    Args:
        input_size: Number of input features per sample.
    """

    def __init__(self, input_size: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.BatchNorm1d(input_size),
            nn.Linear(input_size, 256),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            # Tanh bounds the 64-d hidden features to (-1, 1). Note that it
            # does NOT bound the model output: the final Linear is unbounded.
            nn.Tanh(),
            nn.Linear(64, 1),
        )

        # Initialise the output layer close to zero. With weights in
        # [-0.01, 0.01], a zero bias and 64 tanh-bounded inputs, the initial
        # prediction magnitude is at most 64 * 0.01 = 0.64.
        output_layer = self.net[-1]
        nn.init.uniform_(output_layer.weight, -0.01, 0.01)
        nn.init.constant_(output_layer.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` and return the prediction.

        Args:
            x: Input batch; expected to be shaped ``(batch, input_size)``
                given the BatchNorm1d/Linear front end -- confirm against
                the caller.

        Returns:
            The output of ``self.net(x)``; for a ``(batch, input_size)``
            input this is ``(batch, 1)``.
        """
        return self.net(x)