update: inference code

This commit is contained in:
alikia2x (寒寒) 2025-02-01 19:30:44 +08:00
parent ce1b17dbad
commit fadf121fea
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
4 changed files with 6 additions and 6 deletions

View File

View File

@ -10,7 +10,7 @@ import tty
import termios
from sentence_transformers import SentenceTransformer
from db_utils import fetch_entry_data, parse_entry_data
from modelV3_4 import VideoClassifierV3_4
from modelV3_9 import VideoClassifierV3_9
class LabelingSystem:
def __init__(self, mode='model_testing', database_path="./data/main.db",
@ -27,7 +27,7 @@ class LabelingSystem:
self.model = None
self.sentence_transformer = None
if self.mode == 'model_testing':
self.model = VideoClassifierV3_4()
self.model = VideoClassifierV3_9()
self.model.load_state_dict(torch.load(model_path))
self.model.eval()
self.sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")

View File

@ -3,7 +3,7 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import sqlite3
import json
import torch
from modelV3_4 import VideoClassifierV3_4
from modelV3_9 import VideoClassifierV3_9
from sentence_transformers import SentenceTransformer
from tqdm import tqdm # 导入 tqdm
@ -43,8 +43,8 @@ def parse_entry_data(data):
def initialize_model():
"""初始化模型和文本编码器"""
model = VideoClassifierV3_4()
model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.8.pt', map_location=torch.device('cpu')))
model = VideoClassifierV3_9()
model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.9.pt', map_location=torch.device('cpu')))
model.eval()
st_model = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")

View File

@ -1,7 +1,7 @@
from labeling_system import LabelingSystem
DATABASE_PATH = "./data/main.db"
MODEL_PATH = "./filter/checkpoints/best_model_V3.8.pt"
MODEL_PATH = "./filter/checkpoints/best_model_V3.9.pt"
OUTPUT_FILE = "./data/filter/real_test.jsonl"
BATCH_SIZE = 50