Docstrings och type hinting. Lagt till route för /api/select_endpoint_llm. Brutit ut header-generering till get_auth_headers().
This commit is contained in:
+128
-52
@@ -1,7 +1,7 @@
|
|||||||
|
|
||||||
# Import the necessary functions from ollama, Flask, requests, threading
|
# Import the necessary functions from ollama, Flask, requests, threading
|
||||||
from ollama import Client
|
from ollama import Client
|
||||||
from flask import Flask, request, jsonify, send_from_directory, render_template, session, make_response
|
from flask import Flask, request, jsonify, send_from_directory, render_template, session, make_response, Response
|
||||||
from flask_cors import CORS, cross_origin # CORS stands for Cross-Origin Resource Sharing. This is necessary to allow the frontend to make requests to our backend.
|
from flask_cors import CORS, cross_origin # CORS stands for Cross-Origin Resource Sharing. This is necessary to allow the frontend to make requests to our backend.
|
||||||
import requests
|
import requests
|
||||||
import json
|
import json
|
||||||
@@ -9,6 +9,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import utils
|
import utils
|
||||||
from utils import GlobalState
|
from utils import GlobalState
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
# Create a logger for this module
|
# Create a logger for this module
|
||||||
global_state = GlobalState() # Import the singleton that holds global states (e.g., logger)
|
global_state = GlobalState() # Import the singleton that holds global states (e.g., logger)
|
||||||
@@ -38,11 +39,17 @@ app.config['SESSION_TYPE'] = 'filesystem' # Store sessions on the filesystem
|
|||||||
logger.debug("flask app template folder: %s", app.template_folder)
|
logger.debug("flask app template folder: %s", app.template_folder)
|
||||||
|
|
||||||
@app.route('/')
|
@app.route('/')
|
||||||
def index():
|
def index() -> Response:
|
||||||
"""
|
|
||||||
This route serves index.html to connecting clients
|
|
||||||
"""
|
"""
|
||||||
|
This route serves index.html to connecting clients.
|
||||||
|
|
||||||
|
Initializes a new chat session by clearing the chat history in the session object.
|
||||||
|
Retrieves environment variables for the backend API endpoint, host URL of LLMs, and the selected LLM model.
|
||||||
|
Reads the client HTML template from file and passes it to the index.html template along with other necessary parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response: A Flask response containing the rendered index.html template.
|
||||||
|
"""
|
||||||
session['chat_history'] = [] # The session object (actually, a dictonary) holds the chat session
|
session['chat_history'] = [] # The session object (actually, a dictonary) holds the chat session
|
||||||
logger.debug("Entering route '/'")
|
logger.debug("Entering route '/'")
|
||||||
api_endpoint = global_state.get_backend_api_ep() # Retrieve the environment variable
|
api_endpoint = global_state.get_backend_api_ep() # Retrieve the environment variable
|
||||||
@@ -75,11 +82,22 @@ def set_session():
|
|||||||
# return 'No user ID found'
|
# return 'No user ID found'
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.route('/<path:filename>')
def serve_static(filename: str | Path) -> Response:
    """
    Serve a file from the directory configured as STATIC_FOLDER.

    Args:
        filename (str | Path): Path of the requested file, taken from the URL
            and resolved relative to ``app.config['STATIC_FOLDER']``.

    Returns:
        Response: The Flask response produced by ``send_from_directory``.
    """
    # Look the directory up per request so the value currently stored in the
    # app config is the one that is used.
    static_root = app.config['STATIC_FOLDER']
    return send_from_directory(static_root, filename)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# CORS(app, resources={
