update: the features for the model
This commit is contained in:
parent
a6211782cb
commit
bf919da1ea
@ -9,21 +9,49 @@ from torch.utils.data import Dataset
|
||||
import datetime
|
||||
|
||||
class VideoPlayDataset(Dataset):
|
||||
def __init__(self, data_dir, publish_time_path, term='long'):
|
||||
def __init__(self, data_dir, publish_time_path, term='long', seed=42):
|
||||
if seed is not None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
self.data_dir = data_dir
|
||||
self.series_dict = self._load_and_process_data(publish_time_path)
|
||||
self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
|
||||
self.term = term
|
||||
# Set time window based on term
|
||||
self.time_window = 1000 * 24 * 3600 if term == 'long' else 7 * 24 * 3600
|
||||
MINUTE = 60
|
||||
HOUR = 3600
|
||||
DAY = 24 * HOUR
|
||||
|
||||
if term == 'long':
|
||||
self.feature_windows = [3600, 6*3600, 24*3600, 3*24*3600, 7*24*3600, 30*24*3600, 100*24*3600]
|
||||
self.feature_windows = [
|
||||
1 * HOUR,
|
||||
6 * HOUR,
|
||||
1 *DAY,
|
||||
3 * DAY,
|
||||
7 * DAY,
|
||||
30 * DAY,
|
||||
100 * DAY
|
||||
]
|
||||
else:
|
||||
self.feature_windows = [
|
||||
(3600, 0), (7200, 3600), (10800, 7200), (10800, 0),
|
||||
(21600, 10800), (21600, 0), (64800, 43200), (86400, 21600),
|
||||
(86400, 0), (172800, 86400), (259200, 0), (345600, 86400),
|
||||
(604800, 0)
|
||||
( 15 * MINUTE, 0 * MINUTE),
|
||||
( 40 * MINUTE, 0 * MINUTE),
|
||||
( 1 * HOUR, 0 * HOUR),
|
||||
( 2 * HOUR, 1 * HOUR),
|
||||
( 3 * HOUR, 2 * HOUR),
|
||||
( 3 * HOUR, 0 * HOUR),
|
||||
#( 6 * HOUR, 3 * HOUR),
|
||||
( 6 * HOUR, 0 * HOUR),
|
||||
(18 * HOUR, 12 * HOUR),
|
||||
#( 1 * DAY, 6 * HOUR),
|
||||
( 1 * DAY, 0 * DAY),
|
||||
#( 2 * DAY, 1 * DAY),
|
||||
( 3 * DAY, 0 * DAY),
|
||||
#( 4 * DAY, 1 * DAY),
|
||||
( 7 * DAY, 0 * DAY)
|
||||
]
|
||||
|
||||
def _extract_features(self, series, current_idx, target_idx):
|
||||
|
@ -4,20 +4,20 @@ from model import CompactPredictor
|
||||
import torch
|
||||
|
||||
def main():
|
||||
model = CompactPredictor(18).to('cpu', dtype=torch.float32)
|
||||
model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0407.pt'))
|
||||
model = CompactPredictor(16).to('cpu', dtype=torch.float32)
|
||||
model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0504.pt'))
|
||||
model.eval()
|
||||
# inference
|
||||
initial = 999469
|
||||
initial = 999917
|
||||
last = initial
|
||||
start_time = '2025-03-11 15:03:31'
|
||||
for i in range(1, 64):
|
||||
hour = i / 4.2342
|
||||
start_time = '2025-03-11 18:43:52'
|
||||
for i in range(1, 48):
|
||||
hour = i / 30
|
||||
sec = hour * 3600
|
||||
time_d = np.log2(sec)
|
||||
data = [time_d, np.log2(initial+1), # time_delta, current_views
|
||||
6.319244, 6.96288, 7.04251, 8.38551, 7.648974, 9.061098, 9.147728, 10.07276, 10.653134, 10.092601, 12.008604, 11.676683, 13.230796, # grows_feat
|
||||
0.627442, 0.232492, 24.778674 # time_feat
|
||||
5.231997, 6.473876, 7.063624, 7.026946, 6.9753, 8.599954, 9.448747, 7.236474, 10.881226, 12.128971, 13.351179, # grows_feat
|
||||
0.7798611111, 0.2541666667, 24.778674 # time_feat
|
||||
]
|
||||
np_arr = np.array([data])
|
||||
tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32)
|
||||
|
@ -55,14 +55,14 @@ def train(model, dataloader, device, epochs=100):
|
||||
t = float(torch.exp2(targets[r])) - 1
|
||||
o = float(torch.exp2(outputs[r])) - 1
|
||||
d = features[r].cpu().numpy()[0]
|
||||
speed = np.exp2(features[r].cpu().numpy()[5]) / 24
|
||||
speed = np.exp2(features[r].cpu().numpy()[6]) / 6
|
||||
time_diff = np.exp2(d) / 3600
|
||||
inc = speed * time_diff
|
||||
model_error = abs(t - o)
|
||||
reg_error = abs(inc - t)
|
||||
if model_error < reg_error:
|
||||
good += 1
|
||||
# print(f"{t:07.1f} | {o:07.1f} | {d:07.1f} | {inc:07.1f} | {good/samples_count*100:.1f}%")
|
||||
#print(f"{t:07.1f} | {o:07.1f} | {d:07.1f} | {inc:07.1f} | {good/samples_count*100:.1f}%")
|
||||
writer.add_scalar('Train/WinRate', good/samples_count, global_step)
|
||||
|
||||
print(f"Epoch {epoch+1} | Avg Loss: {total_loss/len(dataloader):.4f}")
|
||||
|
Loading…
Reference in New Issue
Block a user