update: inference code
This commit is contained in:
parent
ce1b17dbad
commit
fadf121fea
0
data/filter/model_predicted_new.jsonl
Normal file
0
data/filter/model_predicted_new.jsonl
Normal file
@ -10,7 +10,7 @@ import tty
|
||||
import termios
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from db_utils import fetch_entry_data, parse_entry_data
|
||||
from modelV3_4 import VideoClassifierV3_4
|
||||
from modelV3_9 import VideoClassifierV3_9
|
||||
|
||||
class LabelingSystem:
|
||||
def __init__(self, mode='model_testing', database_path="./data/main.db",
|
||||
@ -27,7 +27,7 @@ class LabelingSystem:
|
||||
self.model = None
|
||||
self.sentence_transformer = None
|
||||
if self.mode == 'model_testing':
|
||||
self.model = VideoClassifierV3_4()
|
||||
self.model = VideoClassifierV3_9()
|
||||
self.model.load_state_dict(torch.load(model_path))
|
||||
self.model.eval()
|
||||
self.sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
|
||||
|
@ -3,7 +3,7 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
import sqlite3
|
||||
import json
|
||||
import torch
|
||||
from modelV3_4 import VideoClassifierV3_4
|
||||
from modelV3_9 import VideoClassifierV3_9
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from tqdm import tqdm # 导入 tqdm
|
||||
|
||||
@ -43,8 +43,8 @@ def parse_entry_data(data):
|
||||
|
||||
def initialize_model():
|
||||
"""初始化模型和文本编码器"""
|
||||
model = VideoClassifierV3_4()
|
||||
model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.8.pt', map_location=torch.device('cpu')))
|
||||
model = VideoClassifierV3_9()
|
||||
model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.9.pt', map_location=torch.device('cpu')))
|
||||
model.eval()
|
||||
|
||||
st_model = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
|
||||
|
@ -1,7 +1,7 @@
|
||||
from labeling_system import LabelingSystem
|
||||
|
||||
DATABASE_PATH = "./data/main.db"
|
||||
MODEL_PATH = "./filter/checkpoints/best_model_V3.8.pt"
|
||||
MODEL_PATH = "./filter/checkpoints/best_model_V3.9.pt"
|
||||
OUTPUT_FILE = "./data/filter/real_test.jsonl"
|
||||
BATCH_SIZE = 50
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user