diff --git a/smartassist/src/backend.py b/smartassist/src/backend.py index 8ab7b6b..92db094 100644 --- a/smartassist/src/backend.py +++ b/smartassist/src/backend.py @@ -1,7 +1,7 @@ # Import the necessary functions from ollama, Flask, requests, threading from ollama import Client -from flask import Flask, request, jsonify, send_from_directory, render_template, session, make_response +from flask import Flask, request, jsonify, send_from_directory, render_template, session, make_response, Response from flask_cors import CORS, cross_origin # CORS stands for Cross-Origin Resource Sharing. This is necessary to allow the frontend to make requests to our backend. import requests import json @@ -9,6 +9,7 @@ import logging import os import utils from utils import GlobalState +from pathlib import Path # Create a logger for this module global_state = GlobalState() # Import the singleton that holds global states (e.g., logger) @@ -38,11 +39,17 @@ app.config['SESSION_TYPE'] = 'filesystem' # Store sessions on the filesystem logger.debug("flask app template folder: %s", app.template_folder) @app.route('/') -def index(): - """ - This route serves index.html to connecting clients +def index() -> Response: """ + This route serves index.html to connecting clients. + Initializes a new chat session by clearing the chat history in the session object. + Retrieves environment variables for the backend API endpoint, host URL of LLMs, and the selected LLM model. + Reads the client HTML template from file and passes it to the index.html template along with other necessary parameters. + + Returns: + Response: A Flask response containing the rendered index.html template. + """ session['chat_history'] = [] # The session object (actually, a dictonary) holds the chat session logger.debug("Entering route '/'") api_endpoint = global_state.get_backend_api_ep() # Retrieve the environment variable @@ -75,11 +82,22 @@ def set_session(): # return 'No user ID found' + @app.route('/') -def serve_static(filename): +def serve_static(filename: str | Path) -> Response: + """ + Serves a static file from the application's static folder. + + Args: + filename (str or os.PathLike[str]): The path to the static file, relative to the STATIC_FOLDER directory. + + Returns: + Response: A Flask response containing the contents of the static file. + """ return send_from_directory(app.config['STATIC_FOLDER'], filename) + # CORS(app, resources={ # r"/api/chat": { # "origins": "*", @@ -87,33 +105,52 @@ def serve_static(filename): # } # }) -CORS(app, resources={ - r"/api/chat": { - "origins": "*" - } -}) +# CORS(app, resources={ +# r"/api/chat": { +# "origins": "*" +# } +# }) @app.route('/api/tags', methods=['GET']) -def tag(url = "http://localhost:11434/api/tags", headers = None): - """Get a list of models for the server located at url.""" +def get_tags(url: str = "http://localhost:11434/api/tags", headers: dict = None) -> dict: + """ + Retrieves a list of available models from a server. + + Args: + url (str): The URL of the server to query. Defaults to http://localhost:11434/api/tags. + headers (dict, optional): A dictionary of HTTP headers to include in the request. Defaults to None. + + Returns: + dict: A JSON response containing a list of available models, or an error message if the request fails. + + Raises: + requests.exceptions.RequestException: If there is a problem with the request. + """ try: logger.debug(f"url: {url} headers: {headers}") response = requests.get(url, headers=headers) return response.json() - # return response except requests.exceptions.RequestException as e: - logger.error("Request Exception: %s", str(e)) + logger.error("Request Exception: %s", str(e)) return {'error': 'Failed to process request'} + + @app.route('/api/chat', methods=['POST']) -def chat(): -# def chat(model = "phi3:mini"): +def chat() -> dict[str, any]: """ - This function handles the chat. The frontend client (web browser) calls the - backend server through this endpoint (/api/chat) that manage queries - to the LLM (Large Language Model) server and it also manages the response - from the LLM server. + Handles chat functionality by sending a query to an LLM server and + returning the response. + + This endpoint expects a JSON payload with the following structure: + { + 'query': str, + 'url_server': str (optional), + 'model': str (optional) + } + + :return: A dictionary containing the LLM's response """ # Get the message from the JSON in the request body data = request.get_json() @@ -138,26 +175,7 @@ def chat(): "stream": False } url = url_server - - # TODO: This section should only run when changing to new endpoint... - # begin refactor ################################### - headers = { # Set default header - "Content-Type": "application/json", - } - found_endpoint = False - endpoints = global_state.get_endpoints() - for endpoint in endpoints: - if endpoint["url"] == url: # Look for endpoint with this url - found_endpoint = True - if endpoint["provider"] == "ollama": # Currently only supporting ollama servers - if "requestOptions" in endpoint: # Check if authentication is needed - headers.update({ - "Authorization": endpoint["requestOptions"]["headers"]["Authorization"] - }) - if found_endpoint == False: - # Raise some error or whatever... - logger.debug(f"Host {url} not found") - # end refactor ################################### + headers = get_auth_headers(url) try: url = url + "/api/generate" @@ -179,24 +197,53 @@ def chat(): @app.route('/api/endpoints', methods=['GET']) -def get_endpoints(): - # Replace this with your actual logic to fetch endpoint data - # endpoints = [ - # {'endpoint': 'endpoint1', 'llm': 'llm1'}, - # {'endpoint': 'endpoint1', 'llm': 'llm2'}, - # {'endpoint': 'endpoint2', 'llm': 'llm1'}, - # {'endpoint': 'endpoint2', 'llm': 'llm3'}, - # {'endpoint': 'endpoint2', 'llm': 'llm4'}, - # {'endpoint': 'endpoint3', 'llm': 'llm4'}, - # ] - endpoints = [] # List of dictionaries, each of which contains {'endpoint': 'endpoint1', 'llm': 'llm1'} +def get_endpoints() -> str: + """ + Returns a list of available endpoints with their corresponding LLMs. + + This endpoint fetches all endpoints and their associated LLMs from the global state, + then returns them as a JSON response. + + :return: A JSON string representing a dictionary containing a list of dictionaries, + each representing an endpoint title and supported LLM. + """ + endpoints = [] # List of dictionaries, each of which contains {'title': 'title1', 'llm': 'llm1'} eps = global_state.get_endpoints() for ep in eps: llms = global_state.get_list_of_available_llms(ep) for llm in llms: - endpoints.append({'endpoint': ep.get('title'), 'llm': llm}) + endpoints.append({'title': ep.get('title'), 'llm': llm}) return jsonify(endpoints) +@app.route('/api/select_endpoint_llm', methods=['POST']) +def select_endpoint_llm() -> Response: + """ + Selects the endpoint associated with the tuple (title, LLM) from the request body. + + Request Body: + - title: str - The title of the endpoint to select. + - llm: str - The LLM to set. + + Returns: + A JSON response indicating whether the endpoint and LLM were selected successfully. + + Raises: + ValueError: If there is not exactly one endpoint with the specified title. + """ + data = request.get_json() + title = data['title'] + llm = data['llm'] + + endpoints = global_state.get_endpoints_with_key_value('title', title) + if len(endpoints) != 1: + raise ValueError(f"Expected exactly one endpoint with title '{title}', found {len(endpoints)}") + + global_state.set_host_url(endpoints[0]['url']) + global_state.set_llm(llm) + logger.debug(f"Updated to host url {endpoints[0]['url']} and LLM {llm}") + + return jsonify({'message': 'Endpoint and LLM selected successfully'}) + @app.route('/smartassist', methods=["POST"]) def smartassist(): @@ -216,6 +263,35 @@ def get_response(user_query): response = client.generate_response(user_query) # Generate and retrieve the response based on user's query return response +def get_auth_headers(url: str) -> dict: + """ + Returns authentication headers for a given URL. + + This function checks if an endpoint with the provided URL exists in the global state, + and returns the corresponding authentication headers. If no such endpoint is found, + it returns a default header. + """ + # TODO: The full operation should only have to run when changing to new endpoint. + + # Set default header + headers = { + "Content-Type": "application/json", + } + + found_endpoint = False + endpoints = global_state.get_endpoints() + for endpoint in endpoints: + if endpoint["url"] == url: # Look for endpoint with this URL + found_endpoint = True + #if endpoint["provider"] == "ollama": # Currently only supporting ollama servers - not needed if API the same + if "requestOptions" in endpoint: # Check if authentication is needed + headers.update({ + "Authorization": endpoint["requestOptions"]["headers"]["Authorization"] + }) + if not found_endpoint: + logger.debug(f"Host {url} not found") + + return headers def run_flask(fport=5005): """