temp: try to modify some features for the pred model
This commit is contained in:
parent
2ed909268e
commit
ba6b8bd5b3
@ -20,7 +20,7 @@ class VideoPlayDataset(Dataset):
|
|||||||
self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
|
self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
|
||||||
self.term = term
|
self.term = term
|
||||||
# Set time window based on term
|
# Set time window based on term
|
||||||
self.time_window = 1000 * 24 * 3600 if term == 'long' else 7 * 24 * 3600
|
self.time_window = 1000 * 24 * 3600 if term == 'long' else 3 * 24 * 3600
|
||||||
MINUTE = 60
|
MINUTE = 60
|
||||||
HOUR = 3600
|
HOUR = 3600
|
||||||
DAY = 24 * HOUR
|
DAY = 24 * HOUR
|
||||||
@ -37,7 +37,7 @@ class VideoPlayDataset(Dataset):
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
self.feature_windows = [
|
self.feature_windows = [
|
||||||
( 5 * MINUTE, 0 * MINUTE),
|
#( 5 * MINUTE, 0 * MINUTE),
|
||||||
( 15 * MINUTE, 0 * MINUTE),
|
( 15 * MINUTE, 0 * MINUTE),
|
||||||
( 40 * MINUTE, 0 * MINUTE),
|
( 40 * MINUTE, 0 * MINUTE),
|
||||||
( 1 * HOUR, 0 * HOUR),
|
( 1 * HOUR, 0 * HOUR),
|
||||||
@ -46,7 +46,7 @@ class VideoPlayDataset(Dataset):
|
|||||||
( 3 * HOUR, 0 * HOUR),
|
( 3 * HOUR, 0 * HOUR),
|
||||||
#( 6 * HOUR, 3 * HOUR),
|
#( 6 * HOUR, 3 * HOUR),
|
||||||
( 6 * HOUR, 0 * HOUR),
|
( 6 * HOUR, 0 * HOUR),
|
||||||
(18 * HOUR, 12 * HOUR),
|
#(18 * HOUR, 12 * HOUR),
|
||||||
#( 1 * DAY, 6 * HOUR),
|
#( 1 * DAY, 6 * HOUR),
|
||||||
( 1 * DAY, 0 * DAY),
|
( 1 * DAY, 0 * DAY),
|
||||||
#( 2 * DAY, 1 * DAY),
|
#( 2 * DAY, 1 * DAY),
|
||||||
|
@ -4,20 +4,20 @@ from model import CompactPredictor
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
model = CompactPredictor(16).to('cpu', dtype=torch.float32)
|
model = CompactPredictor(15).to('cpu', dtype=torch.float32)
|
||||||
model.load_state_dict(torch.load('./pred/checkpoints/model_20250315_0530.pt'))
|
model.load_state_dict(torch.load('./pred/checkpoints/model_20250320_0045.pt'))
|
||||||
model.eval()
|
model.eval()
|
||||||
# inference
|
# inference
|
||||||
initial = 99906
|
initial = 999704
|
||||||
last = initial
|
last = initial
|
||||||
start_time = '2025-03-16 14:48:42'
|
start_time = '2025-03-19 22:00:42'
|
||||||
for i in range(1, 48):
|
for i in range(1, 48):
|
||||||
hour = i / 4
|
hour = i / 6
|
||||||
sec = hour * 3600
|
sec = hour * 3600
|
||||||
time_d = np.log2(sec)
|
time_d = np.log2(sec)
|
||||||
data = [time_d, np.log2(initial+1), # time_delta, current_views
|
data = [time_d, np.log2(initial+1), # time_delta, current_views
|
||||||
2.456146, 3.562719, 4.106399, 1.0, 1.0, 5.634413, 6.619818, 1.0, 8.608774, 10.19127, 11.412958, # grows_feat
|
4.857981, 6.29067, 6.869476, 6.58392, 6.523051, 8.242355, 8.841574, 10.203909, 11.449314, 12.659556, # grows_feat
|
||||||
0.617153, 0.945308, 22.091431 # time_feat
|
0.916956, 0.416708, 28.003162 # time_feat
|
||||||
]
|
]
|
||||||
np_arr = np.array([data])
|
np_arr = np.array([data])
|
||||||
tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32)
|
tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32)
|
||||||
|
@ -38,7 +38,7 @@ def train(model, dataloader, device, epochs=100):
|
|||||||
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,
|
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,
|
||||||
total_steps=len(dataloader)*30)
|
total_steps=len(dataloader)*30)
|
||||||
# Huber loss
|
# Huber loss
|
||||||
criterion = asymmetricHuberLoss(delta=1.0, beta=2.1)
|
criterion = asymmetricHuberLoss(delta=1.0, beta=2.2)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
global_step = 0
|
global_step = 0
|
||||||
@ -100,7 +100,7 @@ if __name__ == "__main__":
|
|||||||
device = 'mps'
|
device = 'mps'
|
||||||
|
|
||||||
# Initialize dataset and model
|
# Initialize dataset and model
|
||||||
dataset = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv', 'short')
|
dataset = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv', 'short', 712)
|
||||||
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
|
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
|
||||||
|
|
||||||
# Get feature dimension
|
# Get feature dimension
|
||||||
|
Loading…
Reference in New Issue
Block a user