From 7c318c942f0290da9faf9402a7e2fbcad9b45b74 Mon Sep 17 00:00:00 2001
From: alikia2x
Date: Sat, 25 Jan 2025 04:52:51 +0800
Subject: [PATCH] add: predict for filter model

---
 .gitignore                     |   3 +-
 data/filter/labeled_data.jsonl |   1 +
 filter/RunningLogs.txt         |   2 +-
 filter/predict.py              | 139 +++++++++++++++++++++++++++++++++
 filter/tag.py                  |   2 -
 filter/test.py                 |  38 +++++++++
 6 files changed, 181 insertions(+), 4 deletions(-)
 create mode 100644 filter/predict.py
 create mode 100644 filter/test.py

diff --git a/.gitignore b/.gitignore
index b11ef49..238ba4d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -73,4 +73,5 @@ __pycache__
 filter/runs
 data/filter/eval*
 data/filter/train*
-filter/checkpoints
\ No newline at end of file
+filter/checkpoints
+data/filter/model_predicted.jsonl
\ No newline at end of file
diff --git a/data/filter/labeled_data.jsonl b/data/filter/labeled_data.jsonl
index d27927f..9411088 100644
--- a/data/filter/labeled_data.jsonl
+++ b/data/filter/labeled_data.jsonl
@@ -914,3 +914,4 @@
 {"aid": 754905929, "title": "【螺旋音リボ sublimation】未来線【多音階CVVC配布】", "description": "螺旋音リボが一周年を迎えました。\n今回配布する音源名は、螺旋音リボ 多音階CVVC sublimationです。デルタ式7モーラ母音優先リストを使用して録音した三音階CVVC音源です。\n収録音階はF#3,A3,A#3です。A#3は強めに録っております。\nsublimationは昇華という意味の英単語です。\n\nノイズ除去を塩釜(mylist/41289446)さんにお願いいたしました。ありがとうございました!\n\nそれではこれからも螺旋音リボをよろしくお願いいたします。\n\n音源ダウンロード: http://gluglu-singer.jimdo.com/\n素敵な本家: sm15746943\nust: sm16402299\nillustration: 悠みなも mylist/34371459\nvocal: 螺旋音リボ 多音階CVVC sublimation (cho 螺旋音リボ 単独音)\n\n------------------------------------------------------------------------------------\n我是日本人。我使用Google翻译进行了翻译,但是如果我输入有误,我们深感抱歉。\n\n螺旋音リボ正在庆祝其成立一周年。\n这次要分配的声源名称是螺旋音リボ Multitone CVVC升华。使用delta 7-mora元音优先级列表录制的三音CVVC声源。\n记录的音阶为F#3,A3,A#3。强烈记录了#3。\n升华是英文单词,意为升华。\n\n我请盐og先生(mylist / 41289446)消除噪音。非常感谢你!\n\n感谢您一直以来对螺旋音リボ的支持。\n\n声音源下载:http://gluglu-singer.jimdo.com/\n不错的户主:sm15746943\nUST:sm16402299\n插图:悠みなも mylist / 34371459\nvocal:螺旋音リボ 多音CVVC升华(cho螺旋音リボ 单声)", "tags": ["UTAU", "日本", "UTAU音源配布所リンク", "螺旋音リボ", "VOCALOID→UTAUカバー曲", "未来線"], "author_info": "悠みなも: UTAU/螺旋音リボ/DTM Weibo MINAMO_YU0001", "label": 0}
 {"aid": 685634071, "title": "【AI歌手】腐草为萤【Cover】", "description": "原唱:银临\n引擎:歌叽歌叽\nPV借用:av2700587\nPV:星の祈\n绘:otakucake阿饼(P站:pixiv.me/otakucake)", "tags": ["AI", "歌叽歌叽"], "author_info": "一支粉笔w: 人生难得一知己,千古知音最难求。 | 网易云:粉笔 微博:http://weibo.com/Chalk10", "label": 2}
 {"aid": 3372898, "title": "【言和英文】Counting Stars", "description": "自制 终于投了。1.封面仍然灵魂p图。2.这玩意儿9月底就在做了,说好10月投,pv还解决不了。说好11月投,但又因为曲子本身又拖了。上个星期想投,结果,电脑坏掉加发烧。。。。。所以一路拖到今天才投。算是放下心中的一块石头了吧。3.这次画了很多参数,感觉有提升,不知道朱军是怎么看的。5.counting stars是我最最最喜欢的一首歌,已经循环一年。6.1R赶紧出4专!!!", "tags": ["VOCALOID", "COUNTING STARS", "言和", "ONEREPUBLIC"], "author_info": "JackChenZz: 随缘更新", "label": 0}
+{"aid": 589575, "title": "【洛天依】约定", "description": "绫:天依,不管几年之后你能否记得我,我会一直恋你,爱你,铭心记住你! 
QAQ 被小汐撺掇调的,咱非常不擅长这拐来拐去的长调式滴说,不过这次应该算不上渣吧?(好吧- -无视这句)。翻的周慧的《约定》,谱子依旧自己来(其实只是自己重做MIDI了,网上的MIDI做觉得不好用,调子太高),3分43秒的字幕是UP无聊的XXX。PV用慢放的《春来发几只》弄的,剪掉了两组燕子镜头~ ~赶脚两个曲子的BPM同步了", "tags": ["乐正绫", "洛天依", "洛天依翻唱曲", "VOCALOID", "良调教", "约定", "周慧"], "author_info": "星璇の天空: Dr.冥月星璇 成就:不会调教的调教师(U/V/SV/C),专业咸鱼,不会科研的副研究员,不会编程的系统架构师,不会AI的无人机设计师...", "label": 2}
\ No newline at end of file
diff --git a/filter/RunningLogs.txt b/filter/RunningLogs.txt
index 2518942..29ce991 100644
--- a/filter/RunningLogs.txt
+++ b/filter/RunningLogs.txt
@@ -18,4 +18,4 @@ Note
 0324: V3.5-test3 # 用回3.2的FC层试试
 0331: V3.6-test3 # 3.5不太行,我试着调下超参
 0335: V3.7-test3 # 3.6还行,再调超参试试看
-0352: V3.8-test3 # 3.7不行,从3.6的基础重新调
\ No newline at end of file
+0414: V3.8-test3 # 3.7不行,从3.6的基础重新调
\ No newline at end of file
diff --git a/filter/predict.py b/filter/predict.py
new file mode 100644
index 0000000..64e5540
--- /dev/null
+++ b/filter/predict.py
@@ -0,0 +1,139 @@
+import os
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+import sqlite3
+import json
+import torch
+from modelV3_4 import VideoClassifierV3_4
+from sentence_transformers import SentenceTransformer
+
+# 数据库配置
+DATABASE_PATH = "./data/main.db"
+OUTPUT_FILE = "./data/filter/model_predicted.jsonl"
+BATCH_SIZE = 128  # 批量处理的大小
+
+def fetch_all_aids(conn):
+    """获取数据库中所有符合条件的aid"""
+    cursor = conn.cursor()
+    cursor.execute("SELECT aid FROM bili_info_crawl WHERE status = 'success'")
+    aids = [row[0] for row in cursor.fetchall()]
+    return aids
+
+def fetch_entry_data(conn, aid):
+    """获取单个条目的原始数据"""
+    cursor = conn.cursor()
+    cursor.execute("SELECT data FROM bili_info_crawl WHERE aid = ?", (aid,))
+    d = cursor.fetchone()
+    data = d[0] if d else None
+    return data
+
+def parse_entry_data(data):
+    """解析原始数据为结构化信息"""
+    try:
+        obj = json.loads(data)
+        title = obj["View"]["title"]
+        description = obj["View"]["desc"]
+        tags = [tag["tag_name"] for tag in obj["Tags"]
+                if tag["tag_type"] in ["old_channel", "topic"]]
+        author_info = f"{obj['Card']['card']['name']}: {obj['Card']['card']['sign']}"
+        return title, description, tags, author_info
+    except (KeyError, json.JSONDecodeError) as e:
+        print(f"解析错误: {e}")
+        return None, None, None, None
+
+def initialize_model():
+    """初始化模型和文本编码器"""
+    model = VideoClassifierV3_4()
+    model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.8.pt', map_location=torch.device('cpu')))
+    model.eval()
+
+    st_model = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
+    return model, st_model
+
+def predict_batch(model, st_model, batch_data):
+    """批量执行预测"""
+    with torch.no_grad():
+        input_texts = {
+            "title": [entry["title"] for entry in batch_data],
+            "description": [entry["description"] for entry in batch_data],
+            "tags": [" ".join(entry["tags"]) for entry in batch_data],
+            "author_info": [entry["author_info"] for entry in batch_data]
+        }
+        logits = model(input_texts=input_texts, sentence_transformer=st_model)
+        return torch.argmax(logits, dim=1).tolist()
+
+def process_entries():
+    """主处理流程"""
+    # 初始化模型
+    model, st_model = initialize_model()
+
+    # 获取数据库连接
+    conn = sqlite3.connect(DATABASE_PATH)
+
+    # 获取所有aid
+    aids = fetch_all_aids(conn)
+    print(f"总需处理条目数: {len(aids)}")
+
+    # 批量处理并保存结果
+    with open(OUTPUT_FILE, "w", encoding="utf-8") as output:
+        batch_data = []
+        for idx, aid in enumerate(aids, 1):
+            try:
+                # 获取并解析数据
+                raw_data = fetch_entry_data(conn, aid)
+                if not raw_data:
+                    continue
+
+                title, desc, tags, author = parse_entry_data(raw_data)
+                if not title:
+                    continue
+
+                # 构造预测输入
+                entry = {
+                    "aid": aid,
+                    "title": title,
+                    "description": desc,
+                    "tags": tags,
+                    "author_info": author
+                }
+                batch_data.append(entry)
+
+                # 当达到批量大小时进行预测
+                if len(batch_data) >= BATCH_SIZE:
+                    predictions = predict_batch(model, st_model, batch_data)
+                    for entry, prediction in zip(batch_data, predictions):
+                        output.write(json.dumps({
+                            "aid": entry["aid"],
+                            "title": entry["title"],
+                            "description": entry["description"],
+                            "tags": entry["tags"],
+                            "author_info": entry["author_info"],
+                            "label": prediction
+                        }, ensure_ascii=False) + "\n")
+                    batch_data = []  # 清空批量数据
+
+                # 进度显示
+                if idx % 100 == 0:
+                    print(f"已处理 {idx}/{len(aids)} 条...")
+
+            except Exception as e:
+                print(f"处理aid {aid} 时出错: {str(e)}")
+
+        # 处理剩余的条目
+        if batch_data:
+            predictions = predict_batch(model, st_model, batch_data)
+            for entry, prediction in zip(batch_data, predictions):
+                output.write(json.dumps({
+                    "aid": entry["aid"],
+                    "title": entry["title"],
+                    "description": entry["description"],
+                    "tags": entry["tags"],
+                    "author_info": entry["author_info"],
+                    "label": prediction
+                }, ensure_ascii=False) + "\n")
+
+    # 关闭数据库连接
+    conn.close()
+
+if __name__ == "__main__":
+    process_entries()
+    print("预测完成,结果已保存至", OUTPUT_FILE)
diff --git a/filter/tag.py b/filter/tag.py
index 74ea251..6f0556d 100644
--- a/filter/tag.py
+++ b/filter/tag.py
@@ -86,8 +86,6 @@ def label_entries(db_path, aids):
         title, description, tags, author_info, url = parse_entry_data(data)
         if not title:  # 如果解析失败,跳过
             continue
-        if '原创' not in title and '原创' not in description:
-            continue
         # 展示信息
         os.system("clear")
         print(f"AID: {aid}")
diff --git a/filter/test.py b/filter/test.py
new file mode 100644
index 0000000..5b83298
--- /dev/null
+++ b/filter/test.py
@@ -0,0 +1,38 @@
+import os
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
+import torch
+from modelV3_4 import VideoClassifierV3_4
+from sentence_transformers import SentenceTransformer
+
+def predict(json_input):
+    # 加载模型
+    model = VideoClassifierV3_4()
+    model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.8.pt'))
+    model.eval()
+
+    # 加载SentenceTransformer
+    sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
+
+    input_texts = {
+        "title": [json_input["title"]],
+        "description": [json_input["description"]],
+        "tags": [" ".join(json_input["tags"])],
+        "author_info": [json_input["author_info"]]
+    }
+
+    # 预测
+    with torch.no_grad():
+        logits = model(
+            input_texts=input_texts,
+            sentence_transformer=sentence_transformer
+        )
+        pred = torch.argmax(logits, dim=1).item()
+
+    return pred
+
+if __name__ == "__main__":
+    # 示例用法
+    sample_input = {"title": "", "description": "", "tags": ["",""], "author_info": "xx: yy"}
+
+    result = predict(sample_input)
+    print(f"预测结果: {result}")
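Usage note (a minimal sketch, not part of the patch): running "python filter/predict.py" from the repository root reads aids from ./data/main.db and writes one JSON object per line to ./data/filter/model_predicted.jsonl, each carrying aid, title, description, tags, author_info and the predicted label. The snippet below assumes that output file has already been generated and simply tallies the predicted labels.

import json
from collections import Counter

# Count how many entries fell into each predicted label class.
# Assumes filter/predict.py has already written this file (the OUTPUT_FILE
# path used in the patch), one JSON object per line.
label_counts = Counter()
with open("./data/filter/model_predicted.jsonl", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        label_counts[record["label"]] += 1

print(label_counts)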