Docstrings och type hinting. Lagt till route för /api/select_endpoint_llm. Brutit ut header-generering till get_auth_headers().

This commit is contained in:
Joakim Persson
2024-08-06 17:32:46 +02:00
parent fa98c7b162
commit f7f6ce2e49
+128 -52
View File
@@ -1,7 +1,7 @@
# Import the necessary functions from ollama, Flask, requests, threading # Import the necessary functions from ollama, Flask, requests, threading
from ollama import Client from ollama import Client
from flask import Flask, request, jsonify, send_from_directory, render_template, session, make_response from flask import Flask, request, jsonify, send_from_directory, render_template, session, make_response, Response
from flask_cors import CORS, cross_origin # CORS stands for Cross-Origin Resource Sharing. This is necessary to allow the frontend to make requests to our backend. from flask_cors import CORS, cross_origin # CORS stands for Cross-Origin Resource Sharing. This is necessary to allow the frontend to make requests to our backend.
import requests import requests
import json import json
@@ -9,6 +9,7 @@ import logging
import os import os
import utils import utils
from utils import GlobalState from utils import GlobalState
from pathlib import Path
# Create a logger for this module # Create a logger for this module
global_state = GlobalState() # Import the singleton that holds global states (e.g., logger) global_state = GlobalState() # Import the singleton that holds global states (e.g., logger)
@@ -38,11 +39,17 @@ app.config['SESSION_TYPE'] = 'filesystem' # Store sessions on the filesystem
logger.debug("flask app template folder: %s", app.template_folder) logger.debug("flask app template folder: %s", app.template_folder)
@app.route('/') @app.route('/')
def index(): def index() -> Response:
"""
This route serves index.html to connecting clients
""" """
This route serves index.html to connecting clients.
Initializes a new chat session by clearing the chat history in the session object.
Retrieves environment variables for the backend API endpoint, host URL of LLMs, and the selected LLM model.
Reads the client HTML template from file and passes it to the index.html template along with other necessary parameters.
Returns:
Response: A Flask response containing the rendered index.html template.
"""
session['chat_history'] = [] # The session object (actually, a dictonary) holds the chat session session['chat_history'] = [] # The session object (actually, a dictonary) holds the chat session
logger.debug("Entering route '/'") logger.debug("Entering route '/'")
api_endpoint = global_state.get_backend_api_ep() # Retrieve the environment variable api_endpoint = global_state.get_backend_api_ep() # Retrieve the environment variable
@@ -75,11 +82,22 @@ def set_session():
# return 'No user ID found' # return 'No user ID found'
@app.route('/<path:filename>')
def serve_static(filename: str | Path) -> Response:
    """
    Serve a single file out of the configured static folder.

    Args:
        filename: Path of the requested file, relative to the directory
            stored under app.config['STATIC_FOLDER'].

    Returns:
        Response: Whatever send_from_directory produces for that file.
    """
    # Look the folder up per request so the current app configuration is used.
    static_root = app.config['STATIC_FOLDER']
    return send_from_directory(static_root, filename)
# CORS(app, resources={ # CORS(app, resources={
# r"/api/chat": { # r"/api/chat": {
# "origins": "*", # "origins": "*",
@@ -87,33 +105,52 @@ def serve_static(filename):
# } # }
# }) # })
CORS(app, resources={ # CORS(app, resources={
r"/api/chat": { # r"/api/chat": {
"origins": "*" # "origins": "*"
} # }
}) # })
@app.route('/api/tags', methods=['GET'])
def get_tags(
    url: str = "http://localhost:11434/api/tags",
    headers: dict | None = None,
    timeout: float = 10.0,
) -> dict:
    """
    Retrieve the list of available models from an LLM server.

    Args:
        url: Full URL of the tags endpoint to query.
            Defaults to http://localhost:11434/api/tags.
        headers: Optional HTTP headers to send with the request.
        timeout: Seconds to wait for the server before giving up, so an
            unreachable host cannot block the calling worker indefinitely.

    Returns:
        dict: The decoded JSON body of the server's response, or
        {'error': 'Failed to process request'} if the request fails or the
        body is not valid JSON.

    Note:
        Request errors are logged and converted into the error dictionary;
        they are not propagated to the caller.
    """
    try:
        # NOTE(review): headers may carry credentials (e.g. Authorization);
        # confirm that writing them to the debug log is acceptable.
        logger.debug(f"url: {url} headers: {headers}")
        response = requests.get(url, headers=headers, timeout=timeout)
        return response.json()
    except requests.exceptions.RequestException as e:
        logger.error("Request Exception: %s", str(e))
        return {'error': 'Failed to process request'}
    except ValueError as e:
        # Older requests versions raise a plain ValueError for a non-JSON body.
        logger.error("Invalid JSON in response from %s: %s", url, str(e))
        return {'error': 'Failed to process request'}
