update: the stable pred model

parent 296e4ef4d6
commit 1de8d85d2b
.gitignore (vendored): 2 changes

@@ -79,6 +79,8 @@ node_modules/
 logs/
 __pycache__
 filter/runs
+pred/runs
+pred/checkpoints
 data/
 filter/checkpoints
 scripts
pred/1: 65 deletions (file removed)

@@ -1,65 +0,0 @@
-1151133233
-61967870
-977494472
-891815995
-375265162
-341686360
-2228953
-1951059019
-799277283
-844610791
-1706212240
-339432
-243913657
-16576108
-583566710
-802536340
-2976394
-8321047
-261045912
-381806
-1203136639
-316228425
-257550414
-242976248
-9230106
-517962327
-752662232
-771373147
-63924898
-221567994
-840428043
-78978783
-24990703
-820756
-27171791
-80473511
-847707089
-418226861
-11757544
-232040007
-2371972
-84183673
-829450
-6844720
-39129342
-1203992885
-800408956
-316720732
-33139201
-860855406
-4497808
-25432055
-7366509
-40841777
-1656397450
-371830092
-799978121
-50897913
-674312444
-651329836
-875035826
-469433434
-58814955
-33044780
-946091445
@@ -1,15 +1,19 @@
+import os
 import requests
 import json
 import time

-with open("1", "r") as fp:
+with open("./pred/2", "r") as fp:
     raw = fp.readlines()
     aids = [ int(x.strip()) for x in raw ]

 for aid in aids:
+    if os.path.exists(f"./data/pred/{aid}.json"):
+        continue
     url = f"https://api.bunnyxt.com/tdd/v2/video/{aid}/record?last_count=5000"
     r = requests.get(url)
     data = r.json()
     with open(f"./data/pred/{aid}.json", "w") as fp:
         json.dump(data, fp, ensure_ascii=False, indent=4)
     time.sleep(5)
+    print(aid)
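
Note: the new `os.path.exists` guard makes the crawl resumable, since rerunning the script skips aids that already have a JSON on disk. The loop still assumes every request succeeds, so a single transient network error aborts the whole crawl. A minimal hardening sketch, not part of this commit (`fetch_record` is a hypothetical helper around the same endpoint):

import requests

def fetch_record(aid, retries=3):
    # Hypothetical helper: fetch one video's record, retrying transient failures.
    url = f"https://api.bunnyxt.com/tdd/v2/video/{aid}/record?last_count=5000"
    for attempt in range(retries):
        try:
            r = requests.get(url, timeout=30)
            r.raise_for_status()  # surface HTTP 4xx/5xx as exceptions
            return r.json()
        except requests.RequestException:
            if attempt == retries - 1:
                return None  # give up on this aid, keep crawling the rest
    return None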
pred/dataset.py: 241 changes

@@ -1,206 +1,109 @@
 # dataset.py
 import os
 import json
 import random
 import pandas as pd
 import numpy as np
 import torch
 from torch.utils.data import Dataset
-from datetime import datetime
+import datetime

 class VideoPlayDataset(Dataset):
-    def __init__(self, data_dir, publish_time_path,
-                 min_seq_len=6, max_seq_len=200,
-                 min_forecast_span=60, max_forecast_span=604800):
-        """
-        Improved dataset class with support for irregularly spaced time series
-        :param data_dir: directory of JSON files
-        :param publish_time_path: path to the publish-time CSV
-        :param min_seq_len: minimum number of history points
-        :param max_seq_len: maximum number of history points
-        :param min_forecast_span: minimum forecast horizon (seconds)
-        :param max_forecast_span: maximum forecast horizon (seconds)
-        """
+    def __init__(self, data_dir, publish_time_path, max_future_days=7):
         self.data_dir = data_dir
-        self.min_seq_len = min_seq_len
-        self.max_seq_len = max_seq_len
-        self.min_forecast_span = min_forecast_span
-        self.max_forecast_span = max_forecast_span
-        self.series_dict = self._load_and_process_data(data_dir, publish_time_path)
-        self.valid_series = self._generate_valid_series()
+        self.max_future_seconds = max_future_days * 86400
+        self.series_dict = self._load_and_process_data(publish_time_path)
+        self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
+        self.feature_windows = [3600, 6*3600, 24*3600, 3*24*3600, 7*24*3600]  # 1h, 6h, 24h, 3d, 7d

-    def _load_and_process_data(self, data_dir, publish_time_path):
+    def _extract_features(self, series, current_idx, target_idx):
+        """Extract incremental features"""
+        current_time = series['abs_time'][current_idx]
+        current_play = series['play_count'][current_idx]
+        dt = datetime.datetime.fromtimestamp(current_time)
+        # time features
+        time_features = [
+            dt.hour / 24, (dt.weekday() + 1) / 7,
+            np.log2(max(current_time - series['create_time'], 1))
+        ]
+
+        # windowed growth features (increments)
+        growth_features = []
+        for window in self.feature_windows:
+            prev_time = current_time - window
+            prev_idx = self._get_nearest_value(series, prev_time, current_idx)
+            if prev_idx is not None:
+                time_diff = current_time - series['abs_time'][prev_idx]
+                play_diff = current_play - series['play_count'][prev_idx]
+                scaled_diff = play_diff / (time_diff / window) if time_diff > 0 else 0.0
+            else:
+                scaled_diff = 0.0
+            growth_features.append(np.log2(max(scaled_diff, 1)))
+
+        time_diff = series['abs_time'][target_idx] - series['abs_time'][current_idx]
+
+        return [np.log2(max(time_diff, 1))] + [np.log2(current_play + 1)] + growth_features + time_features
+
+    def _load_and_process_data(self, publish_time_path):
         # load the publish-time data
         publish_df = pd.read_csv(publish_time_path)
         publish_df['published_at'] = pd.to_datetime(publish_df['published_at'])
         publish_dict = dict(zip(publish_df['aid'], publish_df['published_at']))
-
-        # load and process the JSON data
         series_dict = {}
-        for filename in os.listdir(data_dir):
+        for filename in os.listdir(self.data_dir):
             if not filename.endswith('.json'):
                 continue
-            filepath = os.path.join(data_dir, filename)
-            with open(filepath, 'r', encoding='utf-8') as f:
-                json_data = json.load(f)
-                for item in json_data:
-                    aid = item['aid']
-                    if aid not in publish_dict:
-                        continue
-
-                    # compute the relative time
-                    added_time = datetime.fromtimestamp(item['added'])
-                    published_time = publish_dict[aid]
-                    rel_time = (added_time - published_time).total_seconds()
-
-                    # organize the data per video
+            with open(os.path.join(self.data_dir, filename), 'r') as f:
+                data = json.load(f)
+                if 'code' in data:
+                    continue
+                for item in data:
+                    aid = item['aid']
+                    published_time = pd.to_datetime(publish_dict[aid]).timestamp()
                     if aid not in series_dict:
                         series_dict[aid] = {
                             'abs_time': [],
-                            'rel_time': [],
-                            'play_count': []
+                            'play_count': [],
+                            'create_time': published_time
                         }
                     series_dict[aid]['abs_time'].append(item['added'])
-                    series_dict[aid]['rel_time'].append(rel_time)
                     series_dict[aid]['play_count'].append(item['view'])
-
-        # sort by time and compute the time gaps
-        for aid in series_dict:
-            # sort by time
-            sorted_idx = np.argsort(series_dict[aid]['abs_time'])
-            for key in ['abs_time', 'rel_time', 'play_count']:
-                series_dict[aid][key] = np.array(series_dict[aid][key])[sorted_idx]
-
-            # time-gap feature
-            abs_time_arr = series_dict[aid]['abs_time']
-            time_deltas = np.diff(abs_time_arr, prepend=abs_time_arr[0])
-            series_dict[aid]['time_delta'] = time_deltas
-
         return series_dict

-    def _generate_valid_series(self):
-        # build the list of valid series
-        valid_series = []
-        for aid in self.series_dict:
-            series = self.series_dict[aid]
-            n_points = len(series['play_count'])
-
-            # filter out videos with too little data
-            if n_points < self.min_seq_len + 1:
-                continue
-
-            valid_series.append({
-                'aid': aid,
-                'length': n_points,
-                'abs_time': series['abs_time'],
-                'rel_time': series['rel_time'],
-                'play_count': series['play_count'],
-                'time_delta': series['time_delta']
-            })
-        return valid_series
-
     def __len__(self):
-        return sum(s['length'] - self.min_seq_len for s in self.valid_series)
+        return 100000  # virtual length, enables unbounded sampling

+    def _get_nearest_value(self, series, target_time, current_idx):
+        """Find the data point nearest to the given time"""
+        min_diff = float('inf')
+        for i in range(current_idx + 1, len(series['abs_time']), 1):
+            diff = abs(series['abs_time'][i] - target_time)
+            if diff < min_diff:
+                min_diff = diff
+            else:
+                return i - 1
+        return None
+
     def __getitem__(self, idx):
         # randomly pick a video series
         series = random.choice(self.valid_series)
-        max_start = series['length'] - self.min_seq_len - 1
-        start_idx = random.randint(0, max_start)
+        current_idx = random.randint(0, len(series['abs_time'])-2)
+        target_idx = random.randint(max(0, current_idx-50), current_idx)

-        # dynamically choose the history-window length
-        seq_len = random.randint(self.min_seq_len, min(self.max_seq_len, series['length'] - start_idx - 1))
-        end_idx = start_idx + seq_len
+        # extract the features
+        features = self._extract_features(series, current_idx, target_idx)

-        # slice out the history window
-        hist_slice = slice(start_idx, end_idx)
-        x_play = np.log1p(series['play_count'][hist_slice])
-        x_abs_time = series['abs_time'][hist_slice]
-        x_rel_time = series['rel_time'][hist_slice]
-        x_time_delta = series['time_delta'][hist_slice]
-
-        # build the forecast target (dynamic horizon)
-        forecast_span = random.randint(self.min_forecast_span, self.max_forecast_span)
-        target_time = x_abs_time[-1] + forecast_span
-
-        # find the actual target point (handles gaps in the data)
-        future_times = series['abs_time'][end_idx:]
-        future_plays = series['play_count'][end_idx:]
-
-        # first point past the target time
-        target_idx = np.searchsorted(future_times, target_time)
-        if target_idx >= len(future_times):
-            # beyond the data range: take the last point
-            y_play = future_plays[-1] if len(future_plays) > 0 else x_play[-1]
-            actual_span = future_times[-1] - x_abs_time[-1] if len(future_times) > 0 else self.max_forecast_span
-        else:
-            y_play = future_plays[target_idx]
-            actual_span = future_times[target_idx] - x_abs_time[-1]
-
-        y_play_val = np.log1p(y_play)
-
-        # time-related features
-        time_features = np.stack([
-            x_abs_time,
-            x_rel_time,
-            x_time_delta,
-            np.log1p(x_time_delta),  # log transform for the long-tailed distribution
-            (x_time_delta > 3600).astype(float)  # whether the gap exceeds 1 hour
-        ], axis=-1)
+        # target: the future play-count increment
+        current_play = series['play_count'][current_idx]
+        target_play = series['play_count'][target_idx]
+        target_delta = max(target_play - current_play, 0)  # increment

         return {
-            'x_play': torch.FloatTensor(x_play),
-            'x_time_feat': torch.FloatTensor(time_features),
-            'y_play': torch.FloatTensor([y_play_val]),
-            'forecast_span': torch.FloatTensor([actual_span])
+            'features': torch.FloatTensor(features),
+            'target': torch.log2(torch.FloatTensor([target_delta]) + 1)  # predict the increment
         }

 def collate_fn(batch):
-    """dynamic padding"""
-    max_len = max(item['x_play'].shape[0] for item in batch)
-
-    padded_batch = {
-        'x_play': [],
-        'x_time_feat': [],
-        'y_play': [],
-        'forecast_span': [],
-        'padding_mask': []
+    return {
+        'features': torch.stack([x['features'] for x in batch]),
+        'targets': torch.stack([x['target'] for x in batch])
     }
-
-    for item in batch:
-        seq_len = item['x_play'].shape[0]
-        pad_len = max_len - seq_len
-
-        # pad the play-count sequence
-        padded_play = torch.cat([
-            item['x_play'],
-            torch.zeros(pad_len)
-        ])
-        padded_batch['x_play'].append(padded_play)
-
-        # pad the time features
-        padded_time_feat = torch.cat([
-            item['x_time_feat'],
-            torch.zeros(pad_len, item['x_time_feat'].shape[1])
-        ])
-        padded_batch['x_time_feat'].append(padded_time_feat)
-
-        # build the padding mask
-        mask = torch.cat([
-            torch.ones(seq_len),
-            torch.zeros(pad_len)
-        ])
-        padded_batch['padding_mask'].append(mask.bool())
-
-        # remaining fields
-        padded_batch['y_play'].append(item['y_play'])
-        padded_batch['forecast_span'].append(item['forecast_span'])
-
-    # convert to tensors
-    padded_batch['x_play'] = torch.stack(padded_batch['x_play'])
-    padded_batch['x_time_feat'] = torch.stack(padded_batch['x_time_feat'])
-    padded_batch['y_play'] = torch.stack(padded_batch['y_play'])
-    padded_batch['forecast_span'] = torch.stack(padded_batch['forecast_span'])
-    padded_batch['padding_mask'] = torch.stack(padded_batch['padding_mask'])
-
-    return padded_batch
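
Note: each `__getitem__` call now draws a random (video, time point, target) triple, so the dataset is effectively an infinite sampler and `collate_fn` reduces to stacking fixed-size vectors. A minimal sketch of how the new dataset is consumed, assuming the data paths used in pred/train.py below:

from torch.utils.data import DataLoader
from dataset import VideoPlayDataset, collate_fn

# Paths as in pred/train.py; adjust if the crawl wrote elsewhere.
dataset = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv')
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

batch = next(iter(loader))
print(batch['features'].shape)  # torch.Size([4, 10])
print(batch['targets'].shape)   # torch.Size([4, 1])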
pred/inference.py: 17 additions (new file)

@@ -0,0 +1,17 @@
+import numpy as np
+from model import CompactPredictor
+import torch
+
+def main():
+    model = CompactPredictor(10).to('cpu', dtype=torch.float32)
+    model.load_state_dict(torch.load('play_predictor.pth'))
+    model.eval()
+    # inference
+    data = [3,3.9315974229,5.4263146604,9.4958550269,10.9203528554,11.5835529305,13.0426853722,0.7916666667,0.2857142857,24.7794093257]
+    np_arr = np.array([data])
+    tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32)
+    output = model(tensor)
+    print(output)
+
+if __name__ == '__main__':
+    main()
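
Note: the hard-coded `data` list lines up with the return order of `_extract_features` in pred/dataset.py. Annotating it makes the 10 inputs easier to audit; the labels below are a reading of that return statement, not comments from the commit:

# [log2 span] + [log2 play] + 5 growth windows + 3 time features = 10 inputs
data = [
    3.0,            # log2(seconds between now and the prediction target)
    3.9315974229,   # log2(current play count + 1)
    5.4263146604,   # log2 scaled growth, 1h window
    9.4958550269,   # log2 scaled growth, 6h window
    10.9203528554,  # log2 scaled growth, 24h window
    11.5835529305,  # log2 scaled growth, 3d window
    13.0426853722,  # log2 scaled growth, 7d window
    0.7916666667,   # hour of day / 24
    0.2857142857,   # (weekday + 1) / 7
    24.7794093257,  # log2(seconds since publication)
]
assert len(data) == 10  # matches CompactPredictor(10)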
pred/model.py: 194 changes

@@ -1,182 +1,24 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-import math

-class TimeEmbedding(nn.Module):
-    """Time-feature encoding module"""
-    def __init__(self, embed_dim):
+class CompactPredictor(nn.Module):
+    def __init__(self, input_size):
         super().__init__()
-        self.embed_dim = embed_dim
-
-        self.norm = nn.LayerNorm(5)
-
-        # time-feature encoder (matches the 5-dim time features)
-        self.time_encoder = nn.Sequential(
-            nn.Linear(5, 64),  # input dim matches the 5 features of x_time_feat
-            nn.GELU(),
-            nn.LayerNorm(64),
-            nn.Linear(64, embed_dim)
+        self.net = nn.Sequential(
+            nn.BatchNorm1d(input_size),
+            nn.Linear(input_size, 256),
+            nn.LeakyReLU(0.1),
+            nn.Dropout(0.3),
+            nn.Linear(256, 128),
+            nn.LeakyReLU(0.1),
+            nn.Dropout(0.2),
+            nn.Linear(128, 64),
+            nn.Tanh(),  # Tanh bounds the penultimate activations
+            nn.Linear(64, 1)
         )
+        # initialize the last layer near zero
+        nn.init.uniform_(self.net[-1].weight, -0.01, 0.01)
+        nn.init.constant_(self.net[-1].bias, 0.0)

-    def forward(self, time_feat):
-        """
-        time_feat: time features (batch, seq_len, 5)
-        """
-        time_feat = self.norm(time_feat)  # normalize
-        return self.time_encoder(time_feat)
-
-class MultiScaleEncoder(nn.Module):
-    """Multi-scale feature encoder"""
-    def __init__(self, input_dim, d_model, nhead, conv_kernels=[3, 7, 23]):
-        super().__init__()
-        self.d_model = d_model
-
-        self.conv_branches = nn.ModuleList([
-            nn.Sequential(
-                nn.Conv1d(input_dim, d_model, kernel_size=k, padding=k//2),
-                nn.GELU(),
-            ) for k in conv_kernels
-        ])
-
-        # keep the LayerNorms in a separate list
-        self.layer_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in conv_kernels])
-
-        # Transformer encoder
-        self.transformer = nn.TransformerEncoder(
-            nn.TransformerEncoderLayer(
-                d_model,
-                nhead,
-                dim_feedforward=d_model*4,
-                batch_first=True  # use batch_first
-            ),
-            num_layers=4
-        )
-
-        # feature-fusion layer
-        self.fusion = nn.Linear(d_model*(len(conv_kernels)+1), d_model)
-
-    def forward(self, x, padding_mask=None):
-        """
-        x: input features (batch, seq_len, input_dim)
-        padding_mask: padding mask (batch, seq_len)
-        """
-        # convolutional branches
-        conv_features = []
-        x_conv = x.permute(0, 2, 1)  # (batch, input_dim, seq_len)
-        for i, branch in enumerate(self.conv_branches):
-            feat = branch(x_conv)  # (batch, d_model, seq_len)
-            # transpose manually, then apply LayerNorm
-            feat = feat.permute(0, 2, 1)  # (batch, seq_len, d_model)
-            feat = self.layer_norms[i](feat)
-            conv_features.append(feat)
-
-        # Transformer branch
-        trans_feat = self.transformer(
-            x,
-            src_key_padding_mask=padding_mask
-        )  # (batch, seq_len, d_model)
-
-        # concatenate and fuse
-        combined = torch.cat(conv_features + [trans_feat], dim=-1)
-        fused = self.fusion(combined)  # (batch, seq_len, d_model)
-
-        return fused
-
-class VideoPlayPredictor(nn.Module):
-    def __init__(self, d_model=256, nhead=8):
-        super().__init__()
-        self.d_model = d_model
-
-        # feature embeddings
-        self.time_embed = TimeEmbedding(embed_dim=64)
-        self.base_embed = nn.Linear(1 + 64, d_model)  # play count + time features
-
-        # encoder
-        self.encoder = MultiScaleEncoder(d_model, d_model, nhead)
-
-        # time-aware forecast head
-        self.forecast_head = nn.Sequential(
-            nn.Linear(2 * d_model + 1, d_model * 4),  # input dim is 2*d_model + 1
-            nn.GELU(),
-            nn.Linear(d_model * 4, 1),
-            nn.ReLU()  # keep the output non-negative
-        )
-
-        # context extractor
-        self.context_extractor = nn.LSTM(
-            input_size=d_model,
-            hidden_size=d_model,
-            num_layers=2,
-            bidirectional=True,
-            batch_first=True
-        )
-
-        # parameter initialization
-        self._init_weights()
-
-    def _init_weights(self):
-        for name, p in self.named_parameters():
-            if 'forecast_head' in name:
-                if 'weight' in name:
-                    nn.init.xavier_normal_(p, gain=1e-2)  # shrink the init range
-                elif 'bias' in name:
-                    nn.init.constant_(p, 0.0)
-            elif p.dim() > 1:
-                nn.init.xavier_uniform_(p)
-
-    def forward(self, x_play, x_time_feat, padding_mask, forecast_span):
-        """
-        x_play: play-count history (batch, seq_len)
-        x_time_feat: time features (batch, seq_len, 5)
-        padding_mask: padding mask (batch, seq_len)
-        forecast_span: forecast horizon (batch, 1)
-        """
-        batch_size = x_play.size(0)
-
-        # encode the time features
-        time_emb = self.time_embed(x_time_feat)  # (batch, seq_len, 64)
-
-        # concatenate the base features
-        base_feat = torch.cat([
-            x_play.unsqueeze(-1),  # (batch, seq_len, 1)
-            time_emb
-        ], dim=-1)  # (batch, seq_len, 1+64)
-
-        # project to the model dimension
-        embedded = self.base_embed(base_feat)  # (batch, seq_len, d_model)
-
-        # encode
-        encoded = self.encoder(embedded, padding_mask)  # (batch, seq_len, d_model)
-
-        # extract the context
-        context, _ = self.context_extractor(encoded)  # (batch, seq_len, d_model*2)
-        context = context.mean(dim=1)  # (batch, d_model*2)
-
-        # fuse the horizon feature
-        span_feat = torch.log1p(forecast_span) / 10  # normalize
-        combined = torch.cat([
-            context,
-            span_feat
-        ], dim=-1)  # (batch, d_model*2 + 1)
-
-        # final prediction
-        pred = self.forecast_head(combined)  # (batch, 1)
-
-        return pred
-
-class MultiTaskWrapper(nn.Module):
-    """Wrapper adapted to the new data structure"""
-    def __init__(self, model):
-        super().__init__()
-        self.model = model
-
-    def forward(self, batch):
-        return self.model(
-            batch['x_play'],
-            batch['x_time_feat'],
-            batch['padding_mask'],
-            batch['forecast_span']
-        )
+    def forward(self, x):
+        return self.net(x)
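
Note: `CompactPredictor` maps a single feature vector straight to `log2(play increment + 1)`, replacing the whole sequence model. A quick shape check under that reading, with random inputs and untrained weights:

import torch
from model import CompactPredictor

model = CompactPredictor(10)
model.eval()  # BatchNorm1d uses running statistics in eval mode
with torch.no_grad():
    out = model(torch.randn(2, 10))  # (batch, input_size)
print(out.shape)                     # torch.Size([2, 1])
print(torch.exp2(out) - 1)           # undo the log2(x + 1) target transform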
pred/train.py: 129 changes

@@ -1,76 +1,83 @@
 import random
 import time
 import numpy as np
+from torch.utils.tensorboard import SummaryWriter
 from torch.utils.data import DataLoader
-from model import MultiTaskWrapper, VideoPlayPredictor
 import torch
-import torch.nn.functional as F
 from dataset import VideoPlayDataset, collate_fn
+from pred.model import CompactPredictor

-def train(model, dataloader, epochs=100, device='mps'):
-    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
-    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
+def train(model, dataloader, device, epochs=100):
+    writer = SummaryWriter(f'./pred/runs/play_predictor_{time.strftime("%Y%m%d_%H%M")}')
+    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
+    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,
+                                                    total_steps=len(dataloader)*epochs)
+    criterion = torch.nn.MSELoss()

-    steps = 0
-    for epoch in range(epochs):
-        model.train()
-        total_loss = 0
+    model.train()
+    global_step = 0
+    for epoch in range(epochs):
+        total_loss = 0.0
+        for batch_idx, batch in enumerate(dataloader):
+            features = batch['features'].to(device)
+            targets = batch['targets'].to(device)

-        for batch in dataloader:
             optimizer.zero_grad()
-
-            # move the whole batch to the device
-            for k, v in batch.items():
-                if isinstance(v, torch.Tensor):
-                    batch[k] = v.to(device)
-
-            # forward pass
-            pred = model(batch)
-
-            y_play = batch['y_play']
-
-            real = np.expm1(y_play.cpu().detach().numpy())
-            yhat = np.expm1(pred.cpu().detach().numpy())
-            print("real", [int(real[0][0]), int(real[1][0])])
-            print("yhat", [int(yhat[0][0]), int(yhat[1][0])], [float(pred.cpu().detach().numpy()[0][0]), float(pred.cpu().detach().numpy()[1][0])])
-
-            # weighted loss: the longer the horizon, the lower the weight
-            weights = torch.log1p(batch['forecast_span'])
-            loss_per_sample = F.huber_loss(pred, y_play, reduction='none')
-            loss = (loss_per_sample * weights).mean()
-
-            # backward pass
+            outputs = model(features)
+            loss = criterion(outputs, targets)
             loss.backward()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            #torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
             optimizer.step()
+            scheduler.step()

-            steps += 1
-
-            print(f"Epoch {epoch+1} | Step {steps} | Loss: {loss.item():.4f}")
-
-        scheduler.step()
-        avg_loss = total_loss / len(dataloader)
-        print(f"Epoch {epoch+1:03d} | Loss: {avg_loss:.4f}")
+            total_loss += loss.item()
+            global_step += 1

-# initialize the model
-device = 'mps'
-model = MultiTaskWrapper(VideoPlayPredictor())
-model = model.to(device)
+            if global_step % 100 == 0:
+                writer.add_scalar('Loss/train', loss.item(), global_step)
+                writer.add_scalar('LR', scheduler.get_last_lr()[0], global_step)
+            if batch_idx % 50 == 0:
+                # monitor the gradients
+                grad_norms = [
+                    torch.norm(p.grad).item()
+                    for p in model.parameters() if p.grad is not None
+                ]
+                writer.add_scalar('Grad/Norm', sum(grad_norms)/len(grad_norms), global_step)
+                # monitor the parameter values
+                param_means = [torch.mean(p.data).item() for p in model.parameters()]
+                writer.add_scalar('Params/Mean', sum(param_means)/len(param_means), global_step)

-data_dir = './data/pred'
-publish_time_path = './data/pred/publish_time.csv'
-dataset = VideoPlayDataset(
-    data_dir=data_dir,
-    publish_time_path=publish_time_path,
-    min_seq_len=2,                 # at least 2 history points
-    max_seq_len=350,               # at most 350 history points
-    min_forecast_span=60,          # forecast horizon from 1 minute
-    max_forecast_span=86400 * 10   # up to 10 days
-)
-dataloader = DataLoader(
-    dataset,
-    batch_size=2,
-    shuffle=True,
-    collate_fn=collate_fn,  # use the custom collate function
-)
+                samples_count = len(targets)
+                r = random.randint(0, samples_count-1)
+                t = float(torch.exp2(targets[r])) - 1
+                o = float(torch.exp2(outputs[r])) - 1
+                d = features[r].cpu().numpy()[0]
+                speed = np.exp2(features[r].cpu().numpy()[2])
+                time_diff = np.exp2(d) / 3600
+                inc = speed * time_diff
+                model_error = abs(t - o)
+                reg_error = abs(inc - t)
+                print(f"{t:07.1f} | {o:07.1f} | {d:07.1f} | {inc:07.1f} | {model_error < reg_error}")

-# start training
-train(model, dataloader, epochs=20, device=device)
+        print(f"Epoch {epoch+1} | Avg Loss: {total_loss/len(dataloader):.4f}")
+
+    writer.close()
+    return model
+
+if __name__ == "__main__":
+    device = 'mps'
+
+    # initialize the dataset and the model
+    dataset = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv')
+    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
+
+    # get the feature dimension
+    sample = next(iter(dataloader))
+    input_size = sample['features'].shape[1]
+
+    model = CompactPredictor(input_size).to(device)
+    trained_model = train(model, dataloader, device, epochs=30)
+
+    # save the model
+    torch.save(trained_model.state_dict(), 'play_predictor.pth')
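
Note: the debug line printed every 50 batches pits the model against a naive linear-extrapolation baseline: `t` is the true increment and `o` the model's prediction (both mapped back through `2**x - 1`), `inc` extrapolates the 1-hour growth-rate feature over the forecast span, and the final column is True when the model beats the baseline. The same inverse transform applies at inference time; a small sketch (`trained_model` and `features` stand in for any trained model and a `(B, 10)` feature batch from `collate_fn`):

import torch

with torch.no_grad():
    pred_log = trained_model(features)     # model output: log2(increment + 1)
    pred_delta = torch.exp2(pred_log) - 1  # predicted play-count increment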