update: more accurate short-term prediction
parent f0148ec444
commit a6211782cb
pred/dataset.py (104 lines changed)
@@ -1,6 +1,7 @@
 import os
 import json
 import random
+import bisect
 import numpy as np
 import pandas as pd
 import torch
@@ -8,34 +9,41 @@ from torch.utils.data import Dataset
 import datetime
 
 class VideoPlayDataset(Dataset):
-    def __init__(self, data_dir, publish_time_path, term = 'long'):
+    def __init__(self, data_dir, publish_time_path, term='long'):
         self.data_dir = data_dir
         self.series_dict = self._load_and_process_data(publish_time_path)
         self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
         self.term = term
+        # Set time window based on term
+        self.time_window = 1000 * 24 * 3600 if term == 'long' else 7 * 24 * 3600
         if term == 'long':
             self.feature_windows = [3600, 6*3600, 24*3600, 3*24*3600, 7*24*3600, 30*24*3600, 100*24*3600]
         else:
-            self.feature_windows = [3600, 6*3600, 12*3600, 24*3600, 3*24*3600, 7*24*3600, 60*24*3600]
+            self.feature_windows = [
+                (3600, 0), (7200, 3600), (10800, 7200), (10800, 0),
+                (21600, 10800), (21600, 0), (64800, 43200), (86400, 21600),
+                (86400, 0), (172800, 86400), (259200, 0), (345600, 86400),
+                (604800, 0)
+            ]
 
     def _extract_features(self, series, current_idx, target_idx):
-        """Extract incremental features"""
         current_time = series['abs_time'][current_idx]
         current_play = series['play_count'][current_idx]
         dt = datetime.datetime.fromtimestamp(current_time)
 
         if self.term == 'long':
             time_features = [
-                np.log2(max(current_time - series['create_time'],1))
+                np.log2(max(current_time - series['create_time'], 1))
             ]
         else:
             time_features = [
-                (dt.hour * 3600 + dt.minute * 60 + dt.second) / 86400, (dt.weekday() * 24 + dt.hour) / 168,
-                np.log2(max(current_time - series['create_time'],1))
+                (dt.hour * 3600 + dt.minute * 60 + dt.second) / 86400,
+                (dt.weekday() * 24 + dt.hour) / 168,
+                np.log2(max(current_time - series['create_time'], 1))
             ]
 
-        # Window growth features (incremental)
         growth_features = []
-        for window in self.feature_windows:
-            prev_time = current_time - window
-            prev_idx = self._get_nearest_value(series, prev_time, current_idx)
+        if self.term == 'long':
+            for window in self.feature_windows:
+                prev_time = current_time - window
+                prev_idx = self._get_nearest_value(series, prev_time, current_idx)
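The short-term branch now describes each look-back as a (window_start, window_end) pair, an offset interval behind the current observation, instead of a single cut-off; several intervals overlap at different scales, which is what gives the short-term model its finer view of recent dynamics. A minimal sketch of how one pair maps to absolute bounds (the helper name is illustrative, not from the commit):

    def interval_bounds(current_time, window):
        # A pair covers [current_time - window_start, current_time - window_end];
        # window_end == 0 means the interval runs right up to the present.
        window_start, window_end = window
        return current_time - window_start, current_time - window_end

    # e.g. (21600, 10800) is "6h ago to 3h ago"; (21600, 0) is "the last 6h".
    lo, hi = interval_bounds(1_700_000_000, (21600, 10800))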
@@ -45,14 +53,25 @@ class VideoPlayDataset(Dataset):
-                scaled_diff = play_diff / (time_diff / window) if time_diff > 0 else 0.0
-            else:
-                scaled_diff = 0.0
-            growth_features.append(np.log2(max(scaled_diff,1)))
+                    scaled_diff = play_diff / (time_diff / window) if time_diff > 0 else 0.0
+                else:
+                    scaled_diff = 0.0
+                growth_features.append(np.log2(max(scaled_diff, 1)))
+        else:
+            for window_start, window_end in self.feature_windows:
+                prev_time_start = current_time - window_start
+                prev_time_end = current_time - window_end  # window_end is typically 0
+                prev_idx_start = self._get_nearest_value(series, prev_time_start, current_idx)
+                prev_idx_end = self._get_nearest_value(series, prev_time_end, current_idx)
+                if prev_idx_start is not None and prev_idx_end is not None:
+                    time_diff = series['abs_time'][prev_idx_end] - series['abs_time'][prev_idx_start]
+                    play_diff = series['play_count'][prev_idx_end] - series['play_count'][prev_idx_start]
+                    scaled_diff = play_diff / (time_diff / (window_start - window_end)) if time_diff > 0 else 0.0
+                else:
+                    scaled_diff = 0.0
+                growth_features.append(np.log2(max(scaled_diff, 1)))
 
-        time_diff = series['abs_time'][target_idx] - series['abs_time'][current_idx]
-        return [np.log2(max(time_diff,1))] + [np.log2(current_play + 1)] + growth_features + time_features
+        time_diff = series['abs_time'][target_idx] - current_time
+        return [np.log2(max(time_diff, 1))] + [np.log2(current_play + 1)] + growth_features + time_features
 
     def _load_and_process_data(self, publish_time_path):
-        # Load publish time data
         publish_df = pd.read_csv(publish_time_path)
         publish_df['published_at'] = pd.to_datetime(publish_df['published_at'])
         publish_dict = dict(zip(publish_df['aid'], publish_df['published_at']))
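Each interval contributes one growth feature: the observed play-count delta, rescaled to the interval's nominal length (so sparsely sampled series still produce comparable magnitudes) and then log-compressed. A standalone sketch of the arithmetic in the new else branch:

    import numpy as np

    def growth_feature(play_diff, time_diff, window_start, window_end):
        # play_diff was observed over time_diff seconds; normalize it to the
        # nominal interval length before log2 so scales stay comparable.
        nominal = window_start - window_end
        scaled = play_diff / (time_diff / nominal) if time_diff > 0 else 0.0
        return float(np.log2(max(scaled, 1)))

    # e.g. 1200 extra plays measured over 3.5 h, nominal interval 3 h:
    feat = growth_feature(1200, 3.5 * 3600, 21600, 10800)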
@@ -75,42 +94,53 @@ class VideoPlayDataset(Dataset):
                 }
             series_dict[aid]['abs_time'].append(item['added'])
             series_dict[aid]['play_count'].append(item['view'])
+        # Sort each series by absolute time
+        for aid in series_dict:
+            sorted_indices = sorted(range(len(series_dict[aid]['abs_time'])),
+                                    key=lambda k: series_dict[aid]['abs_time'][k])
+            series_dict[aid]['abs_time'] = [series_dict[aid]['abs_time'][i] for i in sorted_indices]
+            series_dict[aid]['play_count'] = [series_dict[aid]['play_count'][i] for i in sorted_indices]
         return series_dict
 
     def __len__(self):
-        return 100000  # Use virtual length for infinite sampling
+        return 100000  # Virtual length for sampling
 
     def _get_nearest_value(self, series, target_time, current_idx):
-        """Get the nearest data point before the specified time"""
-        min_diff = float('inf')
-        for i in range(current_idx + 1, len(series['abs_time'])):
-            diff = abs(series['abs_time'][i] - target_time)
-            if diff < min_diff:
-                min_diff = diff
-            else:
-                return i - 1
-        return len(series['abs_time']) - 1
+        times = series['abs_time']
+        pos = bisect.bisect_right(times, target_time, 0, current_idx + 1)
+        candidates = []
+        if pos > 0:
+            candidates.append(pos - 1)
+        if pos <= current_idx:
+            candidates.append(pos)
+        if not candidates:
+            return None
+        closest_idx = min(candidates, key=lambda i: abs(times[i] - target_time))
+        return closest_idx
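The rewritten `_get_nearest_value` swaps the linear scan for a binary search. `bisect_right` is bounded to indices 0..current_idx, so the lookup can never leak observations from after the anchor point; the nearest sample is then whichever neighbor of the insertion point is closer. A self-contained sketch, assuming `times` is sorted ascending (the loader now sorts each series):

    import bisect

    def nearest_index(times, target_time, current_idx):
        pos = bisect.bisect_right(times, target_time, 0, current_idx + 1)
        candidates = []
        if pos > 0:
            candidates.append(pos - 1)  # last sample at or before target_time
        if pos <= current_idx:
            candidates.append(pos)      # first sample after target_time
        if not candidates:
            return None
        return min(candidates, key=lambda i: abs(times[i] - target_time))

    assert nearest_index([0, 10, 20, 30], 12, 3) == 1  # 10 is closer than 20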
     def __getitem__(self, _idx):
-        series = random.choice(self.valid_series)
-        current_idx = random.randint(0, len(series['abs_time'])-2)
-        if self.term == 'long':
-            range_length = 50
-        else:
-            range_length = 10
-        target_idx = random.randint(max(0, current_idx-range_length), current_idx)
-
-        # Extract features
-        features = self._extract_features(series, current_idx, target_idx)
-
-        # Target value: future play count increment
+        while True:
+            series = random.choice(self.valid_series)
+            if len(series['abs_time']) < 2:
+                continue
+            current_idx = random.randint(0, len(series['abs_time']) - 2)
+            current_time = series['abs_time'][current_idx]
+            max_target_time = current_time + self.time_window
+            candidate_indices = []
+            for j in range(current_idx + 1, len(series['abs_time'])):
+                if series['abs_time'][j] > max_target_time:
+                    break
+                candidate_indices.append(j)
+            if not candidate_indices:
+                continue
+            target_idx = random.choice(candidate_indices)
+            break
         current_play = series['play_count'][current_idx]
         target_play = series['play_count'][target_idx]
-        target_delta = max(target_play - current_play, 0)  # Increment
+        target_delta = max(target_play - current_play, 0)
 
         return {
-            'features': torch.FloatTensor(features),
-            'target': torch.log2(torch.FloatTensor([target_delta]) + 1)  # Output increment
+            'features': torch.FloatTensor(self._extract_features(series, current_idx, target_idx)),
+            'target': torch.log2(torch.FloatTensor([target_delta]) + 1)
         }
 
 def collate_fn(batch):
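`__getitem__` no longer picks the target from a fixed index range around the anchor; it rejection-samples until it finds an anchor with at least one future observation inside `self.time_window` (7 days for the short-term model), then draws the target uniformly from those candidates. A compact sketch of the same sampling logic over a single series (assuming at least one anchor has an in-horizon successor):

    import random

    def sample_pair(abs_time, time_window):
        # abs_time must be sorted ascending; resample the anchor until some
        # future observation falls inside the horizon.
        while True:
            current_idx = random.randint(0, len(abs_time) - 2)
            max_target_time = abs_time[current_idx] + time_window
            candidates = [j for j in range(current_idx + 1, len(abs_time))
                          if abs_time[j] <= max_target_time]
            if candidates:
                return current_idx, random.choice(candidates)

    i, j = sample_pair([0, 3600, 7200, 900000], 7 * 24 * 3600)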
@@ -4,19 +4,19 @@ from model import CompactPredictor
 import torch
 
 def main():
-    model = CompactPredictor(12).to('cpu', dtype=torch.float32)
-    model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0226.pt'))
+    model = CompactPredictor(18).to('cpu', dtype=torch.float32)
+    model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0407.pt'))
     model.eval()
     # inference
     initial = 999469
     last = initial
     start_time = '2025-03-11 15:03:31'
-    for i in range(1, 32):
+    for i in range(1, 64):
         hour = i / 4.2342
         sec = hour * 3600
         time_d = np.log2(sec)
         data = [time_d, np.log2(initial+1),  # time_delta, current_views
-                6.319254, 9.0611, 9.401403, 10.653134, 12.008604, 13.230796, 16.3302,  # grows_feat
+                6.319244, 6.96288, 7.04251, 8.38551, 7.648974, 9.061098, 9.147728, 10.07276, 10.653134, 10.092601, 12.008604, 11.676683, 13.230796,  # grows_feat
                 0.627442, 0.232492, 24.778674  # time_feat
                 ]
         np_arr = np.array([data])
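The jump from CompactPredictor(12) to CompactPredictor(18) follows directly from the new feature layout: 1 log time delta + 1 log current views + 13 growth features (one per interval) + 3 time features. A quick consistency check:

    # One growth feature per (window_start, window_end) interval, plus the
    # time-delta, current-views and three time-of-publication features.
    feature_windows = [
        (3600, 0), (7200, 3600), (10800, 7200), (10800, 0),
        (21600, 10800), (21600, 0), (64800, 43200), (86400, 21600),
        (86400, 0), (172800, 86400), (259200, 0), (345600, 86400),
        (604800, 0),
    ]
    assert 1 + 1 + len(feature_windows) + 3 == 18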
@@ -11,7 +11,7 @@ def train(model, dataloader, device, epochs=100):
     writer = SummaryWriter(f'./pred/runs/play_predictor_{time.strftime("%Y%m%d_%H%M")}')
     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
     scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,
-                                                    total_steps=len(dataloader)*epochs)
+                                                    total_steps=len(dataloader)*30)
     criterion = torch.nn.MSELoss()
 
     model.train()
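One thing worth flagging: the scheduler is now sized for len(dataloader)*30 steps, while the training run below is shortened to 18 epochs, so the one-cycle schedule never reaches the end of its annealing phase; the commit does not say whether that is intentional. For reference, OneCycleLR advances once per batch (sizes below are illustrative):

    import torch

    model = torch.nn.Linear(18, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    steps_per_epoch, planned_epochs = 100, 30
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-3, total_steps=steps_per_epoch * planned_epochs)
    for _ in range(steps_per_epoch * 18):  # training actually stops at epoch 18
        optimizer.step()
        scheduler.step()   # the final LR decay steps are never taken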
@@ -82,7 +82,7 @@ if __name__ == "__main__":
     input_size = sample['features'].shape[1]
 
     model = CompactPredictor(input_size).to(device)
-    trained_model = train(model, dataloader, device, epochs=30)
+    trained_model = train(model, dataloader, device, epochs=18)
 
     # Save model
     torch.save(trained_model.state_dict(), f"./pred/checkpoints/model_{time.strftime('%Y%m%d_%H%M')}.pt")