# cvsa/filter/predict.py
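"""Batch prediction over crawled Bilibili video metadata.

Reads successfully crawled entries from the bili_info_crawl table in the local
SQLite database, classifies each video with VideoClassifierV3_9 using
sentence-transformer embeddings, and writes the predicted labels to a JSONL file.
"""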

import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import sqlite3
import json
import torch
from modelV3_9 import VideoClassifierV3_9
from sentence_transformers import SentenceTransformer
from tqdm import tqdm  # progress bar

# Database configuration
DATABASE_PATH = "./data/main.db"
OUTPUT_FILE = "./data/filter/model_predicted.jsonl"
BATCH_SIZE = 128  # batch size for prediction

def fetch_all_aids(conn):
    """Fetch all aids whose crawl status is 'success'."""
    cursor = conn.cursor()
    cursor.execute("SELECT aid FROM bili_info_crawl WHERE status = 'success'")
    aids = [row[0] for row in cursor.fetchall()]
    return aids

def fetch_entry_data(conn, aid):
    """Fetch the raw crawled data for a single entry."""
    cursor = conn.cursor()
    cursor.execute("SELECT data FROM bili_info_crawl WHERE aid = ?", (aid,))
    d = cursor.fetchone()
    data = d[0] if d else None
    return data

def parse_entry_data(data):
    """Parse the raw JSON data into structured fields."""
    try:
        obj = json.loads(data)
        title = obj["View"]["title"]
        description = obj["View"]["desc"]
        tags = [tag["tag_name"] for tag in obj["Tags"]
                if tag["tag_type"] in ["old_channel", "topic"]]
        author_info = f"{obj['Card']['card']['name']}: {obj['Card']['card']['sign']}"
        return title, description, tags, author_info
    except (KeyError, json.JSONDecodeError) as e:
        print(f"Parse error: {e}")
        return None, None, None, None

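# Illustrative only: a minimal sketch of the raw JSON shape that
# parse_entry_data() expects, inferred from the fields accessed above
# (actual crawled payloads contain many more fields):
#
# {
#     "View": {"title": "...", "desc": "..."},
#     "Tags": [{"tag_name": "...", "tag_type": "old_channel"}],
#     "Card": {"card": {"name": "...", "sign": "..."}}
# }
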
def initialize_model():
    """Initialize the classifier and the sentence-transformer text encoder."""
    model = VideoClassifierV3_9()
    model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.9.pt', map_location=torch.device('cpu')))
    model.eval()
    st_model = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
    return model, st_model

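# Note: the checkpoint path is relative to the working directory the script is
# launched from, and map_location='cpu' keeps loading device-agnostic. The
# PYTORCH_ENABLE_MPS_FALLBACK=1 variable set at the top lets any operator
# unsupported on Apple-silicon MPS fall back to the CPU.
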
def predict_batch(model, st_model, batch_data):
    """Run prediction on a batch of entries and return the predicted labels."""
    with torch.no_grad():
        input_texts = {
            "title": [entry["title"] for entry in batch_data],
            "description": [entry["description"] for entry in batch_data],
            "tags": [" ".join(entry["tags"]) for entry in batch_data],
            "author_info": [entry["author_info"] for entry in batch_data]
        }
        logits = model(input_texts=input_texts, sentence_transformer=st_model)
        return torch.argmax(logits, dim=1).tolist()

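# Illustrative call (hypothetical values) showing the batch layout that
# predict_batch() builds: each field becomes a list of strings, one per entry.
#
# predict_batch(model, st_model, [{
#     "title": "Example title",
#     "description": "Example description",
#     "tags": ["tag_a", "tag_b"],
#     "author_info": "uploader: bio",
# }])  # -> e.g. [1]
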
def process_entries():
    """Main processing pipeline."""
    # Initialize the models
    model, st_model = initialize_model()
    # Open the database connection
    conn = sqlite3.connect(DATABASE_PATH)
    # Fetch all aids to process
    aids = fetch_all_aids(conn)
    print(f"Total entries to process: {len(aids)}")
    # Process in batches and save the results
    with open(OUTPUT_FILE, "w", encoding="utf-8") as output:
        batch_data = []
        for aid in tqdm(aids, desc="Processing entries"):
            try:
                # Fetch and parse the entry's data
                raw_data = fetch_entry_data(conn, aid)
                if not raw_data:
                    continue
                title, desc, tags, author = parse_entry_data(raw_data)
                if not title:
                    continue
                # Build the prediction input
                entry = {
                    "aid": aid,
                    "title": title,
                    "description": desc,
                    "tags": tags,
                    "author_info": author
                }
                batch_data.append(entry)
                # Run prediction once the batch is full
                if len(batch_data) >= BATCH_SIZE:
                    predictions = predict_batch(model, st_model, batch_data)
                    for entry, prediction in zip(batch_data, predictions):
                        output.write(json.dumps({**entry, "label": prediction}, ensure_ascii=False) + "\n")
                    batch_data = []  # reset the batch
            except Exception as e:
                print(f"Error while processing aid {aid}: {e}")
        # Process the remaining entries
        if batch_data:
            predictions = predict_batch(model, st_model, batch_data)
            for entry, prediction in zip(batch_data, predictions):
                output.write(json.dumps({**entry, "label": prediction}, ensure_ascii=False) + "\n")
    # Close the database connection
    conn.close()

if __name__ == "__main__":
    process_entries()
    print("Prediction finished; results saved to", OUTPUT_FILE)
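# Each line of model_predicted.jsonl is one JSON object (illustrative values):
# {"aid": 123456, "title": "...", "description": "...", "tags": ["..."],
#  "author_info": "uploader: bio", "label": 1}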