From 8e983919e5e006a6516634b96cdff50f07814dee Mon Sep 17 00:00:00 2001 From: Joakim Persson Date: Sun, 4 Aug 2024 11:48:38 +0200 Subject: [PATCH] =?UTF-8?q?Refakrotiserat=20s=C3=A5=20att=20den=20externa?= =?UTF-8?q?=20funktionen=20fetch=5Fmodels=5Ffrom=5Fendpoints()=20gjorts=20?= =?UTF-8?q?om=20till=20klassmetoden=20fetch=5Fmodels()?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- smartassist/src/utils.py | 79 ++++++++++++++++++++-------------------- 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/smartassist/src/utils.py b/smartassist/src/utils.py index e959950..bc0e1df 100644 --- a/smartassist/src/utils.py +++ b/smartassist/src/utils.py @@ -4,45 +4,6 @@ import logging import json import requests -#from backend import GlobalState # Assuming GlobalState is defined there - -def fetch_models_from_endpoints(endpoints, global_state): - """ - Fetch models from endpoints and update the endpoint dictionaries. - - Args: - endpoints (list): List of endpoint dictionaries. - global_state: The global state object. - - Returns: - None - """ - logger = logging.getLogger(__name__) - - for endpoint in endpoints: - if endpoint["provider"] == "ollama": - headers = { - "Content-Type": "application/json", - } - if "requestOptions" in endpoint: # Check if authentication is needed - headers.update({ - "Authorization": endpoint["requestOptions"]["headers"]["Authorization"] - }) - - try: - models_response = requests.get(endpoint["url"] + "/api/tags", headers=headers) - models_response.raise_for_status() # Raise an exception for HTTP errors - models = models_response.json() - except requests.exceptions.RequestException as e: - logger.error("Error fetching models from backend: %s", str(e)) - continue - - if isinstance(models, dict) and 'error' in models: - logger.error('Error fetching models from backend: %s', models['error']) - else: - endpoint["models"] = models.get("models", []) # Get the list of models directly - # logger.debug("models = \n{}".format(json.dumps(models, indent=4))) - class GlobalState: """ @@ -140,8 +101,46 @@ class GlobalState: self.endpoints = endpoints def get_endpoints(self): - """Return the list of endpoints""" + """ + Get the list of endpoints. + + Returns: + List of endpoints + """ return self.endpoints + + def fetch_models(self): + """ + Fetch models from endpoints and update the endpoint dictionaries. + + Returns: + None + """ + logger = logging.getLogger(__name__) + + for endpoint in self.endpoints: + if endpoint["provider"] == "ollama": + headers = { + "Content-Type": "application/json", + } + if "requestOptions" in endpoint: # Check if authentication is needed + headers.update({ + "Authorization": endpoint["requestOptions"]["headers"]["Authorization"] + }) + + try: + models_response = requests.get(endpoint["url"] + "/api/tags", headers=headers) + models_response.raise_for_status() # Raise an exception for HTTP errors + models = models_response.json() + except requests.exceptions.RequestException as e: + logger.error("Error fetching models from backend: %s", str(e)) + continue + + if isinstance(models, dict) and 'error' in models: + logger.error('Error fetching models from backend: %s', models['error']) + else: + endpoint["models"] = models.get("models", []) # Get the list of models directly + def get_list_of_available_llms(self, endpoint=None): """Return a sorted list of LLMs available at endpoint"""