update: inference code
This commit is contained in:
parent
ce1b17dbad
commit
fadf121fea
0
data/filter/model_predicted_new.jsonl
Normal file
0
data/filter/model_predicted_new.jsonl
Normal file
@ -10,7 +10,7 @@ import tty
|
|||||||
import termios
|
import termios
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
from db_utils import fetch_entry_data, parse_entry_data
|
from db_utils import fetch_entry_data, parse_entry_data
|
||||||
from modelV3_4 import VideoClassifierV3_4
|
from modelV3_9 import VideoClassifierV3_9
|
||||||
|
|
||||||
class LabelingSystem:
|
class LabelingSystem:
|
||||||
def __init__(self, mode='model_testing', database_path="./data/main.db",
|
def __init__(self, mode='model_testing', database_path="./data/main.db",
|
||||||
@ -27,7 +27,7 @@ class LabelingSystem:
|
|||||||
self.model = None
|
self.model = None
|
||||||
self.sentence_transformer = None
|
self.sentence_transformer = None
|
||||||
if self.mode == 'model_testing':
|
if self.mode == 'model_testing':
|
||||||
self.model = VideoClassifierV3_4()
|
self.model = VideoClassifierV3_9()
|
||||||
self.model.load_state_dict(torch.load(model_path))
|
self.model.load_state_dict(torch.load(model_path))
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
|
self.sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
|
||||||
|
@ -3,7 +3,7 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
from modelV3_4 import VideoClassifierV3_4
|
from modelV3_9 import VideoClassifierV3_9
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
from tqdm import tqdm # 导入 tqdm
|
from tqdm import tqdm # 导入 tqdm
|
||||||
|
|
||||||
@ -43,8 +43,8 @@ def parse_entry_data(data):
|
|||||||
|
|
||||||
def initialize_model():
|
def initialize_model():
|
||||||
"""初始化模型和文本编码器"""
|
"""初始化模型和文本编码器"""
|
||||||
model = VideoClassifierV3_4()
|
model = VideoClassifierV3_9()
|
||||||
model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.8.pt', map_location=torch.device('cpu')))
|
model.load_state_dict(torch.load('./filter/checkpoints/best_model_V3.9.pt', map_location=torch.device('cpu')))
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
st_model = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
|
st_model = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from labeling_system import LabelingSystem
|
from labeling_system import LabelingSystem
|
||||||
|
|
||||||
DATABASE_PATH = "./data/main.db"
|
DATABASE_PATH = "./data/main.db"
|
||||||
MODEL_PATH = "./filter/checkpoints/best_model_V3.8.pt"
|
MODEL_PATH = "./filter/checkpoints/best_model_V3.9.pt"
|
||||||
OUTPUT_FILE = "./data/filter/real_test.jsonl"
|
OUTPUT_FILE = "./data/filter/real_test.jsonl"
|
||||||
BATCH_SIZE = 50
|
BATCH_SIZE = 50
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user