From 8452c7569b9d9486d18ad375d187810bbf304e04 Mon Sep 17 00:00:00 2001 From: Joakim Persson Date: Mon, 5 Aug 2024 00:50:45 +0200 Subject: [PATCH] =?UTF-8?q?Nu=20anv=C3=A4nds=20informationen=20fr=C3=A5n?= =?UTF-8?q?=20konfigurationsfilen=20n=C3=A4r=20url=20f=C3=B6r=20LLM-server?= =?UTF-8?q?=20best=C3=A4ms?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- smartassist/src/backend.py | 49 ++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/smartassist/src/backend.py b/smartassist/src/backend.py index b746d34..571cbf5 100644 --- a/smartassist/src/backend.py +++ b/smartassist/src/backend.py @@ -12,7 +12,7 @@ from utils import GlobalState # Create a logger for this module global_state = GlobalState() # Import the singleton that holds global states (e.g., logger) -logger = global_state.getLogger(__name__) # Logger for this module, inherit properties of the root logger +logger = global_state.get_logger(__name__) # Logger for this module, inherit properties of the root logger # Find out the path to current directory according to the Python interpreter (venv) @@ -45,13 +45,15 @@ def index(): session['chat_history'] = [] # The session object (actually, a dictonary) holds the chat session logger.debug("Entering route '/'") - # api_endpoint = os.environ['BE_API_ENDPOINT'] # Retrieve the environment variable api_endpoint = global_state.get_backend_api_ep() # Retrieve the environment variable - logger.debug("Backend API endpoint: %s", api_endpoint) + host_url = global_state.get_host_url() use_model = global_state.get_llm() + logger.debug("Backend API endpoint:\t%s", api_endpoint) + logger.debug("Host of LLMs:\t\t%s", host_url) + logger.debug("LLM to use:\t\t\t%s", use_model) with open('smartassist/src/html/client.html', 'r') as f: client_html = f.read() - logger.debug("Client HTML (first few characters): %s", client_html[:50]) # Print to see if it's loading + # logger.debug("Client HTML (first few characters): %s", client_html[:50]) # Print to see if it's loading # logger.debug("Client HTML (all characters): %s", client_html) # Print to see if it's loading return render_template('index.html', api_endpoint=api_endpoint, use_model = use_model, client_content=client_html) @@ -93,7 +95,6 @@ CORS(app, resources={ @app.route('/api/tags', methods=['GET']) def tag(url = "http://localhost:11434/api/tags", headers = None): -# def tag(url = "http://localhost:11434/api/tags", headers = {"Content-Type": "application/json"}): """Get a list of models for the server located at url.""" try: logger.debug(f"url: {url} headers: {headers}") @@ -106,8 +107,8 @@ def tag(url = "http://localhost:11434/api/tags", headers = None): @app.route('/api/chat', methods=['POST']) -def chat(model = "phi3:mini"): -# def chat(url_server = "http://localhost:11434/api/generate", model = "phi3:mini"): +def chat(): +# def chat(model = "phi3:mini"): """ This function handles the chat. The frontend client (web browser) calls the backend server through this endpoint (/api/chat) that manage queries @@ -117,8 +118,9 @@ def chat(model = "phi3:mini"): # Get the message from the JSON in the request body data = request.get_json() message = data.get('query') - url_server = data.get('url_server', "https://ollama-test.wara-ops.org/api/generate") # Use provided URL or default - model = data.get('model', model) # Use provided model or default if not provided + url_server = data.get('url_server', global_state.get_host_url()) # Use provided URL or current if not provided + # url_server = data.get('url_server', "https://ollama-test.wara-ops.org/api/generate") # Use provided URL or default + model = data.get('model', global_state.get_llm()) # Use provided model or current if not provided # Get chat history from session storage (e.g., a dictionary) chat_history = session.get('chat_history', []) @@ -135,13 +137,30 @@ def chat(model = "phi3:mini"): 'prompt': '\n'.join([f"{item['role']}: {item['message']}" for item in chat_history]), "stream": False } - + url = url_server + + # TODO: This section should only run when changing to new endpoint... + # begin refactor ################################### + headers = { # Set default header + "Content-Type": "application/json", + } + found_endpoint = False + endpoints = global_state.get_endpoints() + for endpoint in endpoints: + if endpoint["url"] == url: # Look for endpoint with this url + found_endpoint = True + if endpoint["provider"] == "ollama": # Currently only supporting ollama servers + if "requestOptions" in endpoint: # Check if authentication is needed + headers.update({ + "Authorization": endpoint["requestOptions"]["headers"]["Authorization"] + }) + if found_endpoint == False: + # Raise some error or whatever... + logger.debug(f"Host {url} not found") + # end refactor ################################### + try: - url = url_server - headers = { - "Content-Type": "application/json", - "Authorization": "Basic ZWNzanBlcjoxM2JjMTU4ZDhmNmY5YTU4YTkzZDNmY2I=" - } + url = url + "/api/generate" logger.debug(f"url: {url} headers: {headers}") response = requests.post(url, headers=headers,