cvsa/filter/test.py
2025-01-29 03:11:34 +08:00

70 lines
2.4 KiB
Python

import os, json
import random
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
import torch
from modelV3_4 import VideoClassifierV3_4
from sentence_transformers import SentenceTransformer
from tag import getch
def predict(json_input, *, model=None, sentence_transformer=None):
    """Classify a single video record and return the predicted class index.

    Args:
        json_input: Mapping with the keys ``title``, ``description``,
            ``tags`` (an iterable of strings, joined with single spaces)
            and ``author_info``.
        model: Optional already-loaded classifier to reuse. When omitted, a
            new ``VideoClassifierV3_4`` is built and its weights are loaded
            from ``./filter/checkpoints/best_model_V3.8.pt`` on this call.
        sentence_transformer: Optional already-constructed encoder to reuse.
            When omitted, ``Thaweewat/jina-embedding-v3-m2v-1024`` is
            instantiated on this call.

    Returns:
        The index of the largest logit along dim 1, as a Python ``int``.
    """
    if model is None:
        # Load the model.
        model = VideoClassifierV3_4()
        # map_location="cpu" lets a checkpoint that was saved on another
        # device (CUDA / MPS) be deserialized on a machine without it;
        # load_state_dict then copies the values into the model's parameters.
        state_dict = torch.load(
            './filter/checkpoints/best_model_V3.8.pt',
            map_location="cpu",
        )
        model.load_state_dict(state_dict)
    model.eval()

    if sentence_transformer is None:
        # Load the SentenceTransformer.
        sentence_transformer = SentenceTransformer(
            "Thaweewat/jina-embedding-v3-m2v-1024"
        )

    # Each field is wrapped in a one-element list: one record per call.
    input_texts = {
        "title": [json_input["title"]],
        "description": [json_input["description"]],
        "tags": [" ".join(json_input["tags"])],
        "author_info": [json_input["author_info"]],
    }

    # Predict without tracking gradients.
    with torch.no_grad():
        logits = model(
            input_texts=input_texts,
            sentence_transformer=sentence_transformer,
        )
    pred = torch.argmax(logits, dim=1).item()
    return pred
if __name__ == "__main__":
    # Interactive labelling session: show up to 200 randomly sampled records
    # from the model-predicted file, read a one-key human label for each, and
    # append the model/human label pairs to the real-test JSONL file.
    #
    # Records are written with ensure_ascii=False (non-ASCII text is emitted
    # verbatim), so both files are opened explicitly as UTF-8 instead of
    # relying on the platform's default encoding.
    with open('data/filter/model_predicted.jsonl', 'r', encoding='utf-8') as fp:
        # Skip blank lines (e.g. a trailing empty line) that json.loads rejects.
        data = [json.loads(line) for line in fp if line.strip()]

    sampled = random.sample(data, min(200, len(data)))
    valid_keys = ("0", "1", "2", "s", "q")
    test_data = []

    for sample in sampled:
        label = sample['label']
        os.system("clear")
        print(f"AID: {sample['aid']}")
        print(f"Title: {sample['title']}")
        print(f"Tags: {', '.join(sample['tags'])}")
        print(f"Author Info: {sample['author_info']}")
        print(f"Description: {sample['description']}")

        # Wait for user input until a recognised key is pressed.
        while True:
            print("Label (0 or 1 or 2, s to skip, q to quit): ", end="", flush=True)
            real_label = getch().lower()
            if real_label in valid_keys:
                break
            print("\nInvalid input. Please enter 0, 1, 2, s or q.")

        if real_label == "s":  # Skip this sample.
            continue
        if real_label == "q":  # Quit; labels collected so far are still saved.
            break

        test_data.append({
            "aid": sample['aid'],
            "title": sample['title'],
            "tags": sample['tags'],
            "author_info": sample['author_info'],
            "description": sample['description'],
            "model": label,
            "human": int(real_label),
        })

    # Append (never overwrite) so results from earlier sessions are kept.
    with open("./data/filter/real_test.jsonl", "a", encoding="utf-8") as fp:
        fp.writelines(
            json.dumps(item, ensure_ascii=False) + "\n" for item in test_data
        )