From 1de8d85d2bca4405ad282a4fd97f629ddac7db51 Mon Sep 17 00:00:00 2001 From: alikia2x Date: Thu, 13 Mar 2025 21:07:17 +0800 Subject: [PATCH] update: the stable pred model --- .gitignore | 2 + pred/1 | 65 ---------- pred/{1.py => crawler.py} | 8 +- pred/dataset.py | 243 ++++++++++++-------------------------- pred/inference.py | 17 +++ pred/model.py | 194 +++--------------------------- pred/train.py | 133 +++++++++++---------- 7 files changed, 186 insertions(+), 476 deletions(-) delete mode 100644 pred/1 rename pred/{1.py => crawler.py} (71%) create mode 100644 pred/inference.py diff --git a/.gitignore b/.gitignore index b27a6b6..31d6ddf 100644 --- a/.gitignore +++ b/.gitignore @@ -79,6 +79,8 @@ node_modules/ logs/ __pycache__ filter/runs +pred/runs +pred/checkpoints data/ filter/checkpoints scripts diff --git a/pred/1 b/pred/1 deleted file mode 100644 index 8b9c350..0000000 --- a/pred/1 +++ /dev/null @@ -1,65 +0,0 @@ -1151133233 -61967870 -977494472 -891815995 -375265162 -341686360 -2228953 -1951059019 -799277283 -844610791 -1706212240 -339432 -243913657 -16576108 -583566710 -802536340 -2976394 -8321047 -261045912 -381806 -1203136639 -316228425 -257550414 -242976248 -9230106 -517962327 -752662232 -771373147 -63924898 -221567994 -840428043 -78978783 -24990703 -820756 -27171791 -80473511 -847707089 -418226861 -11757544 -232040007 -2371972 -84183673 -829450 -6844720 -39129342 -1203992885 -800408956 -316720732 -33139201 -860855406 -4497808 -25432055 -7366509 -40841777 -1656397450 -371830092 -799978121 -50897913 -674312444 -651329836 -875035826 -469433434 -58814955 -33044780 -946091445 \ No newline at end of file diff --git a/pred/1.py b/pred/crawler.py similarity index 71% rename from pred/1.py rename to pred/crawler.py index 8d8064c..53008d8 100644 --- a/pred/1.py +++ b/pred/crawler.py @@ -1,15 +1,19 @@ +import os import requests import json import time -with open("1", "r") as fp: +with open("./pred/2", "r") as fp: raw = fp.readlines() aids = [ int(x.strip()) for x in raw ] for aid in aids: + if os.path.exists(f"./data/pred/{aid}.json"): + continue url = f"https://api.bunnyxt.com/tdd/v2/video/{aid}/record?last_count=5000" r = requests.get(url) data = r.json() with open (f"./data/pred/{aid}.json", "w") as fp: json.dump(data, fp, ensure_ascii=False, indent=4) - time.sleep(5) \ No newline at end of file + time.sleep(5) + print(aid) \ No newline at end of file diff --git a/pred/dataset.py b/pred/dataset.py index 1de35c1..dfa7614 100644 --- a/pred/dataset.py +++ b/pred/dataset.py @@ -1,206 +1,109 @@ +# dataset.py import os import json import random -import pandas as pd import numpy as np -from torch.utils.data import Dataset -from datetime import datetime +import pandas as pd import torch +from torch.utils.data import Dataset +import datetime class VideoPlayDataset(Dataset): - def __init__(self, data_dir, publish_time_path, - min_seq_len=6, max_seq_len=200, - min_forecast_span=60, max_forecast_span=604800): - """ - 改进后的数据集类,支持非等间隔时间序列 - :param data_dir: JSON文件目录 - :param publish_time_path: 发布时间CSV路径 - :param min_seq_len: 最小历史数据点数 - :param max_seq_len: 最大历史数据点数 - :param min_forecast_span: 最小预测时间跨度(秒) - :param max_forecast_span: 最大预测时间跨度(秒) - """ + def __init__(self, data_dir, publish_time_path, max_future_days=7): self.data_dir = data_dir - self.min_seq_len = min_seq_len - self.max_seq_len = max_seq_len - self.min_forecast_span = min_forecast_span - self.max_forecast_span = max_forecast_span - self.series_dict = self._load_and_process_data(data_dir, publish_time_path) - self.valid_series = self._generate_valid_series() + self.max_future_seconds = max_future_days * 86400 + self.series_dict = self._load_and_process_data(publish_time_path) + self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1] + self.feature_windows = [3600, 6*3600, 24*3600, 3*24*3600, 7*24*3600] # 1h,6h,24h,3d,7d - def _load_and_process_data(self, data_dir, publish_time_path): + def _extract_features(self, series, current_idx, target_idx): + """提取增量特征""" + current_time = series['abs_time'][current_idx] + current_play = series['play_count'][current_idx] + dt = datetime.datetime.fromtimestamp(current_time) + # 时间特征 + time_features = [ + dt.hour / 24, (dt.weekday() + 1) / 7, + np.log2(max(current_time - series['create_time'],1)) + ] + + # 窗口增长特征(增量) + growth_features = [] + for window in self.feature_windows: + prev_time = current_time - window + prev_idx = self._get_nearest_value(series, prev_time, current_idx) + if prev_idx is not None: + time_diff = current_time - series['abs_time'][prev_idx] + play_diff = current_play - series['play_count'][prev_idx] + scaled_diff = play_diff / (time_diff / window) if time_diff > 0 else 0.0 + else: + scaled_diff = 0.0 + growth_features.append(np.log2(max(scaled_diff,1))) + + time_diff = series['abs_time'][target_idx] - series['abs_time'][current_idx] + + return [np.log2(max(time_diff,1))] + [np.log2(current_play + 1)] + growth_features + time_features + + def _load_and_process_data(self, publish_time_path): # 加载发布时间数据 publish_df = pd.read_csv(publish_time_path) publish_df['published_at'] = pd.to_datetime(publish_df['published_at']) publish_dict = dict(zip(publish_df['aid'], publish_df['published_at'])) - - # 加载并处理JSON数据 series_dict = {} - for filename in os.listdir(data_dir): + for filename in os.listdir(self.data_dir): if not filename.endswith('.json'): continue - filepath = os.path.join(data_dir, filename) - with open(filepath, 'r', encoding='utf-8') as f: - json_data = json.load(f) - for item in json_data: + with open(os.path.join(self.data_dir, filename), 'r') as f: + data = json.load(f) + if 'code' in data: + continue + for item in data: aid = item['aid'] - if aid not in publish_dict: - continue - - # 计算相对时间 - added_time = datetime.fromtimestamp(item['added']) - published_time = publish_dict[aid] - rel_time = (added_time - published_time).total_seconds() - - # 按视频组织数据 + published_time = pd.to_datetime(publish_dict[aid]).timestamp() if aid not in series_dict: series_dict[aid] = { 'abs_time': [], - 'rel_time': [], - 'play_count': [] + 'play_count': [], + 'create_time': published_time } - series_dict[aid]['abs_time'].append(item['added']) - series_dict[aid]['rel_time'].append(rel_time) series_dict[aid]['play_count'].append(item['view']) - - # 按时间排序并计算时间间隔 - for aid in series_dict: - # 按时间排序 - sorted_idx = np.argsort(series_dict[aid]['abs_time']) - for key in ['abs_time', 'rel_time', 'play_count']: - series_dict[aid][key] = np.array(series_dict[aid][key])[sorted_idx] - - # 计算时间间隔特征 - abs_time_arr = series_dict[aid]['abs_time'] - time_deltas = np.diff(abs_time_arr, prepend=abs_time_arr[0]) - series_dict[aid]['time_delta'] = time_deltas - return series_dict - def _generate_valid_series(self): - # 生成有效数据序列 - valid_series = [] - for aid in self.series_dict: - series = self.series_dict[aid] - n_points = len(series['play_count']) - - # 过滤数据量不足的视频 - if n_points < self.min_seq_len + 1: - continue - - valid_series.append({ - 'aid': aid, - 'length': n_points, - 'abs_time': series['abs_time'], - 'rel_time': series['rel_time'], - 'play_count': series['play_count'], - 'time_delta': series['time_delta'] - }) - return valid_series - def __len__(self): - return sum(s['length'] - self.min_seq_len for s in self.valid_series) + return 100000 # 使用虚拟长度实现无限采样 + + def _get_nearest_value(self, series, target_time, current_idx): + """获取指定时间前最近的数据点""" + min_diff = float('inf') + for i in range(current_idx + 1, len(series['abs_time']), 1): + diff = abs(series['abs_time'][i] - target_time) + if diff < min_diff: + min_diff = diff + else: + return i - 1 + return None def __getitem__(self, idx): - # 随机选择视频序列 series = random.choice(self.valid_series) - max_start = series['length'] - self.min_seq_len - 1 - start_idx = random.randint(0, max_start) + current_idx = random.randint(0, len(series['abs_time'])-2) + target_idx = random.randint(max(0, current_idx-50), current_idx) - # 动态确定历史窗口长度 - seq_len = random.randint(self.min_seq_len, min(self.max_seq_len, series['length'] - start_idx - 1)) - end_idx = start_idx + seq_len - - # 提取历史窗口特征 - hist_slice = slice(start_idx, end_idx) - x_play = np.log1p(series['play_count'][hist_slice]) - x_abs_time = series['abs_time'][hist_slice] - x_rel_time = series['rel_time'][hist_slice] - x_time_delta = series['time_delta'][hist_slice] - - # 生成预测目标(动态时间跨度) - forecast_span = random.randint(self.min_forecast_span, self.max_forecast_span) - target_time = x_abs_time[-1] + forecast_span - - # 寻找实际目标点(处理数据间隙) - future_times = series['abs_time'][end_idx:] - future_plays = series['play_count'][end_idx:] - - # 找到第一个超过目标时间的点 - target_idx = np.searchsorted(future_times, target_time) - if target_idx >= len(future_times): - # 若超出数据范围,取最后一个点 - y_play = future_plays[-1] if len(future_plays) > 0 else x_play[-1] - actual_span = future_times[-1] - x_abs_time[-1] if len(future_times) > 0 else self.max_forecast_span - else: - y_play = future_plays[target_idx] - actual_span = future_times[target_idx] - x_abs_time[-1] + # 提取特征 + features = self._extract_features(series, current_idx, target_idx) - y_play_val = np.log1p(y_play) - - # 构造时间相关特征 - time_features = np.stack([ - x_abs_time, - x_rel_time, - x_time_delta, - np.log1p(x_time_delta), # 对数变换处理长尾分布 - (x_time_delta > 3600).astype(float) # 间隔是否大于1小时 - ], axis=-1) + # 目标值:未来播放量增量 + current_play = series['play_count'][current_idx] + target_play = series['play_count'][target_idx] + target_delta = max(target_play - current_play, 0) # 增量 return { - 'x_play': torch.FloatTensor(x_play), - 'x_time_feat': torch.FloatTensor(time_features), - 'y_play': torch.FloatTensor([y_play_val]), - 'forecast_span': torch.FloatTensor([actual_span]) + 'features': torch.FloatTensor(features), + 'target': torch.log2(torch.FloatTensor([target_delta]) + 1) # 输出增量 } def collate_fn(batch): - """动态填充处理""" - max_len = max(item['x_play'].shape[0] for item in batch) - - padded_batch = { - 'x_play': [], - 'x_time_feat': [], - 'y_play': [], - 'forecast_span': [], - 'padding_mask': [] - } - - for item in batch: - seq_len = item['x_play'].shape[0] - pad_len = max_len - seq_len - - # 填充播放量数据 - padded_play = torch.cat([ - item['x_play'], - torch.zeros(pad_len) - ]) - padded_batch['x_play'].append(padded_play) - - # 填充时间特征 - padded_time_feat = torch.cat([ - item['x_time_feat'], - torch.zeros(pad_len, item['x_time_feat'].shape[1]) - ]) - padded_batch['x_time_feat'].append(padded_time_feat) - - # 创建padding mask - mask = torch.cat([ - torch.ones(seq_len), - torch.zeros(pad_len) - ]) - padded_batch['padding_mask'].append(mask.bool()) - - # 其他字段 - padded_batch['y_play'].append(item['y_play']) - padded_batch['forecast_span'].append(item['forecast_span']) - - # 转换为张量 - padded_batch['x_play'] = torch.stack(padded_batch['x_play']) - padded_batch['x_time_feat'] = torch.stack(padded_batch['x_time_feat']) - padded_batch['y_play'] = torch.stack(padded_batch['y_play']) - padded_batch['forecast_span'] = torch.stack(padded_batch['forecast_span']) - padded_batch['padding_mask'] = torch.stack(padded_batch['padding_mask']) - - return padded_batch + return { + 'features': torch.stack([x['features'] for x in batch]), + 'targets': torch.stack([x['target'] for x in batch]) + } \ No newline at end of file diff --git a/pred/inference.py b/pred/inference.py new file mode 100644 index 0000000..3efb34a --- /dev/null +++ b/pred/inference.py @@ -0,0 +1,17 @@ +import numpy as np +from model import CompactPredictor +import torch + +def main(): + model = CompactPredictor(10).to('cpu', dtype=torch.float32) + model.load_state_dict(torch.load('play_predictor.pth')) + model.eval() + # inference + data = [3,3.9315974229,5.4263146604,9.4958550269,10.9203528554,11.5835529305,13.0426853722,0.7916666667,0.2857142857,24.7794093257] + np_arr = np.array([data]) + tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32) + output = model(tensor) + print(output) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/pred/model.py b/pred/model.py index 1754087..ee98bf6 100644 --- a/pred/model.py +++ b/pred/model.py @@ -1,182 +1,24 @@ -import torch import torch.nn as nn import torch.nn.functional as F -import math -class TimeEmbedding(nn.Module): - """时间特征编码模块""" - def __init__(self, embed_dim): +class CompactPredictor(nn.Module): + def __init__(self, input_size): super().__init__() - self.embed_dim = embed_dim - - self.norm = nn.LayerNorm(5) - - # 时间特征编码(适配新的5维时间特征) - self.time_encoder = nn.Sequential( - nn.Linear(5, 64), # 输入维度对应x_time_feat的5个特征 - nn.GELU(), - nn.LayerNorm(64), - nn.Linear(64, embed_dim) + self.net = nn.Sequential( + nn.BatchNorm1d(input_size), + nn.Linear(input_size, 256), + nn.LeakyReLU(0.1), + nn.Dropout(0.3), + nn.Linear(256, 128), + nn.LeakyReLU(0.1), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Tanh(), # 使用Tanh限制输出范围 + nn.Linear(64, 1) ) - - def forward(self, time_feat): - """ - time_feat: 时间特征 (batch, seq_len, 5) - """ - time_feat = self.norm(time_feat) # 应用归一化 - return self.time_encoder(time_feat) + # 初始化最后一层为接近零的值 + nn.init.uniform_(self.net[-1].weight, -0.01, 0.01) + nn.init.constant_(self.net[-1].bias, 0.0) - -class MultiScaleEncoder(nn.Module): - """多尺度特征编码器""" - def __init__(self, input_dim, d_model, nhead, conv_kernels=[3, 7, 23]): - super().__init__() - self.d_model = d_model - - self.conv_branches = nn.ModuleList([ - nn.Sequential( - nn.Conv1d(input_dim, d_model, kernel_size=k, padding=k//2), - nn.GELU(), - ) for k in conv_kernels - ]) - - # 添加 LayerNorm 到单独的列表中 - self.layer_norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in conv_kernels]) - - # Transformer编码器 - self.transformer = nn.TransformerEncoder( - nn.TransformerEncoderLayer( - d_model, - nhead, - dim_feedforward=d_model*4, - batch_first=True # 修改为batch_first - ), - num_layers=4 - ) - - # 特征融合层 - self.fusion = nn.Linear(d_model*(len(conv_kernels)+1), d_model) - - def forward(self, x, padding_mask=None): - """ - x: 输入特征 (batch, seq_len, input_dim) - padding_mask: 填充掩码 (batch, seq_len) - """ - - # 卷积分支处理 - conv_features = [] - x_conv = x.permute(0, 2, 1) # (batch, input_dim, seq_len) - for i, branch in enumerate(self.conv_branches): - feat = branch(x_conv) # 输出形状 (batch, d_model, seq_len) - # 手动转置并应用 LayerNorm - feat = feat.permute(0, 2, 1) # (batch, seq_len, d_model) - feat = self.layer_norms[i](feat) # 应用 LayerNorm - conv_features.append(feat) - - # Transformer分支处理 - trans_feat = self.transformer( - x, - src_key_padding_mask=padding_mask - ) # (batch, seq_len, d_model) - - # 特征拼接与融合 - combined = torch.cat(conv_features + [trans_feat], dim=-1) - fused = self.fusion(combined) # (batch, seq_len, d_model) - - return fused - -class VideoPlayPredictor(nn.Module): - def __init__(self, d_model=256, nhead=8): - super().__init__() - self.d_model = d_model - - # 特征嵌入 - self.time_embed = TimeEmbedding(embed_dim=64) - self.base_embed = nn.Linear(1 + 64, d_model) # 播放量 + 时间特征 - - # 编码器 - self.encoder = MultiScaleEncoder(d_model, d_model, nhead) - - # 时间感知预测头 - self.forecast_head = nn.Sequential( - nn.Linear(2 * d_model + 1, d_model * 4), # 关键修改:输入维度为 2*d_model +1 - nn.GELU(), - nn.Linear(d_model * 4, 1), - nn.ReLU() # 确保输出非负 - ) - - # 上下文提取器 - self.context_extractor = nn.LSTM( - input_size=d_model, - hidden_size=d_model, - num_layers=2, - bidirectional=True, - batch_first=True - ) - - # 初始化参数 - self._init_weights() - - def _init_weights(self): - for name, p in self.named_parameters(): - if 'forecast_head' in name: - if 'weight' in name: - nn.init.xavier_normal_(p, gain=1e-2) # 缩小初始化范围 - elif 'bias' in name: - nn.init.constant_(p, 0.0) - elif p.dim() > 1: - nn.init.xavier_uniform_(p) - - def forward(self, x_play, x_time_feat, padding_mask, forecast_span): - """ - x_play: 历史播放量 (batch, seq_len) - x_time_feat: 时间特征 (batch, seq_len, 5) - padding_mask: 填充掩码 (batch, seq_len) - forecast_span: 预测时间跨度 (batch, 1) - """ - batch_size = x_play.size(0) - - # 时间特征编码 - time_emb = self.time_embed(x_time_feat) # (batch, seq_len, 64) - - # 基础特征拼接 - base_feat = torch.cat([ - x_play.unsqueeze(-1), # (batch, seq_len, 1) - time_emb - ], dim=-1) # (batch, seq_len, 1+64) - - # 投影到模型维度 - embedded = self.base_embed(base_feat) # (batch, seq_len, d_model) - - # 编码特征 - encoded = self.encoder(embedded, padding_mask) # (batch, seq_len, d_model) - - # 提取上下文 - context, _ = self.context_extractor(encoded) # (batch, seq_len, d_model*2) - context = context.mean(dim=1) # (batch, d_model*2) - - # 融合时间跨度特征 - span_feat = torch.log1p(forecast_span) / 10 # 归一化 - combined = torch.cat([ - context, - span_feat - ], dim=-1) # (batch, d_model*2 + 1) - - # 最终预测 - pred = self.forecast_head(combined) # (batch, 1) - - return pred - -class MultiTaskWrapper(nn.Module): - """适配新数据结构的封装""" - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, batch): - return self.model( - batch['x_play'], - batch['x_time_feat'], - batch['padding_mask'], - batch['forecast_span'] - ) + def forward(self, x): + return self.net(x) \ No newline at end of file diff --git a/pred/train.py b/pred/train.py index e73085f..603eb17 100644 --- a/pred/train.py +++ b/pred/train.py @@ -1,76 +1,83 @@ +import random +import time import numpy as np +from torch.utils.tensorboard import SummaryWriter from torch.utils.data import DataLoader -from model import MultiTaskWrapper, VideoPlayPredictor import torch -import torch.nn.functional as F from dataset import VideoPlayDataset, collate_fn +from pred.model import CompactPredictor -def train(model, dataloader, epochs=100, device='mps'): - optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs) +def train(model, dataloader, device, epochs=100): + writer = SummaryWriter(f'./pred/runs/play_predictor_{time.strftime("%Y%m%d_%H%M")}') + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01) + scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, + total_steps=len(dataloader)*epochs) + criterion = torch.nn.MSELoss() - steps = 0 + model.train() + global_step = 0 for epoch in range(epochs): - model.train() - total_loss = 0 - - for batch in dataloader: + total_loss = 0.0 + for batch_idx, batch in enumerate(dataloader): + features = batch['features'].to(device) + targets = batch['targets'].to(device) + optimizer.zero_grad() - - # movel whole batch to device - for k, v in batch.items(): - if isinstance(v, torch.Tensor): - batch[k] = v.to(device) - - # 前向传播 - pred = model(batch) - - y_play = batch['y_play'] - - real = np.expm1(y_play.cpu().detach().numpy()) - yhat = np.expm1(pred.cpu().detach().numpy()) - print("real", [int(real[0][0]), int(real[1][0])]) - print("yhat", [int(yhat[0][0]), int(yhat[1][0])], [float(pred.cpu().detach().numpy()[0][0]), float(pred.cpu().detach().numpy()[1][0])]) - - # 计算加权损失 - weights = torch.log1p(batch['forecast_span']) # 时间越长权重越低 - loss_per_sample = F.huber_loss(pred, y_play, reduction='none') - loss = (loss_per_sample * weights).mean() - - # 反向传播 + outputs = model(features) + loss = criterion(outputs, targets) loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + #torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() + scheduler.step() - steps += 1 + total_loss += loss.item() + global_step += 1 + + if global_step % 100 == 0: + writer.add_scalar('Loss/train', loss.item(), global_step) + writer.add_scalar('LR', scheduler.get_last_lr()[0], global_step) + if batch_idx % 50 == 0: + # 监控梯度 + grad_norms = [ + torch.norm(p.grad).item() + for p in model.parameters() if p.grad is not None + ] + writer.add_scalar('Grad/Norm', sum(grad_norms)/len(grad_norms), global_step) + + # 监控参数值 + param_means = [torch.mean(p.data).item() for p in model.parameters()] + writer.add_scalar('Params/Mean', sum(param_means)/len(param_means), global_step) - print(f"Epoch {epoch+1} | Step {steps} | Loss: {loss.item():.4f}") - - scheduler.step() - avg_loss = total_loss / len(dataloader) - print(f"Epoch {epoch+1:03d} | Loss: {avg_loss:.4f}") + samples_count = len(targets) + r = random.randint(0, samples_count-1) + t = float(torch.exp2(targets[r])) - 1 + o = float(torch.exp2(outputs[r])) - 1 + d = features[r].cpu().numpy()[0] + speed = np.exp2(features[r].cpu().numpy()[2]) + time_diff = np.exp2(d) / 3600 + inc = speed * time_diff + model_error = abs(t - o) + reg_error = abs(inc - t) + print(f"{t:07.1f} | {o:07.1f} | {d:07.1f} | {inc:07.1f} | {model_error < reg_error}") + + print(f"Epoch {epoch+1} | Avg Loss: {total_loss/len(dataloader):.4f}") + + writer.close() + return model -# 初始化模型 -device = 'mps' -model = MultiTaskWrapper(VideoPlayPredictor()) -model = model.to(device) - -data_dir = './data/pred' -publish_time_path = './data/pred/publish_time.csv' -dataset = VideoPlayDataset( - data_dir=data_dir, - publish_time_path=publish_time_path, - min_seq_len=2, # 至少2个历史点 - max_seq_len=350, # 最多350个历史点 - min_forecast_span=60, # 预测跨度1分钟到 - max_forecast_span=86400 * 10 # 10天 -) -dataloader = DataLoader( - dataset, - batch_size=2, - shuffle=True, - collate_fn=collate_fn, # 使用自定义collate函数 -) - -# 开始训练 -train(model, dataloader, epochs=20, device=device) \ No newline at end of file +if __name__ == "__main__": + device = 'mps' + + # 初始化数据集和模型 + dataset = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv') + dataloader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn) + + # 获取特征维度 + sample = next(iter(dataloader)) + input_size = sample['features'].shape[1] + + model = CompactPredictor(input_size).to(device) + trained_model = train(model, dataloader, device, epochs=30) + + # 保存模型 + torch.save(trained_model.state_dict(), 'play_predictor.pth') \ No newline at end of file