From 06628d5c19281ecd3dbc10573c53fc220f9e7814 Mon Sep 17 00:00:00 2001 From: Joakim Persson Date: Sat, 3 Aug 2024 18:44:35 +0200 Subject: [PATCH] =?UTF-8?q?Lagt=20till=20funktion=20som=20h=C3=A4mtar=20in?= =?UTF-8?q?formation=20om=20tillg=C3=A4ngliga=20LLM=20fr=C3=A5n=20endpoint?= =?UTF-8?q?-servrar?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- smartassist/src/utils.py | 43 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/smartassist/src/utils.py b/smartassist/src/utils.py index d0c5c6d..2fb8fa0 100644 --- a/smartassist/src/utils.py +++ b/smartassist/src/utils.py @@ -2,6 +2,49 @@ # imported to more than one other module. The rational for defining these things here # is that it is easier to avoid circular imports when they are defined in a central location. import logging +import json +import requests +#from backend import GlobalState # Assuming GlobalState is defined there + +def fetch_models_from_endpoints(endpoints, global_state): + """ + Fetch models from endpoints and update the endpoint dictionaries. + + Args: + endpoints (list): List of endpoint dictionaries. + global_state: The global state object. + + Returns: + None + """ + logger = logging.getLogger(__name__) + + for endpoint in endpoints: + if endpoint["provider"] == "ollama": + headers = { + "Content-Type": "application/json", + } + if "requestOptions" in endpoint: # Check if authentication is needed + headers.update({ + "Authorization": endpoint["requestOptions"]["headers"]["Authorization"] + }) + + try: + models_response = requests.get(endpoint["url"] + "/api/tags", headers=headers) + models_response.raise_for_status() # Raise an exception for HTTP errors + models = models_response.json() + except requests.exceptions.RequestException as e: + logger.error("Error fetching models from backend: %s", str(e)) + continue + + if isinstance(models, dict) and 'error' in models: + logger.error('Error fetching models from backend: %s', models['error']) + else: + endpoint["models"] = models.get("models", []) # get the list of models directly + logger.debug("models = \n{}".format(json.dumps(models, indent=4))) + if endpoint["model"] is not "AUTODETECT": # Check if specified model is available + logger.debug("Asking for specific model") + class GlobalState: """