add: predict for filter model

This commit is contained in:
alikia2x (寒寒) 2025-01-25 04:52:51 +08:00
parent 175f3c3f6b
commit 7c318c942f
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
6 changed files with 181 additions and 4 deletions

3
.gitignore vendored
View File

@ -73,4 +73,5 @@ __pycache__
filter/runs filter/runs
data/filter/eval* data/filter/eval*
data/filter/train* data/filter/train*
filter/checkpoints filter/checkpoints
data/filter/model_predicted.jsonl

View File

@ -914,3 +914,4 @@
{"aid": 754905929, "title": "【螺旋音リボ sublimation】未来線【多音階CVVC配布】", "description": "螺旋音リボが一周年を迎えました。\n今回配布する音源名は、螺旋音リボ 多音階CVVC sublimationです。デルタ式7モーラ母音優先リストを使用して録音した三音階CVVC音源です。\n収録音階はF#3,A3,A#3です。A#3は強めに録っております。\nsublimationは昇華という意味の英単語です。\n\nイズ除去を塩釜mylist/41289446さんにお願いいたしました。ありがとうございました\n\nそれではこれからも螺旋音リボをよろしくお願いいたします。\n\n音源ダウンロード: http://gluglu-singer.jimdo.com/\n素敵な本家: sm15746943\nust: sm16402299\nillustration: 悠みなも mylist/34371459\nvocal: 螺旋音リボ 多音階CVVC sublimation (cho 螺旋音リボ 単独音)\n\n------------------------------------------------------------------------------------\n我是日本人。我使用Google翻译进行了翻译但是如果我输入有误我们深感抱歉。\n\n螺旋音リボ正在庆祝其成立一周年。\n这次要分配的声源名称是螺旋音リボ Multitone CVVC升华。使用delta 7-mora元音优先级列表录制的三音CVVC声源。\n记录的音阶为F3A3A3。强烈记录了3。\n升华是英文单词意为升华。\n\n我请盐og先生mylist / 41289446消除噪音。非常感谢你\n\n感谢您一直以来对螺旋音リボ的支持。\n\n声音源下载http://gluglu-singer.jimdo.com/\n不错的户主sm15746943\nUSTsm16402299\n插图悠みなも mylist / 34371459\nvocal螺旋音リボ 多音CVVC升华cho螺旋音リボ 单声)", "tags": ["UTAU", "日本", "UTAU音源配布所リンク", "螺旋音リボ", "VOCALOID→UTAUカバー曲", "未来線"], "author_info": "悠みなも: UTAU/螺旋音リボ/DTM Weibo MINAMO_YU0001", "label": 0} {"aid": 754905929, "title": "【螺旋音リボ sublimation】未来線【多音階CVVC配布】", "description": "螺旋音リボが一周年を迎えました。\n今回配布する音源名は、螺旋音リボ 多音階CVVC sublimationです。デルタ式7モーラ母音優先リストを使用して録音した三音階CVVC音源です。\n収録音階はF#3,A3,A#3です。A#3は強めに録っております。\nsublimationは昇華という意味の英単語です。\n\nイズ除去を塩釜mylist/41289446さんにお願いいたしました。ありがとうございました\n\nそれではこれからも螺旋音リボをよろしくお願いいたします。\n\n音源ダウンロード: http://gluglu-singer.jimdo.com/\n素敵な本家: sm15746943\nust: sm16402299\nillustration: 悠みなも mylist/34371459\nvocal: 螺旋音リボ 多音階CVVC sublimation (cho 螺旋音リボ 単独音)\n\n------------------------------------------------------------------------------------\n我是日本人。我使用Google翻译进行了翻译但是如果我输入有误我们深感抱歉。\n\n螺旋音リボ正在庆祝其成立一周年。\n这次要分配的声源名称是螺旋音リボ Multitone CVVC升华。使用delta 7-mora元音优先级列表录制的三音CVVC声源。\n记录的音阶为F3A3A3。强烈记录了3。\n升华是英文单词意为升华。\n\n我请盐og先生mylist / 41289446消除噪音。非常感谢你\n\n感谢您一直以来对螺旋音リボ的支持。\n\n声音源下载http://gluglu-singer.jimdo.com/\n不错的户主sm15746943\nUSTsm16402299\n插图悠みなも mylist / 34371459\nvocal螺旋音リボ 
多音CVVC升华cho螺旋音リボ 单声)", "tags": ["UTAU", "日本", "UTAU音源配布所リンク", "螺旋音リボ", "VOCALOID→UTAUカバー曲", "未来線"], "author_info": "悠みなも: UTAU/螺旋音リボ/DTM Weibo MINAMO_YU0001", "label": 0}
{"aid": 685634071, "title": "【AI歌手】腐草为萤【Cover】", "description": "原唱:银临\n引擎歌叽歌叽\nPV借用av2700587\nPV星の祈\n绘otakucake阿饼P站pixiv.me/otakucake", "tags": ["AI", "歌叽歌叽"], "author_info": "一支粉笔w: 人生难得一知己,千古知音最难求。 | 网易云:粉笔 微博http://weibo.com/Chalk10", "label": 2} {"aid": 685634071, "title": "【AI歌手】腐草为萤【Cover】", "description": "原唱:银临\n引擎歌叽歌叽\nPV借用av2700587\nPV星の祈\n绘otakucake阿饼P站pixiv.me/otakucake", "tags": ["AI", "歌叽歌叽"], "author_info": "一支粉笔w: 人生难得一知己,千古知音最难求。 | 网易云:粉笔 微博http://weibo.com/Chalk10", "label": 2}
{"aid": 3372898, "title": "【言和英文】Counting Stars", "description": "自制 终于投了。1.封面仍然灵魂p图。2.这玩意儿9月底就在做了说好10月投pv还解决不了。说好11月投但又因为曲子本身又拖了。上个星期想投结果电脑坏掉加发烧。。。。。所以一路拖到今天才投。算是放下心中的一块石头了吧。3.这次画了很多参数感觉有提升不知道朱军是怎么看的。5.counting stars是我最最最喜欢的一首歌已经循环一年。6.1R赶紧出4专", "tags": ["VOCALOID", "COUNTING STARS", "言和", "ONEREPUBLIC"], "author_info": "JackChenZz: 随缘更新", "label": 0} {"aid": 3372898, "title": "【言和英文】Counting Stars", "description": "自制 终于投了。1.封面仍然灵魂p图。2.这玩意儿9月底就在做了说好10月投pv还解决不了。说好11月投但又因为曲子本身又拖了。上个星期想投结果电脑坏掉加发烧。。。。。所以一路拖到今天才投。算是放下心中的一块石头了吧。3.这次画了很多参数感觉有提升不知道朱军是怎么看的。5.counting stars是我最最最喜欢的一首歌已经循环一年。6.1R赶紧出4专", "tags": ["VOCALOID", "COUNTING STARS", "言和", "ONEREPUBLIC"], "author_info": "JackChenZz: 随缘更新", "label": 0}
{"aid": 589575, "title": "【洛天依】约定", "description": "绫:天依,不管几年之后你能否记得我,我会一直恋你,爱你,铭心记住你! QAQ 被小汐撺掇调的,咱非常不擅长这拐来拐去的长调式滴说,不过这次应该算不上渣吧?(好吧- -无视这句。翻的周慧的《约定》谱子依旧自己来其实只是自己重做MIDI了网上的MIDI做觉得不好用调子太高3分43秒的字幕是UP无聊的XXX。PV用慢放的《春来发几只》弄的剪掉了两组燕子镜头~ ~赶脚两个曲子的BPM同步了", "tags": ["乐正绫", "洛天依", "洛天依翻唱曲", "VOCALOID", "良调教", "约定", "周慧"], "author_info": "星璇の天空: Dr.冥月星璇 成就不会调教的调教师U/V/SV/C)专业咸鱼不会科研的副研究员不会编程的系统架构师不会AI的无人机设计师...", "label": 2}

