update: the features for the model

This commit is contained in:
alikia2x (寒寒) 2025-03-15 05:07:49 +08:00
parent a6211782cb
commit bf919da1ea
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
3 changed files with 44 additions and 16 deletions

View File

@ -9,21 +9,49 @@ from torch.utils.data import Dataset
import datetime
class VideoPlayDataset(Dataset):
def __init__(self, data_dir, publish_time_path, term='long'):
def __init__(self, data_dir, publish_time_path, term='long', seed=42):
if seed is not None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
self.data_dir = data_dir
self.series_dict = self._load_and_process_data(publish_time_path)
self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
self.term = term
# Set time window based on term
self.time_window = 1000 * 24 * 3600 if term == 'long' else 7 * 24 * 3600
MINUTE = 60
HOUR = 3600
DAY = 24 * HOUR
if term == 'long':
self.feature_windows = [3600, 6*3600, 24*3600, 3*24*3600, 7*24*3600, 30*24*3600, 100*24*3600]
self.feature_windows = [
1 * HOUR,
6 * HOUR,
1 *DAY,
3 * DAY,
7 * DAY,
30 * DAY,
100 * DAY
]
else:
self.feature_windows = [
(3600, 0), (7200, 3600), (10800, 7200), (10800, 0),
(21600, 10800), (21600, 0), (64800, 43200), (86400, 21600),
(86400, 0), (172800, 86400), (259200, 0), (345600, 86400),
(604800, 0)
( 15 * MINUTE, 0 * MINUTE),
( 40 * MINUTE, 0 * MINUTE),
( 1 * HOUR, 0 * HOUR),
( 2 * HOUR, 1 * HOUR),
( 3 * HOUR, 2 * HOUR),
( 3 * HOUR, 0 * HOUR),
#( 6 * HOUR, 3 * HOUR),
( 6 * HOUR, 0 * HOUR),
(18 * HOUR, 12 * HOUR),
#( 1 * DAY, 6 * HOUR),
( 1 * DAY, 0 * DAY),
#( 2 * DAY, 1 * DAY),
( 3 * DAY, 0 * DAY),
#( 4 * DAY, 1 * DAY),
( 7 * DAY, 0 * DAY)
]
def _extract_features(self, series, current_idx, target_idx):

View File

@ -4,20 +4,20 @@ from model import CompactPredictor
import torch
def main():
model = CompactPredictor(18).to('cpu', dtype=torch.float32)
model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0407.pt'))
model = CompactPredictor(16).to('cpu', dtype=torch.float32)
model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0504.pt'))
model.eval()
# inference
initial = 999469
initial = 999917
last = initial
start_time = '2025-03-11 15:03:31'
for i in range(1, 64):
hour = i / 4.2342
start_time = '2025-03-11 18:43:52'
for i in range(1, 48):
hour = i / 30
sec = hour * 3600
time_d = np.log2(sec)
data = [time_d, np.log2(initial+1), # time_delta, current_views
6.319244, 6.96288, 7.04251, 8.38551, 7.648974, 9.061098, 9.147728, 10.07276, 10.653134, 10.092601, 12.008604, 11.676683, 13.230796, # grows_feat
0.627442, 0.232492, 24.778674 # time_feat
5.231997, 6.473876, 7.063624, 7.026946, 6.9753, 8.599954, 9.448747, 7.236474, 10.881226, 12.128971, 13.351179, # grows_feat
0.7798611111, 0.2541666667, 24.778674 # time_feat
]
np_arr = np.array([data])
tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32)

View File

@ -55,14 +55,14 @@ def train(model, dataloader, device, epochs=100):
t = float(torch.exp2(targets[r])) - 1
o = float(torch.exp2(outputs[r])) - 1
d = features[r].cpu().numpy()[0]
speed = np.exp2(features[r].cpu().numpy()[5]) / 24
speed = np.exp2(features[r].cpu().numpy()[6]) / 6
time_diff = np.exp2(d) / 3600
inc = speed * time_diff
model_error = abs(t - o)
reg_error = abs(inc - t)
if model_error < reg_error:
good += 1
# print(f"{t:07.1f} | {o:07.1f} | {d:07.1f} | {inc:07.1f} | {good/samples_count*100:.1f}%")
#print(f"{t:07.1f} | {o:07.1f} | {d:07.1f} | {inc:07.1f} | {good/samples_count*100:.1f}%")
writer.add_scalar('Train/WinRate', good/samples_count, global_step)
print(f"Epoch {epoch+1} | Avg Loss: {total_loss/len(dataloader):.4f}")