@app.route('/api/chat', methods=['POST']) @app.route('/api/chat', methods=['POST'])
def chat(): def chat() -> dict[str, any]:
# def chat(model = "phi3:mini"):
""" """
This function handles the chat. The frontend client (web browser) calls the Handles chat functionality by sending a query to an LLM server and
backend server through this endpoint (/api/chat) that manage queries returning the response.
to the LLM (Large Language Model) server and it also manages the response
from the LLM server. This endpoint expects a JSON payload with the following structure:
{
'query': str,
'url_server': str (optional),
'model': str (optional)
}
:return: A dictionary containing the LLM's response
""" """
# Get the message from the JSON in the request body # Get the message from the JSON in the request body
data = request.get_json() data = request.get_json()
@@ -138,26 +175,7 @@ def chat():
"stream": False "stream": False
} }
url = url_server url = url_server
headers = get_auth_headers(url)
# TODO: This section should only run when changing to new endpoint...
# begin refactor ###################################
headers = { # Set default header
"Content-Type": "application/json",
}
found_endpoint = False
endpoints = global_state.get_endpoints()
for endpoint in endpoints:
if endpoint["url"] == url: # Look for endpoint with this url
found_endpoint = True
if endpoint["provider"] == "ollama": # Currently only supporting ollama servers
if "requestOptions" in endpoint: # Check if authentication is needed
headers.update({
"Authorization": endpoint["requestOptions"]["headers"]["Authorization"]
})
if found_endpoint == False:
# Raise some error or whatever...
logger.debug(f"Host {url} not found")
# end refactor ###################################
try: try:
url = url + "/api/generate" url = url + "/api/generate"
@@ -179,24 +197,53 @@ def chat():
@app.route('/api/endpoints', methods=['GET'])
def get_endpoints() -> Response:
    """
    Return every (endpoint title, LLM) combination as a JSON array.

    Each endpoint reported by the global state is paired with each LLM that
    the global state lists as available for it.

    Returns:
        Response: A JSON response whose body is a list of objects of the
        form {'title': <endpoint title>, 'llm': <llm>}. The list is empty
        when no endpoint has any available LLM.
    """
    # One entry per (endpoint, llm) pair, in the order the global state yields them.
    endpoints = [
        {'title': ep.get('title'), 'llm': llm}
        for ep in global_state.get_endpoints()
        for llm in global_state.get_list_of_available_llms(ep)
    ]
    return jsonify(endpoints)
@app.route('/api/select_endpoint_llm', methods=['POST'])
def select_endpoint_llm() -> Response:
    """
    Make the endpoint and LLM named in the request body the active selection.

    Request Body (JSON):
        - title: Title of the endpoint to select.
        - llm: The LLM to set.

    The endpoint is looked up by its title in the global state. Its URL is
    stored as the host URL and the given LLM as the current LLM.

    Returns:
        Response: A JSON response with a 'message' confirming the selection.

    Raises:
        ValueError: If the number of endpoints with the given title is not
            exactly one.
    """
    payload = request.get_json()
    # Both fields are read up front, so a missing key fails before any state changes.
    title = payload['title']
    llm = payload['llm']

    matches = global_state.get_endpoints_with_key_value('title', title)
    match_count = len(matches)
    if match_count != 1:
        raise ValueError(f"Expected exactly one endpoint with title '{title}', found {match_count}")

    host_url = matches[0]['url']
    global_state.set_host_url(host_url)
    global_state.set_llm(llm)
    logger.debug(f"Updated to host url {host_url} and LLM {llm}")

    return jsonify({'message': 'Endpoint and LLM selected successfully'})
@app.route('/smartassist', methods=["POST"]) @app.route('/smartassist', methods=["POST"])
def smartassist(): def smartassist():
@@ -216,6 +263,35 @@ def get_response(user_query):
response = client.generate_response(user_query) # Generate and retrieve the response based on user's query response = client.generate_response(user_query) # Generate and retrieve the response based on user's query
return response return response
def get_auth_headers(url: str) -> dict:
    """
    Build the HTTP headers to use when talking to the endpoint at `url`.

    The result always contains a JSON Content-Type header. If an endpoint in
    the global state has this URL and carries a "requestOptions" entry, its
    Authorization header value is added as well. If no endpoint has this
    URL, a debug message is logged and only the default header is returned.

    Args:
        url: The endpoint URL to look up; compared for exact equality.

    Returns:
        dict: The headers for the request.
    """
    # TODO: The full operation should only have to run when changing to new endpoint.
    result = {"Content-Type": "application/json"}
    matched = False

    for candidate in global_state.get_endpoints():
        if candidate["url"] != url:
            continue
        matched = True
        # The provider is deliberately not checked: not needed if the API is the same.
        if "requestOptions" not in candidate:
            continue  # No authentication configured for this endpoint
        result["Authorization"] = candidate["requestOptions"]["headers"]["Authorization"]

    if not matched:
        logger.debug(f"Host {url} not found")
    return result
def run_flask(fport=5005): def run_flask(fport=5005):
""" """