add: the 'term' param in Dataset class

parent 23e5d6a8c9
commit f0148ec444
pred/dataset.py
@@ -1,4 +1,3 @@
-# dataset.py
 import os
 import json
 import random
@@ -9,25 +8,33 @@ from torch.utils.data import Dataset
 import datetime
 
 class VideoPlayDataset(Dataset):
-    def __init__(self, data_dir, publish_time_path, max_future_days=7):
+    def __init__(self, data_dir, publish_time_path, term = 'long'):
         self.data_dir = data_dir
-        self.max_future_seconds = max_future_days * 86400
         self.series_dict = self._load_and_process_data(publish_time_path)
         self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
-        self.feature_windows = [3600, 3*3600, 6*3600, 24*3600, 3*24*3600, 7*24*3600, 60*24*3600]
+        self.term = term
+        if term == 'long':
+            self.feature_windows = [3600, 6*3600, 24*3600, 3*24*3600, 7*24*3600, 30*24*3600, 100*24*3600]
+        else:
+            self.feature_windows = [3600, 6*3600, 12*3600, 24*3600, 3*24*3600, 7*24*3600, 60*24*3600]
 
     def _extract_features(self, series, current_idx, target_idx):
-        """提取增量特征"""
+        """Extract incremental features"""
         current_time = series['abs_time'][current_idx]
         current_play = series['play_count'][current_idx]
         dt = datetime.datetime.fromtimestamp(current_time)
-        # 时间特征
-        time_features = [
-            (dt.hour * 3600 + dt.minute * 60 + dt.second) / 86400, (dt.weekday() * 24 + dt.hour) / 168,
-            np.log2(max(current_time - series['create_time'],1))
-        ]
-
-        # 窗口增长特征(增量)
+        if self.term == 'long':
+            time_features = [
+                np.log2(max(current_time - series['create_time'],1))
+            ]
+        else:
+            time_features = [
+                (dt.hour * 3600 + dt.minute * 60 + dt.second) / 86400, (dt.weekday() * 24 + dt.hour) / 168,
+                np.log2(max(current_time - series['create_time'],1))
+            ]
+
+        # Window growth features (incremental)
         growth_features = []
         for window in self.feature_windows:
             prev_time = current_time - window
@@ -45,7 +52,7 @@ class VideoPlayDataset(Dataset):
         return [np.log2(max(time_diff,1))] + [np.log2(current_play + 1)] + growth_features + time_features
 
     def _load_and_process_data(self, publish_time_path):
-        # 加载发布时间数据
+        # Load publish time data
        publish_df = pd.read_csv(publish_time_path)
        publish_df['published_at'] = pd.to_datetime(publish_df['published_at'])
        publish_dict = dict(zip(publish_df['aid'], publish_df['published_at']))
@@ -71,10 +78,10 @@ class VideoPlayDataset(Dataset):
         return series_dict
 
     def __len__(self):
-        return 100000 # 使用虚拟长度实现无限采样
+        return 100000 # Use virtual length for infinite sampling
 
     def _get_nearest_value(self, series, target_time, current_idx):
-        """获取指定时间前最近的数据点"""
+        """Get the nearest data point before the specified time"""
        min_diff = float('inf')
        for i in range(current_idx + 1, len(series['abs_time'])):
            diff = abs(series['abs_time'][i] - target_time)
@@ -84,22 +91,26 @@ class VideoPlayDataset(Dataset):
                 return i - 1
         return len(series['abs_time']) - 1
 
-    def __getitem__(self, idx):
+    def __getitem__(self, _idx):
         series = random.choice(self.valid_series)
         current_idx = random.randint(0, len(series['abs_time'])-2)
-        target_idx = random.randint(max(0, current_idx-10), current_idx)
+        if self.term == 'long':
+            range_length = 50
+        else:
+            range_length = 10
+        target_idx = random.randint(max(0, current_idx-range_length), current_idx)
 
-        # 提取特征
+        # Extract features
         features = self._extract_features(series, current_idx, target_idx)
 
-        # 目标值:未来播放量增量
+        # Target value: future play count increment
         current_play = series['play_count'][current_idx]
         target_play = series['play_count'][target_idx]
-        target_delta = max(target_play - current_play, 0) # 增量
+        target_delta = max(target_play - current_play, 0) # Increment
 
         return {
             'features': torch.FloatTensor(features),
-            'target': torch.log2(torch.FloatTensor([target_delta]) + 1) # 输出增量
+            'target': torch.log2(torch.FloatTensor([target_delta]) + 1) # Output increment
         }
 
 def collate_fn(batch):
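Reviewer note: with this change the feature vector length depends on 'term'. Assuming one growth feature per window (the middle of _extract_features is cut from this diff), that is 1 (time delta) + 1 (current plays) + 7 window growth features, plus 1 time feature for 'long' or 3 for 'short', i.e. 10 vs. 12 inputs, matching the sizes used in export_onnx.py below. A minimal sketch of the two variants; the data paths are illustrative:

from dataset import VideoPlayDataset

# Both variants keep seven growth windows; they differ in window spans,
# the time features used, and how far back a training target may be sampled.
long_ds = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv', term='long')
short_ds = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv', term='short')

sample = short_ds[0]             # the index is ignored; sampling is random
print(sample['features'].shape)  # torch.Size([12]) for 'short', [10] for 'long'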
pred/export_onnx.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+import torch
+import torch.onnx
+from model import CompactPredictor
+
+def export_model(input_size, checkpoint_path, onnx_path):
+    model = CompactPredictor(input_size)
+    model.load_state_dict(torch.load(checkpoint_path))
+
+    dummy_input = torch.randn(1, input_size)
+
+    model.eval()
+
+    torch.onnx.export(model,                     # Model to be exported
+                      dummy_input,               # Model input
+                      onnx_path,                 # Save path
+                      export_params=True,        # Whether to export model parameters
+                      opset_version=11,          # ONNX opset version
+                      do_constant_folding=True,  # Whether to perform constant folding optimization
+                      input_names=['input'],     # Input node name
+                      output_names=['output'],   # Output node name
+                      dynamic_axes={'input': {0: 'batch_size'},  # Dynamic batch size
+                                    'output': {0: 'batch_size'}})
+
+    print(f"ONNX model has been exported to: {onnx_path}")
+
+if __name__ == '__main__':
+    export_model(10, './pred/checkpoints/long_term.pt', 'long_term.onnx')
+    export_model(12, './pred/checkpoints/short_term.pt', 'short_term.onnx')
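A quick smoke test for the exported graphs; this uses the standard onnxruntime API, which is not part of this repo:

import numpy as np
import onnxruntime as ort

# Run a dummy batch through the exported long-term model.
sess = ort.InferenceSession('long_term.onnx')
batch = np.random.randn(4, 10).astype(np.float32)  # input_size=10 for 'long'
(out,) = sess.run(['output'], {'input': batch})
print(out.shape)  # (4, 1); the dynamic_axes entry allows any batch size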
@@ -1,27 +1,31 @@
 import datetime
 import numpy as np
 from model import CompactPredictor
 import torch
 
 def main():
-    model = CompactPredictor(10).to('cpu', dtype=torch.float32)
-    model.load_state_dict(torch.load('./pred/checkpoints/play_predictor.pth'))
+    model = CompactPredictor(12).to('cpu', dtype=torch.float32)
+    model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0226.pt'))
     model.eval()
     # inference
-    last = 999469
-    for i in range(1, 48):
-        hour = i / 2
+    initial = 999469
+    last = initial
+    start_time = '2025-03-11 15:03:31'
+    for i in range(1, 32):
+        hour = i / 4.2342
         sec = hour * 3600
         time_d = np.log2(sec)
-        data = [time_d, 19.9295936113, # time_delta, current_views
-                6.1575520046,8.980,10.6183855023,12.0313328273,13.2537252486, # growth_feat
-                0.625,0.2857142857,24.7794093257 # time_feat
+        data = [time_d, np.log2(initial+1), # time_delta, current_views
+                6.319254, 9.0611, 9.401403, 10.653134, 12.008604, 13.230796, 16.3302, # grows_feat
+                0.627442, 0.232492, 24.778674 # time_feat
         ]
         np_arr = np.array([data])
         tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32)
         output = model(tensor)
         num = output.detach().numpy()[0][0]
-        views_pred = int(np.exp2(num)) + 999469
-        print(f"{int(15+hour)%24:02d}:{int((15+hour)*60)%60:02d}", views_pred, views_pred - last)
+        views_pred = int(np.exp2(num)) + initial
+        current_time = datetime.datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S') + datetime.timedelta(hours=hour)
+        print(current_time.strftime('%m-%d %H:%M'), views_pred, views_pred - last)
         last = views_pred
 
 if __name__ == '__main__':
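Note: the dataset's target is torch.log2(delta + 1), so the exact inverse of a prediction is 2^output - 1; the script's int(np.exp2(num)) skips the "- 1", which is off by at most one view. A hypothetical helper (not in the repo) making the inversion explicit:

import numpy as np

def predicted_views(model_output: float, current_views: int) -> int:
    """Invert the log2(delta + 1) training target back to absolute views."""
    delta = np.exp2(model_output) - 1
    return current_views + int(max(delta, 0))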
pred/model.py
@@ -1,5 +1,4 @@
 import torch.nn as nn
-import torch.nn.functional as F
 
 class CompactPredictor(nn.Module):
     def __init__(self, input_size):
@@ -13,10 +12,10 @@ class CompactPredictor(nn.Module):
             nn.LeakyReLU(0.1),
             nn.Dropout(0.2),
             nn.Linear(128, 64),
-            nn.Tanh(), # 使用Tanh限制输出范围
+            nn.Tanh(), # Use Tanh to limit the output range
             nn.Linear(64, 1)
         )
-        # 初始化最后一层为接近零的值
+        # Initialize the last layer to values close to zero
         nn.init.uniform_(self.net[-1].weight, -0.01, 0.01)
         nn.init.constant_(self.net[-1].bias, 0.0)
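For context, a hedged reconstruction of CompactPredictor: only the tail of self.net (from nn.LeakyReLU(0.1) onward) and the init lines are visible in this diff, so the input layer and the forward method below are assumptions:

import torch.nn as nn

class CompactPredictor(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, 128),  # assumed: something must map input_size -> 128
            nn.LeakyReLU(0.1),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Tanh(),                   # Use Tanh to limit the output range
            nn.Linear(64, 1)
        )
        # Initialize the last layer to values close to zero
        nn.init.uniform_(self.net[-1].weight, -0.01, 0.01)
        nn.init.constant_(self.net[-1].bias, 0.0)

    def forward(self, x):  # assumed
        return self.net(x)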
@@ -37,14 +37,14 @@ def train(model, dataloader, device, epochs=100):
             writer.add_scalar('Loss/train', loss.item(), global_step)
             writer.add_scalar('LR', scheduler.get_last_lr()[0], global_step)
             if batch_idx % 50 == 0:
-                # 监控梯度
+                # Monitor gradients
                 grad_norms = [
                     torch.norm(p.grad).item()
                     for p in model.parameters() if p.grad is not None
                 ]
                 writer.add_scalar('Grad/Norm', sum(grad_norms)/len(grad_norms), global_step)
 
-                # 监控参数值
+                # Monitor parameter values
                 param_means = [torch.mean(p.data).item() for p in model.parameters()]
                 writer.add_scalar('Params/Mean', sum(param_means)/len(param_means), global_step)
 
@@ -62,7 +62,7 @@ def train(model, dataloader, device, epochs=100):
                 reg_error = abs(inc - t)
                 if model_error < reg_error:
                     good += 1
-                #print(f"{t:07.1f} | {o:07.1f} | {d:07.1f} | {inc:07.1f} | {good/samples_count*100:.1f}%")
+                # print(f"{t:07.1f} | {o:07.1f} | {d:07.1f} | {inc:07.1f} | {good/samples_count*100:.1f}%")
             writer.add_scalar('Train/WinRate', good/samples_count, global_step)
 
         print(f"Epoch {epoch+1} | Avg Loss: {total_loss/len(dataloader):.4f}")
@@ -73,16 +73,16 @@ def train(model, dataloader, device, epochs=100):
 if __name__ == "__main__":
     device = 'mps'
 
-    # 初始化数据集和模型
-    dataset = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv')
+    # Initialize dataset and model
+    dataset = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv', 'short')
     dataloader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
 
-    # 获取特征维度
+    # Get feature dimension
     sample = next(iter(dataloader))
     input_size = sample['features'].shape[1]
 
     model = CompactPredictor(input_size).to(device)
     trained_model = train(model, dataloader, device, epochs=30)
 
-    # 保存模型
-    torch.save(trained_model.state_dict(), 'play_predictor.pth')
+    # Save model
+    torch.save(trained_model.state_dict(), f"./pred/checkpoints/model_{time.strftime('%Y%m%d_%H%M')}.pt")
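Reviewer note: the new checkpoint path relies on time.strftime, so the training script needs 'import time' at the top, which this diff does not show being added. A sketch of the full flow after this commit, with illustrative paths:

import time
import torch
from torch.utils.data import DataLoader

from dataset import VideoPlayDataset, collate_fn
from model import CompactPredictor
from export_onnx import export_model

dataset = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv', 'short')
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
input_size = next(iter(dataloader))['features'].shape[1]  # 12 for 'short'

model = CompactPredictor(input_size)
# ... train(model, dataloader, device, epochs=30) as above ...
ckpt = f"./pred/checkpoints/model_{time.strftime('%Y%m%d_%H%M')}.pt"
torch.save(model.state_dict(), ckpt)
export_model(input_size, ckpt, 'short_term.onnx')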