View File

@ -18,4 +18,4 @@ Note
0324: V3.5-test3 # 用回3.2的FC层试试 0324: V3.5-test3 # 用回3.2的FC层试试
0331: V3.6-test3 # 3.5不太行,我试着调下超参 0331: V3.6-test3 # 3.5不太行,我试着调下超参
0335: V3.7-test3 # 3.6还行,再调超参试试看 0335: V3.7-test3 # 3.6还行,再调超参试试看
0352: V3.8-test3 # 3.7不行从3.6的基础重新调 0414: V3.8-test3 # 3.7不行从3.6的基础重新调

139
filter/predict.py Normal file
View File

@ -0,0 +1,139 @@
import os
# Let PyTorch fall back to CPU for ops unsupported by the Apple MPS backend.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import sqlite3
import json
import torch
from modelV3_4 import VideoClassifierV3_4
from sentence_transformers import SentenceTransformer

# Database configuration
DATABASE_PATH = "./data/main.db"
OUTPUT_FILE = "./data/filter/model_predicted.jsonl"
BATCH_SIZE = 128  # number of entries classified per model call
def fetch_all_aids(conn):
    """Return the aid of every row in bili_info_crawl whose crawl succeeded."""
    cur = conn.execute(
        "SELECT aid FROM bili_info_crawl WHERE status = 'success'"
    )
    return [aid for (aid,) in cur.fetchall()]
