update: structurize video insertion code, use js-compatible filter model arch

alikia2x (寒寒) 2025-02-08 00:36:40 +08:00
parent cf4ff398b8
commit b21e6da07a
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
28 changed files with 731 additions and 272 deletions

.gitignore (vendored, 27 lines changed)

@@ -61,19 +61,6 @@ TEST-results.xml
package-lock.json
.eslintcache
*v8.log
/lib/
# project specific
data/main.db
.env
logs/
__pycache__
filter/runs
data/filter/eval*
data/filter/train*
filter/checkpoints
data/filter/model_predicted*
scripts
# dotenv environment variable files
.env
@@ -86,3 +73,17 @@ scripts
_fresh/
# npm dependencies
node_modules/
# project specific
data/main.db
.env
logs/
__pycache__
filter/runs
data/filter/eval*
data/filter/train*
filter/checkpoints
data/filter/model_predicted*
scripts
model/

deno.json

@@ -1,46 +1,46 @@
{
"lock": false,
"tasks": {
"crawl-raw-bili": "deno --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run src/db/raw/insertAidsToDB.ts",
"crawl-bili-aids": "deno --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run src/db/raw/fetchAids.ts",
"check": "deno fmt --check && deno lint && deno check **/*.ts && deno check **/*.tsx",
"cli": "echo \"import '\\$fresh/src/dev/cli.ts'\" | deno run --unstable -A -",
"manifest": "deno task cli manifest $(pwd)",
"start": "deno run -A --watch=static/,routes/ dev.ts",
"build": "deno run -A dev.ts build",
"preview": "deno run -A main.ts",
"update": "deno run -A -r https://fresh.deno.dev/update ."
},
"lint": {
"rules": {
"tags": ["fresh", "recommended"]
}
},
"exclude": ["**/_fresh/*"],
"imports": {
"@std/assert": "jsr:@std/assert@1",
"@types/better-sqlite3": "npm:@types/better-sqlite3@^7.6.12",
"axios": "npm:axios@^1.7.9",
"better-sqlite3": "npm:better-sqlite3@^11.7.2",
"$fresh/": "https://deno.land/x/fresh@1.7.3/",
"preact": "https://esm.sh/preact@10.22.0",
"preact/": "https://esm.sh/preact@10.22.0/",
"@preact/signals": "https://esm.sh/*@preact/signals@1.2.2",
"@preact/signals-core": "https://esm.sh/*@preact/signals-core@1.5.1",
"tailwindcss": "npm:tailwindcss@3.4.1",
"tailwindcss/": "npm:/tailwindcss@3.4.1/",
"tailwindcss/plugin": "npm:/tailwindcss@3.4.1/plugin.js",
"$std/": "https://deno.land/std@0.216.0/"
},
"compilerOptions": {
"jsx": "react-jsx",
"jsxImportSource": "preact"
},
"nodeModulesDir": "auto",
"fmt": {
"useTabs": true,
"lineWidth": 120,
"indentWidth": 4,
"semiColons": true
}
"lock": false,
"tasks": {
"crawl-raw-bili": "deno --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run src/db/raw/insertAidsToDB.ts",
"crawl-bili-aids": "deno --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run src/db/raw/fetchAids.ts",
"check": "deno fmt --check && deno lint && deno check **/*.ts && deno check **/*.tsx",
"cli": "echo \"import '\\$fresh/src/dev/cli.ts'\" | deno run --unstable -A -",
"manifest": "deno task cli manifest $(pwd)",
"start": "deno run -A --watch=static/,routes/ dev.ts",
"build": "deno run -A dev.ts build",
"preview": "deno run -A main.ts",
"update": "deno run -A -r https://fresh.deno.dev/update ."
},
"lint": {
"rules": {
"tags": ["fresh", "recommended"]
}
},
"exclude": ["**/_fresh/*"],
"imports": {
"@std/assert": "jsr:@std/assert@1",
"$fresh/": "https://deno.land/x/fresh@1.7.3/",
"preact": "https://esm.sh/preact@10.22.0",
"preact/": "https://esm.sh/preact@10.22.0/",
"@preact/signals": "https://esm.sh/*@preact/signals@1.2.2",
"@preact/signals-core": "https://esm.sh/*@preact/signals-core@1.5.1",
"tailwindcss": "npm:tailwindcss@3.4.1",
"tailwindcss/": "npm:/tailwindcss@3.4.1/",
"tailwindcss/plugin": "npm:/tailwindcss@3.4.1/plugin.js",
"$std/": "https://deno.land/std@0.216.0/",
"@huggingface/transformers": "npm:@huggingface/transformers@3.0.0",
"bullmq": "npm:bullmq",
"lib/": "./lib/"
},
"compilerOptions": {
"jsx": "react-jsx",
"jsxImportSource": "preact"
},
"nodeModulesDir": "auto",
"fmt": {
"useTabs": true,
"lineWidth": 120,
"indentWidth": 4,
"semiColons": true
}
}

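The substantive changes here: the better-sqlite3 and axios dependencies are dropped, and three imports are added for the new ML and queue code (`@huggingface/transformers`, `bullmq`, and a `lib/` alias). For illustration, the new entries resolve like this in the files below:

import { AutoModel, AutoTokenizer } from "@huggingface/transformers"; // npm:@huggingface/transformers@3.0.0
import { Queue } from "bullmq"; // npm:bullmq
import { getVideoTags } from "lib/net/getVideoTags.ts"; // "lib/" maps to ./lib/
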
New file: checkpoint conversion script (V3.9 to V3.10)

@@ -0,0 +1,26 @@
import torch
from modelV3_10 import VideoClassifierV3_10
from modelV3_9 import VideoClassifierV3_9

def convert_checkpoint(original_model, new_model):
    """Convert the original checkpoint to the new architecture."""
    state_dict = original_model.state_dict()
    # Copy all parameters directly (the architectures remain compatible)
    new_model.load_state_dict(state_dict)
    return new_model

# Usage example
original_model = VideoClassifierV3_9()
new_model = VideoClassifierV3_10()

# Load the original checkpoint
original_model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.9.pt'))

# Convert the parameters
converted_model = convert_checkpoint(original_model, new_model)

# Save the converted model
torch.save(converted_model.state_dict(), './filter/checkpoints/best_model_V3.10.pt')

filter/dataset.py

@@ -97,7 +97,7 @@ class MultiChannelDataset(Dataset):
        example = self.examples[idx]

        # Process tags: join the tag array into a single string
        tags_text = " ".join(example['tags'])
        tags_text = ",".join(example['tags'])

        # Return the text dict
        texts = {
filter/embedding.py (new file, 31 lines)

@@ -0,0 +1,31 @@
import torch
from model2vec import StaticModel

def prepare_batch(batch_data, device="cpu"):
    """
    Convert batch_data into the model's expected input shape [batch_size, num_channels, embedding_dim].

    Args:
        batch_data (dict): input batch in the format {
            "title": [text1, text2, ...],
            "description": [text1, text2, ...],
            "tags": [text1, text2, ...],
            "author_info": [text1, text2, ...]
        }
        device (str): device to run on, "cpu" or "cuda"

    Returns:
        torch.Tensor: tensor of shape [batch_size, num_channels, embedding_dim]
    """
    # 1. Encode each channel's texts separately
    channel_embeddings = []
    model = StaticModel.from_pretrained("./model/embedding/")
    for channel in ["title", "description", "tags", "author_info"]:
        texts = batch_data[channel]  # texts for the current channel
        embeddings = torch.from_numpy(model.encode(texts)).to(torch.float32).to(device)  # [batch_size, embedding_dim]
        channel_embeddings.append(embeddings)

    # 2. Stack the per-channel embeddings into [batch_size, num_channels, embedding_dim]
    batch_tensor = torch.stack(channel_embeddings, dim=1)  # stack along dim=1
    return batch_tensor

filter/modelV3_1.py (deleted)

@@ -1,58 +0,0 @@
import torch
import torch.nn as nn

class VideoClassifierV3_1(nn.Module):
    def __init__(self, embedding_dim=1024, hidden_dim=384, output_dim=3):
        super().__init__()
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']

        # Improvement 1: temperature-scaled channel weights (more flexible than fixed weights)
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
        self.temperature = 1.7  # tunable smoothing coefficient

        # Improvement 2: a more robust fully connected stack
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
            nn.BatchNorm1d(hidden_dim*2),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, output_dim)
        )

        # Improvement 3: output layer initialization
        nn.init.xavier_uniform_(self.fc[-1].weight)
        nn.init.zeros_(self.fc[-1].bias)

    def forward(self, input_texts, sentence_transformer):
        # Merge the texts of all channels for batched encoding
        all_texts = [text for channel in self.channel_names for text in input_texts[channel]]

        # Generate embeddings with the (frozen) SentenceTransformer
        with torch.no_grad():
            task = "classification"
            embeddings = torch.tensor(
                sentence_transformer.encode(all_texts, task=task),
                device=next(self.parameters()).device
            )

        # Split the embeddings back per channel and weight them
        split_sizes = [len(input_texts[name]) for name in self.channel_names]
        channel_features = torch.split(embeddings, split_sizes, dim=0)
        channel_features = torch.stack(channel_features, dim=1)  # [batch, 4, 1024]

        # Improvement 4: temperature-scaled softmax weighting
        weights = torch.softmax(self.channel_weights / self.temperature, dim=0)
        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)

        # Concatenate features
        combined = weighted_features.view(weighted_features.size(0), -1)

        # Fully connected layers
        return self.fc(combined)

    def get_channel_weights(self):
        """Return the per-channel weights (with temperature scaling)."""
        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()

filter/modelV3_10.py (new file, 97 lines)

@@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class VideoClassifierV3_10(nn.Module):
    def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3):
        super().__init__()
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']

        # Learnable temperature coefficient
        self.temperature = nn.Parameter(torch.tensor(1.7))

        # Constrained channel weights (Sigmoid in place of Softmax)
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))

        # Enhanced nonlinear stack
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
            nn.BatchNorm1d(hidden_dim*2),
            nn.Dropout(0.2),
            nn.GELU(),
            nn.Linear(hidden_dim*2, output_dim)
        )

        # Weight initialization
        self._init_weights()

    def _init_weights(self):
        for layer in self.fc:
            if isinstance(layer, nn.Linear):
                # Kaiming init with ReLU parameters (an approximation for GELU)
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
                # Alternatively, Xavier init (better suited to general-purpose cases):
                # nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('relu'))
                nn.init.zeros_(layer.bias)

    def forward(self, channel_features: torch.Tensor):
        """
        Input shape:  [batch_size, num_channels, embedding_dim]
        Output shape: [batch_size, output_dim]
        """
        # Adaptive channel weights, Sigmoid-constrained to [0, 1]
        weights = torch.sigmoid(self.channel_weights)
        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)

        # Concatenate features
        combined = weighted_features.view(weighted_features.size(0), -1)
        return self.fc(combined)

    def get_channel_weights(self):
        """Return the per-channel weights (with temperature scaling)."""
        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()


class AdaptiveRecallLoss(nn.Module):
    def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
        """
        Args:
            class_weights (torch.Tensor): per-class weights
            alpha (float): recall adjustment factor (0-1)
            gamma (float): Focal Loss parameter
            fp_penalty (float): penalty strength for class-0 false positives
        """
        super().__init__()
        self.class_weights = class_weights
        self.alpha = alpha
        self.gamma = gamma
        self.fp_penalty = fp_penalty

    def forward(self, logits, targets):
        # Base cross-entropy loss
        ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')

        # Focal Loss component
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        # Recall enhancement (weight hard samples more)
        class_mask = F.one_hot(targets, num_classes=len(self.class_weights))
        class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask
        recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1)

        # Class-0 false-positive penalty
        probs = F.softmax(logits, dim=1)
        fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
        fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()

        # Total loss
        total_loss = recall_loss.mean() + fp_loss / len(targets)
        return total_loss

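Reading the loss off the forward pass above: with per-sample weighted cross-entropy \ell_i, p_i = e^{-\ell_i}, prediction \hat{y}_i = \arg\max_c z_{i,c}, and \lambda = fp_penalty, the batch loss is

    \mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \big(\alpha + (1 - \alpha)\,p_i\big)\,(1 - p_i)^{\gamma}\,\ell_i \;+\; \frac{\lambda}{N} \sum_{i \,:\, y_i \neq 0,\ \hat{y}_i = 0} \operatorname{softmax}(z_i)_0^{\,2}

that is, a focal loss with an extra recall-oriented modulation on the target class, plus a quadratic penalty on the class-0 probability mass assigned to samples whose true class is not 0.
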
filter/modelV3_2.py (deleted)

@@ -1,58 +0,0 @@
import torch
import torch.nn as nn

class VideoClassifierV3_2(nn.Module):
    def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3):
        super().__init__()
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']

        # Improvement 1: temperature-scaled channel weights (more flexible than fixed weights)
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
        self.temperature = 1.7  # tunable smoothing coefficient

        # Improvement 2: a more robust fully connected stack
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
            nn.BatchNorm1d(hidden_dim*2),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, output_dim)
        )

        # Improvement 3: output layer initialization
        nn.init.xavier_uniform_(self.fc[-1].weight)
        nn.init.zeros_(self.fc[-1].bias)

    def forward(self, input_texts, sentence_transformer):
        # Merge the texts of all channels for batched encoding
        all_texts = [text for channel in self.channel_names for text in input_texts[channel]]

        # Generate embeddings with the (frozen) SentenceTransformer
        with torch.no_grad():
            task = "classification"
            embeddings = torch.tensor(
                sentence_transformer.encode(all_texts, task=task),
                device=next(self.parameters()).device
            )

        # Split the embeddings back per channel and weight them
        split_sizes = [len(input_texts[name]) for name in self.channel_names]
        channel_features = torch.split(embeddings, split_sizes, dim=0)
        channel_features = torch.stack(channel_features, dim=1)  # [batch, 4, 1024]

        # Improvement 4: temperature-scaled softmax weighting
        weights = torch.softmax(self.channel_weights / self.temperature, dim=0)
        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)

        # Concatenate features
        combined = weighted_features.view(weighted_features.size(0), -1)

        # Fully connected layers
        return self.fc(combined)

    def get_channel_weights(self):
        """Return the per-channel weights (with temperature scaling)."""
        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()

filter/modelV3_3.py (deleted)

@@ -1,56 +0,0 @@
import torch
import torch.nn as nn

class VideoClassifierV3_3(nn.Module):
    def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3):
        super().__init__()
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']

        # Temperature-scaled channel weights (more flexible than fixed weights)
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
        self.temperature = 1.7  # tunable smoothing coefficient

        # Improved nonlinear stack
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
            nn.BatchNorm1d(hidden_dim*2),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, output_dim)
        )

        # Output layer initialization
        nn.init.xavier_uniform_(self.fc[-1].weight)
        nn.init.zeros_(self.fc[-1].bias)

    def forward(self, input_texts, sentence_transformer):
        # Merge the texts of all channels for batched encoding
        all_texts = [text for channel in self.channel_names for text in input_texts[channel]]

        # Generate embeddings with the (frozen) SentenceTransformer
        with torch.no_grad():
            task = "classification"
            embeddings = torch.tensor(
                sentence_transformer.encode(all_texts, task=task),
                device=next(self.parameters()).device
            )

        # Split the embeddings back per channel and weight them
        split_sizes = [len(input_texts[name]) for name in self.channel_names]
        channel_features = torch.split(embeddings, split_sizes, dim=0)
        channel_features = torch.stack(channel_features, dim=1)  # [batch, 4, 1024]

        # Improvement 4: temperature-scaled softmax weighting
        weights = torch.softmax(self.channel_weights / self.temperature, dim=0)
        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)

        # Concatenate features
        combined = weighted_features.view(weighted_features.size(0), -1)

        # Fully connected layers
        return self.fc(combined)

    def get_channel_weights(self):
        """Return the per-channel weights (with temperature scaling)."""
        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()

filter/onnx_export.py (new file, 32 lines)

@@ -0,0 +1,32 @@
import torch
from modelV3_10 import VideoClassifierV3_10

def export_onnx(model_path="./filter/checkpoints/best_model_V3.10.pt",
                onnx_path="./model/video_classifier_v3_10.onnx"):
    # Initialize the model
    model = VideoClassifierV3_10()
    model.load_state_dict(torch.load(model_path))
    model.eval()

    # Create a dummy input that matches the input spec
    dummy_input = torch.randn(1, 4, 1024)  # [batch=1, channels=4, embedding_dim=1024]

    # Export to ONNX
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["channel_features"],
        output_names=["logits"],
        dynamic_axes={
            "channel_features": {0: "batch_size"},
            "logits": {0: "batch_size"}
        },
        opset_version=13,
        do_constant_folding=True
    )
    print(f"Model successfully exported to {onnx_path}")

# Run the export
export_onnx()

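Nothing in this commit consumes the exported graph yet. For orientation, a minimal sketch of running it from TypeScript, assuming onnxruntime-node (which is not among this commit's dependencies; the runtime bundled with transformers.js would work similarly):

// Sketch only: run the exported classifier outside Python.
// Input/output names match the export above.
import * as ort from "onnxruntime-node";

const session = await ort.InferenceSession.create("./model/video_classifier_v3_10.onnx");

// One video: 4 channels (title, description, tags, author_info), 1024-dim embeddings each.
const features = new Float32Array(1 * 4 * 1024); // fill with real embeddings, e.g. from classifyVideo.ts
const feeds = { channel_features: new ort.Tensor("float32", features, [1, 4, 1024]) };

const { logits } = await session.run(feeds);
console.log(logits.data); // three class scores for the batch item
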
filter/train.py

@@ -4,13 +4,14 @@ import numpy as np
from torch.utils.data import DataLoader
import torch.optim as optim
from dataset import MultiChannelDataset
from filter.modelV3_9 import VideoClassifierV3_9, AdaptiveRecallLoss
from filter.modelV3_10 import VideoClassifierV3_10, AdaptiveRecallLoss
from sentence_transformers import SentenceTransformer
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
import os
import torch
from torch.utils.tensorboard import SummaryWriter  # TensorBoard logging
import time
from embedding import prepare_batch

# Dynamically generate the run subdirectory name
@@ -52,8 +53,8 @@ class_weights = torch.tensor(

# Initialize the model and the SentenceTransformer
sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
model = VideoClassifierV3_9()
checkpoint_name = './filter/checkpoints/best_model_V3.9.pt'
model = VideoClassifierV3_10()
checkpoint_name = './filter/checkpoints/best_model_V3.11.pt'

# Model save path
os.makedirs('./filter/checkpoints', exist_ok=True)
@@ -77,8 +78,8 @@ def evaluate(model, dataloader):
    with torch.no_grad():
        for batch in dataloader:
            # Pass in the text dict and the sentence_transformer
            logits = model(input_texts=batch['texts'], sentence_transformer=sentence_transformer)
            batch_tensor = prepare_batch(batch['texts'], device="cpu")
            logits = model(batch_tensor)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch['label'].cpu().numpy())
@@ -110,8 +111,10 @@ for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        batch_tensor = prepare_batch(batch['texts'], device="cpu")
        # Pass in the text dict and the sentence_transformer
        logits = model(input_texts=batch['texts'], sentence_transformer=sentence_transformer)
        logits = model(batch_tensor)
        loss = criterion(logits, batch['label'])
        loss.backward()

lib/db/allData.ts (new file, 15 lines)

@@ -0,0 +1,15 @@
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import { AllDataType } from "lib/db/schema.d.ts";
export async function videoExistsInAllData(client: Client, aid: number) {
return await client.queryObject<{ exists: boolean }>("SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)", [aid])
.then((result) => result.rows[0].exists);
}
export async function insertIntoAllData(client: Client, data: AllDataType) {
console.log(`inserted ${data.aid}`);
return await client.queryObject(
"INSERT INTO all_data (aid, bvid, description, uid, tags, title, published_at) VALUES ($1, $2, $3, $4, $5, $6, $7)",
[data.aid, data.bvid, data.description, data.uid, data.tags, data.title, data.published_at],
);
}

lib/db/schema.d.ts (vendored, new file, 9 lines)

@@ -0,0 +1,9 @@
export interface AllDataType {
aid: number;
bvid: string | null;
description: string | null;
uid: number | null;
tags: string | null;
title: string | null;
published_at: string | null;
}

New file: SentenceTransformer usage example (lib/ml)

@@ -0,0 +1,19 @@
import { SentenceTransformer } from "./model.ts"; // Changed import path
async function main() {
const sentenceTransformer = await SentenceTransformer.from_pretrained(
"mixedbread-ai/mxbai-embed-large-v1",
);
const outputs = await sentenceTransformer.encode([
"Hello world",
"How are you guys doing?",
"Today is Friday!",
]);
// @ts-ignore
console.log(outputs["last_hidden_state"]);
return outputs;
}
main(); // Keep main function call if you want this file to be runnable directly for testing.

New file: SentenceTransformer wrapper (lib/ml, imported above as ./model.ts)

@@ -0,0 +1,40 @@
// lib/ml/sentence_transformer_model.ts
import { AutoModel, AutoTokenizer, PretrainedOptions } from "@huggingface/transformers";
export class SentenceTransformer {
constructor(
private readonly tokenizer: AutoTokenizer,
private readonly model: AutoModel,
) {}
static async from_pretrained(
modelName: string,
options?: PretrainedOptions,
): Promise<SentenceTransformer> {
if (!options) {
options = {
progress_callback: undefined,
cache_dir: undefined,
local_files_only: false,
revision: "main",
};
}
const tokenizer = await AutoTokenizer.from_pretrained(modelName, options);
const model = await AutoModel.from_pretrained(modelName, options);
return new SentenceTransformer(tokenizer, model);
}
async encode(sentences: string[]): Promise<any> { // Changed return type to 'any' for now to match console.log output
//@ts-ignore
const modelInputs = await this.tokenizer(sentences, {
padding: true,
truncation: true,
});
//@ts-ignore
const outputs = await this.model(modelInputs);
return outputs;
}
}

New file: Pooling module (lib/ml)

@@ -0,0 +1,34 @@
import { Tensor } from "@huggingface/transformers";
//@ts-ignore
import { Callable } from "@huggingface/transformers/src/utils/core.js"; // Keep as is for now, might need adjustment
export interface PoolingConfig {
word_embedding_dimension: number;
pooling_mode_cls_token: boolean;
pooling_mode_mean_tokens: boolean;
pooling_mode_max_tokens: boolean;
pooling_mode_mean_sqrt_len_tokens: boolean;
}
export interface PoolingInput {
token_embeddings: Tensor;
attention_mask: Tensor;
}
export interface PoolingOutput {
sentence_embedding: Tensor;
}
export class Pooling extends Callable {
constructor(private readonly config: PoolingConfig) {
super();
}
// async _call(inputs: any) { // Keep if pooling functionality is needed
// return this.forward(inputs);
// }
// async forward(inputs: PoolingInput): PoolingOutput { // Keep if pooling functionality is needed
// }
}

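The forward path of Pooling is still commented out, so SentenceTransformer.encode currently returns raw model outputs rather than sentence embeddings. For reference, the mean-pooling mode the config describes could be hand-rolled like this (a sketch, not part of this commit):

// Sketch only: mean pooling over token embeddings, masking out padding via attention_mask.
import { Tensor } from "@huggingface/transformers";

function meanPooling(tokenEmbeddings: Tensor, attentionMask: Tensor): number[][] {
    const [batch, seqLen, dim] = tokenEmbeddings.dims;
    const emb = tokenEmbeddings.data as Float32Array;
    const mask = attentionMask.data; // 0/1 per token (may be a BigInt64Array)
    const result: number[][] = [];
    for (let b = 0; b < batch; b++) {
        const sum = new Array(dim).fill(0);
        let count = 0;
        for (let t = 0; t < seqLen; t++) {
            if (Number(mask[b * seqLen + t]) === 0) continue; // skip padding
            count++;
            for (let d = 0; d < dim; d++) {
                sum[d] += emb[(b * seqLen + t) * dim + d];
            }
        }
        result.push(sum.map((x) => x / Math.max(count, 1)));
    }
    return result;
}
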
lib/ml/classifyVideo.ts (new file, 32 lines)

@@ -0,0 +1,32 @@
import { AutoModel, AutoTokenizer, Tensor } from '@huggingface/transformers';
const modelName = "alikia2x/jina-embedding-v3-m2v-1024";
const modelConfig = {
config: { model_type: 'model2vec' },
dtype: 'fp32',
revision: 'refs/pr/1',
cache_dir: undefined,
local_files_only: true,
};
const tokenizerConfig = {
revision: 'refs/pr/2'
};
const model = await AutoModel.from_pretrained(modelName, modelConfig);
const tokenizer = await AutoTokenizer.from_pretrained(modelName, tokenizerConfig);
const texts = ['hello', 'hello world'];
const { input_ids } = await tokenizer(texts, { add_special_tokens: false, return_tensor: false });
const cumsum = arr => arr.reduce((acc, num, i) => [...acc, num + (acc[i - 1] || 0)], []);
const offsets = [0, ...cumsum(input_ids.slice(0, -1).map(x => x.length))];
const flattened_input_ids = input_ids.flat();
const modelInputs = {
input_ids: new Tensor('int64', flattened_input_ids, [flattened_input_ids.length]),
offsets: new Tensor('int64', offsets, [offsets.length])
};
const { embeddings } = await model(modelInputs);
console.log(embeddings.tolist()); // output matches python version

lib/mq/index.ts (new file, 2 lines)

@@ -0,0 +1,2 @@
import { Queue } from "bullmq";

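This module is just a stub so far; the import only signals intent. A minimal bullmq setup would look roughly like the following (queue and job names are hypothetical, and a local Redis instance is assumed):

// Hypothetical sketch of where lib/mq/index.ts is headed.
import { Queue } from "bullmq";

export const videoQueue = new Queue("video", {
    connection: { host: "localhost", port: 6379 },
});

// Enqueue a video for classification/insertion:
// await videoQueue.add("classify", { aid: 170001 });
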
lib/net/bilibili.d.ts (vendored, new file, 117 lines)

@@ -0,0 +1,117 @@
interface BaseResponse<T> {
code: number;
message: string;
ttl: number;
data: T;
}
export type VideoListResponse = BaseResponse<VideoListData>;
export type VideoTagsResponse = BaseResponse<VideoTagsData>;
type VideoTagsData = VideoTags[];
interface VideoTags {
tag_id: number;
tag_name: string;
cover: string;
head_cover: string;
content: string;
short_content: string;
type: number;
state: number;
ctime: number;
count: {
view: number;
use: number;
atten: number;
}
is_atten: number;
likes: number;
hates: number;
attribute: number;
liked: number;
hated: number;
extra_attr: number;
}
interface VideoListData {
archives: VideoListVideo[];
page: {
num: number;
size: number;
count: number;
};
}
interface VideoListVideo {
aid: number;
videos: number;
tid: number;
tname: string;
copyright: number;
pic: string;
title: string;
pubdate: number;
ctime: number;
desc: string;
state: number;
duration: number;
mission_id?: number;
rights: {
bp: number;
elec: number;
download: number;
movie: number;
pay: number;
hd5: number;
no_reprint: number;
autoplay: number;
ugc_pay: number;
is_cooperation: number;
ugc_pay_preview: number;
no_background: number;
arc_pay: number;
pay_free_watch: number;
},
owner: {
mid: number;
name: string;
face: string;
},
stat: {
aid: number;
view: number;
danmaku: number;
reply: number;
favorite: number;
coin: number;
share: number;
now_rank: number;
his_rank: number;
like: number;
dislike: number;
vt: number;
vv: number;
},
dynamic: string;
cid: number;
dimension: {
width: number;
height: number;
rotate: number;
},
season_id?: number;
short_link_v2: string;
first_frame: string;
pub_location: string;
cover43: string;
tidv2: number;
tname_v2: string;
bvid: string;
season_type: number;
is_ogv: number;
ovg_info: string | null;
rcmd_season: string;
enable_vt: number;
ai_rcmd: null | string;
}

lib/net/getLatestVideos.ts (new file)

@@ -0,0 +1,48 @@
import { VideoListResponse } from "lib/net/bilibili.d.ts";
import formatPublishedAt from "lib/utils/formatTimestampToPostgre.ts";
import { getVideoTags } from "lib/net/getVideoTags.ts";
import { AllDataType } from "lib/db/schema.d.ts";
import { sleep } from "lib/utils/sleep.ts";
export async function getLatestVideos(page: number = 1, pageSize: number = 10): Promise<AllDataType[] | null> {
try {
const response = await fetch(`https://api.bilibili.com/x/web-interface/newlist?rid=30&ps=${pageSize}&pn=${page}`);
const data: VideoListResponse = await response.json();
if (data.code !== 0) {
console.error(`Error fetching videos: ${data.message}`);
return null;
}
if (data.data.archives.length === 0) {
console.warn("No more videos found");
return [];
}
const videoPromises = data.data.archives.map(async (video) => {
const published_at = formatPublishedAt(video.pubdate + 3600 * 8);
await sleep(Math.random() * pageSize * 250); // stagger the tag requests
const tags = await getVideoTags(video.aid);
let processedTags = null;
if (tags !== null) {
processedTags = tags.join(',');
}
return {
aid: video.aid,
bvid: video.bvid,
description: video.desc,
uid: video.owner.mid,
tags: processedTags,
title: video.title,
published_at: published_at,
} as AllDataType;
});
const result = await Promise.all(videoPromises);
return result;
} catch (error) {
console.error(error);
return null;
}
}

lib/net/getVideoTags.ts (new file, 18 lines)

@@ -0,0 +1,18 @@
import { VideoTagsResponse } from "lib/net/bilibili.d.ts";
export async function getVideoTags(aid: number): Promise<string[] | null> {
try {
const url = `https://api.bilibili.com/x/tag/archive/tags?aid=${aid}`;
const res = await fetch(url);
const data: VideoTagsResponse = await res.json();
if (data.code != 0) {
console.error(`Error fetching tags for video ${aid}: ${data.message}`);
return [];
}
return data.data.map((tag) => tag.tag_name);
}
catch {
console.error(`Error fetching tags for video ${aid}`);
return null;
}
}

New file: video insertion script

@@ -0,0 +1,81 @@
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import { getLatestVideos } from "lib/net/getLatestVideos.ts";
import { insertIntoAllData, videoExistsInAllData } from "lib/db/allData.ts";
import { sleep } from "lib/utils/sleep.ts";
const requiredEnvVars = ["DB_HOST", "DB_NAME", "DB_USER", "DB_PASSWORD", "DB_PORT"];
const unsetVars = requiredEnvVars.filter((key) => Deno.env.get(key) === undefined);
if (unsetVars.length > 0) {
throw new Error(`Missing required environment variables: ${unsetVars.join(", ")}`);
}
const databaseHost = Deno.env.get("DB_HOST")!;
const databaseName = Deno.env.get("DB_NAME")!;
const databaseUser = Deno.env.get("DB_USER")!;
const databasePassword = Deno.env.get("DB_PASSWORD")!;
const databasePort = Deno.env.get("DB_PORT")!;
const postgresConfig = {
hostname: databaseHost,
port: parseInt(databasePort),
database: databaseName,
user: databaseUser,
password: databasePassword,
};
async function connectToPostgres() {
const client = new Client(postgresConfig);
await client.connect();
return client;
}
export async function insertLatestVideos() {
const client = await connectToPostgres();
let page = 334;
let failCount = 0;
while (true) {
try {
const videos = await getLatestVideos(page, 10);
if (videos == null) {
failCount++;
if (failCount > 5) {
break;
}
continue;
}
if (videos.length == 0) {
console.warn("No more videos found");
break;
}
let allExists = true;
for (const video of videos) {
const videoExists = await videoExistsInAllData(client, video.aid);
if (!videoExists) {
allExists = false;
await insertIntoAllData(client, video);
}
}
if (allExists) {
console.log("All videos already exist in all_data, stop crawling.");
break;
}
console.log(`Page ${page} crawled, total: ${(page - 1) * 10 + videos.length} videos.`);
page++;
} catch (error) {
console.error(error);
failCount++;
if (failCount > 5) {
break;
}
continue;
}
finally {
await sleep(Math.random() * 4000 + 1000);
}
}
}
insertLatestVideos();

lib/utils/formatTimestampToPostgre.ts (new file)

@@ -0,0 +1,4 @@
export default function formatTimestamp(timestamp: number) {
const date = new Date(timestamp * 1000);
return date.toISOString().slice(0, 19).replace("T", " ");
}

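Note that the helper itself formats in UTC; the UTC+8 shift happens at the call site in getLatestVideos.ts (`video.pubdate + 3600 * 8`). For example:

// toISOString() is always UTC, so the epoch formats as:
formatTimestamp(0); // "1970-01-01 00:00:00"
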
lib/utils/sleep.ts (new file, 3 lines)

@@ -0,0 +1,3 @@
export async function sleep(ms: number) {
await new Promise((resolve) => setTimeout(resolve, ms));
}

Modified: serverless function handler

@@ -1,6 +1,7 @@
import { Buffer } from "node:buffer";
"use strict";
export const handler = async (event, context) => {
export const handler = async (event, _context) => {
const eventObj = JSON.parse(event);
console.log(`receive event: ${JSON.stringify(eventObj)}`);

Deleted: legacy getLatestVideos database script

@@ -1,35 +0,0 @@
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
const API_URL = "https://api.bilibili.com/x/web-interface/newlist?rid=30&ps=50&pn=";
const requiredEnvVars = ["DB_HOST", "DB_NAME", "DB_USER", "DB_PASSWORD", "DB_PORT"];
const unsetVars = requiredEnvVars.filter((key) => !Deno.env.get(key));
if (unsetVars.length > 0) {
throw new Error(`Missing required environment variables: ${unsetVars.join(", ")}`);
}
const databaseHost = Deno.env.get("DB_HOST")!;
const databaseName = Deno.env.get("DB_NAME")!;
const databaseUser = Deno.env.get("DB_USER")!;
const databasePassword = Deno.env.get("DB_PASSWORD")!;
const databasePort = Deno.env.get("DB_PORT")!;
const postgresConfig = {
hostname: databaseHost,
port: parseInt(databasePort),
database: databaseName,
user: databaseUser,
password: databasePassword,
};
async function connectToPostgres() {
const client = new Client(postgresConfig);
await client.connect();
return client;
}
export async function getLatestVideos() {
const client = await connectToPostgres();
}

New file: getLatestVideos test

@@ -0,0 +1,25 @@
import { assertEquals } from "jsr:@std/assert";
import { getLatestVideos } from "lib/net/getLatestVideos.ts";
Deno.test("Get latest videos", async () => {
const videos = (await getLatestVideos(1, 5))!;
assertEquals(videos.length, 5);
videos.forEach((video) => {
assertVideoProperties(video);
});
});
function assertVideoProperties(video: object) {
const aid = "aid" in video && typeof video.aid === "number";
const bvid = "bvid" in video && typeof video.bvid === "string" &&
video.bvid.length === 12 && video.bvid.startsWith("BV");
const description = "description" in video && typeof video.description === "string";
const uid = "uid" in video && typeof video.uid === "number";
const tags = "tags" in video && (typeof video.tags === "string" || video.tags === null);
const title = "title" in video && typeof video.title === "string";
const publishedAt = "published_at" in video && typeof video.published_at === "string";
const match = aid && bvid && description && uid && tags && title && publishedAt;
assertEquals(match, true);
}

New file: getVideoTags test

@@ -0,0 +1,28 @@
import { assertEquals } from "jsr:@std/assert";
import { getVideoTags } from "lib/net/getVideoTags.ts";
Deno.test("Get video tags - regular video", async () => {
const tags = (await getVideoTags(826597951))!.sort();
assertEquals(tags, [
"纯白P",
"中华墨水娘",
"中华少女",
"中华粘土娘",
"中华缘木娘",
"中华少女Project",
"提糯Tino",
"中华烛火娘",
"中华烁金娘",
"新世代音乐人计划女生季",
].sort());
});
Deno.test("Get video tags - non-existent video", async () => {
const tags = (await getVideoTags(8265979511111111));
assertEquals(tags, []);
});
Deno.test("Get video tags - video with no tag", async () => {
const tags = (await getVideoTags(981001865));
assertEquals(tags, []);
});