update: more accurate short-term prediction

alikia2x (寒寒) 2025-03-15 04:19:04 +08:00
parent f0148ec444
commit a6211782cb
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
3 changed files with 83 additions and 53 deletions

View File

@@ -1,6 +1,7 @@
 import os
 import json
 import random
+import bisect
 import numpy as np
 import pandas as pd
 import torch
@@ -8,51 +9,69 @@ from torch.utils.data import Dataset
 import datetime

 class VideoPlayDataset(Dataset):
-    def __init__(self, data_dir, publish_time_path, term = 'long'):
+    def __init__(self, data_dir, publish_time_path, term='long'):
         self.data_dir = data_dir
         self.series_dict = self._load_and_process_data(publish_time_path)
         self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
         self.term = term
+        # Set time window based on term
+        self.time_window = 1000 * 24 * 3600 if term == 'long' else 7 * 24 * 3600
         if term == 'long':
             self.feature_windows = [3600, 6*3600, 24*3600, 3*24*3600, 7*24*3600, 30*24*3600, 100*24*3600]
         else:
-            self.feature_windows = [3600, 6*3600, 12*3600, 24*3600, 3*24*3600, 7*24*3600, 60*24*3600]
+            self.feature_windows = [
+                (3600, 0), (7200, 3600), (10800, 7200), (10800, 0),
+                (21600, 10800), (21600, 0), (64800, 43200), (86400, 21600),
+                (86400, 0), (172800, 86400), (259200, 0), (345600, 86400),
+                (604800, 0)
+            ]

     def _extract_features(self, series, current_idx, target_idx):
-        """Extract incremental features"""
         current_time = series['abs_time'][current_idx]
         current_play = series['play_count'][current_idx]
         dt = datetime.datetime.fromtimestamp(current_time)
         if self.term == 'long':
             time_features = [
-                np.log2(max(current_time - series['create_time'],1))
+                np.log2(max(current_time - series['create_time'], 1))
             ]
         else:
             time_features = [
-                (dt.hour * 3600 + dt.minute * 60 + dt.second) / 86400, (dt.weekday() * 24 + dt.hour) / 168,
-                np.log2(max(current_time - series['create_time'],1))
+                (dt.hour * 3600 + dt.minute * 60 + dt.second) / 86400,
+                (dt.weekday() * 24 + dt.hour) / 168,
+                np.log2(max(current_time - series['create_time'], 1))
             ]
-        # Window growth features (incremental)
         growth_features = []
-        for window in self.feature_windows:
-            prev_time = current_time - window
-            prev_idx = self._get_nearest_value(series, prev_time, current_idx)
-            if prev_idx is not None:
-                time_diff = current_time - series['abs_time'][prev_idx]
-                play_diff = current_play - series['play_count'][prev_idx]
-                scaled_diff = play_diff / (time_diff / window) if time_diff > 0 else 0.0
-            else:
-                scaled_diff = 0.0
-            growth_features.append(np.log2(max(scaled_diff,1)))
+        if self.term == 'long':
+            for window in self.feature_windows:
+                prev_time = current_time - window
+                prev_idx = self._get_nearest_value(series, prev_time, current_idx)
+                if prev_idx is not None:
+                    time_diff = current_time - series['abs_time'][prev_idx]
+                    play_diff = current_play - series['play_count'][prev_idx]
+                    scaled_diff = play_diff / (time_diff / window) if time_diff > 0 else 0.0
+                else:
+                    scaled_diff = 0.0
+                growth_features.append(np.log2(max(scaled_diff, 1)))
+        else:
+            for window_start, window_end in self.feature_windows:
+                prev_time_start = current_time - window_start
+                prev_time_end = current_time - window_end  # window_end is typically 0
+                prev_idx_start = self._get_nearest_value(series, prev_time_start, current_idx)
+                prev_idx_end = self._get_nearest_value(series, prev_time_end, current_idx)
+                if prev_idx_start is not None and prev_idx_end is not None:
+                    time_diff = series['abs_time'][prev_idx_end] - series['abs_time'][prev_idx_start]
+                    play_diff = series['play_count'][prev_idx_end] - series['play_count'][prev_idx_start]
+                    scaled_diff = play_diff / (time_diff / (window_start - window_end)) if time_diff > 0 else 0.0
+                else:
+                    scaled_diff = 0.0
+                growth_features.append(np.log2(max(scaled_diff, 1)))

-        time_diff = series['abs_time'][target_idx] - series['abs_time'][current_idx]
-        return [np.log2(max(time_diff,1))] + [np.log2(current_play + 1)] + growth_features + time_features
+        time_diff = series['abs_time'][target_idx] - current_time
+        return [np.log2(max(time_diff, 1))] + [np.log2(current_play + 1)] + growth_features + time_features

     def _load_and_process_data(self, publish_time_path):
-        # Load publish time data
         publish_df = pd.read_csv(publish_time_path)
         publish_df['published_at'] = pd.to_datetime(publish_df['published_at'])
         publish_dict = dict(zip(publish_df['aid'], publish_df['published_at']))
@@ -75,42 +94,53 @@ class VideoPlayDataset(Dataset):
             }
             series_dict[aid]['abs_time'].append(item['added'])
             series_dict[aid]['play_count'].append(item['view'])
+        # Sort each series by absolute time
+        for aid in series_dict:
+            sorted_indices = sorted(range(len(series_dict[aid]['abs_time'])),
+                                    key=lambda k: series_dict[aid]['abs_time'][k])
+            series_dict[aid]['abs_time'] = [series_dict[aid]['abs_time'][i] for i in sorted_indices]
+            series_dict[aid]['play_count'] = [series_dict[aid]['play_count'][i] for i in sorted_indices]
         return series_dict

     def __len__(self):
-        return 100000  # Use virtual length for infinite sampling
+        return 100000  # Virtual length for sampling

     def _get_nearest_value(self, series, target_time, current_idx):
-        """Get the nearest data point before the specified time"""
-        min_diff = float('inf')
-        for i in range(current_idx + 1, len(series['abs_time'])):
-            diff = abs(series['abs_time'][i] - target_time)
-            if diff < min_diff:
-                min_diff = diff
-            else:
-                return i - 1
-        return len(series['abs_time']) - 1
+        times = series['abs_time']
+        pos = bisect.bisect_right(times, target_time, 0, current_idx + 1)
+        candidates = []
+        if pos > 0:
+            candidates.append(pos - 1)
+        if pos <= current_idx:
+            candidates.append(pos)
+        if not candidates:
+            return None
+        closest_idx = min(candidates, key=lambda i: abs(times[i] - target_time))
+        return closest_idx

     def __getitem__(self, _idx):
-        series = random.choice(self.valid_series)
-        current_idx = random.randint(0, len(series['abs_time'])-2)
-        if self.term == 'long':
-            range_length = 50
-        else:
-            range_length = 10
-        target_idx = random.randint(max(0, current_idx-range_length), current_idx)
-
-        # Extract features
-        features = self._extract_features(series, current_idx, target_idx)
-
-        # Target value: future play count increment
+        while True:
+            series = random.choice(self.valid_series)
+            if len(series['abs_time']) < 2:
+                continue
+            current_idx = random.randint(0, len(series['abs_time']) - 2)
+            current_time = series['abs_time'][current_idx]
+            max_target_time = current_time + self.time_window
+            candidate_indices = []
+            for j in range(current_idx + 1, len(series['abs_time'])):
+                if series['abs_time'][j] > max_target_time:
+                    break
+                candidate_indices.append(j)
+            if not candidate_indices:
+                continue
+            target_idx = random.choice(candidate_indices)
+            break
         current_play = series['play_count'][current_idx]
         target_play = series['play_count'][target_idx]
-        target_delta = max(target_play - current_play, 0)  # Increment
+        target_delta = max(target_play - current_play, 0)
         return {
-            'features': torch.FloatTensor(features),
-            'target': torch.log2(torch.FloatTensor([target_delta]) + 1)  # Output increment
+            'features': torch.FloatTensor(self._extract_features(series, current_idx, target_idx)),
+            'target': torch.log2(torch.FloatTensor([target_delta]) + 1)
         }

 def collate_fn(batch):
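
The short-term branch now measures growth over explicit (window_start, window_end) intervals before the sampling time instead of only cumulative windows anchored at the current moment, and _get_nearest_value replaces the linear scan with a binary search over the timestamps (which the loader now sorts). A minimal self-contained sketch of both ideas, assuming sorted per-hour toy data; the names nearest_index and interval_growth_feature are illustrative, not the repository's API:

import bisect
import numpy as np

def nearest_index(times, target_time, current_idx):
    # Binary-search only the observations up to current_idx (no look-ahead),
    # then pick whichever neighbour of the insertion point is closest.
    pos = bisect.bisect_right(times, target_time, 0, current_idx + 1)
    candidates = [i for i in (pos - 1, pos) if 0 <= i <= current_idx]
    return min(candidates, key=lambda i: abs(times[i] - target_time)) if candidates else None

def interval_growth_feature(times, plays, current_idx, window_start, window_end):
    # Plays gained between (t - window_start) and (t - window_end),
    # rescaled to the nominal interval length and log-compressed.
    t = times[current_idx]
    i = nearest_index(times, t - window_start, current_idx)
    j = nearest_index(times, t - window_end, current_idx)
    if i is None or j is None:
        return 0.0
    dt = times[j] - times[i]
    dp = plays[j] - plays[i]
    scaled = dp / (dt / (window_start - window_end)) if dt > 0 else 0.0
    return float(np.log2(max(scaled, 1)))

# Toy series: one observation per hour, 500 plays gained each hour.
times = [i * 3600 for i in range(10)]
plays = [i * 500 for i in range(10)]
print(interval_growth_feature(times, plays, 9, 7200, 3600))  # growth 2h..1h ago
print(interval_growth_feature(times, plays, 9, 7200, 0))     # growth over the last 2h

Because each interval is rescaled by its nominal length, a (7200, 3600) feature and a (7200, 0) feature are directly comparable even when the underlying observations are unevenly spaced.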

View File

@@ -4,19 +4,19 @@ from model import CompactPredictor
 import torch

 def main():
-    model = CompactPredictor(12).to('cpu', dtype=torch.float32)
-    model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0226.pt'))
+    model = CompactPredictor(18).to('cpu', dtype=torch.float32)
+    model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0407.pt'))
     model.eval()
     # inference
     initial = 999469
     last = initial
     start_time = '2025-03-11 15:03:31'
-    for i in range(1, 32):
+    for i in range(1, 64):
         hour = i / 4.2342
         sec = hour * 3600
         time_d = np.log2(sec)
         data = [time_d, np.log2(initial+1), # time_delta, current_views
-                6.319254, 9.0611, 9.401403, 10.653134, 12.008604, 13.230796, 16.3302, # grows_feat
+                6.319244, 6.96288, 7.04251, 8.38551, 7.648974, 9.061098, 9.147728, 10.07276, 10.653134, 10.092601, 12.008604, 11.676683, 13.230796, # grows_feat
                 0.627442, 0.232492, 24.778674 # time_feat
         ]
         np_arr = np.array([data])
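
The 18 model inputs line up with the short-term branch of _extract_features: log2 of the prediction horizon, log2 of the current view count, 13 interval growth features, and 3 time features. Since the dataset trains against log2(target_delta + 1), a prediction has to be inverted before it can be read as views; a sketch of that decoding, with illustrative variable names:

import numpy as np

def decode_prediction(log_delta_pred: float, current_views: int) -> int:
    # The model outputs log2(delta + 1); invert the transform and
    # add the predicted increment back onto the current view count.
    delta = max(2 ** log_delta_pred - 1, 0)
    return current_views + int(round(delta))

print(decode_prediction(10.0, 999469))  # -> 999469 + 1023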

View File

@@ -11,7 +11,7 @@ def train(model, dataloader, device, epochs=100):
     writer = SummaryWriter(f'./pred/runs/play_predictor_{time.strftime("%Y%m%d_%H%M")}')
     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
     scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,
-                                                    total_steps=len(dataloader)*epochs)
+                                                    total_steps=len(dataloader)*30)
     criterion = torch.nn.MSELoss()
     model.train()
@@ -82,7 +82,7 @@ if __name__ == "__main__":
     input_size = sample['features'].shape[1]
     model = CompactPredictor(input_size).to(device)
-    trained_model = train(model, dataloader, device, epochs=30)
+    trained_model = train(model, dataloader, device, epochs=18)

     # Save model
     torch.save(trained_model.state_dict(), f"./pred/checkpoints/model_{time.strftime('%Y%m%d_%H%M')}.pt")
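
With these two changes the OneCycleLR schedule is sized for 30 epochs of steps while train() is called with epochs=18, so training stops partway down the annealing phase and ends at a learning rate well above the schedule's floor. A minimal sketch of that wiring, with illustrative sizes in place of the real dataloader:

import torch

model = torch.nn.Linear(18, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

steps_per_epoch = 100   # illustrative; len(dataloader) in the real script
scheduled_epochs = 30   # schedule length; actual training runs 18 epochs
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-3, total_steps=steps_per_epoch * scheduled_epochs)

for step in range(steps_per_epoch * 18):  # stops before the schedule ends
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # still well above the schedule's final value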