update: structurize video insertion code, use js-compatible filter model arch
This commit is contained in:
parent
cf4ff398b8
commit
b21e6da07a
27
.gitignore
vendored
27
.gitignore
vendored
@ -61,19 +61,6 @@ TEST-results.xml
|
|||||||
package-lock.json
|
package-lock.json
|
||||||
.eslintcache
|
.eslintcache
|
||||||
*v8.log
|
*v8.log
|
||||||
/lib/
|
|
||||||
|
|
||||||
# project specific
|
|
||||||
data/main.db
|
|
||||||
.env
|
|
||||||
logs/
|
|
||||||
__pycache__
|
|
||||||
filter/runs
|
|
||||||
data/filter/eval*
|
|
||||||
data/filter/train*
|
|
||||||
filter/checkpoints
|
|
||||||
data/filter/model_predicted*
|
|
||||||
scripts
|
|
||||||
|
|
||||||
# dotenv environment variable files
|
# dotenv environment variable files
|
||||||
.env
|
.env
|
||||||
@ -86,3 +73,17 @@ scripts
|
|||||||
_fresh/
|
_fresh/
|
||||||
# npm dependencies
|
# npm dependencies
|
||||||
node_modules/
|
node_modules/
|
||||||
|
|
||||||
|
|
||||||
|
# project specific
|
||||||
|
data/main.db
|
||||||
|
.env
|
||||||
|
logs/
|
||||||
|
__pycache__
|
||||||
|
filter/runs
|
||||||
|
data/filter/eval*
|
||||||
|
data/filter/train*
|
||||||
|
filter/checkpoints
|
||||||
|
data/filter/model_predicted*
|
||||||
|
scripts
|
||||||
|
model/
|
||||||
|
@ -19,9 +19,6 @@
|
|||||||
"exclude": ["**/_fresh/*"],
|
"exclude": ["**/_fresh/*"],
|
||||||
"imports": {
|
"imports": {
|
||||||
"@std/assert": "jsr:@std/assert@1",
|
"@std/assert": "jsr:@std/assert@1",
|
||||||
"@types/better-sqlite3": "npm:@types/better-sqlite3@^7.6.12",
|
|
||||||
"axios": "npm:axios@^1.7.9",
|
|
||||||
"better-sqlite3": "npm:better-sqlite3@^11.7.2",
|
|
||||||
"$fresh/": "https://deno.land/x/fresh@1.7.3/",
|
"$fresh/": "https://deno.land/x/fresh@1.7.3/",
|
||||||
"preact": "https://esm.sh/preact@10.22.0",
|
"preact": "https://esm.sh/preact@10.22.0",
|
||||||
"preact/": "https://esm.sh/preact@10.22.0/",
|
"preact/": "https://esm.sh/preact@10.22.0/",
|
||||||
@ -30,7 +27,10 @@
|
|||||||
"tailwindcss": "npm:tailwindcss@3.4.1",
|
"tailwindcss": "npm:tailwindcss@3.4.1",
|
||||||
"tailwindcss/": "npm:/tailwindcss@3.4.1/",
|
"tailwindcss/": "npm:/tailwindcss@3.4.1/",
|
||||||
"tailwindcss/plugin": "npm:/tailwindcss@3.4.1/plugin.js",
|
"tailwindcss/plugin": "npm:/tailwindcss@3.4.1/plugin.js",
|
||||||
"$std/": "https://deno.land/std@0.216.0/"
|
"$std/": "https://deno.land/std@0.216.0/",
|
||||||
|
"@huggingface/transformers": "npm:@huggingface/transformers@3.0.0",
|
||||||
|
"bullmq": "npm:bullmq",
|
||||||
|
"lib/": "./lib/"
|
||||||
},
|
},
|
||||||
"compilerOptions": {
|
"compilerOptions": {
|
||||||
"jsx": "react-jsx",
|
"jsx": "react-jsx",
|
||||||
|
26
filter/checkpoint_conversion.py
Normal file
26
filter/checkpoint_conversion.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from modelV3_10 import VideoClassifierV3_10
|
||||||
|
from modelV3_9 import VideoClassifierV3_9
|
||||||
|
|
||||||
|
|
||||||
|
def convert_checkpoint(original_model, new_model):
    """Transfer the weights of ``original_model`` into ``new_model``.

    The source state dict is loaded unchanged with the default (strict)
    matching, so both modules must expose the same parameter/buffer names
    and shapes.

    Args:
        original_model: Module whose ``state_dict()`` supplies the weights.
        new_model: Module that receives the weights; it is modified in place.

    Returns:
        ``new_model``, after loading.
    """
    source_weights = original_model.state_dict()
    new_model.load_state_dict(source_weights)
    return new_model
|
||||||
|
|
||||||
|
# Example usage. Guarded so that importing this module (for instance to reuse
# convert_checkpoint) does not read or write any checkpoint file.
if __name__ == "__main__":
    original_model = VideoClassifierV3_9()
    new_model = VideoClassifierV3_10()

    # Load the original checkpoint. map_location="cpu" allows a checkpoint
    # saved from a GPU run to be converted on a machine without CUDA; both
    # models live on the CPU here anyway.
    original_model.load_state_dict(
        torch.load('./filter/checkpoints/best_model_V3.9.pt', map_location="cpu")
    )

    # Copy the parameters into the new architecture.
    converted_model = convert_checkpoint(original_model, new_model)

    # Save the converted weights.
    torch.save(converted_model.state_dict(), './filter/checkpoints/best_model_V3.10.pt')
|
@ -97,7 +97,7 @@ class MultiChannelDataset(Dataset):
|
|||||||
example = self.examples[idx]
|
example = self.examples[idx]
|
||||||
|
|
||||||
# 处理tags(将数组转换为空格分隔的字符串)
|
# Process tags (join the tag array into a comma-separated string)
|
||||||
tags_text = " ".join(example['tags'])
|
tags_text = ",".join(example['tags'])
|
||||||
|
|
||||||
# 返回文本字典
|
# 返回文本字典
|
||||||
texts = {
|
texts = {
|
||||||
|
31
filter/embedding.py
Normal file
31
filter/embedding.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import torch
|
||||||
|
from model2vec import StaticModel
|
||||||
|
|
||||||
|
|
||||||
|
# Directory of the local model2vec embedding model shared by all channels.
_EMBEDDING_MODEL_PATH = "./model/embedding/"

# Channel order; it fixes the layout of dim=1 in the tensor built below.
_CHANNEL_NAMES = ("title", "description", "tags", "author_info")

# Already-loaded embedding models keyed by path, so that repeated
# prepare_batch calls (one per training/evaluation batch) reuse one instance
# instead of reloading the model from disk every time.
_EMBEDDING_MODELS = {}


def _get_embedding_model(path=_EMBEDDING_MODEL_PATH):
    """Return the StaticModel stored at ``path``, loading it on first use only."""
    model = _EMBEDDING_MODELS.get(path)
    if model is None:
        model = StaticModel.from_pretrained(path)
        _EMBEDDING_MODELS[path] = model
    return model


def prepare_batch(batch_data, device="cpu"):
    """Encode a batch of per-channel texts into one model input tensor.

    Args:
        batch_data (dict): Maps each of "title", "description", "tags" and
            "author_info" to a list of texts. All lists must have the same
            length, otherwise the final ``torch.stack`` fails.
        device (str): Device the resulting tensor is moved to
            (e.g. "cpu" or "cuda").

    Returns:
        torch.Tensor: float32 tensor of shape
        [batch_size, num_channels, embedding_dim].
    """
    model = _get_embedding_model()

    # 1. Encode every channel separately. encode() yields a NumPy array of
    #    shape [batch_size, embedding_dim], which is converted to float32.
    channel_embeddings = [
        torch.from_numpy(model.encode(batch_data[channel])).to(torch.float32).to(device)
        for channel in _CHANNEL_NAMES
    ]

    # 2. Stack along dim=1 -> [batch_size, num_channels, embedding_dim].
    return torch.stack(channel_embeddings, dim=1)
|
@ -1,58 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
class VideoClassifierV3_1(nn.Module):
    """Classify a video from four text channels using frozen sentence embeddings.

    The texts of each channel (title, description, tags, author_info) are
    embedded by an encoder passed to ``forward``, scaled by a learnable
    softmax channel weight, concatenated and fed to an MLP head.
    """

    def __init__(self, embedding_dim=1024, hidden_dim=384, output_dim=3):
        super().__init__()
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']

        # Learnable per-channel logits; softmax(logits / temperature) gives
        # the mixing weights. The temperature is a fixed smoothing constant.
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
        self.temperature = 1.7

        # MLP head over the concatenated, weighted channel embeddings.
        wide_dim = hidden_dim * 2
        head_layers = [
            nn.Linear(embedding_dim * self.num_channels, wide_dim),
            nn.BatchNorm1d(wide_dim),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(wide_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.fc = nn.Sequential(*head_layers)

        # Output layer: Xavier-uniform weights, zero bias.
        output_layer = self.fc[-1]
        nn.init.xavier_uniform_(output_layer.weight)
        nn.init.zeros_(output_layer.bias)

    def _softmax_channel_weights(self):
        """Return the temperature-scaled softmax over the channel logits."""
        return torch.softmax(self.channel_weights / self.temperature, dim=0)

    def forward(self, input_texts, sentence_transformer):
        """Embed the texts of every channel and return the head's logits.

        Args:
            input_texts: Mapping from channel name to a list of texts.
            sentence_transformer: Object whose ``encode(texts, task=...)``
                returns one embedding row per text.

        Returns:
            The output of the MLP head for the batch.
        """
        # Flatten all channels into one list so the encoder runs only once.
        flat_texts = []
        texts_per_channel = []
        for name in self.channel_names:
            channel_texts = input_texts[name]
            flat_texts.extend(channel_texts)
            texts_per_channel.append(len(channel_texts))

        # The encoder stays frozen: no gradients are tracked for it.
        with torch.no_grad():
            encoded = sentence_transformer.encode(flat_texts, task="classification")
            embeddings = torch.tensor(encoded, device=next(self.parameters()).device)

        # Undo the flattening: one chunk per channel, stacked on dim=1
        # -> [batch, num_channels, embedding_dim].
        per_channel = torch.split(embeddings, texts_per_channel, dim=0)
        stacked = torch.stack(per_channel, dim=1)

        # Scale every channel by its softmax weight, then concatenate.
        scaled = stacked * self._softmax_channel_weights().unsqueeze(0).unsqueeze(-1)
        flattened = scaled.view(scaled.size(0), -1)

        return self.fc(flattened)

    def get_channel_weights(self):
        """Return the channel weights (temperature-scaled softmax) as a NumPy array."""
        return self._softmax_channel_weights().detach().cpu().numpy()
|
|
97
filter/modelV3_10.py
Normal file
97
filter/modelV3_10.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class VideoClassifierV3_10(nn.Module):
    """Classifier over pre-computed channel embeddings.

    Unlike the text-based variants, ``forward`` takes an embedding tensor
    directly, so the module contains tensor operations only.
    """

    def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3):
        super().__init__()
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']

        # Learnable temperature. forward() does not read it; it stays a
        # registered parameter so that state dicts containing a
        # "temperature" entry still load with strict matching.
        self.temperature = nn.Parameter(torch.tensor(1.7))

        # Per-channel logits; forward() squashes them with a sigmoid, so
        # each channel gets an independent gate instead of a softmax share.
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))

        # Classification head over the concatenated, gated embeddings.
        self.fc = nn.Sequential(
            nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
            nn.BatchNorm1d(hidden_dim*2),
            nn.Dropout(0.2),
            nn.GELU(),
            nn.Linear(hidden_dim*2, output_dim)
        )

        self._init_weights()

    def _init_weights(self):
        """Initialise every Linear layer of the head (Kaiming-normal weights, zero bias)."""
        for layer in self.fc:
            if isinstance(layer, nn.Linear):
                # The ReLU setting is used as an approximation for GELU.
                # Xavier initialisation would be a more generic alternative.
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
                nn.init.zeros_(layer.bias)

    def forward(self, channel_features: torch.Tensor):
        """Return class logits for a batch of channel embeddings.

        Args:
            channel_features: Tensor of shape
                [batch_size, num_channels, embedding_dim].

        Returns:
            Tensor of shape [batch_size, output_dim].
        """
        # Sigmoid-constrained channel gates, each in (0, 1).
        weights = torch.sigmoid(self.channel_weights)
        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)

        # Concatenate the channels into one feature vector per sample.
        combined = weighted_features.view(weighted_features.size(0), -1)

        return self.fc(combined)

    def get_channel_weights(self):
        """Return the per-channel weights that forward() applies.

        These are ``sigmoid(channel_weights)``. The previous implementation
        returned a temperature-scaled softmax, which forward() no longer
        uses, so the reported values did not match the weights in effect.

        Returns:
            NumPy array with ``num_channels`` entries, each in (0, 1).
        """
        return torch.sigmoid(self.channel_weights).detach().cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptiveRecallLoss(nn.Module):
    """Weighted focal loss with a recall term and a class-0 false-positive penalty."""

    def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
        """
        Args:
            class_weights (torch.Tensor): Per-class weights passed to the
                cross-entropy term; its length also defines the number of
                classes.
            alpha (float): Recall adjustment factor (0-1).
            gamma (float): Focal-loss focusing parameter.
            fp_penalty (float): Strength of the class-0 false-positive
                penalty.
        """
        super().__init__()
        # Registered as a buffer so that .to(device)/.cuda() on the loss
        # module moves the weights together with it. persistent=False keeps
        # the module's state_dict unchanged.
        self.register_buffer("class_weights", class_weights, persistent=False)
        self.alpha = alpha
        self.gamma = gamma
        self.fp_penalty = fp_penalty

    def forward(self, logits, targets):
        """Return the scalar loss for ``logits`` against class indices ``targets``."""
        # Per-sample weighted cross entropy.
        ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')

        # Focal component: down-weights samples whose loss is already small.
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        # Recall term: scale each sample's focal loss by
        # alpha + (1 - alpha) * pt, placed on its target class via the
        # one-hot mask and summed back to one value per sample.
        class_mask = F.one_hot(targets, num_classes=len(self.class_weights))
        class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask
        recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1)

        # Penalty for samples predicted as class 0 whose target is not 0:
        # squared class-0 probability, summed over those samples.
        probs = F.softmax(logits, dim=1)
        fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
        fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()

        # Total: mean recall loss plus the penalty averaged over the batch.
        total_loss = recall_loss.mean() + fp_loss / len(targets)

        return total_loss
|
@ -1,58 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
class VideoClassifierV3_2(nn.Module):
    """Four-channel text classifier built on frozen sentence embeddings.

    Texts from the title, description, tags and author_info channels are
    embedded by the encoder handed to ``forward``, weighted by a learnable
    temperature-softmax, concatenated and classified by an MLP head.
    """

    def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3):
        super().__init__()
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']

        # Channel logits are learned; the temperature is a fixed constant
        # that smooths the resulting softmax.
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
        self.temperature = 1.7

        # Fully connected head.
        doubled = hidden_dim * 2
        input_layer = nn.Linear(embedding_dim * self.num_channels, doubled)
        batch_norm = nn.BatchNorm1d(doubled)
        dropout = nn.Dropout(0.1)
        activation = nn.ReLU()
        hidden_layer = nn.Linear(doubled, hidden_dim)
        layer_norm = nn.LayerNorm(hidden_dim)
        output_layer = nn.Linear(hidden_dim, output_dim)
        self.fc = nn.Sequential(
            input_layer, batch_norm, dropout, activation,
            hidden_layer, layer_norm, output_layer,
        )

        # Re-initialise the output layer: Xavier-uniform weights, zero bias.
        nn.init.xavier_uniform_(output_layer.weight)
        nn.init.zeros_(output_layer.bias)

    def _mixing_weights(self):
        """Return softmax(channel_weights / temperature)."""
        return torch.softmax(self.channel_weights / self.temperature, dim=0)

    def forward(self, input_texts, sentence_transformer):
        """Embed all channel texts and return the head's logits.

        Args:
            input_texts: Mapping from channel name to a list of texts.
            sentence_transformer: Object whose ``encode(texts, task=...)``
                returns one embedding row per text.
        """
        # Gather the texts of all channels for a single encoder call.
        merged_texts = []
        chunk_sizes = []
        for channel in self.channel_names:
            current = input_texts[channel]
            merged_texts.extend(current)
            chunk_sizes.append(len(current))

        # Embeddings come from the frozen encoder, so no gradients here.
        with torch.no_grad():
            raw = sentence_transformer.encode(merged_texts, task="classification")
            embeddings = torch.tensor(raw, device=next(self.parameters()).device)

        # Split back per channel and stack
        # -> [batch, num_channels, embedding_dim].
        chunks = torch.split(embeddings, chunk_sizes, dim=0)
        features = torch.stack(chunks, dim=1)

        # Apply the temperature-softmax channel weights.
        features = features * self._mixing_weights().unsqueeze(0).unsqueeze(-1)

        # Concatenate the channels and run the head.
        return self.fc(features.view(features.size(0), -1))

    def get_channel_weights(self):
        """Return the channel weights (temperature-scaled softmax) as a NumPy array."""
        return self._mixing_weights().detach().cpu().numpy()