|
# CORS(app, resources={
|
||||||
# r"/api/chat": {
|
# r"/api/chat": {
|
||||||
# "origins": "*",
|
# "origins": "*",
|
||||||
@@ -87,33 +105,52 @@ def serve_static(filename):
|
|||||||
# }
|
# }
|
||||||
# })
|
# })
|
||||||
|
|
||||||
CORS(app, resources={
|
# CORS(app, resources={
|
||||||
r"/api/chat": {
|
# r"/api/chat": {
|
||||||
"origins": "*"
|
# "origins": "*"
|
||||||
}
|
# }
|
||||||
})
|
# })
|
||||||
|
|
||||||
@app.route('/api/tags', methods=['GET'])
def get_tags(
    url: str = "http://localhost:11434/api/tags",
    headers: dict | None = None,
    timeout: float | None = 30,
) -> dict:
    """
    Retrieve the list of available models from a server.

    Args:
        url (str): The URL to query. Defaults to
            http://localhost:11434/api/tags.
        headers (dict | None): HTTP headers to send with the request.
            Defaults to None.
        timeout (float | None): Seconds to wait for the server before giving
            up; None disables the timeout. Defaults to 30.

    Returns:
        dict: The decoded JSON body of the server's response, or
            ``{'error': 'Failed to process request'}`` if the request fails
            or the body is not valid JSON.
    """
    # Never write credentials to the log: mask the Authorization value.
    loggable_headers = None
    if headers is not None:
        loggable_headers = {
            key: "<redacted>" if str(key).lower() == "authorization" else value
            for key, value in headers.items()
        }
    logger.debug("url: %s headers: %s", url, loggable_headers)

    try:
        # Without a timeout an unresponsive server would block this worker
        # indefinitely.
        response = requests.get(url, headers=headers, timeout=timeout)
        return response.json()
    except requests.exceptions.RequestException as e:
        logger.error("Request Exception: %s", str(e))
        return {'error': 'Failed to process request'}
    except ValueError as e:
        # Older versions of requests raise a plain ValueError (not a
        # RequestException) when the body is not valid JSON.
        logger.error("Invalid JSON in response: %s", str(e))
        return {'error': 'Failed to process request'}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/chat', methods=['POST'])
