add: predict for filter model
This commit is contained in:
parent
175f3c3f6b
commit
7c318c942f
1
.gitignore
vendored
1
.gitignore
vendored
@ -74,3 +74,4 @@ filter/runs
|
|||||||
data/filter/eval*
|
data/filter/eval*
|
||||||
data/filter/train*
|
data/filter/train*
|
||||||
filter/checkpoints
|
filter/checkpoints
|
||||||
|
data/filter/model_predicted.jsonl
|
@ -914,3 +914,4 @@
|
|||||||
{"aid": 754905929, "title": "【螺旋音リボ sublimation】未来線【多音階CVVC配布】", "description": "螺旋音リボが一周年を迎えました。\n今回配布する音源名は、螺旋音リボ 多音階CVVC sublimationです。デルタ式7モーラ母音優先リストを使用して録音した三音階CVVC音源です。\n収録音階はF#3,A3,A#3です。A#3は強めに録っております。\nsublimationは昇華という意味の英単語です。\n\nノイズ除去を塩釜(mylist/41289446)さんにお願いいたしました。ありがとうございました!\n\nそれではこれからも螺旋音リボをよろしくお願いいたします。\n\n音源ダウンロード: http://gluglu-singer.jimdo.com/\n素敵な本家: sm15746943\nust: sm16402299\nillustration: 悠みなも mylist/34371459\nvocal: 螺旋音リボ 多音階CVVC sublimation (cho 螺旋音リボ 単独音)\n\n------------------------------------------------------------------------------------\n我是日本人。我使用Google翻译进行了翻译,但是如果我输入有误,我们深感抱歉。\n\n螺旋音リボ正在庆祝其成立一周年。\n这次要分配的声源名称是螺旋音リボ Multitone CVVC升华。使用delta 7-mora元音优先级列表录制的三音CVVC声源。\n记录的音阶为F#3,A3,A#3。强烈记录了#3。\n升华是英文单词,意为升华。\n\n我请盐og先生(mylist / 41289446)消除噪音。非常感谢你!\n\n感谢您一直以来对螺旋音リボ的支持。\n\n声音源下载:http://gluglu-singer.jimdo.com/\n不错的户主:sm15746943\nUST:sm16402299\n插图:悠みなも mylist / 34371459\nvocal:螺旋音リボ 多音CVVC升华(cho螺旋音リボ 单声)", "tags": ["UTAU", "日本", "UTAU音源配布所リンク", "螺旋音リボ", "VOCALOID→UTAUカバー曲", "未来線"], "author_info": "悠みなも: UTAU/螺旋音リボ/DTM Weibo MINAMO_YU0001", "label": 0}
|
{"aid": 754905929, "title": "【螺旋音リボ sublimation】未来線【多音階CVVC配布】", "description": "螺旋音リボが一周年を迎えました。\n今回配布する音源名は、螺旋音リボ 多音階CVVC sublimationです。デルタ式7モーラ母音優先リストを使用して録音した三音階CVVC音源です。\n収録音階はF#3,A3,A#3です。A#3は強めに録っております。\nsublimationは昇華という意味の英単語です。\n\nノイズ除去を塩釜(mylist/41289446)さんにお願いいたしました。ありがとうございました!\n\nそれではこれからも螺旋音リボをよろしくお願いいたします。\n\n音源ダウンロード: http://gluglu-singer.jimdo.com/\n素敵な本家: sm15746943\nust: sm16402299\nillustration: 悠みなも mylist/34371459\nvocal: 螺旋音リボ 多音階CVVC sublimation (cho 螺旋音リボ 単独音)\n\n------------------------------------------------------------------------------------\n我是日本人。我使用Google翻译进行了翻译,但是如果我输入有误,我们深感抱歉。\n\n螺旋音リボ正在庆祝其成立一周年。\n这次要分配的声源名称是螺旋音リボ Multitone CVVC升华。使用delta 7-mora元音优先级列表录制的三音CVVC声源。\n记录的音阶为F#3,A3,A#3。强烈记录了#3。\n升华是英文单词,意为升华。\n\n我请盐og先生(mylist / 41289446)消除噪音。非常感谢你!\n\n感谢您一直以来对螺旋音リボ的支持。\n\n声音源下载:http://gluglu-singer.jimdo.com/\n不错的户主:sm15746943\nUST:sm16402299\n插图:悠みなも mylist / 34371459\nvocal:螺旋音リボ 多音CVVC升华(cho螺旋音リボ 单声)", "tags": ["UTAU", "日本", "UTAU音源配布所リンク", "螺旋音リボ", "VOCALOID→UTAUカバー曲", "未来線"], "author_info": "悠みなも: UTAU/螺旋音リボ/DTM Weibo MINAMO_YU0001", "label": 0}
|
||||||
{"aid": 685634071, "title": "【AI歌手】腐草为萤【Cover】", "description": "原唱:银临\n引擎:歌叽歌叽\nPV借用:av2700587\nPV:星の祈\n绘:otakucake阿饼(P站:pixiv.me/otakucake)", "tags": ["AI", "歌叽歌叽"], "author_info": "一支粉笔w: 人生难得一知己,千古知音最难求。 | 网易云:粉笔 微博:http://weibo.com/Chalk10", "label": 2}
|
{"aid": 685634071, "title": "【AI歌手】腐草为萤【Cover】", "description": "原唱:银临\n引擎:歌叽歌叽\nPV借用:av2700587\nPV:星の祈\n绘:otakucake阿饼(P站:pixiv.me/otakucake)", "tags": ["AI", "歌叽歌叽"], "author_info": "一支粉笔w: 人生难得一知己,千古知音最难求。 | 网易云:粉笔 微博:http://weibo.com/Chalk10", "label": 2}
|
||||||
{"aid": 3372898, "title": "【言和英文】Counting Stars", "description": "自制 终于投了。1.封面仍然灵魂p图。2.这玩意儿9月底就在做了,说好10月投,pv还解决不了。说好11月投,但又因为曲子本身又拖了。上个星期想投,结果,电脑坏掉加发烧。。。。。所以一路拖到今天才投。算是放下心中的一块石头了吧。3.这次画了很多参数,感觉有提升,不知道朱军是怎么看的。5.counting stars是我最最最喜欢的一首歌,已经循环一年。6.1R赶紧出4专!!!", "tags": ["VOCALOID", "COUNTING STARS", "言和", "ONEREPUBLIC"], "author_info": "JackChenZz: 随缘更新", "label": 0}
|
{"aid": 3372898, "title": "【言和英文】Counting Stars", "description": "自制 终于投了。1.封面仍然灵魂p图。2.这玩意儿9月底就在做了,说好10月投,pv还解决不了。说好11月投,但又因为曲子本身又拖了。上个星期想投,结果,电脑坏掉加发烧。。。。。所以一路拖到今天才投。算是放下心中的一块石头了吧。3.这次画了很多参数,感觉有提升,不知道朱军是怎么看的。5.counting stars是我最最最喜欢的一首歌,已经循环一年。6.1R赶紧出4专!!!", "tags": ["VOCALOID", "COUNTING STARS", "言和", "ONEREPUBLIC"], "author_info": "JackChenZz: 随缘更新", "label": 0}
|
||||||
|
{"aid": 589575, "title": "【洛天依】约定", "description": "绫:天依,不管几年之后你能否记得我,我会一直恋你,爱你,铭心记住你! QAQ 被小汐撺掇调的,咱非常不擅长这拐来拐去的长调式滴说,不过这次应该算不上渣吧?(好吧- -无视这句)。翻的周慧的《约定》,谱子依旧自己来(其实只是自己重做MIDI了,网上的MIDI做觉得不好用,调子太高),3分43秒的字幕是UP无聊的XXX。PV用慢放的《春来发几只》弄的,剪掉了两组燕子镜头~ ~赶脚两个曲子的BPM同步了", "tags": ["乐正绫", "洛天依", "洛天依翻唱曲", "VOCALOID", "良调教", "约定", "周慧"], "author_info": "星璇の天空: Dr.冥月星璇 成就:不会调教的调教师(U/V/SV/C),专业咸鱼,不会科研的副研究员,不会编程的系统架构师,不会AI的无人机设计师...", "label": 2}
|
@ -18,4 +18,4 @@ Note
|
|||||||
0324: V3.5-test3 # 用回3.2的FC层试试
|
0324: V3.5-test3 # 用回3.2的FC层试试
|
||||||
0331: V3.6-test3 # 3.5不太行,我试着调下超参
|
0331: V3.6-test3 # 3.5不太行,我试着调下超参
|
||||||
0335: V3.7-test3 # 3.6还行,再调超参试试看
|
0335: V3.7-test3 # 3.6还行,再调超参试试看
|
||||||
0352: V3.8-test3 # 3.7不行,从3.6的基础重新调
|
0414: V3.8-test3 # 3.7不行,从3.6的基础重新调
|
139
filter/predict.py
Normal file
139
filter/predict.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
import os
|
||||||
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
|
import sqlite3
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
from modelV3_4 import VideoClassifierV3_4
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
# 数据库配置
|
||||||
|
DATABASE_PATH = "./data/main.db"
|
||||||
|
OUTPUT_FILE = "./data/filter/model_predicted.jsonl"
|
||||||
|
BATCH_SIZE = 128 # 批量处理的大小
|
||||||
|
|
||||||
|
def fetch_all_aids(conn):
    """Return all aids in bili_info_crawl whose crawl status is 'success'."""
    cur = conn.execute(
        "SELECT aid FROM bili_info_crawl WHERE status = 'success'"
    )
    return [aid for (aid,) in cur.fetchall()]
