diff --git a/bun.lock b/bun.lock index dcf6804..b87f06b 100644 --- a/bun.lock +++ b/bun.lock @@ -71,7 +71,7 @@ "date-fns": "^4.1.0", "express": "^5.1.0", "ioredis": "^5.6.1", - "onnxruntime-node": "1.19.2", + "onnxruntime-node": "1.23.0", "postgres": "^3.4.5", }, "devDependencies": { @@ -1122,6 +1122,8 @@ "acorn-typescript": ["acorn-typescript@1.4.13", "", { "peerDependencies": { "acorn": ">=8.9.0" } }, "sha512-xsc9Xv0xlVfwp2o7sQ+GCQ1PgbkdcpWdTzrwXxO3xDMTAywVS3oXVOcOHuRjAPkS4P9b+yc/qNF15460v+jp4Q=="], + "adm-zip": ["adm-zip@0.5.16", "", {}, "sha512-TGw5yVi4saajsSEgz25grObGHEUaDrniwvA2qwSC060KfqGPdglhvPMA2lPIoxs3PQIItj2iag35fONcQqgUaQ=="], + "agent-base": ["agent-base@7.1.4", "", {}, "sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ=="], "ajv": ["ajv@6.12.6", "", { "dependencies": { "fast-deep-equal": "^3.1.1", "fast-json-stable-stringify": "^2.0.0", "json-schema-traverse": "^0.4.1", "uri-js": "^4.2.2" } }, "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g=="], @@ -2174,9 +2176,9 @@ "oniguruma-to-es": ["oniguruma-to-es@4.3.3", "", { "dependencies": { "oniguruma-parser": "^0.12.1", "regex": "^6.0.1", "regex-recursion": "^6.0.2" } }, "sha512-rPiZhzC3wXwE59YQMRDodUwwT9FZ9nNBwQQfsd1wfdtlKEyCdRV0avrTcSZ5xlIvGRVPd/cx6ZN45ECmS39xvg=="], - "onnxruntime-common": ["onnxruntime-common@1.19.2", "", {}, "sha512-a4R7wYEVFbZBlp0BfhpbFWqe4opCor3KM+5Wm22Az3NGDcQMiU2hfG/0MfnBs+1ZrlSGmlgWeMcXQkDk1UFb8Q=="], + "onnxruntime-common": ["onnxruntime-common@1.23.0", "", {}, "sha512-Auz8S9D7vpF8ok7fzTobvD1XdQDftRf/S7pHmjeCr3Xdymi4z1C7zx4vnT6nnUjbpelZdGwda0BmWHCCTMKUTg=="], - "onnxruntime-node": ["onnxruntime-node@1.19.2", "", { "dependencies": { "onnxruntime-common": "1.19.2", "tar": "^7.0.1" }, "os": [ "linux", "win32", "darwin", ] }, "sha512-9eHMP/HKbbeUcqte1JYzaaRC8JPn7ojWeCeoyShO86TOR97OCyIyAIOGX3V95ErjslVhJRXY8Em/caIUc0hm1Q=="], + "onnxruntime-node": ["onnxruntime-node@1.23.0", "", { "dependencies": { "adm-zip": "^0.5.16", "global-agent": "^3.0.0", "onnxruntime-common": "1.23.0" }, "os": [ "linux", "win32", "darwin", ] }, "sha512-j7QVuR4ouektZjOopKtXcIGsB3C6R9kVbgS10lc2e5SxoWMUhmCwxNl4qBslpWJaPrFxvrQrQaQLQdqku8V19w=="], "onnxruntime-web": ["onnxruntime-web@1.22.0-dev.20250409-89f8206ba4", "", { "dependencies": { "flatbuffers": "^25.1.24", "guid-typescript": "^1.0.9", "long": "^5.2.3", "onnxruntime-common": "1.22.0-dev.20250409-89f8206ba4", "platform": "^1.3.6", "protobufjs": "^7.2.4" } }, "sha512-0uS76OPgH0hWCPrFKlL8kYVV7ckM7t/36HfbgoFw6Nd0CZVVbQC4PkrR8mBX8LtNUFZO25IQBqV2Hx2ho3FlbQ=="], diff --git a/ecosystem.config.js b/ecosystem.config.js deleted file mode 100644 index c491e06..0000000 --- a/ecosystem.config.js +++ /dev/null @@ -1,17 +0,0 @@ -module.exports = { - apps: [ - { - name: 'crawler-worker', - script: 'src/worker.ts', - cwd: './packages/api', - interpreter: 'bun', - instances: 1, - autorestart: true, - watch: false, - max_memory_restart: '1G', - env: { - PATH: `${process.env.HOME}/.bun/bin:${process.env.PATH}`, // Add "~/.bun/bin/bun" to PATH - }, - }, - ], -}; diff --git a/ecosystem.config.mjs b/ecosystem.config.mjs new file mode 100644 index 0000000..5de69a7 --- /dev/null +++ b/ecosystem.config.mjs @@ -0,0 +1,42 @@ +export const apps = [ + { + name: 'crawler-jobadder', + script: 'src/jobAdder.wrapper.ts', + cwd: './packages/crawler', + interpreter: 'bun', + }, + { + name: 'crawler-worker', + script: 'src/worker.ts', + cwd: './packages/crawler', + interpreter: 'bun', + env: { + LOG_VERBOSE: "logs/crawler/verbose.log", + LOG_WARN: "logs/crawler/warn.log", + LOG_ERR: "logs/crawler/error.log" + } + }, + { + name: 'crawler-filter', + script: 'src/filterWorker.wrapper.ts', + cwd: './packages/crawler', + interpreter: 'bun', + env: { + LOG_VERBOSE: "logs/crawler/verbose.log", + LOG_WARN: "logs/crawler/warn.log", + LOG_ERR: "logs/crawler/error.log" + } + }, + { + name: 'ml-api', + script: 'start.py', + cwd: './ml/api', + interpreter: process.env.PYTHON_INTERPRETER || 'python3', + env: { + PYTHONPATH: './ml/api:./ml/filter', + LOG_VERBOSE: "logs/ml/verbose.log", + LOG_WARN: "logs/ml/warn.log", + LOG_ERR: "logs/ml/error.log" + } + }, +] \ No newline at end of file diff --git a/ml/api/main.py b/ml/api/main.py new file mode 100644 index 0000000..462f60a --- /dev/null +++ b/ml/api/main.py @@ -0,0 +1,207 @@ +import os +import torch +import numpy as np +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from transformers import AutoTokenizer +from typing import List, Dict +import logging + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Initialize FastAPI app +app = FastAPI(title="CVSA ML API", version="1.0.0") + +# Global variables for models +tokenizer = None +classifier_model = None + +class ClassificationRequest(BaseModel): + title: str + description: str + tags: str + aid: int = None + +class ClassificationResponse(BaseModel): + label: int + probabilities: List[float] + aid: int = None + +class HealthResponse(BaseModel): + status: str + models_loaded: bool + +def load_models(): + """Load the tokenizer and classifier models""" + global tokenizer, classifier_model + + try: + # Load tokenizer + logger.info("Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3") + + # Load classifier model + logger.info("Loading classifier model...") + from model_config import VideoClassifierV3_15 + + model_path = "../../model/akari/3.17.pt" + classifier_model = VideoClassifierV3_15() + classifier_model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) + classifier_model.eval() + + logger.info("All models loaded successfully") + return True + + except Exception as e: + logger.error(f"Failed to load models: {str(e)}") + return False + +def softmax(logits: np.ndarray) -> np.ndarray: + """Apply softmax to logits""" + exp_logits = np.exp(logits - np.max(logits)) + return exp_logits / np.sum(exp_logits) + +def get_jina_embeddings_1024(texts: List[str]) -> np.ndarray: + """Get Jina embeddings using tokenizer and ONNX-like processing""" + if tokenizer is None: + raise ValueError("Tokenizer not loaded") + + import onnxruntime as ort + + session = ort.InferenceSession("../../model/embedding/model.onnx") + + encoded_inputs = tokenizer( + texts, + add_special_tokens=False, # 关键:不添加特殊token(与JS一致) + return_attention_mask=False, + return_tensors=None # 返回原生Python列表,便于后续处理 + ) + input_ids = encoded_inputs["input_ids"] # 形状: [batch_size, seq_len_i](每个样本长度可能不同) + + # 2. 计算offsets(与JS的cumsum逻辑完全一致) + # 先获取每个样本的token长度 + lengths = [len(ids) for ids in input_ids] + # 计算累积和(排除最后一个样本) + cumsum = [] + current_sum = 0 + for l in lengths[:-1]: # 只累加前n-1个样本的长度 + current_sum += l + cumsum.append(current_sum) + # 构建offsets:起始为0,后面跟累积和 + offsets = [0] + cumsum # 形状: [batch_size] + + # 3. 展平input_ids为一维数组 + flattened_input_ids = [] + for ids in input_ids: + flattened_input_ids.extend(ids) # 直接拼接所有token id + flattened_input_ids = np.array(flattened_input_ids, dtype=np.int64) + + # 4. 准备ONNX输入(与JS的tensor形状保持一致) + inputs = { + "input_ids": ort.OrtValue.ortvalue_from_numpy(flattened_input_ids), + "offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64)) + } + + # 5. 运行模型推理 + outputs = session.run(None, inputs) + embeddings = outputs[0] # 假设第一个输出是embeddings,形状: [batch_size, embedding_dim] + + return torch.tensor(embeddings, dtype=torch.float32).numpy() + +@app.on_event("startup") +async def startup_event(): + """Load models on startup""" + success = load_models() + if not success: + logger.error("Failed to load models during startup") + +@app.get("/health", response_model=HealthResponse) +async def health_check(): + """Health check endpoint""" + models_loaded = tokenizer is not None and classifier_model is not None + return HealthResponse( + status="healthy" if models_loaded else "models_not_loaded", + models_loaded=models_loaded + ) + +@app.post("/classify", response_model=ClassificationResponse) +async def classify_video(request: ClassificationRequest): + """Classify a video based on title, description, and tags""" + try: + if tokenizer is None or classifier_model is None: + raise HTTPException(status_code=503, detail="Models not loaded") + + # Get embeddings for each channel + texts = [request.title, request.description, request.tags] + embeddings = get_jina_embeddings_1024(texts) + + # Prepare input for classifier (batch_size=1, channels=3, embedding_dim=1024) + channel_features = torch.tensor(embeddings).unsqueeze(0) # [1, 3, 1024] + + # Run inference + with torch.no_grad(): + logits = classifier_model(channel_features) + probabilities = softmax(logits.numpy()[0]) + predicted_label = int(np.argmax(probabilities)) + + logger.info(f"Classification completed for aid {request.aid}: label={predicted_label}") + + return ClassificationResponse( + label=predicted_label, + probabilities=probabilities.tolist(), + aid=request.aid + ) + + except Exception as e: + logger.error(f"Classification error for aid {request.aid}: {str(e)}") + raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}") + +@app.post("/classify_batch") +async def classify_video_batch(requests: List[ClassificationRequest]): + """Classify multiple videos in batch""" + try: + if tokenizer is None or classifier_model is None: + raise HTTPException(status_code=503, detail="Models not loaded") + + results = [] + for request in requests: + try: + # Get embeddings for each channel + texts = [request.title, request.description, request.tags] + embeddings = get_jina_embeddings_1024(texts) + + # Prepare input for classifier + channel_features = torch.tensor(embeddings).unsqueeze(0) + + # Run inference + with torch.no_grad(): + logits = classifier_model(channel_features) + probabilities = softmax(logits.numpy()[0]) + predicted_label = int(np.argmax(probabilities)) + + results.append({ + "aid": request.aid, + "label": predicted_label, + "probabilities": probabilities.tolist() + }) + + except Exception as e: + logger.error(f"Batch classification error for aid {request.aid}: {str(e)}") + results.append({ + "aid": request.aid, + "label": -1, + "probabilities": [], + "error": str(e) + }) + + return {"results": results} + + except Exception as e: + logger.error(f"Batch classification failed: {str(e)}") + raise HTTPException(status_code=500, detail=f"Batch classification failed: {str(e)}") + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8544) \ No newline at end of file diff --git a/ml/api/model_config.py b/ml/api/model_config.py new file mode 100644 index 0000000..9e6be19 --- /dev/null +++ b/ml/api/model_config.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class VideoClassifierV3_15(nn.Module): + def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3, temperature=1.7): + super().__init__() + self.num_channels = 3 + self.channel_names = ['title', 'description', 'tags'] + + # 可学习温度系数 + self.temperature = nn.Parameter(torch.tensor(temperature)) + + # 带约束的通道权重(使用Sigmoid替代Softmax) + self.channel_weights = nn.Parameter(torch.ones(self.num_channels)) + + # 增强的非线性层 + self.fc = nn.Sequential( + nn.Linear(embedding_dim * self.num_channels, hidden_dim*2), + nn.BatchNorm1d(hidden_dim*2), + nn.Dropout(0.2), + nn.GELU(), + nn.Linear(hidden_dim*2, output_dim) + ) + + # 权重初始化 + self._init_weights() + + def _init_weights(self): + for layer in self.fc: + if isinstance(layer, nn.Linear): + # 使用ReLU的初始化参数(GELU的近似) + nn.init.kaiming_normal_(layer.weight, nonlinearity='relu') # 修改这里 + + # 或者使用Xavier初始化(更适合通用场景) + # nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('relu')) + + nn.init.zeros_(layer.bias) + + + def forward(self, channel_features: torch.Tensor): + """ + 输入格式: [batch_size, num_channels, embedding_dim] + 输出格式: [batch_size, output_dim] + """ + + # 自适应通道权重(Sigmoid约束) + weights = torch.sigmoid(self.channel_weights) # [0,1]范围 + weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1) + + # 特征拼接 + combined = weighted_features.view(weighted_features.size(0), -1) + + return self.fc(combined) + + def get_channel_weights(self): + """获取各通道权重(带温度调节)""" + return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy() + + +class AdaptiveRecallLoss(nn.Module): + def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5): + """ + Args: + class_weights (torch.Tensor): 类别权重 + alpha (float): 召回率调节因子(0-1) + gamma (float): Focal Loss参数 + fp_penalty (float): 类别0假阳性惩罚强度 + """ + super().__init__() + self.class_weights = class_weights + self.alpha = alpha + self.gamma = gamma + self.fp_penalty = fp_penalty + + def forward(self, logits, targets): + # 基础交叉熵损失 + ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none') + + # Focal Loss组件 + pt = torch.exp(-ce_loss) + focal_loss = ((1 - pt) ** self.gamma) * ce_loss + + # 召回率增强(对困难样本加权) + class_mask = F.one_hot(targets, num_classes=len(self.class_weights)) + class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask + recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1) + + # 类别0假阳性惩罚 + probs = F.softmax(logits, dim=1) + fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0) + fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum() + + # 总损失 + total_loss = recall_loss.mean() + fp_loss / len(targets) + + return total_loss \ No newline at end of file diff --git a/ml/api/requirements.txt b/ml/api/requirements.txt new file mode 100644 index 0000000..7c2e9b0 --- /dev/null +++ b/ml/api/requirements.txt @@ -0,0 +1,8 @@ +-i https://pypi.tuna.tsinghua.edu.cn/simple +fastapi==0.104.1 +uvicorn==0.24.0 +torch==2.6.0 +transformers==4.35.2 +numpy==1.26.4 +pydantic==2.5.0 +python-multipart==0.0.6 diff --git a/ml/api/start.py b/ml/api/start.py new file mode 100644 index 0000000..0a4fab4 --- /dev/null +++ b/ml/api/start.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +""" +Startup script for the ML API service +""" +import subprocess +import sys +import os + +def main(): + # Change to the ml/api directory + os.chdir(os.path.dirname(os.path.abspath(__file__))) + + # Start the FastAPI server + cmd = [ + sys.executable, "-m", "uvicorn", + "main:app", + "--host", "0.0.0.0", + "--port", "8544", + "--reload" + ] + + try: + subprocess.run(cmd, check=True) + except subprocess.CalledProcessError as e: + print(f"Failed to start server: {e}") + sys.exit(1) + except KeyboardInterrupt: + print("\nServer stopped by user") + sys.exit(0) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/ml/filter/embedding.py b/ml/filter/embedding.py index ccecc9a..b97c342 100644 --- a/ml/filter/embedding.py +++ b/ml/filter/embedding.py @@ -84,7 +84,6 @@ def prepare_batch_per_token(batch_data, max_length=1024): # Step 4: 将输出重塑为 [batch_size, seq_length, embedding_dim] # 注意:这里假设 ONNX 输出的形状是 [total_tokens, embedding_dim] # 需要根据实际序列长度重新分组 - batch_size = len(texts) embeddings_split = np.split(embeddings, np.cumsum(input_ids_lengths[:-1])) padded_embeddings = [] for emb, seq_len in zip(embeddings_split, input_ids_lengths): diff --git a/packages/core/mq/lockManager.ts b/packages/core/mq/lockManager.ts index 75667f4..a11e22f 100644 --- a/packages/core/mq/lockManager.ts +++ b/packages/core/mq/lockManager.ts @@ -1,13 +1,14 @@ -import { redis, RedisClient } from "bun"; +import { Redis } from "ioredis"; +import { redis } from "@core/db/redis"; class LockManager { - private redis: RedisClient; + private redis: Redis; /* * Create a new LockManager * @param redisClient The Redis client used to store the lock data */ - constructor(redisClient: RedisClient) { + constructor(redisClient: Redis) { this.redis = redisClient; } @@ -48,7 +49,8 @@ class LockManager { */ async isLocked(id: string): Promise { const key = `cvsa:lock:${id}`; - return await this.redis.exists(key); + const result = await this.redis.exists(key); + return result === 1; } } diff --git a/packages/core/mq/multipleRateLimiter.ts b/packages/core/mq/multipleRateLimiter.ts index 49ad656..6be42d6 100644 --- a/packages/core/mq/multipleRateLimiter.ts +++ b/packages/core/mq/multipleRateLimiter.ts @@ -1,6 +1,5 @@ import { RateLimiter as Limiter } from "@koshnic/ratelimit"; -import { redis } from "bun"; -import Redis from "ioredis"; +import { redis } from "@core/db/redis"; export interface RateLimiterConfig { duration: number; @@ -9,7 +8,6 @@ export interface RateLimiterConfig { export class RateLimiterError extends Error { public code: string; - constructor(message: string) { super(message); this.name = "RateLimiterError"; @@ -30,7 +28,7 @@ export class MultipleRateLimiter { */ constructor(name: string, configs: RateLimiterConfig[]) { this.configs = configs; - this.limiter = new Limiter(redis as unknown as Redis); + this.limiter = new Limiter(redis); this.name = name; } diff --git a/packages/crawler/db/snapshotSchedule.ts b/packages/crawler/db/snapshotSchedule.ts index 8eed276..f5b5e07 100644 --- a/packages/crawler/db/snapshotSchedule.ts +++ b/packages/crawler/db/snapshotSchedule.ts @@ -1,8 +1,8 @@ import type { SnapshotScheduleType } from "@core/db/schema.d"; import logger from "@core/log"; import { MINUTE } from "@core/lib"; -import { redis } from "bun"; -import { RedisClient } from "bun"; +import { redis } from "@core/db/redis"; +import { Redis } from "ioredis"; import { parseTimestampFromPsql } from "../utils/formatTimestampToPostgre"; import type { Psql } from "@core/db/psql.d"; @@ -14,7 +14,7 @@ function getCurrentWindowIndex(): number { return Math.floor(minutesSinceMidnight / 5); } -export async function refreshSnapshotWindowCounts(sql: Psql, redisClient: RedisClient) { +export async function refreshSnapshotWindowCounts(sql: Psql, redisClient: Redis) { const now = new Date(); const startTime = now.getTime(); @@ -37,19 +37,19 @@ export async function refreshSnapshotWindowCounts(sql: Psql, redisClient: RedisC const targetOffset = Math.floor((row.window_start.getTime() - startTime) / (5 * MINUTE)); const offset = currentWindow + targetOffset; if (offset >= 0) { - await redisClient.hmset(REDIS_KEY, [offset.toString(), row.count.toString()]); + await redisClient.hset(REDIS_KEY, offset.toString(), Number(row.count)); } } } -export async function initSnapshotWindowCounts(sql: Psql, redisClient: RedisClient) { +export async function initSnapshotWindowCounts(sql: Psql, redisClient: Redis) { await refreshSnapshotWindowCounts(sql, redisClient); setInterval(async () => { await refreshSnapshotWindowCounts(sql, redisClient); }, 5 * MINUTE); } -async function getWindowCount(redisClient: RedisClient, offset: number): Promise { +async function getWindowCount(redisClient: Redis, offset: number): Promise { const count = await redisClient.hget(REDIS_KEY, offset.toString()); return count ? parseInt(count, 10) : 0; } @@ -239,7 +239,7 @@ export async function bulkScheduleSnapshot( export async function adjustSnapshotTime( expectedStartTime: Date, allowedCounts: number = 1000, - redisClient: RedisClient + redisClient: Redis ): Promise { const currentWindow = getCurrentWindowIndex(); const targetOffset = Math.floor((expectedStartTime.getTime() - Date.now()) / (5 * MINUTE)) - 6; diff --git a/packages/crawler/ml/akari.ts b/packages/crawler/ml/akari.ts index 6f7f83f..6fdf268 100644 --- a/packages/crawler/ml/akari.ts +++ b/packages/crawler/ml/akari.ts @@ -4,10 +4,14 @@ import logger from "@core/log"; import { WorkerError } from "mq/schema"; import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers"; import { AkariModelVersion } from "./const"; +import path from "node:path"; + +const currentDir = import.meta.dir; +const modelDir = path.join(currentDir, "../../../model/"); const tokenizerModel = "alikia2x/jina-embedding-v3-m2v-1024"; -const onnxClassifierPath = `../../model/akari/${AkariModelVersion}.onnx`; -const onnxEmbeddingPath = "../../model/embedding/model.onnx"; +const onnxClassifierPath = path.join(modelDir, `./akari/${AkariModelVersion}.onnx`); +const onnxEmbeddingPath = path.join(modelDir, "./embedding/model.onnx"); class AkariProto extends AIManager { private tokenizer: PreTrainedTokenizer | null = null; @@ -59,16 +63,24 @@ class AkariProto extends AIManager { }); const cumsum = (arr: number[]): number[] => - arr.reduce((acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], []); + arr.reduce( + (acc: number[], num: number, i: number) => [...acc, num + (acc[i - 1] || 0)], + [] + ); - const offsets: number[] = [0, ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length))]; + const offsets: number[] = [ + 0, + ...cumsum(input_ids.slice(0, -1).map((x: string) => x.length)) + ]; const flattened_input_ids = input_ids.flat(); const inputs = { input_ids: new ort.Tensor("int64", new BigInt64Array(flattened_input_ids.map(BigInt)), [ flattened_input_ids.length ]), - offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [offsets.length]) + offsets: new ort.Tensor("int64", new BigInt64Array(offsets.map(BigInt)), [ + offsets.length + ]) }; const { embeddings } = await session.run(inputs); @@ -83,11 +95,19 @@ class AkariProto extends AIManager { return this.softmax(logits.data as Float32Array); } - public async classifyVideo(title: string, description: string, tags: string, aid?: number): Promise { + public async classifyVideo( + title: string, + description: string, + tags: string, + aid?: number + ): Promise { const embeddings = await this.getJinaEmbeddings1024([title, description, tags]); const probabilities = await this.runClassification(embeddings); if (aid) { - logger.log(`Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, "ml"); + logger.log( + `Prediction result for aid: ${aid}: [${probabilities.map((p) => p.toFixed(5))}]`, + "ml" + ); } return probabilities.indexOf(Math.max(...probabilities)); } diff --git a/packages/crawler/ml/akari_api.ts b/packages/crawler/ml/akari_api.ts new file mode 100644 index 0000000..c00792e --- /dev/null +++ b/packages/crawler/ml/akari_api.ts @@ -0,0 +1,65 @@ +import apiManager from "./api_manager"; +import logger from "@core/log"; +import { WorkerError } from "mq/schema"; + +class AkariAPI { + private readonly serviceReady: Promise; + + constructor() { + // Wait for the ML API service to be ready on startup + this.serviceReady = apiManager.waitForService(); + } + + public async init(): Promise { + const isReady = await this.serviceReady; + if (!isReady) { + throw new WorkerError( + new Error("ML API service failed to become ready"), + "ml", + "fn:init" + ); + } + logger.log("Akari API initialized successfully", "ml"); + } + + public async classifyVideo( + title: string, + description: string, + tags: string, + aid?: number + ): Promise { + try { + // Ensure service is ready + await this.serviceReady; + + const label = await apiManager.classifyVideo(title, description, tags, aid); + return label; + } catch (error) { + logger.error(`Classification failed for aid ${aid}: ${error}`, "ml"); + throw new WorkerError(error as Error, "ml", "fn:classifyVideo"); + } + } + + public async classifyVideosBatch( + videos: Array<{ title: string; description: string; tags: string; aid?: number }> + ): Promise> { + try { + // Ensure service is ready + await this.serviceReady; + + const results = await apiManager.classifyVideosBatch(videos); + return results; + } catch (error) { + logger.error(`Batch classification failed: ${error}`, "ml"); + throw new WorkerError(error as Error, "ml", "fn:classifyVideosBatch"); + } + } + + public async healthCheck(): Promise { + return await apiManager.healthCheck(); + } +} + +// Create a singleton instance +const Akari = new AkariAPI(); +export default Akari; \ No newline at end of file diff --git a/packages/crawler/ml/api_manager.ts b/packages/crawler/ml/api_manager.ts new file mode 100644 index 0000000..0904a12 --- /dev/null +++ b/packages/crawler/ml/api_manager.ts @@ -0,0 +1,146 @@ +import logger from "@core/log"; +import { WorkerError } from "mq/schema"; + +interface ClassificationRequest { + title: string; + description: string; + tags: string; + aid?: number; +} + +interface ClassificationResponse { + label: number; + probabilities: number[]; + aid?: number; +} + +interface HealthResponse { + status: string; + models_loaded: boolean; +} + +export class APIManager { + private readonly baseUrl: string; + private readonly timeout: number; + + constructor(baseUrl: string = "http://localhost:8544", timeout: number = 30000) { + this.baseUrl = baseUrl; + this.timeout = timeout; + } + + public async healthCheck(): Promise { + try { + const response = await fetch(`${this.baseUrl}/health`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + signal: AbortSignal.timeout(this.timeout), + }); + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + + const data: HealthResponse = await response.json(); + return data.models_loaded; + } catch (error) { + logger.error(`Health check failed: ${error}`, "ml"); + return false; + } + } + + public async classifyVideo( + title: string, + description: string, + tags: string, + aid?: number + ): Promise { + const request: ClassificationRequest = { + title: title.trim() || "untitled", + description: description.trim() || "N/A", + tags: tags.trim() || "empty", + aid: aid + }; + + try { + const response = await fetch(`${this.baseUrl}/classify`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(request), + signal: AbortSignal.timeout(this.timeout), + }); + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + + const data: ClassificationResponse = await response.json(); + + if (aid) { + logger.log( + `Prediction result for aid: ${aid}: [${data.probabilities.map((p) => p.toFixed(5))}]`, + "ml" + ); + } + + return data.label; + } catch (error) { + logger.error(`Classification failed for aid ${aid}: ${error}`, "ml"); + throw new WorkerError(error as Error, "ml", "fn:classifyVideo"); + } + } + + public async classifyVideosBatch( + requests: Array<{ title: string; description: string; tags: string; aid?: number }> + ): Promise> { + try { + const response = await fetch(`${this.baseUrl}/classify_batch`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(requests), + signal: AbortSignal.timeout(this.timeout * 2), // Longer timeout for batch + }); + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + + const data = await response.json(); + return data.results; + } catch (error) { + logger.error(`Batch classification failed: ${error}`, "ml"); + throw new WorkerError(error as Error, "ml", "fn:classifyVideosBatch"); + } + } + + public async waitForService(timeoutMs: number = 60000): Promise { + const startTime = Date.now(); + const checkInterval = 2000; // Check every 2 seconds + + while (Date.now() - startTime < timeoutMs) { + try { + const isHealthy = await this.healthCheck(); + if (isHealthy) { + logger.log("ML API service is healthy", "ml"); + return true; + } + } catch (error) { + // Service not ready yet, continue waiting + } + + await new Promise(resolve => setTimeout(resolve, checkInterval)); + } + + logger.error("ML API service did not become ready within timeout", "ml"); + return false; + } +} + +// Create a singleton instance +const apiManager = new APIManager(); +export default apiManager; \ No newline at end of file diff --git a/packages/crawler/mq/exec/classifyVideo.ts b/packages/crawler/mq/exec/classifyVideo.ts index 04b52e2..bd0ec44 100644 --- a/packages/crawler/mq/exec/classifyVideo.ts +++ b/packages/crawler/mq/exec/classifyVideo.ts @@ -1,6 +1,6 @@ import { Job } from "bullmq"; import { getUnlabelledVideos, getVideoInfoFromAllData, insertVideoLabel } from "../../db/bilibili_metadata"; -import Akari from "ml/akari"; +import Akari from "ml/akari_api"; import { ClassifyVideoQueue } from "mq/index"; import logger from "@core/log"; import { lockManager } from "@core/mq/lockManager"; diff --git a/packages/crawler/mq/index.ts b/packages/crawler/mq/index.ts index 0f4dafd..869e716 100644 --- a/packages/crawler/mq/index.ts +++ b/packages/crawler/mq/index.ts @@ -1,5 +1,5 @@ import { Queue, ConnectionOptions } from "bullmq"; -import { redis } from "bun"; +import { redis } from "@core/db/redis"; export const LatestVideosQueue = new Queue("latestVideos", { connection: redis as ConnectionOptions diff --git a/packages/crawler/mq/init.ts b/packages/crawler/mq/init.ts index 1ac9bac..845f064 100644 --- a/packages/crawler/mq/init.ts +++ b/packages/crawler/mq/init.ts @@ -2,74 +2,74 @@ import { HOUR, MINUTE, SECOND } from "@core/lib"; import { ClassifyVideoQueue, LatestVideosQueue, SnapshotQueue } from "mq/index"; import logger from "@core/log"; import { initSnapshotWindowCounts } from "db/snapshotSchedule"; -import { redis } from "bun"; +import { redis } from "@core/db/redis"; import { sql } from "@core/db/dbNew"; export async function initMQ() { await initSnapshotWindowCounts(sql, redis); - await LatestVideosQueue.upsertJobScheduler("getLatestVideos", { - every: 1 * MINUTE, - immediately: true - }); + // await LatestVideosQueue.upsertJobScheduler("getLatestVideos", { + // every: 1 * MINUTE, + // immediately: true + // }); await ClassifyVideoQueue.upsertJobScheduler("classifyVideos", { every: 5 * MINUTE, immediately: true }); - await LatestVideosQueue.upsertJobScheduler("collectSongs", { - every: 3 * MINUTE, - immediately: true - }); + // await LatestVideosQueue.upsertJobScheduler("collectSongs", { + // every: 3 * MINUTE, + // immediately: true + // }); - await SnapshotQueue.upsertJobScheduler( - "snapshotTick", - { - every: 1 * SECOND, - immediately: true - }, - { - opts: { - removeOnComplete: 300, - removeOnFail: 600 - } - } - ); + // await SnapshotQueue.upsertJobScheduler( + // "snapshotTick", + // { + // every: 1 * SECOND, + // immediately: true + // }, + // { + // opts: { + // removeOnComplete: 300, + // removeOnFail: 600 + // } + // } + // ); - await SnapshotQueue.upsertJobScheduler( - "bulkSnapshotTick", - { - every: 15 * SECOND, - immediately: true - }, - { - opts: { - removeOnComplete: 60, - removeOnFail: 600 - } - } - ); + // await SnapshotQueue.upsertJobScheduler( + // "bulkSnapshotTick", + // { + // every: 15 * SECOND, + // immediately: true + // }, + // { + // opts: { + // removeOnComplete: 60, + // removeOnFail: 600 + // } + // } + // ); - await SnapshotQueue.upsertJobScheduler("dispatchMilestoneSnapshots", { - every: 5 * MINUTE, - immediately: true - }); + // await SnapshotQueue.upsertJobScheduler("dispatchMilestoneSnapshots", { + // every: 5 * MINUTE, + // immediately: true + // }); - await SnapshotQueue.upsertJobScheduler("dispatchRegularSnapshots", { - every: 30 * MINUTE, - immediately: true - }); + // await SnapshotQueue.upsertJobScheduler("dispatchRegularSnapshots", { + // every: 30 * MINUTE, + // immediately: true + // }); - await SnapshotQueue.upsertJobScheduler("dispatchArchiveSnapshots", { - every: 2 * HOUR, - immediately: false - }); + // await SnapshotQueue.upsertJobScheduler("dispatchArchiveSnapshots", { + // every: 2 * HOUR, + // immediately: false + // }); - await SnapshotQueue.upsertJobScheduler("scheduleCleanup", { - every: 2 * MINUTE, - immediately: true - }); + // await SnapshotQueue.upsertJobScheduler("scheduleCleanup", { + // every: 2 * MINUTE, + // immediately: true + // }); logger.log("Message queue initialized."); } diff --git a/packages/crawler/package.json b/packages/crawler/package.json index f611778..20d2adf 100644 --- a/packages/crawler/package.json +++ b/packages/crawler/package.json @@ -22,7 +22,6 @@ "date-fns": "^4.1.0", "express": "^5.1.0", "ioredis": "^5.6.1", - "onnxruntime-node": "1.19.2", "postgres": "^3.4.5" } } diff --git a/packages/crawler/src/build.ts b/packages/crawler/src/build.ts deleted file mode 100644 index 42b7210..0000000 --- a/packages/crawler/src/build.ts +++ /dev/null @@ -1,14 +0,0 @@ -import Bun from "bun"; - -await Bun.build({ - entrypoints: ["./src/filterWorker.ts"], - outdir: "./build", - target: "node" -}); - -const file = Bun.file("./build/filterWorker.js"); -const code = await file.text(); - -const modifiedCode = code.replaceAll("../bin/napi-v3/", "../../../node_modules/onnxruntime-node/bin/napi-v3/"); - -await Bun.write("./build/filterWorker.js", modifiedCode); diff --git a/packages/crawler/src/filterWorker.ts b/packages/crawler/src/filterWorker.ts index b5b55e5..4a2983a 100644 --- a/packages/crawler/src/filterWorker.ts +++ b/packages/crawler/src/filterWorker.ts @@ -1,21 +1,18 @@ import { ConnectionOptions, Job, Worker } from "bullmq"; -import { redis } from "bun"; +import { redis } from "@core/db/redis"; import logger from "@core/log"; import { classifyVideosWorker, classifyVideoWorker } from "mq/exec/classifyVideo"; import { WorkerError } from "mq/schema"; import { lockManager } from "@core/mq/lockManager"; -import Akari from "ml/akari"; +import Akari from "ml/akari_api"; -const shutdown = async (signal: string) => { +const shutdown = async (signal: string, filterWorker: Worker) => { logger.log(`${signal} Received: Shutting down workers...`, "mq"); await filterWorker.close(true); process.exit(0); }; -process.on("SIGINT", () => shutdown("SIGINT")); -process.on("SIGTERM", () => shutdown("SIGTERM")); - -await Akari.init(); +await Akari.init() const filterWorker = new Worker( "classifyVideo", @@ -32,6 +29,9 @@ const filterWorker = new Worker( { connection: redis as ConnectionOptions, concurrency: 2, removeOnComplete: { count: 1000 } } ); +process.on("SIGINT", () => shutdown("SIGINT", filterWorker)); +process.on("SIGTERM", () => shutdown("SIGTERM", filterWorker)); + filterWorker.on("active", () => { logger.log("Worker (filter) activated.", "mq"); }); diff --git a/packages/crawler/src/filterWorker.wrapper.ts b/packages/crawler/src/filterWorker.wrapper.ts new file mode 100644 index 0000000..214447b --- /dev/null +++ b/packages/crawler/src/filterWorker.wrapper.ts @@ -0,0 +1,11 @@ +#!/usr/bin/env bun +/** + * PM2 wrapper to handle Bun's async module loading + * Bypasses require-in-the-middle issues with TypeScript files + */ +// When PM2's require-in-the-middle tries to hook into module loading, +// it fails with async modules. This wrapper uses import() which works correctly. +import("./filterWorker.ts").catch((error) => { + console.error("Failed to start server:", error); + process.exit(1); +}); diff --git a/packages/crawler/src/jobAdder.wrapper.ts b/packages/crawler/src/jobAdder.wrapper.ts new file mode 100644 index 0000000..0837a6d --- /dev/null +++ b/packages/crawler/src/jobAdder.wrapper.ts @@ -0,0 +1,11 @@ +#!/usr/bin/env bun +/** + * PM2 wrapper to handle Bun's async module loading + * Bypasses require-in-the-middle issues with TypeScript files + */ +// When PM2's require-in-the-middle tries to hook into module loading, +// it fails with async modules. This wrapper uses import() which works correctly. +import("./jobAdder.ts").catch((error) => { + console.error("Failed to start server:", error); + process.exit(1); +}); diff --git a/packages/crawler/src/worker.ts b/packages/crawler/src/worker.ts index 6f0fd6b..7dc42d6 100644 --- a/packages/crawler/src/worker.ts +++ b/packages/crawler/src/worker.ts @@ -12,7 +12,7 @@ import { snapshotVideoWorker, takeBulkSnapshotForVideosWorker } from "mq/exec/executors"; -import { redis } from "bun"; +import { redis } from "@core/db/redis"; import logger from "@core/log"; import { lockManager } from "@core/mq/lockManager"; import { WorkerError } from "mq/schema"; diff --git a/packages/crawler/tsconfig.json b/packages/crawler/tsconfig.json index 5902fcd..fcfec29 100644 --- a/packages/crawler/tsconfig.json +++ b/packages/crawler/tsconfig.json @@ -5,6 +5,7 @@ "target": "esnext", "module": "esnext", "useDefineForClassFields": true, + "allowImportingTsExtensions": true, "moduleResolution": "node10", "lib": ["ESNext", "DOM", "DOM.Iterable"], "skipLibCheck": true,