173 lines
6.0 KiB
Python
173 lines
6.0 KiB
Python
import os
|
|
import json
|
|
import random
|
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
|
import torch
|
|
from modelV3_4 import VideoClassifierV3_4
|
|
from sentence_transformers import SentenceTransformer
|
|
import sys
|
|
import tty
|
|
import termios
|
|
from db_utils import fetch_entry_data, parse_entry_data
|
|
|
|
# Database location handed to fetch_entry_data() when an entry's metadata is looked up.
DATABASE_PATH = "./data/main.db"
|
|
BATCH_SIZE = 50 # Batch size for dynamic candidate loading
|
|
|
|
class LabelingSystem:
    """Interactive terminal tool for hand-labeling model-predicted entries.

    Candidate aids are read from ``CANDIDATES_PATH``. Each candidate is shown
    together with the classifier's prediction, and the label typed by the user
    is stored under the ``human`` key and persisted to ``ENTRIES_PATH``
    (one JSON object per line).
    """

    # JSONL file holding every entry labeled so far.
    ENTRIES_PATH = "./data/filter/real_test.jsonl"
    # JSONL file whose lines each carry at least an 'aid' key.
    CANDIDATES_PATH = 'data/filter/model_predicted.jsonl'

    def __init__(self):
        """Load the classifier, the embedding model, saved labels and a first batch."""
        # Model setup
        self.model = VideoClassifierV3_4()
        # map_location="cpu" makes loading independent of the device the
        # checkpoint was saved from; load_state_dict copies the values into
        # the model's own parameters afterwards.
        self.model.load_state_dict(
            torch.load('./filter/checkpoints/best_model_V3.8.pt', map_location="cpu")
        )
        self.model.eval()
        self.sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")

        # Data state
        self.existing_entries = self._load_existing_entries()
        self.existing_aids = {entry['aid'] for entry in self.existing_entries}
        self.candidate_pool = []
        self.history = []
        self.current_index = -1  # -1 means labeling has not started yet

        # Load the first batch of candidates
        self._load_more_candidates()

    def _save_entry(self, entry):
        """Insert ``entry`` or replace the stored one with the same ``aid``, then persist.

        Args:
            entry: Dict with at least an ``'aid'`` key; it must be JSON-serializable.
        """
        existing_index = next(
            (i for i, e in enumerate(self.existing_entries) if e['aid'] == entry['aid']),
            None,
        )

        # Update in place or append
        if existing_index is not None:
            self.existing_entries[existing_index] = entry
        else:
            self.existing_entries.append(entry)
        # Record the aid so later candidate reloads do not offer it again.
        self.existing_aids.add(entry['aid'])

        # Rewrite the whole file. Writing to a sibling temp file and swapping
        # it in means an interruption mid-write cannot truncate saved labels.
        tmp_path = self.ENTRIES_PATH + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as fp:
            fp.writelines(
                json.dumps(item, ensure_ascii=False) + "\n"
                for item in self.existing_entries
            )
        os.replace(tmp_path, self.ENTRIES_PATH)

    def _load_existing_entries(self):
        """Return the entries stored in ``ENTRIES_PATH``, or ``[]`` if the file is absent."""
        if not os.path.exists(self.ENTRIES_PATH):
            return []

        with open(self.ENTRIES_PATH, "r", encoding="utf-8") as fp:
            # Blank lines (e.g. a trailing newline added by hand) are skipped
            # instead of crashing json.loads.
            return [json.loads(line) for line in fp if line.strip()]

    def _load_more_candidates(self):
        """Append up to ``BATCH_SIZE`` randomly chosen, not-yet-seen aids to the pool.

        An aid is skipped when it is already labeled (``existing_aids``), was
        already shown in this session (``history``), or is already queued.
        """
        seen = {item['aid'] for item in self.history}
        seen.update(self.candidate_pool)

        new_candidates = []
        with open(self.CANDIDATES_PATH, 'r', encoding="utf-8") as fp:
            for line in fp:
                if not line.strip():
                    continue
                aid = json.loads(line)['aid']
                if aid in self.existing_aids or aid in seen:
                    continue
                seen.add(aid)  # also drops duplicates inside the file
                new_candidates.append(aid)

        # Shuffle, then take one batch
        random.shuffle(new_candidates)
        self.candidate_pool.extend(new_candidates[:BATCH_SIZE])

    def _get_entry_details(self, aid):
        """Fetch an entry's metadata and attach the model's predicted label.

        Args:
            aid: Identifier passed to ``fetch_entry_data``.

        Returns:
            Dict with keys ``aid``, ``title``, ``description``, ``tags``,
            ``author_info``, ``url``, ``model_label`` (argmax over dim 1 of the
            model output) and ``user_label`` (always ``None`` here).
        """
        # Metadata; parse_entry_data is unpacked as a 5-tuple.
        title, description, tags, author_info, url = parse_entry_data(
            fetch_entry_data(DATABASE_PATH, aid)
        )

        # Model prediction on a batch of one
        with torch.no_grad():
            logits = self.model(
                input_texts={
                    "title": [title],
                    "description": [description],
                    "tags": [" ".join(tags)],
                    "author_info": [author_info]
                },
                sentence_transformer=self.sentence_transformer
            )
            model_label = torch.argmax(logits, dim=1).item()

        return {
            'aid': aid,
            'title': title,
            'description': description,
            'tags': tags,
            'author_info': author_info,
            'url': url,
            'model_label': model_label,
            'user_label': None
        }

    def _display_entry(self, entry):
        """Clear the terminal and print ``entry``'s fields and labels."""
        os.system("clear")
        print(f"AID: {entry['aid']}")
        print(f"URL: {entry['url']}")
        print(f"Title: {entry['title']}")
        print(f"Tags: {', '.join(entry['tags'])}")
        print(f"Author Info: {entry['author_info']}")
        print(f"Description: {entry['description']}")
        print(f"\nModel Prediction: {entry['model_label']}")
        # run() stores the user's choice under 'human'; 'user_label' is only a
        # fallback, so a label given earlier is visible when navigating back.
        label = entry.get('human', entry.get('user_label'))
        if label is not None:
            print(f"Your Label: {label}")

    def run(self):
        """Run the labeling loop until the user quits or no candidates remain.

        Keys: ``0``/``1``/``2`` label the current entry and advance, ``s``
        skips, left/up go back, right/down go forward, ``q`` quits. Any other
        key redraws the current entry.
        """
        while True:
            if self.current_index < 0:
                self.current_index = 0

            # Moving past the end of the history pulls in a new candidate.
            if self.current_index >= len(self.history):
                if not self.candidate_pool:
                    self._load_more_candidates()
                if not self.candidate_pool:
                    print("\nAll entries processed!")
                    return

                aid = self.candidate_pool.pop(0)
                entry = self._get_entry_details(aid)
                self.history.append(entry)
                self.current_index = len(self.history) - 1

            current_entry = self.history[self.current_index]
            self._display_entry(current_entry)

            # Read one key
            print("\nLabel (0/1/2, s=skip, ←↑/→↓=nav, q=quit): ", end="", flush=True)
            cmd = getch().lower()

            if cmd in ('left', 'up'):
                self.current_index = max(0, self.current_index - 1)
            elif cmd in ('right', 'down'):
                self.current_index += 1
            elif cmd in ('0', '1', '2'):
                current_entry['human'] = int(cmd)
                self._save_entry(current_entry)
                self.current_index += 1  # advance automatically
            elif cmd == 's':
                self.current_index += 1  # skip without saving
            elif cmd == 'q':
                return