|
||||||
|
|
||||||
|
def fetch_entry_data(conn, aid):
    """Return the raw `data` text stored for *aid*, or None if no row exists."""
    row = conn.execute(
        "SELECT data FROM bili_info_crawl WHERE aid = ?", (aid,)
    ).fetchone()
    if row is None:
        return None
    return row[0]
|
||||||
|
|
||||||
|
def parse_entry_data(data):
    """Parse a raw crawl record (JSON string) into structured fields.

    Returns a (title, description, tags, author_info) tuple; on malformed
    JSON or a missing key, logs the error and returns (None, None, None, None).
    Only tags of type "old_channel" or "topic" are kept.
    """
    try:
        obj = json.loads(data)
        view = obj["View"]
        title = view["title"]
        description = view["desc"]
        wanted = ("old_channel", "topic")
        tags = [t["tag_name"] for t in obj["Tags"] if t["tag_type"] in wanted]
        card = obj["Card"]["card"]
        author_info = "{}: {}".format(card["name"], card["sign"])
        return title, description, tags, author_info
    except (KeyError, json.JSONDecodeError) as e:
        print(f"解析错误: {e}")
        return None, None, None, None
|
||||||
|
|
||||||
|
def initialize_model():
    """Load the trained classifier (CPU weights) and the text encoder.

    Returns:
        (classifier, encoder): the VideoClassifierV3_4 switched to eval
        mode, and the SentenceTransformer used to embed the text fields.
    """
    checkpoint_path = './filter/checkpoints/best_model_V3.8.pt'
    classifier = VideoClassifierV3_4()
    state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    classifier.load_state_dict(state)
    classifier.eval()

    encoder = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
    return classifier, encoder
