70 lines
2.4 KiB
Python
70 lines
2.4 KiB
Python
import os, json
|
|
import random
|
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
|
|
import torch
|
|
from modelV3_4 import VideoClassifierV3_4
|
|
from sentence_transformers import SentenceTransformer
|
|
from tag import getch
|
|
|
|
_DEFAULT_CHECKPOINT = './filter/checkpoints/best_model_V3.8.pt'
_DEFAULT_EMBEDDING_MODEL = "Thaweewat/jina-embedding-v3-m2v-1024"

# Loaded objects keyed by checkpoint path / embedding model name, so repeated
# predict() calls reuse them instead of reloading weights every time.
_model_cache = {}
_embedder_cache = {}


def _get_model(checkpoint_path):
    """Return an eval-mode VideoClassifierV3_4 loaded from *checkpoint_path* (cached)."""
    model = _model_cache.get(checkpoint_path)
    if model is None:
        model = VideoClassifierV3_4()
        # The model is constructed on CPU; map_location="cpu" lets a checkpoint
        # that was saved on CUDA/MPS load on a machine without that device.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state_dict)
        model.eval()
        _model_cache[checkpoint_path] = model
    return model


def _get_sentence_transformer(model_name):
    """Return a SentenceTransformer for *model_name*, loading it at most once."""
    embedder = _embedder_cache.get(model_name)
    if embedder is None:
        embedder = SentenceTransformer(model_name)
        _embedder_cache[model_name] = embedder
    return embedder


def predict(json_input, *, checkpoint_path=_DEFAULT_CHECKPOINT,
            embedding_model=_DEFAULT_EMBEDDING_MODEL):
    """Classify one video record with the VideoClassifierV3_4 filter model.

    The classifier and the sentence-embedding model are loaded on first use
    and cached for later calls with the same checkpoint / model name.

    Args:
        json_input: mapping with keys "title", "description", "tags" and
            "author_info". "tags" must be an iterable of strings; it is joined
            with single spaces. Every field is passed to the model as a
            one-element list (presumably of strings; confirm against
            VideoClassifierV3_4).
        checkpoint_path: path of the state-dict checkpoint to load.
        embedding_model: SentenceTransformer model name or path that the
            classifier uses to embed the text fields.

    Returns:
        The argmax over dim 1 of the model's logits, as a Python int.
        NOTE(review): the labelling script accepts labels 0/1/2, which suggests
        three classes; this is not verified here.
    """
    model = _get_model(checkpoint_path)
    sentence_transformer = _get_sentence_transformer(embedding_model)

    # A batch of size one: each field is wrapped in a single-element list.
    input_texts = {
        "title": [json_input["title"]],
        "description": [json_input["description"]],
        "tags": [" ".join(json_input["tags"])],
        "author_info": [json_input["author_info"]],
    }

    # Inference only, so gradient tracking is not needed.
    with torch.no_grad():
        logits = model(
            input_texts=input_texts,
            sentence_transformer=sentence_transformer,
        )
        # .item() requires the argmax result to have exactly one element,
        # i.e. a batch of one.
        pred = torch.argmax(logits, dim=1).item()

    return pred
|
|
|
|
_VALID_KEYS = ("0", "1", "2", "s", "q")


def _clear_screen():
    """Clear the terminal ("cls" on Windows, "clear" elsewhere)."""
    os.system("cls" if os.name == "nt" else "clear")


def _show_sample(sample):
    """Print the fields of one sample for the human labeller."""
    print(f"AID: {sample['aid']}")
    print(f"Title: {sample['title']}")
    print(f"Tags: {', '.join(sample['tags'])}")
    print(f"Author Info: {sample['author_info']}")
    print(f"Description: {sample['description']}")


def _read_key():
    """Prompt until the user presses a key in _VALID_KEYS; return it lower-cased."""
    while True:
        print("Label (0 or 1 or 2, s to skip, q to quit): ", end="", flush=True)
        key = getch().lower()
        if key in _VALID_KEYS:
            return key
        print("\nInvalid input. Please enter 0, 1, 2, s or q.")


def main(input_path='data/filter/model_predicted.jsonl',
         output_path="./data/filter/real_test.jsonl",
         sample_size=200):
    """Interactively hand-label a random sample of model-predicted records.

    Reads JSON-lines records from *input_path*, skipping blank lines. It shows
    up to *sample_size* randomly chosen records and asks for a label. "0", "1"
    and "2" record a label, "s" skips the record, and "q" ends the session.
    Each labelled record is appended to *output_path* as one JSON line holding
    the sample's "aid", "title", "tags", "author_info" and "description", its
    stored "label" under "model", and the human's label under "human".
    """
    # Records contain non-ASCII text (written with ensure_ascii=False), so
    # UTF-8 is set explicitly instead of relying on the locale default.
    with open(input_path, 'r', encoding="utf-8") as fp:
        data = [json.loads(line) for line in fp if line.strip()]

    sampled = random.sample(data, min(sample_size, len(data)))

    # Records are appended and flushed as soon as they are labelled, so an
    # interrupted session keeps everything labelled up to that point.
    with open(output_path, "a", encoding="utf-8") as out:
        for sample in sampled:
            label = sample['label']
            _clear_screen()
            _show_sample(sample)

            key = _read_key()
            if key == "s":  # skip this sample
                continue
            if key == "q":  # quit the session
                break

            record = {
                "aid": sample['aid'],
                "title": sample['title'],
                "tags": sample['tags'],
                "author_info": sample['author_info'],
                "description": sample['description'],
                "model": label,
                "human": int(key),
            }
            out.write(json.dumps(record, ensure_ascii=False) + "\n")
            out.flush()


if __name__ == "__main__":
    main()