|
|
@ -1,56 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
class VideoClassifierV3_3(nn.Module):
    """Four-channel text classifier with a single hidden layer.

    Channel texts are embedded by the encoder passed to ``forward``,
    weighted by a learnable temperature-softmax, concatenated and
    classified by a two-layer head.
    """

    def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3):
        super().__init__()
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']

        # Learnable channel logits plus a fixed smoothing temperature.
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
        self.temperature = 1.7

        # Head: Linear -> BatchNorm -> Dropout -> ReLU -> Linear.
        inner_dim = hidden_dim * 2
        self.fc = nn.Sequential(*[
            nn.Linear(embedding_dim * self.num_channels, inner_dim),
            nn.BatchNorm1d(inner_dim),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(inner_dim, output_dim),
        ])

        # Output layer initialisation: Xavier-uniform weights, zero bias.
        final_layer = self.fc[-1]
        nn.init.xavier_uniform_(final_layer.weight)
        nn.init.zeros_(final_layer.bias)

    def _weights_tensor(self):
        """Return softmax(channel_weights / temperature) as a tensor."""
        return torch.softmax(self.channel_weights / self.temperature, dim=0)

    def forward(self, input_texts, sentence_transformer):
        """Embed all channel texts and return the head's logits.

        Args:
            input_texts: Mapping from channel name to a list of texts.
            sentence_transformer: Object whose ``encode(texts, task=...)``
                returns one embedding row per text.
        """
        # One flat list of texts, remembering how many belong to each channel.
        every_text = []
        counts = []
        for key in self.channel_names:
            entries = input_texts[key]
            every_text.extend(entries)
            counts.append(len(entries))

        # Frozen encoder: embeddings are produced without gradient tracking.
        with torch.no_grad():
            vectors = sentence_transformer.encode(every_text, task="classification")
            embeddings = torch.tensor(vectors, device=next(self.parameters()).device)

        # Regroup by channel -> [batch, num_channels, embedding_dim].
        grouped = torch.stack(torch.split(embeddings, counts, dim=0), dim=1)

        # Temperature-softmax channel weighting.
        channel_scale = self._weights_tensor().unsqueeze(0).unsqueeze(-1)
        weighted = grouped * channel_scale

        # Concatenate the channels and classify.
        merged = weighted.view(weighted.size(0), -1)
        return self.fc(merged)

    def get_channel_weights(self):
        """Return the channel weights (temperature-scaled softmax) as a NumPy array."""
        return self._weights_tensor().detach().cpu().numpy()
