From b21e6da07a65f40aa288f6a145266877c99fe479 Mon Sep 17 00:00:00 2001 From: alikia2x Date: Sat, 8 Feb 2025 00:36:40 +0800 Subject: [PATCH] update: structurize video insertion code, use js-compatible filter model arch --- .gitignore | 27 +++--- deno.json | 88 +++++++++---------- filter/checkpoint_conversion.py | 26 ++++++ filter/dataset.py | 2 +- filter/embedding.py | 31 +++++++ filter/modelV3_1.py | 58 ------------- filter/modelV3_10.py | 97 +++++++++++++++++++++ filter/modelV3_2.py | 58 ------------- filter/modelV3_3.py | 56 ------------ filter/onnx_export.py | 32 +++++++ filter/train.py | 15 ++-- lib/db/allData.ts | 15 ++++ lib/db/schema.d.ts | 9 ++ lib/ml/SentenceTransformer/index.ts | 19 +++++ lib/ml/SentenceTransformer/model.ts | 40 +++++++++ lib/ml/SentenceTransformer/pooling.ts | 34 ++++++++ lib/ml/classifyVideo.ts | 32 +++++++ lib/mq/index.ts | 2 + lib/net/bilibili.d.ts | 117 ++++++++++++++++++++++++++ lib/net/getLatestVideos.ts | 48 +++++++++++ lib/net/getVideoTags.ts | 18 ++++ lib/task/insertLatestVideo.ts | 81 ++++++++++++++++++ lib/utils/formatTimestampToPostgre.ts | 4 + lib/utils/sleep.ts | 3 + src/db/raw/aliyun-fc.mjs | 3 +- src/net/getLatestVideos.ts | 35 -------- test/net/getLatestVideos.test.ts | 25 ++++++ test/net/getVideoTags.test.ts | 28 ++++++ 28 files changed, 731 insertions(+), 272 deletions(-) create mode 100644 filter/checkpoint_conversion.py create mode 100644 filter/embedding.py delete mode 100644 filter/modelV3_1.py create mode 100644 filter/modelV3_10.py delete mode 100644 filter/modelV3_2.py delete mode 100644 filter/modelV3_3.py create mode 100644 filter/onnx_export.py create mode 100644 lib/db/allData.ts create mode 100644 lib/db/schema.d.ts create mode 100644 lib/ml/SentenceTransformer/index.ts create mode 100644 lib/ml/SentenceTransformer/model.ts create mode 100644 lib/ml/SentenceTransformer/pooling.ts create mode 100644 lib/ml/classifyVideo.ts create mode 100644 lib/mq/index.ts create mode 100644 lib/net/bilibili.d.ts create mode 100644 lib/net/getLatestVideos.ts create mode 100644 lib/net/getVideoTags.ts create mode 100644 lib/task/insertLatestVideo.ts create mode 100644 lib/utils/formatTimestampToPostgre.ts create mode 100644 lib/utils/sleep.ts delete mode 100644 src/net/getLatestVideos.ts create mode 100644 test/net/getLatestVideos.test.ts create mode 100644 test/net/getVideoTags.test.ts diff --git a/.gitignore b/.gitignore index d17200a..e3b07d1 100644 --- a/.gitignore +++ b/.gitignore @@ -61,19 +61,6 @@ TEST-results.xml package-lock.json .eslintcache *v8.log -/lib/ - -# project specific -data/main.db -.env -logs/ -__pycache__ -filter/runs -data/filter/eval* -data/filter/train* -filter/checkpoints -data/filter/model_predicted* -scripts # dotenv environment variable files .env @@ -86,3 +73,17 @@ scripts _fresh/ # npm dependencies node_modules/ + + +# project specific +data/main.db +.env +logs/ +__pycache__ +filter/runs +data/filter/eval* +data/filter/train* +filter/checkpoints +data/filter/model_predicted* +scripts +model/ diff --git a/deno.json b/deno.json index 0d0db8f..4594440 100644 --- a/deno.json +++ b/deno.json @@ -1,46 +1,46 @@ { - "lock": false, - "tasks": { - "crawl-raw-bili": "deno --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run src/db/raw/insertAidsToDB.ts", - "crawl-bili-aids": "deno --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run src/db/raw/fetchAids.ts", - "check": "deno fmt --check && deno lint && deno check **/*.ts && deno check **/*.tsx", - "cli": "echo \"import '\\$fresh/src/dev/cli.ts'\" | deno run --unstable -A -", - "manifest": "deno task cli manifest $(pwd)", - "start": "deno run -A --watch=static/,routes/ dev.ts", - "build": "deno run -A dev.ts build", - "preview": "deno run -A main.ts", - "update": "deno run -A -r https://fresh.deno.dev/update ." - }, - "lint": { - "rules": { - "tags": ["fresh", "recommended"] - } - }, - "exclude": ["**/_fresh/*"], - "imports": { - "@std/assert": "jsr:@std/assert@1", - "@types/better-sqlite3": "npm:@types/better-sqlite3@^7.6.12", - "axios": "npm:axios@^1.7.9", - "better-sqlite3": "npm:better-sqlite3@^11.7.2", - "$fresh/": "https://deno.land/x/fresh@1.7.3/", - "preact": "https://esm.sh/preact@10.22.0", - "preact/": "https://esm.sh/preact@10.22.0/", - "@preact/signals": "https://esm.sh/*@preact/signals@1.2.2", - "@preact/signals-core": "https://esm.sh/*@preact/signals-core@1.5.1", - "tailwindcss": "npm:tailwindcss@3.4.1", - "tailwindcss/": "npm:/tailwindcss@3.4.1/", - "tailwindcss/plugin": "npm:/tailwindcss@3.4.1/plugin.js", - "$std/": "https://deno.land/std@0.216.0/" - }, - "compilerOptions": { - "jsx": "react-jsx", - "jsxImportSource": "preact" - }, - "nodeModulesDir": "auto", - "fmt": { - "useTabs": true, - "lineWidth": 120, - "indentWidth": 4, - "semiColons": true - } + "lock": false, + "tasks": { + "crawl-raw-bili": "deno --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run src/db/raw/insertAidsToDB.ts", + "crawl-bili-aids": "deno --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run src/db/raw/fetchAids.ts", + "check": "deno fmt --check && deno lint && deno check **/*.ts && deno check **/*.tsx", + "cli": "echo \"import '\\$fresh/src/dev/cli.ts'\" | deno run --unstable -A -", + "manifest": "deno task cli manifest $(pwd)", + "start": "deno run -A --watch=static/,routes/ dev.ts", + "build": "deno run -A dev.ts build", + "preview": "deno run -A main.ts", + "update": "deno run -A -r https://fresh.deno.dev/update ." + }, + "lint": { + "rules": { + "tags": ["fresh", "recommended"] + } + }, + "exclude": ["**/_fresh/*"], + "imports": { + "@std/assert": "jsr:@std/assert@1", + "$fresh/": "https://deno.land/x/fresh@1.7.3/", + "preact": "https://esm.sh/preact@10.22.0", + "preact/": "https://esm.sh/preact@10.22.0/", + "@preact/signals": "https://esm.sh/*@preact/signals@1.2.2", + "@preact/signals-core": "https://esm.sh/*@preact/signals-core@1.5.1", + "tailwindcss": "npm:tailwindcss@3.4.1", + "tailwindcss/": "npm:/tailwindcss@3.4.1/", + "tailwindcss/plugin": "npm:/tailwindcss@3.4.1/plugin.js", + "$std/": "https://deno.land/std@0.216.0/", + "@huggingface/transformers": "npm:@huggingface/transformers@3.0.0", + "bullmq": "npm:bullmq", + "lib/": "./lib/" + }, + "compilerOptions": { + "jsx": "react-jsx", + "jsxImportSource": "preact" + }, + "nodeModulesDir": "auto", + "fmt": { + "useTabs": true, + "lineWidth": 120, + "indentWidth": 4, + "semiColons": true + } } diff --git a/filter/checkpoint_conversion.py b/filter/checkpoint_conversion.py new file mode 100644 index 0000000..61a72ed --- /dev/null +++ b/filter/checkpoint_conversion.py @@ -0,0 +1,26 @@ +import torch + +from modelV3_10 import VideoClassifierV3_10 +from modelV3_9 import VideoClassifierV3_9 + + +def convert_checkpoint(original_model, new_model): + """转换原始checkpoint到新结构""" + state_dict = original_model.state_dict() + + # 直接复制所有参数(因为结构保持兼容) + new_model.load_state_dict(state_dict) + return new_model + +# 使用示例 +original_model = VideoClassifierV3_9() +new_model = VideoClassifierV3_10() + +# 加载原始checkpoint +original_model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.9.pt')) + +# 转换参数 +converted_model = convert_checkpoint(original_model, new_model) + +# 保存转换后的模型 +torch.save(converted_model.state_dict(), './filter/checkpoints/best_model_V3.10.pt') \ No newline at end of file diff --git a/filter/dataset.py b/filter/dataset.py index 62ef371..7a4edc1 100644 --- a/filter/dataset.py +++ b/filter/dataset.py @@ -97,7 +97,7 @@ class MultiChannelDataset(Dataset): example = self.examples[idx] # 处理tags(将数组转换为空格分隔的字符串) - tags_text = " ".join(example['tags']) + tags_text = ",".join(example['tags']) # 返回文本字典 texts = { diff --git a/filter/embedding.py b/filter/embedding.py new file mode 100644 index 0000000..d433b4a --- /dev/null +++ b/filter/embedding.py @@ -0,0 +1,31 @@ +import torch +from model2vec import StaticModel + + +def prepare_batch(batch_data, device="cpu"): + """ + 将输入的 batch_data 转换为模型所需的输入格式 [batch_size, num_channels, embedding_dim]。 + + 参数: + batch_data (dict): 输入的 batch 数据,格式为 { + "title": [text1, text2, ...], + "description": [text1, text2, ...], + "tags": [text1, text2, ...], + "author_info": [text1, text2, ...] + } + device (str): 模型运行的设备(如 "cpu" 或 "cuda")。 + + 返回: + torch.Tensor: 形状为 [batch_size, num_channels, embedding_dim] 的张量。 + """ + # 1. 对每个通道的文本分别编码 + channel_embeddings = [] + model = StaticModel.from_pretrained("./model/embedding/") + for channel in ["title", "description", "tags", "author_info"]: + texts = batch_data[channel] # 获取当前通道的文本列表 + embeddings = torch.from_numpy(model.encode(texts)).to(torch.float32).to(device) # 编码为 [batch_size, embedding_dim] + channel_embeddings.append(embeddings) + + # 2. 将编码结果堆叠为 [batch_size, num_channels, embedding_dim] + batch_tensor = torch.stack(channel_embeddings, dim=1) # 在 dim=1 上堆叠 + return batch_tensor \ No newline at end of file diff --git a/filter/modelV3_1.py b/filter/modelV3_1.py deleted file mode 100644 index 81e9988..0000000 --- a/filter/modelV3_1.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -import torch.nn as nn - -class VideoClassifierV3_1(nn.Module): - def __init__(self, embedding_dim=1024, hidden_dim=384, output_dim=3): - super().__init__() - self.num_channels = 4 - self.channel_names = ['title', 'description', 'tags', 'author_info'] - - # 改进1:带温度系数的通道权重(比原始固定权重更灵活) - self.channel_weights = nn.Parameter(torch.ones(self.num_channels)) - self.temperature = 1.7 # 可调节的平滑系数 - - # 改进2:更稳健的全连接结构 - self.fc = nn.Sequential( - nn.Linear(embedding_dim * self.num_channels, hidden_dim*2), - nn.BatchNorm1d(hidden_dim*2), - nn.Dropout(0.1), - nn.ReLU(), - nn.Linear(hidden_dim*2, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.Linear(hidden_dim, output_dim) - ) - - # 改进3:输出层初始化 - nn.init.xavier_uniform_(self.fc[-1].weight) - nn.init.zeros_(self.fc[-1].bias) - - def forward(self, input_texts, sentence_transformer): - # 合并所有通道文本进行批量编码 - all_texts = [text for channel in self.channel_names for text in input_texts[channel]] - - # 使用SentenceTransformer生成嵌入(保持冻结) - with torch.no_grad(): - task = "classification" - embeddings = torch.tensor( - sentence_transformer.encode(all_texts, task=task), - device=next(self.parameters()).device - ) - - # 分割嵌入并加权 - split_sizes = [len(input_texts[name]) for name in self.channel_names] - channel_features = torch.split(embeddings, split_sizes, dim=0) - channel_features = torch.stack(channel_features, dim=1) # [batch, 4, 1024] - - # 改进4:带温度系数的softmax加权 - weights = torch.softmax(self.channel_weights / self.temperature, dim=0) - weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1) - - # 拼接特征 - combined = weighted_features.view(weighted_features.size(0), -1) - - # 全连接层 - return self.fc(combined) - - def get_channel_weights(self): - """获取各通道权重(带温度调节)""" - return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy() \ No newline at end of file diff --git a/filter/modelV3_10.py b/filter/modelV3_10.py new file mode 100644 index 0000000..909590b --- /dev/null +++ b/filter/modelV3_10.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class VideoClassifierV3_10(nn.Module): + def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3): + super().__init__() + self.num_channels = 4 + self.channel_names = ['title', 'description', 'tags', 'author_info'] + + # 可学习温度系数 + self.temperature = nn.Parameter(torch.tensor(1.7)) + + # 带约束的通道权重(使用Sigmoid替代Softmax) + self.channel_weights = nn.Parameter(torch.ones(self.num_channels)) + + # 增强的非线性层 + self.fc = nn.Sequential( + nn.Linear(embedding_dim * self.num_channels, hidden_dim*2), + nn.BatchNorm1d(hidden_dim*2), + nn.Dropout(0.2), + nn.GELU(), + nn.Linear(hidden_dim*2, output_dim) + ) + + # 权重初始化 + self._init_weights() + + def _init_weights(self): + for layer in self.fc: + if isinstance(layer, nn.Linear): + # 使用ReLU的初始化参数(GELU的近似) + nn.init.kaiming_normal_(layer.weight, nonlinearity='relu') # 修改这里 + + # 或者使用Xavier初始化(更适合通用场景) + # nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('relu')) + + nn.init.zeros_(layer.bias) + + + def forward(self, channel_features: torch.Tensor): + """ + 输入格式: [batch_size, num_channels, embedding_dim] + 输出格式: [batch_size, output_dim] + """ + + # 自适应通道权重(Sigmoid约束) + weights = torch.sigmoid(self.channel_weights) # [0,1]范围 + weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1) + + # 特征拼接 + combined = weighted_features.view(weighted_features.size(0), -1) + + return self.fc(combined) + + def get_channel_weights(self): + """获取各通道权重(带温度调节)""" + return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy() + + +class AdaptiveRecallLoss(nn.Module): + def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5): + """ + Args: + class_weights (torch.Tensor): 类别权重 + alpha (float): 召回率调节因子(0-1) + gamma (float): Focal Loss参数 + fp_penalty (float): 类别0假阳性惩罚强度 + """ + super().__init__() + self.class_weights = class_weights + self.alpha = alpha + self.gamma = gamma + self.fp_penalty = fp_penalty + + def forward(self, logits, targets): + # 基础交叉熵损失 + ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none') + + # Focal Loss组件 + pt = torch.exp(-ce_loss) + focal_loss = ((1 - pt) ** self.gamma) * ce_loss + + # 召回率增强(对困难样本加权) + class_mask = F.one_hot(targets, num_classes=len(self.class_weights)) + class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask + recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1) + + # 类别0假阳性惩罚 + probs = F.softmax(logits, dim=1) + fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0) + fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum() + + # 总损失 + total_loss = recall_loss.mean() + fp_loss / len(targets) + + return total_loss \ No newline at end of file diff --git a/filter/modelV3_2.py b/filter/modelV3_2.py deleted file mode 100644 index 15e6b17..0000000 --- a/filter/modelV3_2.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -import torch.nn as nn - -class VideoClassifierV3_2(nn.Module): - def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3): - super().__init__() - self.num_channels = 4 - self.channel_names = ['title', 'description', 'tags', 'author_info'] - - # 改进1:带温度系数的通道权重(比原始固定权重更灵活) - self.channel_weights = nn.Parameter(torch.ones(self.num_channels)) - self.temperature = 1.7 # 可调节的平滑系数 - - # 改进2:更稳健的全连接结构 - self.fc = nn.Sequential( - nn.Linear(embedding_dim * self.num_channels, hidden_dim*2), - nn.BatchNorm1d(hidden_dim*2), - nn.Dropout(0.1), - nn.ReLU(), - nn.Linear(hidden_dim*2, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.Linear(hidden_dim, output_dim) - ) - - # 改进3:输出层初始化 - nn.init.xavier_uniform_(self.fc[-1].weight) - nn.init.zeros_(self.fc[-1].bias) - - def forward(self, input_texts, sentence_transformer): - # 合并所有通道文本进行批量编码 - all_texts = [text for channel in self.channel_names for text in input_texts[channel]] - - # 使用SentenceTransformer生成嵌入(保持冻结) - with torch.no_grad(): - task = "classification" - embeddings = torch.tensor( - sentence_transformer.encode(all_texts, task=task), - device=next(self.parameters()).device - ) - - # 分割嵌入并加权 - split_sizes = [len(input_texts[name]) for name in self.channel_names] - channel_features = torch.split(embeddings, split_sizes, dim=0) - channel_features = torch.stack(channel_features, dim=1) # [batch, 4, 1024] - - # 改进4:带温度系数的softmax加权 - weights = torch.softmax(self.channel_weights / self.temperature, dim=0) - weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1) - - # 拼接特征 - combined = weighted_features.view(weighted_features.size(0), -1) - - # 全连接层 - return self.fc(combined) - - def get_channel_weights(self): - """获取各通道权重(带温度调节)""" - return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy() \ No newline at end of file diff --git a/filter/modelV3_3.py b/filter/modelV3_3.py deleted file mode 100644 index 5ab4f57..0000000 --- a/filter/modelV3_3.py +++ /dev/null @@ -1,56 +0,0 @@ -import torch -import torch.nn as nn - -class VideoClassifierV3_3(nn.Module): - def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3): - super().__init__() - self.num_channels = 4 - self.channel_names = ['title', 'description', 'tags', 'author_info'] - - # 带温度系数的通道权重(比原始固定权重更灵活) - self.channel_weights = nn.Parameter(torch.ones(self.num_channels)) - self.temperature = 1.7 # 可调节的平滑系数 - - # 改进后的非线性层 - self.fc = nn.Sequential( - nn.Linear(embedding_dim * self.num_channels, hidden_dim*2), - nn.BatchNorm1d(hidden_dim*2), - nn.Dropout(0.1), - nn.ReLU(), - nn.Linear(hidden_dim*2, output_dim) - ) - - # 输出层初始化 - nn.init.xavier_uniform_(self.fc[-1].weight) - nn.init.zeros_(self.fc[-1].bias) - - def forward(self, input_texts, sentence_transformer): - # 合并所有通道文本进行批量编码 - all_texts = [text for channel in self.channel_names for text in input_texts[channel]] - - # 使用SentenceTransformer生成嵌入(保持冻结) - with torch.no_grad(): - task = "classification" - embeddings = torch.tensor( - sentence_transformer.encode(all_texts, task=task), - device=next(self.parameters()).device - ) - - # 分割嵌入并加权 - split_sizes = [len(input_texts[name]) for name in self.channel_names] - channel_features = torch.split(embeddings, split_sizes, dim=0) - channel_features = torch.stack(channel_features, dim=1) # [batch, 4, 1024] - - # 改进4:带温度系数的softmax加权 - weights = torch.softmax(self.channel_weights / self.temperature, dim=0) - weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1) - - # 拼接特征 - combined = weighted_features.view(weighted_features.size(0), -1) - - # 全连接层 - return self.fc(combined) - - def get_channel_weights(self): - """获取各通道权重(带温度调节)""" - return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy() \ No newline at end of file diff --git a/filter/onnx_export.py b/filter/onnx_export.py new file mode 100644 index 0000000..6337ef3 --- /dev/null +++ b/filter/onnx_export.py @@ -0,0 +1,32 @@ +import torch +from modelV3_10 import VideoClassifierV3_10 + + +def export_onnx(model_path="./filter/checkpoints/best_model_V3.10.pt", + onnx_path="./model/video_classifier_v3_10.onnx"): + # 初始化模型 + model = VideoClassifierV3_10() + model.load_state_dict(torch.load(model_path)) + model.eval() + + # 创建符合输入规范的虚拟输入 + dummy_input = torch.randn(1, 4, 1024) # [batch=1, channels=4, embedding_dim=1024] + + # 导出ONNX + torch.onnx.export( + model, + dummy_input, + onnx_path, + input_names=["channel_features"], + output_names=["logits"], + dynamic_axes={ + "channel_features": {0: "batch_size"}, + "logits": {0: "batch_size"} + }, + opset_version=13, + do_constant_folding=True + ) + print(f"模型已成功导出到 {onnx_path}") + +# 执行导出 +export_onnx() \ No newline at end of file diff --git a/filter/train.py b/filter/train.py index e65edbb..ad0bc3d 100644 --- a/filter/train.py +++ b/filter/train.py @@ -4,13 +4,14 @@ import numpy as np from torch.utils.data import DataLoader import torch.optim as optim from dataset import MultiChannelDataset -from filter.modelV3_9 import VideoClassifierV3_9, AdaptiveRecallLoss +from filter.modelV3_10 import VideoClassifierV3_10, AdaptiveRecallLoss from sentence_transformers import SentenceTransformer from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report import os import torch from torch.utils.tensorboard import SummaryWriter # 引入 TensorBoard import time +from embedding import prepare_batch # 动态生成子目录名称 @@ -52,8 +53,8 @@ class_weights = torch.tensor( # 初始化模型和SentenceTransformer sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024") -model = VideoClassifierV3_9() -checkpoint_name = './filter/checkpoints/best_model_V3.9.pt' +model = VideoClassifierV3_10() +checkpoint_name = './filter/checkpoints/best_model_V3.11.pt' # 模型保存路径 os.makedirs('./filter/checkpoints', exist_ok=True) @@ -77,8 +78,8 @@ def evaluate(model, dataloader): with torch.no_grad(): for batch in dataloader: - # 传入文本字典和sentence_transformer - logits = model(input_texts=batch['texts'], sentence_transformer=sentence_transformer) + batch_tensor = prepare_batch(batch['texts'], device="cpu") + logits = model(batch_tensor) preds = torch.argmax(logits, dim=1) all_preds.extend(preds.cpu().numpy()) all_labels.extend(batch['label'].cpu().numpy()) @@ -110,8 +111,10 @@ for epoch in range(num_epochs): for batch_idx, batch in enumerate(train_loader): optimizer.zero_grad() + batch_tensor = prepare_batch(batch['texts'], device="cpu") + # 传入文本字典和sentence_transformer - logits = model(input_texts=batch['texts'], sentence_transformer=sentence_transformer) + logits = model(batch_tensor) loss = criterion(logits, batch['label']) loss.backward() diff --git a/lib/db/allData.ts b/lib/db/allData.ts new file mode 100644 index 0000000..1aa2ff4 --- /dev/null +++ b/lib/db/allData.ts @@ -0,0 +1,15 @@ +import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts"; +import { AllDataType } from "lib/db/schema.d.ts"; + +export async function videoExistsInAllData(client: Client, aid: number) { + return await client.queryObject<{ exists: boolean }>("SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)", [aid]) + .then((result) => result.rows[0].exists); +} + +export async function insertIntoAllData(client: Client, data: AllDataType) { + console.log(`inserted ${data.aid}`) + return await client.queryObject( + "INSERT INTO all_data (aid, bvid, description, uid, tags, title, published_at) VALUES ($1, $2, $3, $4, $5, $6, $7)", + [data.aid, data.bvid, data.description, data.uid, data.tags, data.title, data.published_at], + ); +} diff --git a/lib/db/schema.d.ts b/lib/db/schema.d.ts new file mode 100644 index 0000000..db8c9a4 --- /dev/null +++ b/lib/db/schema.d.ts @@ -0,0 +1,9 @@ +export interface AllDataType { + aid: number; + bvid: string | null; + description: string | null; + uid: number | null; + tags: string | null; + title: string | null; + published_at: string | null; +} \ No newline at end of file diff --git a/lib/ml/SentenceTransformer/index.ts b/lib/ml/SentenceTransformer/index.ts new file mode 100644 index 0000000..3676f2a --- /dev/null +++ b/lib/ml/SentenceTransformer/index.ts @@ -0,0 +1,19 @@ +import { SentenceTransformer } from "./model.ts"; // Changed import path + +async function main() { + const sentenceTransformer = await SentenceTransformer.from_pretrained( + "mixedbread-ai/mxbai-embed-large-v1", + ); + const outputs = await sentenceTransformer.encode([ + "Hello world", + "How are you guys doing?", + "Today is Friday!", + ]); + + // @ts-ignore + console.log(outputs["last_hidden_state"]); + + return outputs; +} + +main(); // Keep main function call if you want this file to be runnable directly for testing. diff --git a/lib/ml/SentenceTransformer/model.ts b/lib/ml/SentenceTransformer/model.ts new file mode 100644 index 0000000..7d8b507 --- /dev/null +++ b/lib/ml/SentenceTransformer/model.ts @@ -0,0 +1,40 @@ +// lib/ml/sentence_transformer_model.ts +import { AutoModel, AutoTokenizer, PretrainedOptions } from "@huggingface/transformers"; + +export class SentenceTransformer { + constructor( + private readonly tokenizer: AutoTokenizer, + private readonly model: AutoModel, + ) {} + + static async from_pretrained( + modelName: string, + options?: PretrainedOptions, + ): Promise { + if (!options) { + options = { + progress_callback: undefined, + cache_dir: undefined, + local_files_only: false, + revision: "main", + }; + } + const tokenizer = await AutoTokenizer.from_pretrained(modelName, options); + const model = await AutoModel.from_pretrained(modelName, options); + + return new SentenceTransformer(tokenizer, model); + } + + async encode(sentences: string[]): Promise { // Changed return type to 'any' for now to match console.log output + //@ts-ignore + const modelInputs = await this.tokenizer(sentences, { + padding: true, + truncation: true, + }); + + //@ts-ignore + const outputs = await this.model(modelInputs); + + return outputs; + } +} diff --git a/lib/ml/SentenceTransformer/pooling.ts b/lib/ml/SentenceTransformer/pooling.ts new file mode 100644 index 0000000..762feb7 --- /dev/null +++ b/lib/ml/SentenceTransformer/pooling.ts @@ -0,0 +1,34 @@ +import { Tensor } from "@huggingface/transformers"; +//@ts-ignore +import { Callable } from "@huggingface/transformers/src/utils/core.js"; // Keep as is for now, might need adjustment + +export interface PoolingConfig { + word_embedding_dimension: number; + pooling_mode_cls_token: boolean; + pooling_mode_mean_tokens: boolean; + pooling_mode_max_tokens: boolean; + pooling_mode_mean_sqrt_len_tokens: boolean; +} + +export interface PoolingInput { + token_embeddings: Tensor; + attention_mask: Tensor; +} + +export interface PoolingOutput { + sentence_embedding: Tensor; +} + +export class Pooling extends Callable { + constructor(private readonly config: PoolingConfig) { + super(); + } + + // async _call(inputs: any) { // Keep if pooling functionality is needed + // return this.forward(inputs); + // } + + // async forward(inputs: PoolingInput): PoolingOutput { // Keep if pooling functionality is needed + + // } +} \ No newline at end of file diff --git a/lib/ml/classifyVideo.ts b/lib/ml/classifyVideo.ts new file mode 100644 index 0000000..6d27e8b --- /dev/null +++ b/lib/ml/classifyVideo.ts @@ -0,0 +1,32 @@ +import { AutoModel, AutoTokenizer, Tensor } from '@huggingface/transformers'; + +const modelName = "alikia2x/jina-embedding-v3-m2v-1024"; + +const modelConfig = { + config: { model_type: 'model2vec' }, + dtype: 'fp32', + revision: 'refs/pr/1', + cache_dir: undefined, + local_files_only: true, +}; +const tokenizerConfig = { + revision: 'refs/pr/2' +}; + +const model = await AutoModel.from_pretrained(modelName, modelConfig); +const tokenizer = await AutoTokenizer.from_pretrained(modelName, tokenizerConfig); + +const texts = ['hello', 'hello world']; +const { input_ids } = await tokenizer(texts, { add_special_tokens: false, return_tensor: false }); + +const cumsum = arr => arr.reduce((acc, num, i) => [...acc, num + (acc[i - 1] || 0)], []); +const offsets = [0, ...cumsum(input_ids.slice(0, -1).map(x => x.length))]; + +const flattened_input_ids = input_ids.flat(); +const modelInputs = { + input_ids: new Tensor('int64', flattened_input_ids, [flattened_input_ids.length]), + offsets: new Tensor('int64', offsets, [offsets.length]) +}; + +const { embeddings } = await model(modelInputs); +console.log(embeddings.tolist()); // output matches python version \ No newline at end of file diff --git a/lib/mq/index.ts b/lib/mq/index.ts new file mode 100644 index 0000000..df17c36 --- /dev/null +++ b/lib/mq/index.ts @@ -0,0 +1,2 @@ +import { Queue } from "bullmq"; + diff --git a/lib/net/bilibili.d.ts b/lib/net/bilibili.d.ts new file mode 100644 index 0000000..a0f682d --- /dev/null +++ b/lib/net/bilibili.d.ts @@ -0,0 +1,117 @@ +interface BaseResponse { + code: number; + message: string; + ttl: number; + data: T; +} + +export type VideoListResponse = BaseResponse; +export type VideoTagsResponse = BaseResponse; + +type VideoTagsData = VideoTags[]; + +interface VideoTags { + tag_id: number; + tag_name: string; + cover: string; + head_cover: string; + content: string; + short_content: string; + type: number; + state: number; + ctime: number; + count: { + view: number; + use: number; + atten: number; + } + is_atten: number; + likes: number; + hates: number; + attribute: number; + liked: number; + hated: number; + extra_attr: number; +} + +interface VideoListData { + archives: VideoListVideo[]; + page: { + num: number; + size: number; + count: number; + }; +} + +interface VideoListVideo { + aid: number; + videos: number; + tid: number; + tname: string; + copyright: number; + pic: string; + title: string; + pubdate: number; + ctime: number; + desc: string; + state: number; + duration: number; + mission_id?: number; + rights: { + bp: number; + elec: number; + download: number; + movie: number; + pay: number; + hd5: number; + no_reprint: number; + autoplay: number; + ugc_pay: number; + is_cooperation: number; + ugc_pay_preview: number; + no_background: number; + arc_pay: number; + pay_free_watch: number; + }, + owner: { + mid: number; + name: string; + face: string; + }, + stat: { + aid: number; + view: number; + danmaku: number; + reply: number; + favorite: number; + coin: number; + share: number; + now_rank: number; + his_rank: number; + like: number; + dislike: number; + vt: number; + vv: number; + }, + dynamic: string; + cid: number; + dimension: { + width: number; + height: number; + rotate: number; + }, + season_id?: number; + short_link_v2: string; + first_frame: string; + pub_location: string; + cover43: string; + tidv2: number; + tname_v2: string; + bvid: string; + season_type: number; + is_ogv: number; + ovg_info: string | null; + rcmd_season: string; + enable_vt: number; + ai_rcmd: null | string; +} diff --git a/lib/net/getLatestVideos.ts b/lib/net/getLatestVideos.ts new file mode 100644 index 0000000..a46b735 --- /dev/null +++ b/lib/net/getLatestVideos.ts @@ -0,0 +1,48 @@ +import { VideoListResponse } from "lib/net/bilibili.d.ts"; +import formatPublishedAt from "lib/utils/formatTimestampToPostgre.ts"; +import { getVideoTags } from "lib/net/getVideoTags.ts"; +import { AllDataType } from "lib/db/schema.d.ts"; +import { sleep } from "lib/utils/sleep.ts"; + +export async function getLatestVideos(page: number = 1, pageSize: number = 10): Promise { + try { + const response = await fetch(`https://api.bilibili.com/x/web-interface/newlist?rid=30&ps=${pageSize}&pn=${page}`); + const data: VideoListResponse = await response.json(); + + if (data.code !== 0) { + console.error(`Error fetching videos: ${data.message}`); + return null; + } + + if (data.data.archives.length === 0) { + console.warn("No more videos found"); + return []; + } + + const videoPromises = data.data.archives.map(async (video) => { + const published_at = formatPublishedAt(video.pubdate + 3600 * 8); + sleep(Math.random() * pageSize * 250); + const tags = await getVideoTags(video.aid); + let processedTags = null; + if (tags !== null) { + processedTags = tags.join(','); + } + return { + aid: video.aid, + bvid: video.bvid, + description: video.desc, + uid: video.owner.mid, + tags: processedTags, + title: video.title, + published_at: published_at, + } as AllDataType; + }); + + const result = await Promise.all(videoPromises); + + return result; + } catch (error) { + console.error(error); + return null; + } +} \ No newline at end of file diff --git a/lib/net/getVideoTags.ts b/lib/net/getVideoTags.ts new file mode 100644 index 0000000..9e33083 --- /dev/null +++ b/lib/net/getVideoTags.ts @@ -0,0 +1,18 @@ +import { VideoTagsResponse } from "lib/net/bilibili.d.ts"; + +export async function getVideoTags(aid: number): Promise { + try { + const url = `https://api.bilibili.com/x/tag/archive/tags?aid=${aid}`; + const res = await fetch(url); + const data: VideoTagsResponse = await res.json(); + if (data.code != 0) { + console.error(`Error fetching tags for video ${aid}: ${data.message}`); + return []; + } + return data.data.map((tag) => tag.tag_name); + } + catch { + console.error(`Error fetching tags for video ${aid}`); + return null; + } +} diff --git a/lib/task/insertLatestVideo.ts b/lib/task/insertLatestVideo.ts new file mode 100644 index 0000000..3eaf60b --- /dev/null +++ b/lib/task/insertLatestVideo.ts @@ -0,0 +1,81 @@ +import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts"; +import { getLatestVideos } from "lib/net/getLatestVideos.ts"; +import { insertIntoAllData, videoExistsInAllData } from "lib/db/allData.ts"; +import { sleep } from "lib/utils/sleep.ts"; + +const requiredEnvVars = ["DB_HOST", "DB_NAME", "DB_USER", "DB_PASSWORD", "DB_PORT"]; + +const unsetVars = requiredEnvVars.filter((key) => Deno.env.get(key) === undefined); + +if (unsetVars.length > 0) { + throw new Error(`Missing required environment variables: ${unsetVars.join(", ")}`); +} + +const databaseHost = Deno.env.get("DB_HOST")!; +const databaseName = Deno.env.get("DB_NAME")!; +const databaseUser = Deno.env.get("DB_USER")!; +const databasePassword = Deno.env.get("DB_PASSWORD")!; +const databasePort = Deno.env.get("DB_PORT")!; + +const postgresConfig = { + hostname: databaseHost, + port: parseInt(databasePort), + database: databaseName, + user: databaseUser, + password: databasePassword, +}; + +async function connectToPostgres() { + const client = new Client(postgresConfig); + await client.connect(); + return client; +} + +export async function insertLatestVideos() { + const client = await connectToPostgres(); + let page = 334; + let failCount = 0; + while (true) { + try { + const videos = await getLatestVideos(page, 10); + if (videos == null) { + failCount++; + if (failCount > 5) { + break; + } + continue; + } + if (videos.length == 0) { + console.warn("No more videos found"); + break; + } + let allExists = true; + for (const video of videos) { + const videoExists = await videoExistsInAllData(client, video.aid); + if (!videoExists) { + allExists = false; + insertIntoAllData(client, video); + } + } + if (allExists) { + console.log("All videos already exist in all_data, stop crawling."); + break; + } + console.log(`Page ${page} crawled, total: ${(page - 1) * 20 + videos.length} videos.`); + page++; + } catch (error) { + console.error(error); + failCount++; + if (failCount > 5) { + break; + } + continue; + } + finally { + await sleep(Math.random() * 4000 + 1000); + } + } +} + + +insertLatestVideos(); \ No newline at end of file diff --git a/lib/utils/formatTimestampToPostgre.ts b/lib/utils/formatTimestampToPostgre.ts new file mode 100644 index 0000000..9b5140a --- /dev/null +++ b/lib/utils/formatTimestampToPostgre.ts @@ -0,0 +1,4 @@ +export default function formatTimestamp(timestamp: number) { + const date = new Date(timestamp * 1000); + return date.toISOString().slice(0, 19).replace("T", " "); +} \ No newline at end of file diff --git a/lib/utils/sleep.ts b/lib/utils/sleep.ts new file mode 100644 index 0000000..3a5dcb9 --- /dev/null +++ b/lib/utils/sleep.ts @@ -0,0 +1,3 @@ +export async function sleep(ms: number) { + await new Promise((resolve) => setTimeout(resolve, ms)); +} \ No newline at end of file diff --git a/src/db/raw/aliyun-fc.mjs b/src/db/raw/aliyun-fc.mjs index 0fcca42..d7a9c00 100644 --- a/src/db/raw/aliyun-fc.mjs +++ b/src/db/raw/aliyun-fc.mjs @@ -1,6 +1,7 @@ +import { Buffer } from "node:buffer"; "use strict"; -export const handler = async (event, context) => { +export const handler = async (event, _context) => { const eventObj = JSON.parse(event); console.log(`receive event: ${JSON.stringify(eventObj)}`); diff --git a/src/net/getLatestVideos.ts b/src/net/getLatestVideos.ts deleted file mode 100644 index b6df59e..0000000 --- a/src/net/getLatestVideos.ts +++ /dev/null @@ -1,35 +0,0 @@ -import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts"; - -const API_URL = "https://api.bilibili.com/x/web-interface/newlist?rid=30&ps=50&pn="; - -const requiredEnvVars = ["DB_HOST", "DB_NAME", "DB_USER", "DB_PASSWORD", "DB_PORT"]; - -const unsetVars = requiredEnvVars.filter((key) => !Deno.env.get(key)); - -if (unsetVars.length > 0) { - throw new Error(`Missing required environment variables: ${unsetVars.join(", ")}`); -} - -const databaseHost = Deno.env.get("DB_HOST")!; -const databaseName = Deno.env.get("DB_NAME")!; -const databaseUser = Deno.env.get("DB_USER")!; -const databasePassword = Deno.env.get("DB_PASSWORD")!; -const databasePort = Deno.env.get("DB_PORT")!; - -const postgresConfig = { - hostname: databaseHost, - port: parseInt(databasePort), - database: databaseName, - user: databaseUser, - password: databasePassword, -}; - -async function connectToPostgres() { - const client = new Client(postgresConfig); - await client.connect(); - return client; -} - -export async function getLatestVideos() { - const client = await connectToPostgres(); -} diff --git a/test/net/getLatestVideos.test.ts b/test/net/getLatestVideos.test.ts new file mode 100644 index 0000000..b2daa4d --- /dev/null +++ b/test/net/getLatestVideos.test.ts @@ -0,0 +1,25 @@ +import { assertEquals } from "jsr:@std/assert"; +import { getLatestVideos } from "lib/net/getLatestVideos.ts"; + +Deno.test("Get latest videos", async () => { + const videos = (await getLatestVideos(1, 5))!; + assertEquals(videos.length, 5); + + videos.forEach((video) => { + assertVideoProperties(video); + }); +}); + +function assertVideoProperties(video: object) { + const aid = "aid" in video && typeof video.aid === "number"; + const bvid = "bvid" in video && typeof video.bvid === "string" && + video.bvid.length === 12 && video.bvid.startsWith("BV"); + const description = "description" in video && typeof video.description === "string"; + const uid = "uid" in video && typeof video.uid === "number"; + const tags = "tags" in video && (typeof video.tags === "string" || video.tags === null); + const title = "title" in video && typeof video.title === "string"; + const publishedAt = "published_at" in video && typeof video.published_at === "string"; + + const match = aid && bvid && description && uid && tags && title && publishedAt; + assertEquals(match, true); +} diff --git a/test/net/getVideoTags.test.ts b/test/net/getVideoTags.test.ts new file mode 100644 index 0000000..dd1f02a --- /dev/null +++ b/test/net/getVideoTags.test.ts @@ -0,0 +1,28 @@ +import { assertEquals } from "jsr:@std/assert"; +import { getVideoTags } from "lib/net/getVideoTags.ts"; + +Deno.test("Get video tags - regular video", async () => { + const tags = (await getVideoTags(826597951)).sort(); + assertEquals(tags, [ + "纯白P", + "中华墨水娘", + "中华少女", + "中华粘土娘", + "中华缘木娘", + "中华少女Project", + "提糯Tino", + "中华烛火娘", + "中华烁金娘", + "新世代音乐人计划女生季", + ].sort()); +}); + +Deno.test("Get video tags - non-existent video", async () => { + const tags = (await getVideoTags(8265979511111111)); + assertEquals(tags, []); +}); + +Deno.test("Get video tags - video with no tag", async () => { + const tags = (await getVideoTags(981001865)); + assertEquals(tags, []); +}); \ No newline at end of file