update: inference code

This commit is contained in:
alikia2x (寒寒) 2025-02-01 19:30:44 +08:00
parent ce1b17dbad
commit fadf121fea
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
4 changed files with 6 additions and 6 deletions

View File

View File

@ -10,7 +10,7 @@ import tty
import termios import termios
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
from db_utils import fetch_entry_data, parse_entry_data from db_utils import fetch_entry_data, parse_entry_data
from modelV3_4 import VideoClassifierV3_4 from modelV3_9 import VideoClassifierV3_9
class LabelingSystem: class LabelingSystem:
def __init__(self, mode='model_testing', database_path="./data/main.db", def __init__(self, mode='model_testing', database_path="./data/main.db",
@ -27,7 +27,7 @@ class LabelingSystem:
self.model = None self.model = None
self.sentence_transformer = None self.sentence_transformer = None
if self.mode == 'model_testing': if self.mode == 'model_testing':
self.model = VideoClassifierV3_4() self.model = VideoClassifierV3_9()
self.model.load_state_dict(torch.load(model_path)) self.model.load_state_dict(torch.load(model_path))
self.model.eval() self.model.eval()
self.sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024") self.sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")

View File

@ -3,7 +3,7 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import sqlite3 import sqlite3
import json import json
import torch import torch
from modelV3_4 import VideoClassifierV3_4 from modelV3_9 import VideoClassifierV3_9
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
from tqdm import tqdm # 导入 tqdm from tqdm import tqdm # 导入 tqdm
@ -43,8 +43,8 @@ def parse_entry_data(data):
def initialize_model(): def initialize_model():
"""初始化模型和文本编码器""" """初始化模型和文本编码器"""
model = VideoClassifierV3_4() model = VideoClassifierV3_9()
model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.8.pt', map_location=torch.device('cpu'))) model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.9.pt', map_location=torch.device('cpu')))
model.eval() model.eval()
st_model = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024") st_model = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")

View File

@ -1,7 +1,7 @@
from labeling_system import LabelingSystem from labeling_system import LabelingSystem
DATABASE_PATH = "./data/main.db" DATABASE_PATH = "./data/main.db"
MODEL_PATH = "./filter/checkpoints/best_model_V3.8.pt" MODEL_PATH = "./filter/checkpoints/best_model_V3.9.pt"
OUTPUT_FILE = "./data/filter/real_test.jsonl" OUTPUT_FILE = "./data/filter/real_test.jsonl"
BATCH_SIZE = 50 BATCH_SIZE = 50