|
|
32
filter/onnx_export.py
Normal file
32
filter/onnx_export.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import torch
|
||||||
|
from modelV3_10 import VideoClassifierV3_10
|
||||||
|
|
||||||
|
|
||||||
|
def export_onnx(model_path="./filter/checkpoints/best_model_V3.10.pt",
                onnx_path="./model/video_classifier_v3_10.onnx"):
    """Export the VideoClassifierV3_10 checkpoint at ``model_path`` to ONNX.

    Args:
        model_path: Path of the state dict to load into the model.
        onnx_path: Destination file of the exported ONNX graph. Its parent
            directory is created when missing.
    """
    import os  # local import so this function does not rely on file-level imports

    # Build the model and load the weights. map_location="cpu" lets a
    # checkpoint saved from a GPU run be exported on a machine without CUDA.
    model = VideoClassifierV3_10()
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()

    # Make sure the output directory exists before the exporter writes to it.
    onnx_dir = os.path.dirname(onnx_path)
    if onnx_dir:
        os.makedirs(onnx_dir, exist_ok=True)

    # Dummy input following the model's input contract:
    # [batch=1, channels=4, embedding_dim=1024].
    dummy_input = torch.randn(1, 4, 1024)

    # Export; only the batch dimension is declared dynamic.
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        input_names=["channel_features"],
        output_names=["logits"],
        dynamic_axes={
            "channel_features": {0: "batch_size"},
            "logits": {0: "batch_size"}
        },
        opset_version=13,
        do_constant_folding=True
    )
    print(f"模型已成功导出到 {onnx_path}")
