95 lines
2.9 KiB
Python
95 lines
2.9 KiB
Python
"""
|
|
Configuration loader for embedding models and other settings
|
|
"""
|
|
|
|
import os
from typing import Dict, Optional

import toml
from pydantic import BaseModel

from logger_config import get_logger
|
|
|
|
# Module-level logger, obtained from logger_config.get_logger with this module's __name__.
logger = get_logger(__name__)
|
|
|
|
|
|
class EmbeddingModelConfig(BaseModel):
    """Settings for a single embedding model entry.

    ``name``, ``dimensions`` and ``type`` have no default and must always be
    supplied; every other field falls back to the default declared below.

    NOTE(review): the field meanings hinted at in the inline comments are
    inferred from the field names only -- confirm against the code that
    consumes this model.
    """

    # --- Required fields (no default) ---
    name: str
    dimensions: int  # presumably the embedding vector size -- TODO confirm
    type: str  # presumably the backend/provider kind -- TODO confirm

    # --- Optional fields ---
    api_endpoint: str = "https://api.openai.com/v1"
    max_tokens: int = 8191
    max_batch_size: int = 8
    api_key_env: str = "OPENAI_API_KEY"  # looks like an env-var name -- verify
    model_path: str = ""
    tokenizer_name: str = ""
|
|
|
|
|
|
class ConfigLoader:
    """Load embedding-model definitions from a TOML file and serve them by key.

    The file is read for two top-level entries: a ``models`` table, each of
    whose entries is turned into an ``EmbeddingModelConfig``, and an optional
    ``model`` key giving the selected model. Loading is best-effort: a
    missing file, a missing ``models`` table, or any error while parsing or
    validating is logged (or silently skipped for a missing table) and leaves
    the loader's state untouched instead of raising.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Create the loader and immediately load the configuration.

        Args:
            config_path: Path to the TOML file. When ``None``, defaults to
                ``embedding_models.toml`` in the same directory as this module.
        """
        if config_path is None:
            # Default to the embedding_models.toml file shipped next to this module
            config_path = os.path.join(
                os.path.dirname(__file__), "embedding_models.toml"
            )

        self.config_path = config_path
        self.embedding_models: Dict[str, EmbeddingModelConfig] = {}
        self.selected_model: Optional[str] = None
        self._load_config()

    def _load_config(self):
        """Load configuration from TOML file.

        On success ``embedding_models`` and ``selected_model`` are replaced
        together. On any failure the error is logged and both attributes keep
        the values they had before the call, so the loader is never left
        half-populated.
        """
        try:
            if not os.path.exists(self.config_path):
                logger.warning(f"Config file not found: {self.config_path}")
                return

            with open(self.config_path, "r", encoding="utf-8") as f:
                config_data = toml.load(f)

            # Nothing to load when the file defines no models table.
            if "models" not in config_data:
                return

            # Build into a local dict first: if one entry fails validation we
            # must not publish the entries that happened to precede it.
            loaded = {
                model_key: EmbeddingModelConfig(**model_data)
                for model_key, model_data in config_data["models"].items()
            }

            if "model" in config_data:
                selected = config_data["model"]
            else:
                # Fall back to the first declared model; ``None`` when the
                # models table is empty (indexing ``[0]`` would raise here).
                selected = next(iter(loaded), None)

            self.embedding_models = loaded
            self.selected_model = selected

            logger.info(
                f"Loaded {len(self.embedding_models)} embedding models from {self.config_path}"
            )

        except Exception as e:
            # Deliberately broad: configuration loading is best-effort and
            # must not raise out of the constructor / module import.
            logger.error(f"Failed to load config from {self.config_path}: {e}")

    def get_selected_model(self) -> Optional[str]:
        """Get selected model for health check.

        Returns:
            The selected model key, or ``None`` when nothing has been loaded.
        """
        return self.selected_model

    def get_embedding_models(self) -> Dict[str, EmbeddingModelConfig]:
        """Get all available embedding models.

        Returns:
            A shallow copy of the key -> config mapping, so callers cannot add
            or remove entries on the loader's own dict.
        """
        return self.embedding_models.copy()

    def get_embedding_model(self, model_name: str) -> EmbeddingModelConfig:
        """Get specific embedding model config.

        Args:
            model_name: Key of the model in the loaded configuration.

        Raises:
            ValueError: If ``model_name`` is not a loaded model key.
        """
        if model_name not in self.embedding_models:
            raise ValueError(
                f"Embedding model '{model_name}' not found in configuration"
            )
        return self.embedding_models[model_name]

    def list_model_names(self) -> list:
        """Get list of available model names (the keys of the loaded models)."""
        return list(self.embedding_models)

    def reload_config(self):
        """Reload configuration from file.

        Both the model mapping and the selected model are cleared first, so a
        reload that finds no file / no models (or fails) leaves the loader
        empty rather than pointing at a selection that no longer exists.
        """
        self.embedding_models = {}
        self.selected_model = None
        self._load_config()
|
|
|
|
|
|
# Global config loader instance
|
|
# Constructed with no path, so the default embedding_models.toml next to this
# module is read when the module is first imported.
config_loader = ConfigLoader()
|