1
0
cvsa/ml/api/main.py

207 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import functools
import logging
import os
from typing import Dict, List, Optional

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer
# Root logging configuration: emit INFO-level records and above.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# The ASGI application that uvicorn serves.
app = FastAPI(version="1.0.0", title="CVSA ML API")

# Model handles filled in by load_models(); None means "not loaded".
tokenizer = None
classifier_model = None
class ClassificationRequest(BaseModel):
    """Request body describing one video to classify by its text metadata."""

    # Video title (embedded as the first input channel).
    title: str
    # Video description (second input channel).
    description: str
    # Video tags as a single string (third input channel); the delimiter
    # format is not visible here — TODO confirm with the client.
    tags: str
    # Optional video id, echoed back in the response. Declared Optional so an
    # explicit ``"aid": null`` validates under pydantic v2 instead of being
    # rejected as "not a valid integer".
    aid: Optional[int] = None
class ClassificationResponse(BaseModel):
    """Result of classifying a single video."""

    # Index of the highest-probability class.
    label: int
    # Softmax probabilities over the classifier's output classes.
    probabilities: List[float]
    # Echo of the request's aid. Must be Optional: the /classify handler
    # passes ``aid=request.aid`` explicitly, which is None whenever the client
    # omitted it, and pydantic v2 rejects an explicit None for a plain ``int``.
    aid: Optional[int] = None
class HealthResponse(BaseModel):
    """Payload returned by the /health endpoint."""

    # Either "healthy" or "models_not_loaded", as chosen by health_check.
    status: str
    # True only when both the tokenizer and the classifier are available.
    models_loaded: bool
def load_models():
    """Load the tokenizer and the video classifier into the module globals.

    Both objects are built into locals first and published to the globals
    only after every step succeeded. A failure part-way through (e.g. while
    loading the classifier weights) therefore never leaves a half-initialised
    classifier with random weights that the endpoints would treat as ready.

    Model files are resolved relative to this source file rather than the
    process working directory, so the service can be started from anywhere.

    Returns:
        True if everything loaded; False if any step raised (the error is
        logged together with its traceback).
    """
    global tokenizer, classifier_model
    try:
        logger.info("Loading tokenizer...")
        new_tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3")

        logger.info("Loading classifier model...")
        from model_config import VideoClassifierV3_15

        base_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(base_dir, "..", "..", "model", "akari", "3.17.pt")
        new_classifier = VideoClassifierV3_15()
        # NOTE: torch.load unpickles the checkpoint; only load trusted files.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        new_classifier.load_state_dict(state_dict)
        new_classifier.eval()  # inference mode (disables dropout etc.)
    except Exception as e:
        logger.exception("Failed to load models: %s", e)
        return False

    # Publish only fully initialised models.
    tokenizer = new_tokenizer
    classifier_model = new_classifier
    logger.info("All models loaded successfully")
    return True
def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Return the softmax of *logits* along *axis*.

    The per-slice maximum is subtracted before exponentiating so large logits
    do not overflow. For 1-D input this normalises over the whole array,
    exactly as before. For N-D input each slice along *axis* sums to 1; with
    the default last axis, each row of a ``(batch, classes)`` array is
    normalised.

    Args:
        logits: Array (or array-like) of raw scores.
        axis: Axis to normalise over; defaults to the last axis.

    Returns:
        Array of the same shape as *logits* with non-negative entries summing
        to 1 along *axis*.
    """
    logits = np.asarray(logits)
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp_logits = np.exp(shifted)
    return exp_logits / np.sum(exp_logits, axis=axis, keepdims=True)
@functools.lru_cache(maxsize=1)
def _get_embedding_session():
    """Create the ONNX Runtime session for the embedding model once and reuse it.

    The path is resolved relative to this source file, not the process CWD.
    A failed creation raises and is not cached, so the next call retries.
    """
    import onnxruntime as ort

    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(base_dir, "..", "..", "model", "embedding", "model.onnx")
    return ort.InferenceSession(model_path)


def get_jina_embeddings_1024(texts: List[str]) -> np.ndarray:
    """Embed *texts* with the ONNX export of jina-embeddings-v3.

    Texts are tokenized without special tokens and packed the same way the JS
    client does. The model receives one flat ``input_ids`` vector containing
    every text's tokens back to back, plus an ``offsets`` vector holding the
    start index of each text's tokens.

    Args:
        texts: Strings to embed; their token lengths may differ.

    Returns:
        float32 array taken from the model's first output; presumably shaped
        ``(len(texts), 1024)`` — TODO confirm against the ONNX graph.

    Raises:
        ValueError: if the tokenizer has not been loaded.
    """
    if tokenizer is None:
        raise ValueError("Tokenizer not loaded")
    session = _get_embedding_session()

    encoded_inputs = tokenizer(
        texts,
        add_special_tokens=False,  # must match the JS client: no special tokens
        return_attention_mask=False,
        return_tensors=None,  # plain Python lists; lengths may differ per text
    )
    input_ids = encoded_inputs["input_ids"]

    # offsets[i] = number of tokens in texts[0..i-1], i.e. where text i starts
    # in the flattened id vector (first entry is always 0).
    lengths = [len(ids) for ids in input_ids]
    offsets = np.cumsum([0] + lengths[:-1], dtype=np.int64)
    flat_input_ids = np.array(
        [token for ids in input_ids for token in ids], dtype=np.int64
    )

    outputs = session.run(None, {"input_ids": flat_input_ids, "offsets": offsets})
    # Assumes the first output is the embedding matrix.
    return np.asarray(outputs[0], dtype=np.float32)
@app.on_event("startup")
async def startup_event():
    """Load the tokenizer and classifier once when the server starts.

    A failure is only logged; startup continues and the service keeps running.
    """
    if load_models():
        return
    logger.error("Failed to load models during startup")
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether both the tokenizer and the classifier are loaded."""
    ready = not (tokenizer is None or classifier_model is None)
    if ready:
        status = "healthy"
    else:
        status = "models_not_loaded"
    return HealthResponse(status=status, models_loaded=ready)
