207 lines
7.9 KiB
Python
207 lines
7.9 KiB
Python
import os
|
|
import json
|
|
import random
|
|
import pandas as pd
|
|
import numpy as np
|
|
from torch.utils.data import Dataset
|
|
from datetime import datetime
|
|
import torch
|
|
|
|
class VideoPlayDataset(Dataset):
    """Play-count forecasting dataset over irregularly sampled time series.

    Every sample is a random history window taken from one video's series,
    paired with a target play count observed roughly ``forecast_span``
    seconds after the window's last point.

    ``__getitem__`` ignores its index and samples at random; ``__len__``
    therefore only determines how many samples make up one epoch.
    """

    def __init__(self, data_dir, publish_time_path,
                 min_seq_len=6, max_seq_len=200,
                 min_forecast_span=60, max_forecast_span=604800):
        """Load all series and keep those long enough to sample from.

        :param data_dir: directory of JSON files; each file holds a list of
            records with ``aid``, ``added`` (passed to
            ``datetime.fromtimestamp``, i.e. epoch seconds) and ``view`` keys
        :param publish_time_path: CSV file with ``aid`` and ``published_at``
            columns
        :param min_seq_len: minimum number of history points per sample (>= 1)
        :param max_seq_len: maximum number of history points per sample
        :param min_forecast_span: minimum forecast horizon in seconds
        :param max_forecast_span: maximum forecast horizon in seconds
        :raises ValueError: if ``min_seq_len < 1`` or a maximum is smaller
            than its minimum
        """
        # Fail fast, before any I/O. These settings would otherwise surface
        # later inside __getitem__: as an "empty range" error from
        # random.randint, or as an IndexError on an empty window when
        # min_seq_len is 0.
        if min_seq_len < 1:
            raise ValueError(f"min_seq_len must be >= 1, got {min_seq_len}")
        if max_seq_len < min_seq_len:
            raise ValueError(
                f"max_seq_len ({max_seq_len}) must be >= "
                f"min_seq_len ({min_seq_len})")
        if max_forecast_span < min_forecast_span:
            raise ValueError(
                f"max_forecast_span ({max_forecast_span}) must be >= "
                f"min_forecast_span ({min_forecast_span})")

        self.data_dir = data_dir
        self.min_seq_len = min_seq_len
        self.max_seq_len = max_seq_len
        self.min_forecast_span = min_forecast_span
        self.max_forecast_span = max_forecast_span
        self.series_dict = self._load_and_process_data(data_dir, publish_time_path)
        self.valid_series = self._generate_valid_series()

    def _load_and_process_data(self, data_dir, publish_time_path):
        """Load publish times and JSON snapshots into per-video sorted arrays.

        Records whose ``aid`` has no publish time are skipped.

        :return: dict mapping ``aid`` to a dict of numpy arrays with keys
            ``abs_time``, ``rel_time``, ``play_count`` and ``time_delta``,
            all sorted by ``abs_time``.
        """
        # Publish time per video.
        publish_df = pd.read_csv(publish_time_path)
        publish_df['published_at'] = pd.to_datetime(publish_df['published_at'])
        publish_dict = dict(zip(publish_df['aid'], publish_df['published_at']))

        series_dict = {}
        # os.listdir order is arbitrary. Sorting makes the series order (and
        # therefore seeded sampling) reproducible across machines.
        for filename in sorted(os.listdir(data_dir)):
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(data_dir, filename)
            with open(filepath, 'r', encoding='utf-8') as f:
                json_data = json.load(f)

            for item in json_data:
                aid = item['aid']
                if aid not in publish_dict:
                    continue

                # Seconds elapsed since publication.
                # NOTE(review): fromtimestamp() returns a naive *local* time.
                # This assumes published_at is also naive and in the same local
                # zone; a tz-aware published_at would make the subtraction
                # raise. Confirm against the CSV's format.
                added_time = datetime.fromtimestamp(item['added'])
                rel_time = (added_time - publish_dict[aid]).total_seconds()

                entry = series_dict.setdefault(
                    aid, {'abs_time': [], 'rel_time': [], 'play_count': []})
                entry['abs_time'].append(item['added'])
                entry['rel_time'].append(rel_time)
                entry['play_count'].append(item['view'])

        for series in series_dict.values():
            # Stable sort keeps file-read order among duplicate timestamps,
            # so ties are ordered deterministically.
            order = np.argsort(series['abs_time'], kind='stable')
            for key in ('abs_time', 'rel_time', 'play_count'):
                series[key] = np.asarray(series[key])[order]

            # Gap to the previous observation; the first point gets 0.
            abs_time = series['abs_time']
            series['time_delta'] = np.diff(abs_time, prepend=abs_time[0])

        return series_dict

    def _generate_valid_series(self):
        """Return the series with enough points for a window plus a target.

        A series qualifies when it has at least ``min_seq_len + 1`` points.
        """
        valid_series = []
        for aid, series in self.series_dict.items():
            n_points = len(series['play_count'])
            if n_points < self.min_seq_len + 1:
                continue
            valid_series.append({
                'aid': aid,
                'length': n_points,
                'abs_time': series['abs_time'],
                'rel_time': series['rel_time'],
                'play_count': series['play_count'],
                'time_delta': series['time_delta'],
            })
        return valid_series

    def __len__(self):
        """Epoch size: the number of possible window start positions over all series."""
        return sum(s['length'] - self.min_seq_len for s in self.valid_series)

    def __getitem__(self, idx):
        """Draw one random (history window, forecast target) sample.

        ``idx`` is ignored. A series is picked uniformly at random, then a
        window start, a window length and a forecast horizon.

        :return: dict of float32 tensors:
            ``x_play`` (seq_len,) holds log1p play counts of the window;
            ``x_time_feat`` (seq_len, 5) holds
            [abs_time, rel_time, time_delta, log1p(time_delta), time_delta > 3600];
            ``y_play`` (1,) is the log1p target play count;
            ``forecast_span`` (1,) is the actual seconds from the window end
            to the target.
        """
        series = random.choice(self.valid_series)
        length = series['length']

        # Leave at least one point after the window to act as the target.
        max_start = length - self.min_seq_len - 1
        start_idx = random.randint(0, max_start)

        # The window length is random, capped so the window ends at
        # length - 2 or earlier.
        seq_len = random.randint(
            self.min_seq_len, min(self.max_seq_len, length - start_idx - 1))
        end_idx = start_idx + seq_len

        hist = slice(start_idx, end_idx)
        x_play = np.log1p(series['play_count'][hist])
        x_abs_time = series['abs_time'][hist]
        x_rel_time = series['rel_time'][hist]
        x_time_delta = series['time_delta'][hist]

        # Desired target time, a random horizon after the window's last point.
        forecast_span = random.randint(self.min_forecast_span, self.max_forecast_span)
        target_time = x_abs_time[-1] + forecast_span

        future_times = series['abs_time'][end_idx:]
        future_plays = series['play_count'][end_idx:]

        if len(future_times) == 0:
            # Unreachable under the window bounds above; kept as a safety net.
            # Use the last observed *raw* count so it is log-transformed
            # exactly once below. The target is then the window's own last
            # point, so the span is 0.
            y_play = series['play_count'][end_idx - 1]
            actual_span = 0.0
        else:
            # First future observation at or after target_time.
            target_idx = np.searchsorted(future_times, target_time)
            if target_idx >= len(future_times):
                # The target lies beyond the data: fall back to the last point.
                target_idx = len(future_times) - 1
            y_play = future_plays[target_idx]
            actual_span = future_times[target_idx] - x_abs_time[-1]

        y_play_val = np.log1p(y_play)

        time_features = np.stack([
            x_abs_time,
            x_rel_time,
            x_time_delta,
            np.log1p(x_time_delta),              # log transform tames the long tail
            (x_time_delta > 3600).astype(float),  # gap longer than one hour
        ], axis=-1)

        # NOTE(review): raw epoch seconds (~1.7e9) lose resolution in float32
        # (spacing ~128 s). Consider normalising abs_time if the model uses it.
        return {
            'x_play': torch.as_tensor(x_play, dtype=torch.float32),
            'x_time_feat': torch.as_tensor(time_features, dtype=torch.float32),
            'y_play': torch.tensor([y_play_val], dtype=torch.float32),
            'forecast_span': torch.tensor([actual_span], dtype=torch.float32),
        }
|
|
|
|
def collate_fn(batch):
    """Collate variable-length samples into one zero-padded batch.

    ``x_play`` and ``x_time_feat`` are right-padded with zeros to the
    longest ``x_play`` in the batch. ``padding_mask`` is a bool tensor of
    shape (batch, max_len) that is True on real positions and False on
    padding. ``y_play`` and ``forecast_span`` are stacked as-is.

    Raises ValueError when ``batch`` is empty.
    """
    lengths = [sample['x_play'].shape[0] for sample in batch]
    longest = max(lengths)
    pad = torch.nn.utils.rnn.pad_sequence

    # A position is real when its index is below that sample's length.
    positions = torch.arange(longest).unsqueeze(0)
    valid = positions < torch.tensor(lengths).unsqueeze(1)

    return {
        'x_play': pad([s['x_play'] for s in batch], batch_first=True),
        'x_time_feat': pad([s['x_time_feat'] for s in batch], batch_first=True),
        'y_play': torch.stack([s['y_play'] for s in batch]),
        'forecast_span': torch.stack([s['forecast_span'] for s in batch]),
        'padding_mask': valid,
    }
|