From bf919da1ea0d09a64e73a29e486fa4fa4ca5096a Mon Sep 17 00:00:00 2001 From: alikia2x Date: Sat, 15 Mar 2025 05:07:49 +0800 Subject: [PATCH] update: the features for the model --- pred/dataset.py | 40 ++++++++++++++++++++++++++++++++++------ pred/inference.py | 16 ++++++++-------- pred/train.py | 4 ++-- 3 files changed, 44 insertions(+), 16 deletions(-) diff --git a/pred/dataset.py b/pred/dataset.py index 0c67454..9ed4846 100644 --- a/pred/dataset.py +++ b/pred/dataset.py @@ -9,21 +9,49 @@ from torch.utils.data import Dataset import datetime class VideoPlayDataset(Dataset): - def __init__(self, data_dir, publish_time_path, term='long'): + def __init__(self, data_dir, publish_time_path, term='long', seed=42): + if seed is not None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + self.data_dir = data_dir self.series_dict = self._load_and_process_data(publish_time_path) self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1] self.term = term # Set time window based on term self.time_window = 1000 * 24 * 3600 if term == 'long' else 7 * 24 * 3600 + MINUTE = 60 + HOUR = 3600 + DAY = 24 * HOUR + if term == 'long': - self.feature_windows = [3600, 6*3600, 24*3600, 3*24*3600, 7*24*3600, 30*24*3600, 100*24*3600] + self.feature_windows = [ + 1 * HOUR, + 6 * HOUR, + 1 *DAY, + 3 * DAY, + 7 * DAY, + 30 * DAY, + 100 * DAY + ] else: self.feature_windows = [ - (3600, 0), (7200, 3600), (10800, 7200), (10800, 0), - (21600, 10800), (21600, 0), (64800, 43200), (86400, 21600), - (86400, 0), (172800, 86400), (259200, 0), (345600, 86400), - (604800, 0) + ( 15 * MINUTE, 0 * MINUTE), + ( 40 * MINUTE, 0 * MINUTE), + ( 1 * HOUR, 0 * HOUR), + ( 2 * HOUR, 1 * HOUR), + ( 3 * HOUR, 2 * HOUR), + ( 3 * HOUR, 0 * HOUR), + #( 6 * HOUR, 3 * HOUR), + ( 6 * HOUR, 0 * HOUR), + (18 * HOUR, 12 * HOUR), + #( 1 * DAY, 6 * HOUR), + ( 1 * DAY, 0 * DAY), + #( 2 * DAY, 1 * DAY), + ( 3 * DAY, 0 * DAY), + #( 4 * DAY, 1 * DAY), + ( 7 * DAY, 0 * DAY) ] def _extract_features(self, series, current_idx, target_idx): diff --git a/pred/inference.py b/pred/inference.py index b8d0226..abfe236 100644 --- a/pred/inference.py +++ b/pred/inference.py @@ -4,20 +4,20 @@ from model import CompactPredictor import torch def main(): - model = CompactPredictor(18).to('cpu', dtype=torch.float32) - model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0407.pt')) + model = CompactPredictor(16).to('cpu', dtype=torch.float32) + model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0504.pt')) model.eval() # inference - initial = 999469 + initial = 999917 last = initial - start_time = '2025-03-11 15:03:31' - for i in range(1, 64): - hour = i / 4.2342 + start_time = '2025-03-11 18:43:52' + for i in range(1, 48): + hour = i / 30 sec = hour * 3600 time_d = np.log2(sec) data = [time_d, np.log2(initial+1), # time_delta, current_views - 6.319244, 6.96288, 7.04251, 8.38551, 7.648974, 9.061098, 9.147728, 10.07276, 10.653134, 10.092601, 12.008604, 11.676683, 13.230796, # grows_feat - 0.627442, 0.232492, 24.778674 # time_feat + 5.231997, 6.473876, 7.063624, 7.026946, 6.9753, 8.599954, 9.448747, 7.236474, 10.881226, 12.128971, 13.351179, # grows_feat + 0.7798611111, 0.2541666667, 24.778674 # time_feat ] np_arr = np.array([data]) tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32) diff --git a/pred/train.py b/pred/train.py index 765f63c..024ac51 100644 --- a/pred/train.py +++ b/pred/train.py @@ -55,14 +55,14 @@ def train(model, dataloader, device, epochs=100): t = float(torch.exp2(targets[r])) - 1 o = float(torch.exp2(outputs[r])) - 1 d = features[r].cpu().numpy()[0] - speed = np.exp2(features[r].cpu().numpy()[5]) / 24 + speed = np.exp2(features[r].cpu().numpy()[6]) / 6 time_diff = np.exp2(d) / 3600 inc = speed * time_diff model_error = abs(t - o) reg_error = abs(inc - t) if model_error < reg_error: good += 1 - # print(f"{t:07.1f} | {o:07.1f} | {d:07.1f} | {inc:07.1f} | {good/samples_count*100:.1f}%") + #print(f"{t:07.1f} | {o:07.1f} | {d:07.1f} | {inc:07.1f} | {good/samples_count*100:.1f}%") writer.add_scalar('Train/WinRate', good/samples_count, global_step) print(f"Epoch {epoch+1} | Avg Loss: {total_loss/len(dataloader):.4f}")