@app.post("/classify", response_model=ClassificationResponse)
async def classify_video(request: ClassificationRequest):
    """Classify one video from its title, description and tags.

    Each text field is embedded separately, and the three embeddings are fed
    to the classifier as a single sample.

    Raises:
        HTTPException: 503 if the models are not loaded; 500 (carrying the
            underlying error message) if embedding or inference fails.
    """
    # Checked before the try block. Otherwise the generic handler below would
    # catch this HTTPException and turn the 503 into a 500.
    if tokenizer is None or classifier_model is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    try:
        # One text per input channel: title, description, tags.
        texts = [request.title, request.description, request.tags]
        embeddings = get_jina_embeddings_1024(texts)
        # Add the batch dimension; presumably [1, 3, 1024] — TODO confirm dim.
        channel_features = torch.tensor(embeddings).unsqueeze(0)
        with torch.no_grad():
            logits = classifier_model(channel_features)
        probabilities = softmax(logits.numpy()[0])
        predicted_label = int(np.argmax(probabilities))
        logger.info(
            "Classification completed for aid %s: label=%s", request.aid, predicted_label
        )
        return ClassificationResponse(
            label=predicted_label,
            probabilities=probabilities.tolist(),
            aid=request.aid,
        )
    except Exception as e:
        logger.exception("Classification error for aid %s: %s", request.aid, e)
        raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}") from e
@app.post("/classify_batch")
async def classify_video_batch(requests: List[ClassificationRequest]):
    """Classify several videos; per-item failures are reported, not raised.

    Returns ``{"results": [...]}``, one dict per request, in order, with
    ``aid``, ``label`` and ``probabilities``. An item that fails gets
    ``label=-1``, empty ``probabilities`` and an ``error`` message, and the
    remaining items are still processed.

    Raises:
        HTTPException: 503 if the models are not loaded; 500 for an
            unexpected failure outside the per-item handling.
    """
    # Checked before the try block so the 503 is not re-wrapped as a 500.
    if tokenizer is None or classifier_model is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    try:
        results = []
        for request in requests:
            try:
                # One text per input channel: title, description, tags.
                texts = [request.title, request.description, request.tags]
                embeddings = get_jina_embeddings_1024(texts)
                # Add the batch dimension; presumably [1, 3, 1024].
                channel_features = torch.tensor(embeddings).unsqueeze(0)
                with torch.no_grad():
                    logits = classifier_model(channel_features)
                probabilities = softmax(logits.numpy()[0])
                predicted_label = int(np.argmax(probabilities))
                results.append({
                    "aid": request.aid,
                    "label": predicted_label,
                    "probabilities": probabilities.tolist(),
                })
            except Exception as e:
                logger.exception(
                    "Batch classification error for aid %s: %s", request.aid, e
                )
                results.append({
                    "aid": request.aid,
                    "label": -1,
                    "probabilities": [],
                    "error": str(e),
                })
        return {"results": results}
    except Exception as e:
        logger.exception("Batch classification failed: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Batch classification failed: {str(e)}"
        ) from e
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=8544, host="0.0.0.0")