def fetch_entry_data(conn, aid):
    """Return the raw crawl payload for *aid*, or None when no row matches."""
    row = conn.execute(
        "SELECT data FROM bili_info_crawl WHERE aid = ?", (aid,)
    ).fetchone()
    if row is None:
        return None
    return row[0]
def parse_entry_data(data):
    """Parse a raw crawl payload into (title, description, tags, author_info).

    Only tags whose tag_type is "old_channel" or "topic" are kept.
    Returns a tuple of four Nones when the payload is malformed.
    """
    try:
        parsed = json.loads(data)
        title = parsed["View"]["title"]
        description = parsed["View"]["desc"]
        tags = [
            t["tag_name"]
            for t in parsed["Tags"]
            if t["tag_type"] in ("old_channel", "topic")
        ]
        card = parsed["Card"]["card"]
        return title, description, tags, f"{card['name']}: {card['sign']}"
    except (KeyError, json.JSONDecodeError) as e:
        # Malformed payloads are logged and skipped by the caller.
        print(f"解析错误: {e}")
        return None, None, None, None
def initialize_model():
    """Load the trained classifier (weights mapped to CPU) and its text encoder."""
    classifier = VideoClassifierV3_4()
    state = torch.load(
        './filter/checkpoints/best_model_V3.8.pt',
        map_location=torch.device('cpu'),
    )
    classifier.load_state_dict(state)
    classifier.eval()  # inference mode: disables dropout/batch-norm updates
    encoder = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
    return classifier, encoder
def predict_batch(model, st_model, batch_data):
    """Classify a batch of entry dicts; return one integer label per entry."""
    titles, descriptions, tag_strings, authors = [], [], [], []
    for entry in batch_data:
        titles.append(entry["title"])
        descriptions.append(entry["description"])
        tag_strings.append(" ".join(entry["tags"]))  # tags joined into one string
        authors.append(entry["author_info"])
    texts = {
        "title": titles,
        "description": descriptions,
        "tags": tag_strings,
        "author_info": authors,
    }
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        scores = model(input_texts=texts, sentence_transformer=st_model)
    return torch.argmax(scores, dim=1).tolist()
def _write_predictions(output, batch_data, predictions):
    """Append one JSON line per entry to *output* with its predicted label."""
    for entry, prediction in zip(batch_data, predictions):
        output.write(json.dumps({
            "aid": entry["aid"],
            "title": entry["title"],
            "description": entry["description"],
            "tags": entry["tags"],
            "author_info": entry["author_info"],
            "label": prediction
        }, ensure_ascii=False) + "\n")


