cvsa/ml/pred/dataset.py

import bisect
import datetime
import json
import os
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

class VideoPlayDataset(Dataset):
    """Randomly samples (current, future) observation pairs from per-video
    play-count time series and builds growth/time features for each pair."""

    def __init__(self, data_dir, publish_time_path, term='long', seed=42):
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
        self.data_dir = data_dir
        self.series_dict = self._load_and_process_data(publish_time_path)
        self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
        self.term = term
        # Set the maximum look-ahead time window based on the term
        self.time_window = 1000 * 24 * 3600 if term == 'long' else 7 * 24 * 3600
        MINUTE = 60
        HOUR = 3600
        DAY = 24 * HOUR
        if term == 'long':
            # Look-back windows (in seconds) for long-term growth features
            self.feature_windows = [
                1 * HOUR,
                6 * HOUR,
                1 * DAY,
                3 * DAY,
                7 * DAY,
                30 * DAY,
                100 * DAY
            ]
        else:
            # (window_start, window_end) look-back intervals for short-term growth features
            self.feature_windows = [
                (15 * MINUTE, 0 * MINUTE),
                (40 * MINUTE, 0 * MINUTE),
                (1 * HOUR, 0 * HOUR),
                (2 * HOUR, 1 * HOUR),
                (3 * HOUR, 2 * HOUR),
                (3 * HOUR, 0 * HOUR),
                # (6 * HOUR, 3 * HOUR),
                (6 * HOUR, 0 * HOUR),
                (18 * HOUR, 12 * HOUR),
                # (1 * DAY, 6 * HOUR),
                (1 * DAY, 0 * DAY),
                # (2 * DAY, 1 * DAY),
                (3 * DAY, 0 * DAY),
                # (4 * DAY, 1 * DAY),
                (7 * DAY, 0 * DAY)
            ]
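    # For the short-term windows above, each (window_start, window_end) pair measures
    # growth over the past interval [t - window_start, t - window_end]; the
    # commented-out pairs are alternative windows that are currently disabled.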
    def _extract_features(self, series, current_idx, target_idx):
        """Build the feature vector for predicting growth from current_idx to target_idx."""
        current_time = series['abs_time'][current_idx]
        current_play = series['play_count'][current_idx]
        dt = datetime.datetime.fromtimestamp(current_time)
        if self.term == 'long':
            time_features = [
                np.log2(max(current_time - series['create_time'], 1))
            ]
        else:
            time_features = [
                (dt.hour * 3600 + dt.minute * 60 + dt.second) / 86400,
                (dt.weekday() * 24 + dt.hour) / 168,
                np.log2(max(current_time - series['create_time'], 1))
            ]
        growth_features = []
        if self.term == 'long':
            for window in self.feature_windows:
                prev_time = current_time - window
                prev_idx = self._get_nearest_value(series, prev_time, current_idx)
                if prev_idx is not None:
                    time_diff = current_time - series['abs_time'][prev_idx]
                    play_diff = current_play - series['play_count'][prev_idx]
                    # Rescale the observed growth to the nominal window length
                    scaled_diff = play_diff / (time_diff / window) if time_diff > 0 else 0.0
                else:
                    scaled_diff = 0.0
                growth_features.append(np.log2(max(scaled_diff, 1)))
        else:
            for window_start, window_end in self.feature_windows:
                prev_time_start = current_time - window_start
                prev_time_end = current_time - window_end  # window_end is typically 0
                prev_idx_start = self._get_nearest_value(series, prev_time_start, current_idx)
                prev_idx_end = self._get_nearest_value(series, prev_time_end, current_idx)
                if prev_idx_start is not None and prev_idx_end is not None:
                    time_diff = series['abs_time'][prev_idx_end] - series['abs_time'][prev_idx_start]
                    play_diff = series['play_count'][prev_idx_end] - series['play_count'][prev_idx_start]
                    scaled_diff = play_diff / (time_diff / (window_start - window_end)) if time_diff > 0 else 0.0
                else:
                    scaled_diff = 0.0
                growth_features.append(np.log2(max(scaled_diff, 1)))
        time_diff = series['abs_time'][target_idx] - current_time
        return [np.log2(max(time_diff, 1))] + [np.log2(current_play + 1)] + growth_features + time_features
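    # The returned vector is: log2 horizon to the target sample, log2 of the current
    # play count, one growth feature per window (delta rescaled to the nominal window
    # length, then log2-compressed), and the time features. With the default windows
    # above this gives 10 features for term='long' and 16 for term='short'.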
    def _load_and_process_data(self, publish_time_path):
        """Load per-video play-count histories from JSON files and index them by aid."""
        publish_df = pd.read_csv(publish_time_path)
        publish_df['published_at'] = pd.to_datetime(publish_df['published_at'])
        publish_dict = dict(zip(publish_df['aid'], publish_df['published_at']))
        series_dict = {}
        for filename in os.listdir(self.data_dir):
            if not filename.endswith('.json'):
                continue
            with open(os.path.join(self.data_dir, filename), 'r') as f:
                data = json.load(f)
            if 'code' in data:
                # Skip payloads that carry a 'code' field (error responses, not records)
                continue
            for item in data:
                aid = item['aid']
                published_time = pd.to_datetime(publish_dict[aid]).timestamp()
                if aid not in series_dict:
                    series_dict[aid] = {
                        'abs_time': [],
                        'play_count': [],
                        'create_time': published_time
                    }
                series_dict[aid]['abs_time'].append(item['added'])
                series_dict[aid]['play_count'].append(item['view'])
        # Sort each series by absolute time
        for aid in series_dict:
            sorted_indices = sorted(range(len(series_dict[aid]['abs_time'])),
                                    key=lambda k: series_dict[aid]['abs_time'][k])
            series_dict[aid]['abs_time'] = [series_dict[aid]['abs_time'][i] for i in sorted_indices]
            series_dict[aid]['play_count'] = [series_dict[aid]['play_count'][i] for i in sorted_indices]
        return series_dict
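    # Expected inputs, as read above: publish_time_path is a CSV with at least
    # 'aid' and 'published_at' columns; each *.json file in data_dir holds a list
    # of snapshots with 'aid', 'added' (Unix timestamp) and 'view' fields.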
    def __len__(self):
        return 100000  # Virtual length for sampling
    def _get_nearest_value(self, series, target_time, current_idx):
        """Return the index of the sample closest in time to target_time,
        considering only samples up to and including current_idx."""
        times = series['abs_time']
        pos = bisect.bisect_right(times, target_time, 0, current_idx + 1)
        candidates = []
        if pos > 0:
            candidates.append(pos - 1)
        if pos <= current_idx:
            candidates.append(pos)
        if not candidates:
            return None
        closest_idx = min(candidates, key=lambda i: abs(times[i] - target_time))
        return closest_idx
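    # Worked example: with abs_time = [0, 100, 200, 300] and current_idx = 2,
    # _get_nearest_value(series, 140, 2) searches indices 0..2 and returns 1,
    # because abs_time[1] = 100 is closer to 140 than abs_time[2] = 200.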
    def __getitem__(self, _idx):
        # The index is ignored: each call draws a random (current, target) pair
        while True:
            series = random.choice(self.valid_series)
            if len(series['abs_time']) < 2:
                continue
            current_idx = random.randint(0, len(series['abs_time']) - 2)
            current_time = series['abs_time'][current_idx]
            max_target_time = current_time + self.time_window
            candidate_indices = []
            for j in range(current_idx + 1, len(series['abs_time'])):
                if series['abs_time'][j] > max_target_time:
                    break
                candidate_indices.append(j)
            if not candidate_indices:
                continue
            target_idx = random.choice(candidate_indices)
            break
        current_play = series['play_count'][current_idx]
        target_play = series['play_count'][target_idx]
        target_delta = max(target_play - current_play, 0)
        return {
            'features': torch.FloatTensor(self._extract_features(series, current_idx, target_idx)),
            'target': torch.log2(torch.FloatTensor([target_delta]) + 1)
        }
def collate_fn(batch):
    return {
        'features': torch.stack([x['features'] for x in batch]),
        'targets': torch.stack([x['target'] for x in batch])
    }
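

# A minimal usage sketch: build the dataset and pull one batch through a
# DataLoader using the module's collate_fn. The paths and batch size below are
# placeholders (assumptions for illustration), not values defined in this repo.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = VideoPlayDataset(
        data_dir='./data',                       # placeholder: directory of *.json snapshots
        publish_time_path='./publish_time.csv',  # placeholder: CSV with 'aid', 'published_at'
        term='long',
    )
    loader = DataLoader(dataset, batch_size=256, collate_fn=collate_fn)
    batch = next(iter(loader))
    # features: (batch, feature_dim); targets: (batch, 1)
    print(batch['features'].shape, batch['targets'].shape)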