|
||||||
|
|
||||||
|
def predict_batch(model, st_model, batch_data):
    """Run the classifier on a list of entry dicts and return int labels.

    Each entry must carry "title", "description", "tags" (list of str) and
    "author_info"; tags are space-joined into one string per entry. Returns
    one argmax class index per entry.
    """
    def column(key):
        return [item[key] for item in batch_data]

    texts = {
        "title": column("title"),
        "description": column("description"),
        "tags": [" ".join(tag_list) for tag_list in column("tags")],
        "author_info": column("author_info"),
    }
    with torch.no_grad():
        logits = model(input_texts=texts, sentence_transformer=st_model)
    return torch.argmax(logits, dim=1).tolist()
|
||||||
|
|
||||||
|
def _write_predictions(output, batch_data, predictions):
    """Write one JSONL record per (entry, predicted label) pair to *output*."""
    for entry, prediction in zip(batch_data, predictions):
        record = {
            "aid": entry["aid"],
            "title": entry["title"],
            "description": entry["description"],
            "tags": entry["tags"],
            "author_info": entry["author_info"],
            "label": prediction,
        }
        output.write(json.dumps(record, ensure_ascii=False) + "\n")


def process_entries():
    """Main pipeline: classify every successful crawl entry, dump JSONL.

    Reads aids from the SQLite database, batches parsed entries up to
    BATCH_SIZE, runs the model on each batch, and writes one JSON line per
    entry (original fields plus the predicted "label") to OUTPUT_FILE.
    Per-entry failures are logged and skipped.
    """
    model, st_model = initialize_model()

    conn = sqlite3.connect(DATABASE_PATH)
    try:
        aids = fetch_all_aids(conn)
        print(f"总需处理条目数: {len(aids)}")

        with open(OUTPUT_FILE, "w", encoding="utf-8") as output:
            batch_data = []
            for idx, aid in enumerate(aids, 1):
                try:
                    raw_data = fetch_entry_data(conn, aid)
                    if not raw_data:
                        continue

                    title, desc, tags, author = parse_entry_data(raw_data)
                    if not title:
                        # parse_entry_data signals failure with all-None.
                        continue

                    batch_data.append({
                        "aid": aid,
                        "title": title,
                        "description": desc,
                        "tags": tags,
                        "author_info": author,
                    })

                    # Flush a full batch through the model.
                    if len(batch_data) >= BATCH_SIZE:
                        predictions = predict_batch(model, st_model, batch_data)
                        _write_predictions(output, batch_data, predictions)
                        batch_data = []

                    # Periodic progress report.
                    if idx % 100 == 0:
                        print(f"已处理 {idx}/{len(aids)} 条...")

                except Exception as e:
                    # Best-effort: log the failing aid and keep going.
                    print(f"处理aid {aid} 时出错: {str(e)}")

            # Flush the trailing partial batch, if any.
            if batch_data:
                predictions = predict_batch(model, st_model, batch_data)
                _write_predictions(output, batch_data, predictions)
    finally:
        # Fix: the original leaked the connection if anything above raised
        # before reaching conn.close().
        conn.close()