|
|
|
|
def getch():
    """Read a single keypress from stdin in raw mode.

    Returns:
        ``'up'``, ``'down'``, ``'right'`` or ``'left'`` for the escape
        sequences ``ESC [ A`` .. ``ESC [ D``; ``'unknown'`` for any other two
        characters following ESC; otherwise the single character read.

    Raises:
        KeyboardInterrupt: Ctrl-C was pressed. Raw mode delivers it as the
            character 0x03 instead of a signal, so it is translated here.
        EOFError: stdin returned no data (closed / end of file).
    """
    arrow_keys = {'[A': 'up', '[B': 'down', '[C': 'right', '[D': 'left'}

    fd = sys.stdin.fileno()
    old_settings = termios.tcgetattr(fd)
    try:
        tty.setraw(fd)
        ch = sys.stdin.read(1)
        if ch == '':
            # Without this the caller's loop would spin forever on a closed stdin.
            raise EOFError("stdin closed while waiting for a key")
        if ch == '\x03':
            raise KeyboardInterrupt
        if ch == '\x1b':
            # NOTE: a bare Esc press blocks here until two more characters arrive.
            seq = sys.stdin.read(2)
            return arrow_keys.get(seq, 'unknown')
        return ch
    finally:
        # Always restore the terminal, including when an exception propagates.
        termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
|
|
|
|
# Script entry point: build the labeling session and start the interactive loop.
if __name__ == "__main__":
    labeling_system = LabelingSystem()
    labeling_system.run()