|
@app.route('/api/chat', methods=['POST'])
|
||||||
def chat():
|
def chat() -> dict[str, any]:
|
||||||
# def chat(model = "phi3:mini"):
|
|
||||||
"""
|
"""
|
||||||
This function handles the chat. The frontend client (web browser) calls the
|
Handles chat functionality by sending a query to an LLM server and
|
||||||
backend server through this endpoint (/api/chat) that manage queries
|
returning the response.
|
||||||
to the LLM (Large Language Model) server and it also manages the response
|
|
||||||
from the LLM server.
|
This endpoint expects a JSON payload with the following structure:
|
||||||
|
{
|
||||||
|
'query': str,
|
||||||
|
'url_server': str (optional),
|
||||||
|
'model': str (optional)
|
||||||
|
}
|
||||||
|
|
||||||
|
:return: A dictionary containing the LLM's response
|
||||||
"""
|
"""
|
||||||
# Get the message from the JSON in the request body
|
# Get the message from the JSON in the request body
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
@@ -138,26 +175,7 @@ def chat():
|
|||||||
"stream": False
|
"stream": False
|
||||||
}
|
}
|
||||||
url = url_server
|
url = url_server
|
||||||
|
headers = get_auth_headers(url)
|
||||||
# TODO: This section should only run when changing to new endpoint...
|
|
||||||
# begin refactor ###################################
|
|
||||||
headers = { # Set default header
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
found_endpoint = False
|
|
||||||
endpoints = global_state.get_endpoints()
|
|
||||||
for endpoint in endpoints:
|
|
||||||
if endpoint["url"] == url: # Look for endpoint with this url
|
|
||||||
found_endpoint = True
|
|
||||||
if endpoint["provider"] == "ollama": # Currently only supporting ollama servers
|
|
||||||
if "requestOptions" in endpoint: # Check if authentication is needed
|
|
||||||
headers.update({
|
|
||||||
"Authorization": endpoint["requestOptions"]["headers"]["Authorization"]
|
|
||||||
})
|
|
||||||
if found_endpoint == False:
|
|
||||||
# Raise some error or whatever...
|
|
||||||
logger.debug(f"Host {url} not found")
|
|
||||||
# end refactor ###################################
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
url = url + "/api/generate"
|
url = url + "/api/generate"
|
||||||
@@ -179,24 +197,53 @@ def chat():
|
|||||||
|
|
||||||
|
|
||||||
@app.route('/api/endpoints', methods=['GET'])
def get_endpoints() -> Response:
    """
    Return the available (endpoint title, LLM) combinations as JSON.

    For every endpoint returned by ``global_state.get_endpoints()`` one entry
    is produced per LLM returned by
    ``global_state.get_list_of_available_llms(endpoint)``.

    Returns:
        Response: A JSON array of objects of the form
            ``{'title': <endpoint title>, 'llm': <llm>}``.
    """
    endpoints = []  # List of dictionaries such as {'title': 'title1', 'llm': 'llm1'}
    for ep in global_state.get_endpoints():
        title = ep.get('title')
        endpoints.extend(
            {'title': title, 'llm': llm}
            for llm in global_state.get_list_of_available_llms(ep)
        )
    return jsonify(endpoints)
|
||||||
|
|
||||||
|
@app.route('/api/select_endpoint_llm', methods=['POST'])
def select_endpoint_llm() -> Response:
    """
    Select the endpoint and LLM named in the request body.

    Request Body (JSON object):
        - title: str - The title of the endpoint to select.
        - llm: str - The LLM to set.

    Returns:
        Response: A JSON message confirming the selection, or a JSON error
            with status 400 when the body is not a JSON object containing
            both 'title' and 'llm'.

    Raises:
        ValueError: If there is not exactly one endpoint with the specified title.
    """
    data = request.get_json()

    # A JSON body of `null`, a list, or an object missing a field would
    # otherwise fail with TypeError/KeyError and surface as an HTTP 500.
    if not isinstance(data, dict) or 'title' not in data or 'llm' not in data:
        return make_response(
            jsonify({'error': "Request body must be a JSON object with 'title' and 'llm'"}),
            400,
        )

    title = data['title']
    llm = data['llm']

    endpoints = global_state.get_endpoints_with_key_value('title', title)
    if len(endpoints) != 1:
        raise ValueError(f"Expected exactly one endpoint with title '{title}', found {len(endpoints)}")

    selected_url = endpoints[0]['url']
    global_state.set_host_url(selected_url)
    global_state.set_llm(llm)
    logger.debug(f"Updated to host url {selected_url} and LLM {llm}")

    return jsonify({'message': 'Endpoint and LLM selected successfully'})
|
||||||
|
|
||||||
|
|
||||||
@app.route('/smartassist', methods=["POST"])
|
@app.route('/smartassist', methods=["POST"])
|
||||||
def smartassist():
|
def smartassist():
|
||||||
@@ -216,6 +263,35 @@ def get_response(user_query):
|
|||||||
response = client.generate_response(user_query) # Generate and retrieve the response based on user's query
|
response = client.generate_response(user_query) # Generate and retrieve the response based on user's query
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
def get_auth_headers(url: str) -> dict:
    """
    Build the HTTP headers for requests sent to the server at ``url``.

    Searches the endpoints held in the global state for entries whose ``url``
    equals the given URL. If such an entry defines
    ``requestOptions.headers.Authorization``, that value is added to the
    default headers. When no endpoint matches, or the matching endpoint has
    no authorization configured, only the default header is returned.

    Args:
        url (str): URL of the endpoint, compared verbatim against each
            configured endpoint's ``url`` value.

    Returns:
        dict: Headers containing ``Content-Type`` and, when configured,
            ``Authorization``.
    """
    # TODO: The full operation should only have to run when changing to new endpoint.

    # Set default header
    headers = {
        "Content-Type": "application/json",
    }

    found_endpoint = False
    for endpoint in global_state.get_endpoints():
        if endpoint.get("url") != url:  # Look for endpoint with this URL
            continue
        found_endpoint = True
        # Authentication is optional: tolerate a missing "requestOptions",
        # "headers" or "Authorization" entry instead of raising KeyError.
        request_options = endpoint.get("requestOptions") or {}
        authorization = (request_options.get("headers") or {}).get("Authorization")
        if authorization is not None:
            headers["Authorization"] = authorization

    if not found_endpoint:
        logger.debug("Host %s not found", url)

    return headers
|
||||||
|
|
||||||
def run_flask(fport=5005):
|
def run_flask(fport=5005):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user