cvsa/pred/train.py

89 lines
3.6 KiB
Python

import random
import time
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import torch
from dataset import VideoPlayDataset, collate_fn
from pred.model import CompactPredictor
def train(model, dataloader, device, epochs=100):
writer = SummaryWriter(f'./pred/runs/play_predictor_{time.strftime("%Y%m%d_%H%M")}')
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,
total_steps=len(dataloader)*epochs)
criterion = torch.nn.MSELoss()
model.train()
global_step = 0
for epoch in range(epochs):
total_loss = 0.0
for batch_idx, batch in enumerate(dataloader):
features = batch['features'].to(device)
targets = batch['targets'].to(device)
optimizer.zero_grad()
outputs = model(features)
loss = criterion(outputs, targets)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
total_loss += loss.item()
global_step += 1
if global_step % 100 == 0:
writer.add_scalar('Loss/train', loss.item(), global_step)
writer.add_scalar('LR', scheduler.get_last_lr()[0], global_step)
if batch_idx % 50 == 0:
# Monitor gradients
grad_norms = [
torch.norm(p.grad).item()
for p in model.parameters() if p.grad is not None
]
writer.add_scalar('Grad/Norm', sum(grad_norms)/len(grad_norms), global_step)
# Monitor parameter values
param_means = [torch.mean(p.data).item() for p in model.parameters()]
writer.add_scalar('Params/Mean', sum(param_means)/len(param_means), global_step)
samples_count = len(targets)
good = 0
for r in range(samples_count):
r = random.randint(0, samples_count-1)
t = float(torch.exp2(targets[r])) - 1
o = float(torch.exp2(outputs[r])) - 1
d = features[r].cpu().numpy()[0]
speed = np.exp2(features[r].cpu().numpy()[5]) / 24
time_diff = np.exp2(d) / 3600
inc = speed * time_diff
model_error = abs(t - o)
reg_error = abs(inc - t)
if model_error < reg_error:
good += 1
# print(f"{t:07.1f} | {o:07.1f} | {d:07.1f} | {inc:07.1f} | {good/samples_count*100:.1f}%")
writer.add_scalar('Train/WinRate', good/samples_count, global_step)
print(f"Epoch {epoch+1} | Avg Loss: {total_loss/len(dataloader):.4f}")
writer.close()
return model
if __name__ == "__main__":
device = 'mps'
# Initialize dataset and model
dataset = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv', 'short')
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
# Get feature dimension
sample = next(iter(dataloader))
input_size = sample['features'].shape[1]
model = CompactPredictor(input_size).to(device)
trained_model = train(model, dataloader, device, epochs=30)
# Save model
torch.save(trained_model.state_dict(), f"./pred/checkpoints/model_{time.strftime('%Y%m%d_%H%M')}.pt")