151 lines
6.6 KiB
Python
151 lines
6.6 KiB
Python
# This module contains definitions of variables, functions, classes, et cetera, that are
|
|
# imported to more than one other module. The rationale for defining these things here
|
|
# is that it is easier to avoid circular imports when they are defined in a central location.
|
|
import logging
|
|
import json
|
|
import requests
|
|
#from backend import GlobalState # Assuming GlobalState is defined there
|
|
|
|
def fetch_models_from_endpoints(endpoints, global_state, timeout=10):
    """
    Fetch models from endpoints and update the endpoint dictionaries.

    Only endpoints whose ``"provider"`` is ``"ollama"`` are queried; all
    others (including endpoints with no ``"provider"`` key) are skipped.
    For each Ollama endpoint a GET request is sent to ``<url>/api/tags``
    and, on success, the ``"models"`` entry of the JSON reply is stored in
    ``endpoint["models"]``. Any failure is logged and that endpoint is left
    unchanged, so one bad endpoint does not stop the others from loading.

    Args:
        endpoints (list): List of endpoint dictionaries. Ollama endpoints
            need a ``"url"`` key and may carry
            ``{"requestOptions": {"headers": {"Authorization": ...}}}``.
        global_state: The global state object (not used by this function;
            kept for interface compatibility).
        timeout (float): Seconds to wait for each endpoint's response.
            Without a timeout, ``requests`` can block indefinitely on an
            unresponsive server.

    Returns:
        None
    """
    logger = logging.getLogger(__name__)

    for endpoint in endpoints:
        # .get() so an endpoint lacking "provider" is skipped instead of
        # raising KeyError and aborting the remaining endpoints.
        if endpoint.get("provider") != "ollama":
            continue

        base_url = endpoint.get("url")
        if not base_url:
            logger.error("Ollama endpoint has no 'url'; skipping: %s", endpoint)
            continue
        # Strip a trailing slash so we never request "<host>//api/tags".
        url = base_url.rstrip("/") + "/api/tags"

        headers = {
            "Content-Type": "application/json",
        }
        if "requestOptions" in endpoint:  # Authentication may be needed
            request_options = endpoint.get("requestOptions") or {}
            authorization = (request_options.get("headers") or {}).get("Authorization")
            if authorization is not None:
                headers["Authorization"] = authorization
            else:
                logger.warning("requestOptions given for %s but no Authorization header found", base_url)

        try:
            models_response = requests.get(url, headers=headers, timeout=timeout)
            models_response.raise_for_status()  # Raise an exception for HTTP errors
            models = models_response.json()
        # ValueError covers invalid JSON on requests versions where the decode
        # error is not a RequestException subclass.
        except (requests.exceptions.RequestException, ValueError) as e:
            logger.error("Error fetching models from backend: %s", str(e))
            continue

        if not isinstance(models, dict):
            logger.error("Unexpected reply from %s: expected a JSON object, got %s", url, type(models).__name__)
            continue

        if 'error' in models:
            logger.error('Error fetching models from backend: %s', models['error'])
        else:
            endpoint["models"] = models.get("models", [])  # Get the list of models directly