|
||||||
|
|
||||||
|
# Run the export only when this file is executed as a script, so that
# importing export_onnx from another module does not trigger an export.
if __name__ == "__main__":
    export_onnx()
|
@ -4,13 +4,14 @@ import numpy as np
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from dataset import MultiChannelDataset
|
from dataset import MultiChannelDataset
|
||||||
from filter.modelV3_9 import VideoClassifierV3_9, AdaptiveRecallLoss
|
from filter.modelV3_10 import VideoClassifierV3_10, AdaptiveRecallLoss
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
|
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.tensorboard import SummaryWriter # 引入 TensorBoard
|
from torch.utils.tensorboard import SummaryWriter # 引入 TensorBoard
|
||||||
import time
|
import time
|
||||||
|
from embedding import prepare_batch
|
||||||
|
|
||||||
|
|
||||||
# 动态生成子目录名称
|
# 动态生成子目录名称
|
||||||
@ -52,8 +53,8 @@ class_weights = torch.tensor(
|
|||||||
|
|
||||||
# 初始化模型和SentenceTransformer
|
# 初始化模型和SentenceTransformer
|
||||||
sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
|
sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
|
||||||
model = VideoClassifierV3_9()
|
model = VideoClassifierV3_10()
|
||||||
checkpoint_name = './filter/checkpoints/best_model_V3.9.pt'
|
checkpoint_name = './filter/checkpoints/best_model_V3.11.pt'
|
||||||
|
|
||||||
# 模型保存路径
|
# 模型保存路径
|
||||||
os.makedirs('./filter/checkpoints', exist_ok=True)
|
os.makedirs('./filter/checkpoints', exist_ok=True)
|
||||||
@ -77,8 +78,8 @@ def evaluate(model, dataloader):
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
# 传入文本字典和sentence_transformer
|
batch_tensor = prepare_batch(batch['texts'], device="cpu")
|
||||||
logits = model(input_texts=batch['texts'], sentence_transformer=sentence_transformer)
|
logits = model(batch_tensor)
|
||||||
preds = torch.argmax(logits, dim=1)
|
preds = torch.argmax(logits, dim=1)
|
||||||
all_preds.extend(preds.cpu().numpy())
|
all_preds.extend(preds.cpu().numpy())
|
||||||
all_labels.extend(batch['label'].cpu().numpy())
|
all_labels.extend(batch['label'].cpu().numpy())
|
||||||
@ -110,8 +111,10 @@ for epoch in range(num_epochs):
|
|||||||
for batch_idx, batch in enumerate(train_loader):
|
for batch_idx, batch in enumerate(train_loader):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
batch_tensor = prepare_batch(batch['texts'], device="cpu")
|
||||||
|
|
||||||
# 传入文本字典和sentence_transformer
|
# Forward pass on the pre-computed channel embedding tensor
|
||||||
logits = model(input_texts=batch['texts'], sentence_transformer=sentence_transformer)
|
logits = model(batch_tensor)
|
||||||
|
|
||||||
loss = criterion(logits, batch['label'])
|
loss = criterion(logits, batch['label'])
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
15
lib/db/allData.ts
Normal file
15
lib/db/allData.ts
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
||||||
|
import { AllDataType } from "lib/db/schema.d.ts";
|
||||||
|
|
||||||
|
/**
 * Check whether a row with the given `aid` exists in the `all_data` table.
 *
 * @param client Postgres client used to run the query.
 * @param aid Video id to look up.
 * @returns The `exists` column of the `SELECT EXISTS(...)` result row.
 */
export async function videoExistsInAllData(client: Client, aid: number) {
    const result = await client.queryObject<{ exists: boolean }>(
        "SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)",
        [aid],
    );
    const [row] = result.rows;
    return row.exists;
}
|
||||||
|
|
||||||
|
/**
 * Insert one video record into the `all_data` table.
 *
 * @param client Postgres client used to run the statement.
 * @param data Record to insert; its fields map one-to-one to the columns.
 * @returns The query result of the INSERT statement.
 */
export async function insertIntoAllData(client: Client, data: AllDataType) {
    const result = await client.queryObject(
        "INSERT INTO all_data (aid, bvid, description, uid, tags, title, published_at) VALUES ($1, $2, $3, $4, $5, $6, $7)",
        [data.aid, data.bvid, data.description, data.uid, data.tags, data.title, data.published_at],
    );
    // Log only after the statement has succeeded; logging before the query
    // would report "inserted" even when the INSERT throws.
    console.log(`inserted ${data.aid}`);
    return result;
}
|
9
lib/db/schema.d.ts
vendored
Normal file
9
lib/db/schema.d.ts
vendored
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
/**
 * One video record with the fields written to the `all_data` table.
 * Every field except `aid` may be null.
 */
export interface AllDataType {
    /** Numeric video id; the only field that is never null. */
    aid: number;
    /** String video id, if known. */
    bvid: string | null;
    /** Video description text. */
    description: string | null;
    /** Numeric id of the uploader. */
    uid: number | null;
    /** Tags stored as a single string. */
    tags: string | null;
    /** Video title. */
    title: string | null;
    /** Publication time, already formatted as a string. */
    published_at: string | null;
}
|
19
lib/ml/SentenceTransformer/index.ts
Normal file
19
lib/ml/SentenceTransformer/index.ts
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import { SentenceTransformer } from "./model.ts"; // Changed import path
|
||||||
|
|
||||||
|
/**
 * Demo: load the embedding model, encode three sample sentences, print the
 * `last_hidden_state` entry of the model output and return the raw output.
 */
async function main() {
    const modelName = "mixedbread-ai/mxbai-embed-large-v1";
    const sampleSentences = [
        "Hello world",
        "How are you guys doing?",
        "Today is Friday!",
    ];

    const sentenceTransformer = await SentenceTransformer.from_pretrained(modelName);
    const outputs = await sentenceTransformer.encode(sampleSentences);

    // @ts-ignore
    console.log(outputs["last_hidden_state"]);

    return outputs;
}
|
||||||
|
|
||||||
|
// Run the demo only when this file is the program entry point, so that
// importing the module does not load the model. Awaiting the call keeps a
// failure from surfacing as a floating, unhandled promise rejection.
if (import.meta.main) {
    await main();
}
|
40
lib/ml/SentenceTransformer/model.ts
Normal file
40
lib/ml/SentenceTransformer/model.ts
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
// lib/ml/SentenceTransformer/model.ts
|
||||||
|
import { AutoModel, AutoTokenizer, PretrainedOptions } from "@huggingface/transformers";
|
||||||
|
|
||||||
|
/**
 * Thin wrapper pairing a tokenizer with a model loaded through
 * `@huggingface/transformers`.
 */
export class SentenceTransformer {
    constructor(
        private readonly tokenizer: AutoTokenizer,
        private readonly model: AutoModel,
    ) {}

    /**
     * Load the tokenizer and the model for `modelName`.
     *
     * @param modelName Model identifier passed to both `from_pretrained` calls.
     * @param options Loading options; when omitted, defaults are used
     *     (revision "main", remote files allowed, no cache dir or callback).
     */
    static async from_pretrained(
        modelName: string,
        options?: PretrainedOptions,
    ): Promise<SentenceTransformer> {
        const loadOptions: PretrainedOptions = options ?? {
            progress_callback: undefined,
            cache_dir: undefined,
            local_files_only: false,
            revision: "main",
        };

        // Loaded one after the other: tokenizer first, then the model.
        const tokenizer = await AutoTokenizer.from_pretrained(modelName, loadOptions);
        const model = await AutoModel.from_pretrained(modelName, loadOptions);

        return new SentenceTransformer(tokenizer, model);
    }

    /**
     * Tokenize `sentences` (with padding and truncation) and run the model.
     *
     * @returns The raw model output; no pooling is applied here.
     */
    async encode(sentences: string[]): Promise<any> {
        const tokenizerOptions = { padding: true, truncation: true };

        //@ts-ignore
        const modelInputs = await this.tokenizer(sentences, tokenizerOptions);

        //@ts-ignore
        return await this.model(modelInputs);
    }
}
|
34
lib/ml/SentenceTransformer/pooling.ts
Normal file
34
lib/ml/SentenceTransformer/pooling.ts
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import { Tensor } from "@huggingface/transformers";
|
||||||
|
//@ts-ignore
|
||||||
|
import { Callable } from "@huggingface/transformers/src/utils/core.js"; // Keep as is for now, might need adjustment
|
||||||
|
|
||||||
|
/**
 * Configuration of a pooling layer: the embedding width plus one boolean
 * switch per pooling mode.
 */
export interface PoolingConfig {
    /** Width of the token embeddings. */
    word_embedding_dimension: number;
    /** Switch for the CLS-token pooling mode. */
    pooling_mode_cls_token: boolean;
    /** Switch for the mean-of-tokens pooling mode. */
    pooling_mode_mean_tokens: boolean;
    /** Switch for the max-over-tokens pooling mode. */
    pooling_mode_max_tokens: boolean;
    /** Switch for the mean-sqrt-length pooling mode. */
    pooling_mode_mean_sqrt_len_tokens: boolean;
}
|
||||||
|
|
||||||
|
/** Input of a pooling step: token embeddings plus the matching attention mask. */
export interface PoolingInput {
    /** Per-token embeddings. */
    token_embeddings: Tensor;
    /** Attention mask belonging to `token_embeddings`. */
    attention_mask: Tensor;
}
|
||||||
|
|
||||||
|
/** Output of a pooling step. */
export interface PoolingOutput {
    /** The pooled sentence embedding. */
    sentence_embedding: Tensor;
}
|
||||||
|
|
||||||
|
/**
 * Pooling layer skeleton. It currently only stores its configuration; the
 * call/forward logic is not implemented yet.
 */
export class Pooling extends Callable {
    constructor(private readonly config: PoolingConfig) {
        super();
    }

    // TODO: when pooling functionality is needed, add
    //   async _call(inputs: any)              -> delegates to forward(inputs)
    //   async forward(inputs: PoolingInput)   -> produces a PoolingOutput
}
|
32
lib/ml/classifyVideo.ts
Normal file
32
lib/ml/classifyVideo.ts
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import { AutoModel, AutoTokenizer, Tensor } from '@huggingface/transformers';
|
||||||
|
|
||||||
|
// Demo script: embed a few texts with the model2vec export of
// jina-embedding-v3 and print the resulting embeddings.
const modelName = "alikia2x/jina-embedding-v3-m2v-1024";

// Model and tokenizer are pinned to different revisions of the same repo.
const modelConfig = {
    config: { model_type: 'model2vec' },
    dtype: 'fp32',
    revision: 'refs/pr/1',
    cache_dir: undefined,
    local_files_only: true,
};
const tokenizerConfig = {
    revision: 'refs/pr/2'
};

const model = await AutoModel.from_pretrained(modelName, modelConfig);
const tokenizer = await AutoTokenizer.from_pretrained(modelName, tokenizerConfig);

const texts = ['hello', 'hello world'];
// return_tensor: false -> plain nested arrays of token ids, one per text.
const { input_ids } = await tokenizer(texts, { add_special_tokens: false, return_tensor: false });

/** Running totals of `values`, e.g. [1, 2, 3] -> [1, 3, 6]; single linear pass. */
const cumsum = (values: number[]): number[] => {
    let running = 0;
    return values.map((value) => (running += value));
};

// The model takes all token ids as one flat sequence plus the start offset
// of every text inside it: 0 for the first text, then the cumulative token
// counts of the preceding texts.
const offsets = [0, ...cumsum(input_ids.slice(0, -1).map((ids: number[]) => ids.length))];

const flattened_input_ids = input_ids.flat();
const modelInputs = {
    input_ids: new Tensor('int64', flattened_input_ids, [flattened_input_ids.length]),
    offsets: new Tensor('int64', offsets, [offsets.length])
};

const { embeddings } = await model(modelInputs);
console.log(embeddings.tolist()); // output matches python version
|
2
lib/mq/index.ts
Normal file
2
lib/mq/index.ts
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
import { Queue } from "bullmq";
|
||||||
|
|
117
lib/net/bilibili.d.ts
vendored
Normal file
117
lib/net/bilibili.d.ts
vendored
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
/**
 * Envelope shared by the API responses declared in this file.
 *
 * @template T Type of the `data` payload.
 */
interface BaseResponse<T> {
    /** Status code; callers in this project treat 0 as success. */
    code: number;
    /** Message accompanying the status code. */
    message: string;
    ttl: number;
    /** Endpoint-specific payload. */
    data: T;
}
|
||||||
|
|
||||||
|
/** Response whose payload is a page of videos. */
export type VideoListResponse = BaseResponse<VideoListData>;

/** Response whose payload is a list of tags. */
export type VideoTagsResponse = BaseResponse<VideoTagsData>;

/** Payload of a tags response: one entry per tag. */
type VideoTagsData = VideoTags[];
|
||||||
|
|
||||||
|
/** A single tag entry of a tags response. */
interface VideoTags {
    // Identity and presentation
    tag_id: number;
    tag_name: string;
    cover: string;
    head_cover: string;
    content: string;
    short_content: string;

    // Classification and state
    type: number;
    state: number;
    ctime: number;

    // Counters
    count: {
        view: number;
        use: number;
        atten: number;
    };

    // Remaining numeric attributes
    is_atten: number;
    likes: number;
    hates: number;
    attribute: number;
    liked: number;
    hated: number;
    extra_attr: number;
}
|
||||||
|
|
||||||
|
/** Payload of a video list response: the videos plus paging information. */
interface VideoListData {
    /** Videos on the returned page. */
    archives: VideoListVideo[];
    /** Paging information. */
    page: {
        num: number;
        size: number;
        count: number;
    };
}
|
||||||
|
|
||||||
|
/** One video entry in the `archives` array of a video list response. */
interface VideoListVideo {
    // Basic metadata
    aid: number;
    videos: number;
    tid: number;
    tname: string;
    copyright: number;
    pic: string;
    title: string;
    pubdate: number;
    ctime: number;
    desc: string;
    state: number;
    duration: number;
    mission_id?: number;

    // Numeric rights flags
    rights: {
        bp: number;
        elec: number;
        download: number;
        movie: number;
        pay: number;
        hd5: number;
        no_reprint: number;
        autoplay: number;
        ugc_pay: number;
        is_cooperation: number;
        ugc_pay_preview: number;
        no_background: number;
        arc_pay: number;
        pay_free_watch: number;
    };

    // Uploader
    owner: {
        mid: number;
        name: string;
        face: string;
    };

    // Statistics
    stat: {
        aid: number;
        view: number;
        danmaku: number;
        reply: number;
        favorite: number;
        coin: number;
        share: number;
        now_rank: number;
        his_rank: number;
        like: number;
        dislike: number;
        vt: number;
        vv: number;
    };

    dynamic: string;
    cid: number;

    // Video dimensions
    dimension: {
        width: number;
        height: number;
        rotate: number;
    };

    season_id?: number;
    short_link_v2: string;
    first_frame: string;
    pub_location: string;
    cover43: string;
    tidv2: number;
    tname_v2: string;
    bvid: string;
    season_type: number;
    is_ogv: number;
    // NOTE(review): possibly a typo for "ogv_info" (cf. is_ogv above) —
    // confirm against an actual API response before renaming.
    ovg_info: string | null;
    rcmd_season: string;
    enable_vt: number;
    ai_rcmd: null | string;
}
|
48
lib/net/getLatestVideos.ts
Normal file
48
lib/net/getLatestVideos.ts
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import { VideoListResponse } from "lib/net/bilibili.d.ts";
|
||||||
|
import formatPublishedAt from "lib/utils/formatTimestampToPostgre.ts";
|
||||||
|
import { getVideoTags } from "lib/net/getVideoTags.ts";
|
||||||
|
import { AllDataType } from "lib/db/schema.d.ts";
|
||||||
|
import { sleep } from "lib/utils/sleep.ts";
|
||||||
|
|
||||||
|
/**
 * Fetch one page of the newest videos and turn it into `AllDataType` records,
 * including the tags of every video.
 *
 * @param page 1-based page number to request.
 * @param pageSize Number of videos per page.
 * @returns The records of the page, `[]` when the page contains no videos,
 *     or `null` when the API reports an error or the request fails.
 */
export async function getLatestVideos(page: number = 1, pageSize: number = 10): Promise<AllDataType[] | null> {
    try {
        const response = await fetch(`https://api.bilibili.com/x/web-interface/newlist?rid=30&ps=${pageSize}&pn=${page}`);
        const data: VideoListResponse = await response.json();

        if (data.code !== 0) {
            console.error(`Error fetching videos: ${data.message}`);
            return null;
        }

        if (data.data.archives.length === 0) {
            console.warn("No more videos found");
            return [];
        }

        const videoPromises = data.data.archives.map(async (video) => {
            // Shift the timestamp by eight hours before formatting
            // (presumably UTC -> UTC+8; confirm against formatPublishedAt).
            const published_at = formatPublishedAt(video.pubdate + 3600 * 8);

            // Random delay so the tag requests of one page are spread out
            // instead of being fired all at once. The sleep must be awaited:
            // without `await` no delay takes place.
            // (Assumes sleep() returns a Promise — confirm in lib/utils/sleep.ts.)
            await sleep(Math.random() * pageSize * 250);

            const tags = await getVideoTags(video.aid);
            // Tags are stored as one comma-separated string; null when the
            // tag lookup itself failed.
            const processedTags = tags !== null ? tags.join(',') : null;

            return {
                aid: video.aid,
                bvid: video.bvid,
                description: video.desc,
                uid: video.owner.mid,
                tags: processedTags,
                title: video.title,
                published_at: published_at,
            } as AllDataType;
        });

        const result = await Promise.all(videoPromises);

        return result;
    } catch (error) {
        console.error(error);
        return null;
    }
}
|
18
lib/net/getVideoTags.ts
Normal file
18
lib/net/getVideoTags.ts
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import { VideoTagsResponse } from "lib/net/bilibili.d.ts";
|
||||||
|
|
||||||
|
/**
 * Fetch the tag names attached to a bilibili video.
 *
 * @param aid - The video's numeric `aid`.
 * @returns The tag names; an empty array when the API answers with a
 *          non-zero `code`; `null` when the request or parsing throws.
 */
export async function getVideoTags(aid: number): Promise<string[] | null> {
	try {
		const url = `https://api.bilibili.com/x/tag/archive/tags?aid=${aid}`;
		const res = await fetch(url);
		const data: VideoTagsResponse = await res.json();
		// Strict comparison, consistent with the other API helpers.
		if (data.code !== 0) {
			console.error(`Error fetching tags for video ${aid}: ${data.message}`);
			return [];
		}
		return data.data.map((tag) => tag.tag_name);
	} catch (error) {
		// Keep the underlying error in the log so failures can be diagnosed.
		console.error(`Error fetching tags for video ${aid}`, error);
		return null;
	}
}
|
81
lib/task/insertLatestVideo.ts
Normal file
81
lib/task/insertLatestVideo.ts
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
||||||
|
import { getLatestVideos } from "lib/net/getLatestVideos.ts";
|
||||||
|
import { insertIntoAllData, videoExistsInAllData } from "lib/db/allData.ts";
|
||||||
|
import { sleep } from "lib/utils/sleep.ts";
|
||||||
|
|
||||||
|
// Environment variables that must be set before the database can be reached.
const requiredEnvVars = ["DB_HOST", "DB_NAME", "DB_USER", "DB_PASSWORD", "DB_PORT"];

// Fail at module load, listing every missing variable at once.
const unsetVars = requiredEnvVars.filter((key) => Deno.env.get(key) === undefined);

if (unsetVars.length > 0) {
	throw new Error(`Missing required environment variables: ${unsetVars.join(", ")}`);
}

const databaseHost = Deno.env.get("DB_HOST")!;
const databaseName = Deno.env.get("DB_NAME")!;
const databaseUser = Deno.env.get("DB_USER")!;
const databasePassword = Deno.env.get("DB_PASSWORD")!;
const databasePort = Deno.env.get("DB_PORT")!;

// Always parse as base 10 and reject values that are not numeric, instead of
// handing NaN to the Postgres client.
const parsedDatabasePort = parseInt(databasePort, 10);

if (Number.isNaN(parsedDatabasePort)) {
	throw new Error(`Invalid DB_PORT: ${databasePort}`);
}

// Connection options handed to the Postgres client.
const postgresConfig = {
	hostname: databaseHost,
	port: parsedDatabasePort,
	database: databaseName,
	user: databaseUser,
	password: databasePassword,
};
|
||||||
|
|
||||||
|
/**
 * Create a Postgres client from `postgresConfig` and open its connection.
 *
 * @returns The connected client.
 */
async function connectToPostgres() {
	const pgClient = new Client(postgresConfig);
	await pgClient.connect();
	return pgClient;
}
|
||||||
|
|
||||||
|
/**
 * Crawl the latest-videos feed page by page and insert every video that is
 * not yet present in `all_data`.
 *
 * Crawling stops when a page is empty, when every video of a page already
 * exists, or after more than five failures (null results or thrown errors).
 * A random 1-5 s pause follows every attempt, including the last one.
 *
 * @param startPage - Page to start from. NOTE(review): the default of 334 is
 *                    kept from the original hard-coded value and looks like a
 *                    leftover from a manual run -- confirm the intended start.
 * @param pageSize - Number of videos requested per page.
 */
export async function insertLatestVideos(startPage: number = 334, pageSize: number = 10) {
	const client = await connectToPostgres();
	let page = startPage;
	let failCount = 0;
	try {
		while (true) {
			try {
				const videos = await getLatestVideos(page, pageSize);
				if (videos == null) {
					failCount++;
					if (failCount > 5) {
						break;
					}
					// Retry the same page.
					continue;
				}
				if (videos.length == 0) {
					console.warn("No more videos found");
					break;
				}
				let allExists = true;
				for (const video of videos) {
					const videoExists = await videoExistsInAllData(client, video.aid);
					if (!videoExists) {
						allExists = false;
						// Await the insert so failures reach the catch below and
						// queries do not overlap on the single client.
						await insertIntoAllData(client, video);
					}
				}
				if (allExists) {
					console.log("All videos already exist in all_data, stop crawling.");
					break;
				}
				// Use the real page size; the previous factor of 20 did not
				// match the 10 videos requested per page.
				console.log(`Page ${page} crawled, total: ${(page - 1) * pageSize + videos.length} videos.`);
				page++;
			} catch (error) {
				console.error(error);
				failCount++;
				if (failCount > 5) {
					break;
				}
				continue;
			} finally {
				// Runs after every attempt, even when leaving via break/continue.
				await sleep(Math.random() * 4000 + 1000);
			}
		}
	} finally {
		// Always release the database connection.
		await client.end();
	}
}
|
||||||
|
|
||||||
|
|
||||||
|
// Start the crawl as soon as this module is loaded. The returned promise is
// not awaited here.
insertLatestVideos();
|
4
lib/utils/formatTimestampToPostgre.ts
Normal file
4
lib/utils/formatTimestampToPostgre.ts
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
/**
 * Convert a Unix timestamp in seconds to a `YYYY-MM-DD HH:MM:SS` string
 * derived from the date's ISO-8601 (UTC) representation.
 *
 * @param timestamp - Seconds since the Unix epoch.
 * @returns The first 19 characters of the ISO string with "T" replaced by a space.
 */
export default function formatTimestamp(timestamp: number) {
	const milliseconds = timestamp * 1000;
	const isoPrefix = new Date(milliseconds).toISOString().substring(0, 19);
	return isoPrefix.split("T").join(" ");
}
|
3
lib/utils/sleep.ts
Normal file
3
lib/utils/sleep.ts
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
/**
 * Resolve after roughly `ms` milliseconds.
 *
 * @param ms - Delay passed to `setTimeout`.
 */
export async function sleep(ms: number) {
	const timer = new Promise((wake) => {
		setTimeout(wake, ms);
	});
	await timer;
}
|
@ -1,6 +1,7 @@
|
|||||||
|
import { Buffer } from "node:buffer";
|
||||||
"use strict";
|
"use strict";
|
||||||
|
|
||||||
export const handler = async (event, context) => {
|
export const handler = async (event, _context) => {
|
||||||
const eventObj = JSON.parse(event);
|
const eventObj = JSON.parse(event);
|
||||||
console.log(`receive event: ${JSON.stringify(eventObj)}`);
|
console.log(`receive event: ${JSON.stringify(eventObj)}`);
|
||||||
|
|
||||||
|
@ -1,35 +0,0 @@
|
|||||||
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
|
||||||
|
|
||||||
const API_URL = "https://api.bilibili.com/x/web-interface/newlist?rid=30&ps=50&pn=";
|
|
||||||
|
|
||||||
const requiredEnvVars = ["DB_HOST", "DB_NAME", "DB_USER", "DB_PASSWORD", "DB_PORT"];
|
|
||||||
|
|
||||||
const unsetVars = requiredEnvVars.filter((key) => !Deno.env.get(key));
|
|
||||||
|
|
||||||
if (unsetVars.length > 0) {
|
|
||||||
throw new Error(`Missing required environment variables: ${unsetVars.join(", ")}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
const databaseHost = Deno.env.get("DB_HOST")!;
|
|
||||||
const databaseName = Deno.env.get("DB_NAME")!;
|
|
||||||
const databaseUser = Deno.env.get("DB_USER")!;
|
|
||||||
const databasePassword = Deno.env.get("DB_PASSWORD")!;
|
|
||||||
const databasePort = Deno.env.get("DB_PORT")!;
|
|
||||||
|
|
||||||
const postgresConfig = {
|
|
||||||
hostname: databaseHost,
|
|
||||||
port: parseInt(databasePort),
|
|
||||||
database: databaseName,
|
|
||||||
user: databaseUser,
|
|
||||||
password: databasePassword,
|
|
||||||
};
|
|
||||||
|
|
||||||
async function connectToPostgres() {
|
|
||||||
const client = new Client(postgresConfig);
|
|
||||||
await client.connect();
|
|
||||||
return client;
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function getLatestVideos() {
|
|
||||||
const client = await connectToPostgres();
|
|
||||||
}
|
|
25
test/net/getLatestVideos.test.ts
Normal file
25
test/net/getLatestVideos.test.ts
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
import { assertEquals } from "jsr:@std/assert";
|
||||||
|
import { getLatestVideos } from "lib/net/getLatestVideos.ts";
|
||||||
|
|
||||||
|
// Live test: requests the first page with five videos and checks the shape
// of every returned row.
Deno.test("Get latest videos", async () => {
	const videos = await getLatestVideos(1, 5);
	// Fail with a clear message instead of a TypeError on `.length` when the
	// request fails and null is returned.
	if (videos === null) {
		throw new Error("getLatestVideos returned null");
	}
	assertEquals(videos.length, 5);

	for (const video of videos) {
		assertVideoProperties(video);
	}
});
|
||||||
|
|
||||||
|
/**
 * Assert that `video` has every expected field with the expected type.
 * Each field is asserted separately so a failure names the offending field.
 *
 * @param video - A row returned by `getLatestVideos`.
 */
function assertVideoProperties(video: object) {
	const checks: Record<string, boolean> = {
		aid: "aid" in video && typeof video.aid === "number",
		// A bvid is expected to be 12 characters long and start with "BV".
		bvid: "bvid" in video && typeof video.bvid === "string" &&
			video.bvid.length === 12 && video.bvid.startsWith("BV"),
		description: "description" in video && typeof video.description === "string",
		uid: "uid" in video && typeof video.uid === "number",
		// Tags may be null when the tag lookup failed.
		tags: "tags" in video && (typeof video.tags === "string" || video.tags === null),
		title: "title" in video && typeof video.title === "string",
		published_at: "published_at" in video && typeof video.published_at === "string",
	};

	for (const [field, valid] of Object.entries(checks)) {
		assertEquals(valid, true, `Invalid or missing field "${field}" in ${JSON.stringify(video)}`);
	}
}
|
28
test/net/getVideoTags.test.ts
Normal file
28
test/net/getVideoTags.test.ts
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
import { assertEquals } from "jsr:@std/assert";
|
||||||
|
import { getVideoTags } from "lib/net/getVideoTags.ts";
|
||||||
|
|
||||||
|
// Live test: compares the tags of a known video, ignoring their order.
Deno.test("Get video tags - regular video", async () => {
	const tags = await getVideoTags(826597951);
	// getVideoTags returns null on request failure; guard before sorting so
	// the file type-checks and the failure message is explicit.
	if (tags === null) {
		throw new Error("getVideoTags returned null for video 826597951");
	}
	assertEquals(tags.sort(), [
		"纯白P",
		"中华墨水娘",
		"中华少女",
		"中华粘土娘",
		"中华缘木娘",
		"中华少女Project",
		"提糯Tino",
		"中华烛火娘",
		"中华烁金娘",
		"新世代音乐人计划女生季",
	].sort());
});
|
||||||
|
|
||||||
|
// A video id that does not exist is expected to produce an empty tag list.
Deno.test("Get video tags - non-existent video", async () => {
	const result = await getVideoTags(8265979511111111);
	assertEquals(result, []);
});

// A video without any tag is expected to produce an empty tag list.
Deno.test("Get video tags - video with no tag", async () => {
	const result = await getVideoTags(981001865);
	assertEquals(result, []);
});
|
Loading…
Reference in New Issue
Block a user