39 lines
1.1 KiB
Python
39 lines
1.1 KiB
Python
import functools
import os

# Set before torch is imported so the MPS-fallback setting (run ops unsupported
# on Apple's MPS backend on CPU instead) is in place when torch initializes.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (deliberately imported after the env var above)
from sentence_transformers import SentenceTransformer  # noqa: E402

from modelV3_4 import VideoClassifierV3_4  # noqa: E402
|
|
|
|
# Default weights file for VideoClassifierV3_4 (relative to the current working
# directory) and the name/path handed to SentenceTransformer.
DEFAULT_CHECKPOINT_PATH = "./filter/checkpoints/best_model_V3.8.pt"
DEFAULT_EMBEDDING_MODEL = "Thaweewat/jina-embedding-v3-m2v-1024"


@functools.lru_cache(maxsize=4)
def _load_classifier(checkpoint_path):
    """Build a VideoClassifierV3_4, load its weights and switch it to eval mode.

    Cached per ``checkpoint_path`` so repeated predictions reuse one model
    instead of re-reading the checkpoint on every call. A checkpoint file that
    changes on disk is not picked up until the process restarts.
    """
    model = VideoClassifierV3_4()
    # The model is constructed on CPU, so load the tensors onto CPU as well;
    # this keeps checkpoints saved on CUDA/MPS loadable on machines without
    # that device (load_state_dict copies into the CPU parameters either way).
    # NOTE(review): torch.load unpickles the file - only load trusted
    # checkpoints (torch >= 1.13 also accepts weights_only=True for this).
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


@functools.lru_cache(maxsize=4)
def _load_sentence_transformer(model_name):
    """Load the SentenceTransformer named ``model_name`` once and cache it."""
    return SentenceTransformer(model_name)


def _build_input_texts(json_input):
    """Turn one video record into the batch-of-one text dict the classifier takes.

    Every field becomes a single-element list. ``tags`` is space-joined when it
    is an iterable of strings, used as-is when it is already a string, and
    becomes an empty string when it is ``None``.

    Raises:
        KeyError: if ``title``, ``description``, ``tags`` or ``author_info``
            is missing from ``json_input``.
    """
    tags = json_input["tags"]
    if tags is None:
        tags_text = ""
    elif isinstance(tags, str):
        # Joining a str would interleave spaces between its characters.
        tags_text = tags
    else:
        tags_text = " ".join(tags)
    return {
        "title": [json_input["title"]],
        "description": [json_input["description"]],
        "tags": [tags_text],
        "author_info": [json_input["author_info"]],
    }


def predict(
    json_input,
    *,
    checkpoint_path=DEFAULT_CHECKPOINT_PATH,
    embedding_model=DEFAULT_EMBEDDING_MODEL,
):
    """Classify a single video record and return the predicted class index.

    Args:
        json_input: Mapping with keys ``"title"``, ``"description"``,
            ``"tags"`` and ``"author_info"``. ``tags`` may be an iterable of
            strings (joined with spaces), a single string, or ``None``.
        checkpoint_path: Path to the VideoClassifierV3_4 state dict.
        embedding_model: Name or path passed to ``SentenceTransformer``.

    Returns:
        int: index of the highest-scoring class (argmax over ``dim=1`` of the
        model output).

    Raises:
        KeyError: if a required key is missing from ``json_input``.

    The classifier and the embedder are loaded on first use and cached per
    checkpoint path / model name, so later calls skip the expensive load.
    """
    # Validate and shape the input first so a bad record fails before any
    # model is loaded.
    input_texts = _build_input_texts(json_input)

    model = _load_classifier(checkpoint_path)
    sentence_transformer = _load_sentence_transformer(embedding_model)

    with torch.no_grad():
        logits = model(
            input_texts=input_texts,
            sentence_transformer=sentence_transformer,
        )

    # The batch holds a single record; presumably logits is (1, num_classes) -
    # take that row's argmax and convert it to a Python int.
    return torch.argmax(logits, dim=1).item()
|
|
|
|
if __name__ == "__main__":
    # Example usage: classify a placeholder video record and print the
    # predicted class index.
    demo_record = {
        "title": "",
        "description": "",
        "tags": ["", ""],
        "author_info": "xx: yy",
    }
    print(f"预测结果: {predict(demo_record)}")
|