# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from training.config import DIMENSIONS  # currently unused in this module


class SelfAttention(nn.Module):
    """Multi-head self-attention over inputs shaped (batch, seq_len, input_dim)."""

    def __init__(self, input_dim, heads):
        super(SelfAttention, self).__init__()
        self.heads = heads
        # Scale attention scores by 1 / sqrt(head_dim).
        self.scale = (input_dim // heads) ** -0.5
        # One projection produces queries, keys, and values in a single pass.
        self.qkv = nn.Linear(input_dim, input_dim * 3)
        # Output projection applied after the heads are merged back together.
        self.fc = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        batch_size, seq_length, embedding_dim = x.shape

        # Project to Q, K, V and split the embedding across heads:
        # (batch, seq_len, heads, 3, head_dim).
        qkv = self.qkv(x).view(
            batch_size, seq_length, self.heads, 3, embedding_dim // self.heads
        )
        q, k, v = qkv[..., 0, :], qkv[..., 1, :], qkv[..., 2, :]

        # Move heads ahead of the sequence dimension: (batch, heads, seq_len, head_dim).
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        # Scaled dot-product attention: weights have shape (batch, heads, seq_len, seq_len).
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)

        # Weighted sum of values, then merge the heads back into the embedding dim.
        attention_output = torch.matmul(attn_weights, v)
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
        attention_output = attention_output.view(batch_size, seq_length, embedding_dim)
        return self.fc(attention_output)


class AttentionBasedModel(nn.Module):
    """Stacks residual self-attention blocks, mean-pools the sequence, and classifies."""

    def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512, num_layers=3):
        super(AttentionBasedModel, self).__init__()
        self.self_attention_layers = nn.ModuleList([
            SelfAttention(input_dim, heads) for _ in range(num_layers)
        ])
        # Classification head applied to the pooled sequence representation.
        self.fc1 = nn.Linear(input_dim, dim_feedforward)
        self.fc2 = nn.Linear(dim_feedforward, num_classes)
        self.dropout = nn.Dropout(0.5)
        # A single LayerNorm shared by all residual attention blocks.
        self.norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        # Residual self-attention blocks with post-norm: x <- LayerNorm(attn(x) + x).
        for attn_layer in self.self_attention_layers:
            attn_output = attn_layer(x)
            x = self.norm(attn_output + x)

        # Mean-pool over the sequence dimension, then classify.
        pooled_output = torch.mean(x, dim=1)
        x = F.relu(self.fc1(pooled_output))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
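

# Usage sketch: a minimal smoke test for the shapes above. The input_dim, sequence
# length, and num_classes here are illustrative placeholders, not values taken from
# training.config.DIMENSIONS.
if __name__ == "__main__":
    model = AttentionBasedModel(input_dim=128, num_classes=10)
    dummy_batch = torch.randn(4, 32, 128)  # (batch, seq_len, input_dim)
    logits = model(dummy_batch)
    print(logits.shape)  # expected: torch.Size([4, 10])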