|
||||||
|
|
||||||
|
# Script entry point: run the full prediction pipeline, then report where
# the results were written.
if __name__ == "__main__":
    process_entries()
    print("预测完成,结果已保存至", OUTPUT_FILE)
|
@ -86,8 +86,6 @@ def label_entries(db_path, aids):
|
|||||||
title, description, tags, author_info, url = parse_entry_data(data)
|
title, description, tags, author_info, url = parse_entry_data(data)
|
||||||
if not title: # 如果解析失败,跳过
|
if not title: # 如果解析失败,跳过
|
||||||
continue
|
continue
|
||||||
if '原创' not in title and '原创' not in description:
|
|
||||||
continue
|
|
||||||
# 展示信息
|
# 展示信息
|
||||||
os.system("clear")
|
os.system("clear")
|
||||||
print(f"AID: {aid}")
|
print(f"AID: {aid}")
|
||||||
|
38
filter/test.py
Normal file
38
filter/test.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
import os
|
||||||
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
|
||||||
|
import torch
|
||||||
|
from modelV3_4 import VideoClassifierV3_4
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
def predict(json_input):
    """Classify a single video entry and return the predicted label.

    Args:
        json_input: dict with "title", "description", "tags" (list of str)
            and "author_info" string fields.

    Returns:
        The int argmax class index produced by the classifier.

    Note: the model and the sentence encoder are reloaded on every call;
    for bulk prediction use filter/predict.py instead.
    """
    # Load the trained classifier. map_location keeps this working on
    # CPU-only hosts even if the checkpoint was saved from an accelerator
    # (consistent with the loader in filter/predict.py).
    model = VideoClassifierV3_4()
    model.load_state_dict(
        torch.load('./filter/checkpoints/best_model_V3.8.pt',
                   map_location=torch.device('cpu'))
    )
    model.eval()

    # Text encoder used to embed the four text fields.
    sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")

    # The model expects batched inputs; wrap the single entry as a batch of 1.
    input_texts = {
        "title": [json_input["title"]],
        "description": [json_input["description"]],
        "tags": [" ".join(json_input["tags"])],
        "author_info": [json_input["author_info"]]
    }

    with torch.no_grad():
        logits = model(
            input_texts=input_texts,
            sentence_transformer=sentence_transformer
        )
        pred = torch.argmax(logits, dim=1).item()

    return pred
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Example usage: a mostly-empty entry; substitute real video fields to
    # exercise the classifier end to end.
    sample_input = {"title": "", "description": "", "tags": ["",""], "author_info": "xx: yy"}

    result = predict(sample_input)
    print(f"预测结果: {result}")
|
Loading…
Reference in New Issue
Block a user