# cvsa/filter/test.py
#
# 39 lines
# 1.1 KiB
# Python

import functools
import os
# Must be set before torch is imported so that ops unsupported on MPS fall back to CPU.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
import torch
from modelV3_4 import VideoClassifierV3_4
from sentence_transformers import SentenceTransformer
# Default checkpoint location. It is relative to the current working directory,
# so the script is expected to run from the directory that contains `filter/`.
DEFAULT_CHECKPOINT_PATH = "./filter/checkpoints/best_model_V3.8.pt"
# SentenceTransformer model (name or path) used to embed the input texts.
DEFAULT_EMBEDDING_MODEL = "Thaweewat/jina-embedding-v3-m2v-1024"


@functools.lru_cache(maxsize=4)
def _load_classifier(checkpoint_path):
    """Build a VideoClassifierV3_4, load *checkpoint_path* into it, and switch to eval mode.

    The result is cached per path, so repeated predict() calls do not re-read
    the checkpoint from disk.
    """
    model = VideoClassifierV3_4()
    # map_location="cpu" lets a checkpoint saved on CUDA/MPS load on any machine.
    # load_state_dict copies values into the model's existing parameters, so
    # where the loaded tensors live does not change the result.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


@functools.lru_cache(maxsize=4)
def _load_sentence_transformer(model_name):
    """Load the SentenceTransformer *model_name* once and reuse it across calls."""
    return SentenceTransformer(model_name)


def predict(json_input, *, checkpoint_path=DEFAULT_CHECKPOINT_PATH,
            embedding_model=DEFAULT_EMBEDDING_MODEL):
    """Classify a single video and return the predicted class index.

    Args:
        json_input: Mapping with the keys "title", "description", "tags" and
            "author_info". "tags" must be an iterable of strings; it is joined
            with single spaces. The other values are passed through unchanged
            (presumably strings; confirm against VideoClassifierV3_4).
        checkpoint_path: State-dict file to load into VideoClassifierV3_4.
        embedding_model: SentenceTransformer model name or path handed to
            the classifier for text embedding.

    Returns:
        int: Index of the largest logit along dim 1 of the model output.

    Raises:
        KeyError: If json_input lacks one of the required keys.
    """
    model = _load_classifier(checkpoint_path)
    sentence_transformer = _load_sentence_transformer(embedding_model)
    # Each field is wrapped in a one-element list, presumably a batch of one
    # sample as the model expects. TODO confirm against VideoClassifierV3_4.
    input_texts = {
        "title": [json_input["title"]],
        "description": [json_input["description"]],
        "tags": [" ".join(json_input["tags"])],
        "author_info": [json_input["author_info"]],
    }
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        logits = model(
            input_texts=input_texts,
            sentence_transformer=sentence_transformer,
        )
    return torch.argmax(logits, dim=1).item()
if __name__ == "__main__":
    # Example usage: classify a sample with empty text fields and a dummy author.
    demo_video = {
        "title": "",
        "description": "",
        "tags": ["", ""],
        "author_info": "xx: yy",
    }
    print(f"预测结果: {predict(demo_video)}")