def process_entries():
    """Predict a label for every successfully crawled entry.

    Reads raw payloads from DATABASE_PATH, classifies them in batches of
    BATCH_SIZE, and writes one JSON object per entry to OUTPUT_FILE (JSONL).
    Per-entry failures are logged and skipped; the DB handle is always closed.
    """
    model, st_model = initialize_model()
    conn = sqlite3.connect(DATABASE_PATH)
    try:
        aids = fetch_all_aids(conn)
        print(f"总需处理条目数: {len(aids)}")

        with open(OUTPUT_FILE, "w", encoding="utf-8") as output:
            batch_data = []
            for idx, aid in enumerate(aids, 1):
                try:
                    raw_data = fetch_entry_data(conn, aid)
                    if not raw_data:
                        continue
                    title, desc, tags, author = parse_entry_data(raw_data)
                    if not title:  # parse failure -> skip
                        continue

                    batch_data.append({
                        "aid": aid,
                        "title": title,
                        "description": desc,
                        "tags": tags,
                        "author_info": author
                    })

                    # Flush a full batch through the model.
                    if len(batch_data) >= BATCH_SIZE:
                        predictions = predict_batch(model, st_model, batch_data)
                        _write_predictions(output, batch_data, predictions)
                        batch_data = []

                    if idx % 100 == 0:
                        print(f"已处理 {idx}/{len(aids)} 条...")
                except Exception as e:
                    # Best-effort pipeline: log and continue on per-entry errors.
                    print(f"处理aid {aid} 时出错: {str(e)}")

            # Flush the final partial batch.
            if batch_data:
                predictions = predict_batch(model, st_model, batch_data)
                _write_predictions(output, batch_data, predictions)
    finally:
        # Original code leaked the connection on any uncaught exception;
        # close it unconditionally.
        conn.close()
# Script entry point: run the full prediction pipeline and report the output path.
if __name__ == "__main__":
    process_entries()
    print("预测完成,结果已保存至", OUTPUT_FILE)

View File

@ -86,8 +86,6 @@ def label_entries(db_path, aids):
title, description, tags, author_info, url = parse_entry_data(data) title, description, tags, author_info, url = parse_entry_data(data)
if not title: # 如果解析失败,跳过 if not title: # 如果解析失败,跳过
continue continue
if '原创' not in title and '原创' not in description:
continue
# 展示信息 # 展示信息
os.system("clear") os.system("clear")
print(f"AID: {aid}") print(f"AID: {aid}")

38
filter/test.py Normal file
View File

@ -0,0 +1,38 @@
import os
# Let PyTorch fall back to CPU for ops unsupported by the Apple MPS backend.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
import torch
from modelV3_4 import VideoClassifierV3_4
from sentence_transformers import SentenceTransformer
def predict(json_input):
    """Classify a single video entry dict and return its integer label.

    Expects keys "title", "description", "tags" (list of str) and
    "author_info", matching the training data schema.
    """
    # Load the classifier weights onto the CPU so this also works on machines
    # without the accelerator the checkpoint was saved from (consistent with
    # the loader in filter/predict.py).
    model = VideoClassifierV3_4()
    model.load_state_dict(torch.load(
        './filter/checkpoints/best_model_V3.8.pt',
        map_location=torch.device('cpu'),
    ))
    model.eval()

    # Text encoder used to embed the four input fields.
    sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")

    input_texts = {
        "title": [json_input["title"]],
        "description": [json_input["description"]],
        "tags": [" ".join(json_input["tags"])],  # tags joined into one string
        "author_info": [json_input["author_info"]]
    }

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(
            input_texts=input_texts,
            sentence_transformer=sentence_transformer
        )
        pred = torch.argmax(logits, dim=1).item()

    return pred
if __name__ == "__main__":
    # 示例用法 / example usage: smoke-test with a minimal (mostly empty) entry.
    sample_input = {"title": "", "description": "", "tags": ["",""], "author_info": "xx: yy"}
    result = predict(sample_input)
    print(f"预测结果: {result}")