|
|
|
|
|
|
class GlobalState:
    """
    Shared, singleton holder for project-wide settings and helpers.

    Every call to ``GlobalState()`` returns the same object: ``__new__``
    creates and initialises the instance on the first call and returns the
    cached instance on every later call. Defining it in this central module
    lets other modules share it without importing each other.

    It holds the logging level and root logger, the currently used LLM host
    and model, the backend configuration, and the list of available
    endpoints.
    """

    _instance = None  # The single shared instance, created lazily by __new__

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(GlobalState, cls).__new__(cls)
            cls._instance.log_level = 'INFO'  # Default logging level
            # Attach one handler to the root logger so that module loggers
            # propagate to it. This only runs on first instantiation, so the
            # handler is not added repeatedly.
            cls._instance.logger = logging.getLogger()
            handler = logging.StreamHandler()  # Or other handler (FileHandler for logs to file)
            formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
            handler.setFormatter(formatter)
            cls._instance.logger.addHandler(handler)
            cls._instance.logger.setLevel(getattr(logging, cls._instance.log_level))  # Initialize root logger level
            cls._instance.logger.info(" __new__(cls): Logger in GlobalState created: %s", cls._instance.logger)
            cls._instance.host_url = None  # Currently used LLM host
            cls._instance.llm = "phi3:mini"  # Default LLM for queries. TODO: Check with ollama server that it actually exists
            # Layout intended to follow the outline of the yaml config file.
            cls._instance.backend = dict()  # Info on which server the clients connect to
            cls._instance.endpoints = []  # Available endpoints (server url, model name, provider et cetera)
        return cls._instance

    def configure_logging(self, level=None):
        """
        Apply a logging level to the root logger.

        Args:
            level (str | None): Level name such as ``"DEBUG"`` or ``"info"``
                (case-insensitive). Defaults to ``self.log_level``.

        Raises:
            AttributeError: If ``level`` is not an attribute of ``logging``.
        """
        if level is None:
            level = self.log_level
        numeric_level = getattr(logging, level.upper())  # Convert string to numeric level
        self.logger.setLevel(numeric_level)
        self.logger.debug(f"utils.py -- configure_logging(): effective log level is {level} which is {self.logger.getEffectiveLevel()}")

    def set_log_level(self, level = 'INFO'):
        """Store ``level`` as the logging level and apply it to the root logger."""
        self.log_level = level
        self.configure_logging()

    def get_log_level(self):
        """Return the stored logging level name."""
        return self.log_level

    def get_effective_log_level(self):
        """Return the numeric effective log level of the root logger."""
        return self.logger.getEffectiveLevel()

    def getLogger(self, module_name = None):
        """Return the logger named ``module_name`` (this module's name by default)."""
        if module_name is None:
            module_name = __name__
        return logging.getLogger(module_name)

    def set_host_url(self, url="http://localhost:11434"):
        """Set the host url to which LLM requests are sent."""
        # Stored as host_url, which get_host_url() reads. The ``host``
        # alias is kept for any caller that still reads that attribute.
        self.host_url = url
        self.host = url

    def get_host_url(self):
        """Get the url for the currently used host for LLMs."""
        return self.host_url

    def set_llm(self, model_name="phi3:mini"):
        """Set the LLM used for queries."""
        self.llm = model_name

    def get_llm(self):
        """Return the LLM used for queries."""
        return self.llm

    def set_backend(self, backend=None):
        """Set the backend server config that web clients connect to."""
        self.backend = backend

    def get_backend(self):
        """Return the backend server config that web clients connect to."""
        return self.backend

    def get_backend_api_ep(self):
        """
        Return the backend API endpoint, i.e. ``backend["url"] + backend["api"]``.

        Raises:
            KeyError: If the backend config lacks ``"url"`` or ``"api"``.
        """
        return self.backend["url"] + self.backend["api"]

    def set_endpoints(self, endpoints=None):
        """
        Replace the list of endpoints; passing None leaves it unchanged.

        Raises:
            ValueError: If ``endpoints`` is given but is not a list.
        """
        if endpoints is not None:
            if not isinstance(endpoints, list):
                raise ValueError("Endpoints must be a list, even if there is just one model")
            self.endpoints = endpoints

    def get_endpoints(self):
        """Return the list of endpoints."""
        return self.endpoints

    def get_list_of_available_llms(self, endpoint=None):
        """
        Return the model names available at ``endpoint``, sorted case-insensitively.

        Args:
            endpoint (dict | None): Endpoint dictionary whose ``"models"``
                entry is a list of dicts, each with a ``"name"`` key.

        Returns:
            list[str] | None: Sorted model names, or None when ``endpoint``
            is None or has no ``"models"`` list (for example when fetching
            its models failed and the key was never set).
        """
        if endpoint is None:
            return None
        models = endpoint.get("models")
        if not isinstance(models, list):
            return None
        return sorted((model['name'] for model in models), key=str.lower)