290 lines
11 KiB
Python
290 lines
11 KiB
Python
# This module contains definitions of variables, functions, classes, et cetera, that are
|
|
# imported by more than one other module. The rationale for defining these things here
|
|
# is that it is easier to avoid circular imports when they are defined in a central location.
|
|
import logging
|
|
import json
|
|
import requests
|
|
from typing import Optional
|
|
from enums import LogLevel
|
|
|
|
class GlobalState:
|
|
"""
|
|
This class holds various variables and methods which are accessible across
|
|
different modules in the Python project using the Singleton design pattern.
|
|
This ensures that only one instance of the class is created and shared among
|
|
all modules, preventing circular imports and providing a centralized location
|
|
for managing shared resources.
|
|
"""
|
|
_instance = None # Private class attribute to hold the single instance of the class
|
|
|
|
def __new__(cls) -> 'GlobalState':
|
|
"""
|
|
Create a new instance of the GlobalState class.
|
|
|
|
This is a singleton implementation, so only one instance will be created.
|
|
"""
|
|
if cls._instance is None:
|
|
cls._instance = super(GlobalState, cls).__new__(cls)
|
|
cls._instance.log_level = 'INFO' # Default logging level
|
|
cls._instance.logger = logging.getLogger() # Get root logger for the caller module
|
|
handler = logging.StreamHandler() # Or other handler (FileHandler for logs to file)
|
|
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
|
handler.setFormatter(formatter)
|
|
cls._instance.logger.addHandler(handler)
|
|
cls._instance.logger.setLevel(getattr(logging, cls._instance.log_level)) # Initialize root logger level
|
|
cls._instance.logger.info(" __new__(cls): Logger in GlobalState created: %s", cls._instance.logger)
|
|
cls._instance.host_url = None # Currently used LLM host
|
|
cls._instance.llm = "phi3:mini" # Default LLM for queries. TODO: Check with ollama server that it actually exists
|
|
# cls._instance.backend_api_ep = "http://localhost:5005/api/chat" # Default backend API endpoint
|
|
# Try making things more aligned with the outline of the yaml file
|
|
cls._instance.backend = dict() # A dictionary that holds info on which server the clients connect to
|
|
cls._instance.endpoints = [] # A list that holds info on which endpoints are available for use (server url, model name, provider et cetera)
|
|
# logging - already done in __new__, perhaps change layout later
|
|
|
|
return cls._instance
|
|
|
|
# def configure_logging(self, level: Optional[LogLevel] = None) -> None:
|
|
# """
|
|
# Configure the logging system for this project.
|
|
|
|
# Args:
|
|
# level (LogLevel): The log level to use. If None, uses the default log level set in `self.log_level`.
|
|
|
|
# Notes:
|
|
# This method sets up logging for the project and logs a message at the debug level indicating the effective log level.
|
|
# """
|
|
def configure_logging(self, level: Optional[str] = None) -> None:
|
|
|
|
"""
|
|
Configure the logging system for this project.
|
|
|
|
Args:
|
|
level (str): The log level to use. Can be one of the standard Python log levels (e.g., 'DEBUG', 'INFO', 'WARNING', etc.). If None, uses the default log level set in `self.log_level`.
|
|
|
|
Notes:
|
|
This method sets up logging for the project and logs a message at the debug level indicating the effective log level.
|
|
"""
|
|
if level is None:
|
|
level = self.log_level
|
|
# numeric_level = getattr(logging, level.upper()) # Convert string to numeric level
|
|
numeric_level = getattr(logging, level.upper()) # Convert string to numeric level
|
|
self.logger.setLevel(numeric_level)
|
|
self.logger.debug(f"utils.py -- configure_logging(): effective log level is {level} which is {self.logger.getEffectiveLevel()}")
|
|
|
|
# def set_log_level(self, level: LogLevel) -> None:
|
|
# """
|
|
# Set the log level for this project.
|
|
|
|
# Args:
|
|
# level (LogLevel): The new log level to use. Can be one of the evels defined in enum.py (e.g., DEBUG, INFO, WARNING, CRITICAL etc.).
|
|
|
|
# Notes:
|
|
# This method updates the `self.log_level` attribute and calls `configure_logging()` to apply the change.
|
|
# """
|
|
def set_log_level(self, level: str = 'INFO') -> None:
|
|
"""
|
|
Set the log level for this project.
|
|
|
|
Args:
|
|
level (str): The new log level to use. Can be one of the standard Python log levels (e.g., 'DEBUG', 'INFO', 'WARNING', etc.).
|
|
|
|
Notes:
|
|
This method updates the `self.log_level` attribute and calls `configure_logging()` to apply the change.
|
|
"""
|
|
self.log_level = level
|
|
self.configure_logging()
|
|
|
|
def get_log_level(self) -> str:
|
|
"""
|
|
Get the current log level.
|
|
|
|
Returns:
|
|
str: The current log level (e.g., 'DEBUG', 'INFO', 'WARNING', etc.).
|
|
"""
|
|
return self.log_level
|
|
|
|
def get_effective_log_level(self) -> int:
|
|
"""
|
|
Get the effective log level of the logger.
|
|
|
|
Returns:
|
|
int: The numeric value of the effective log level.
|
|
"""
|
|
return self.logger.getEffectiveLevel()
|
|
|
|
def get_logger(self, module_name: Optional[str] = None) -> logging.Logger:
|
|
|
|
"""
|
|
Get a logger instance based on the module name.
|
|
|
|
Args:
|
|
module_name (str): The name of the module to get a logger for. If None, uses the current module name (`__name__`).
|
|
|
|
Returns:
|
|
Logger: A logger instance configured for the specified module.
|
|
"""
|
|
if module_name is None:
|
|
module_name = __name__
|
|
logger = logging.getLogger(module_name)
|
|
return logger
|
|
|
|
def set_host_url(self, url: str = "http://localhost:11434") -> None:
|
|
"""
|
|
Set the URL of the host to which LLM requests are sent.
|
|
|
|
Args:
|
|
url (str): The new URL to use. Defaults to 'http://localhost:11434' if not specified.
|
|
"""
|
|
self.host_url = url
|
|
|
|
def get_host_url(self) -> str:
|
|
"""
|
|
Get the current URL of the host used for LLMs.
|
|
|
|
Returns:
|
|
str: The current URL of the host.
|
|
"""
|
|
return self.host_url
|
|
|
|
def set_llm(self, model_name: str = "phi3:mini") -> None:
|
|
"""
|
|
Set the LLM to use for queries.
|
|
|
|
Args:
|
|
model_name (str): The name of the LLM to use. Defaults to 'phi3:mini' if not specified.
|
|
"""
|
|
self.llm = model_name
|
|
|
|
def get_llm(self) -> str:
|
|
"""
|
|
Get the current LLM used for queries.
|
|
|
|
Returns:
|
|
str: The name of the current LLM.
|
|
"""
|
|
return self.llm
|
|
|
|
def set_backend(self, backend: Optional[dict] = None) -> None:
|
|
|
|
"""
|
|
Set the backend server that web clients connect to.
|
|
|
|
Args:
|
|
backend (dict): A dictionary containing information about the backend server. If None, resets the backend server to its default value.
|
|
"""
|
|
self.backend = backend
|
|
|
|
def get_backend(self) -> dict:
|
|
"""
|
|
Get the current backend server used by web clients.
|
|
|
|
Returns:
|
|
dict: A dictionary containing information about the current backend server.
|
|
"""
|
|
return self.backend
|
|
|
|
def get_backend_api_ep(self) -> str:
|
|
"""
|
|
Get the API endpoint of the backend server.
|
|
|
|
Returns:
|
|
str: The URL of the API endpoint.
|
|
"""
|
|
return self.backend["url"]+self.backend["api"]
|
|
|
|
def set_endpoints(self, endpoints: Optional[list[dict]] = None) -> None:
|
|
"""
|
|
Set the list of endpoints used by this object.
|
|
|
|
Args:
|
|
endpoints (list): A list of endpoint dictionaries. Each dictionary should contain information about an endpoint.
|
|
If None, resets the endpoints to their default value.
|
|
|
|
Raises:
|
|
ValueError: If endpoints is not a list.
|
|
|
|
Notes:
|
|
Endpoints can be reset to their default value by passing None as the argument.
|
|
"""
|
|
if endpoints is not None:
|
|
if not isinstance(endpoints, list):
|
|
raise ValueError("Endpoints must be a list, even if there is just one model")
|
|
self.endpoints = endpoints
|
|
|
|
def get_endpoints(self) -> list[dict]:
|
|
"""
|
|
Get the complete list of endpoints.
|
|
|
|
Returns:
|
|
List of endpoints
|
|
"""
|
|
return self.endpoints
|
|
|
|
def get_endpoints_with_key(self, key: str) -> list[dict]:
|
|
"""
|
|
Returns a list of endpoint dictionaries that contain the specified key.
|
|
|
|
Args:
|
|
key (str): The key to search for in the endpoint dictionaries.
|
|
|
|
Returns:
|
|
List[Dict]: A list of endpoint dictionaries containing the specified key.
|
|
"""
|
|
return [ep for ep in self.endpoints if key in ep]
|
|
|
|
|
|
def fetch_models(self) -> None:
|
|
"""
|
|
Fetch models from endpoints and update the endpoint dictionaries.
|
|
Returns:
|
|
None
|
|
"""
|
|
logger = logging.getLogger(__name__)
|
|
|
|
for endpoint in self.endpoints:
|
|
try:
|
|
if endpoint["provider"] == "ollama":
|
|
headers = {
|
|
"Content-Type": "application/json",
|
|
}
|
|
if "requestOptions" in endpoint: # Check if authentication is needed
|
|
headers.update({
|
|
"Authorization": endpoint["requestOptions"]["headers"]["Authorization"]
|
|
})
|
|
|
|
models_response = requests.get(endpoint["url"] + "/api/tags", headers=headers)
|
|
models_response.raise_for_status() # Raise an exception for HTTP errors
|
|
|
|
try:
|
|
models = models_response.json()
|
|
except json.JSONDecodeError as e:
|
|
logger.error(f"Failed to parse JSON response: {e}")
|
|
continue
|
|
|
|
if isinstance(models, dict) and 'error' in models: # Unclear if requests to any API actually add this in the response
|
|
logger.error('Error fetching models from backend: %s', models['error'])
|
|
else:
|
|
endpoint["models"] = models.get("models", []) # Get the list of models directly
|
|
except requests.exceptions.RequestException as e:
|
|
logger.error(f"Request error: {e}")
|
|
except Exception as e:
|
|
logger.error(f"Unexpected error: {e}")
|
|
|
|
return # No value returned
|
|
|
|
|
|
def get_list_of_available_llms(self, endpoint: Optional[dict] = None) -> Optional[list[str]]:
|
|
"""
|
|
Returns a sorted list of Large Language Models (LLMs) available at the specified endpoint.
|
|
|
|
Args:
|
|
endpoint (dict): Optional endpoint dictionary to retrieve LLMs from. If not provided, will use internal endpoint configuration.
|
|
|
|
Returns:
|
|
list: A sorted list of LLM names (strings). Returns None if no LLMs are found or endpoint is invalid.
|
|
"""
|
|
llm_list = None
|
|
if isinstance(endpoint["models"], list):
|
|
llm_list = sorted([list_item['name'] for list_item in endpoint["models"]], key=str.lower)
|
|
return llm_list
|
|
|