Docstrings och type hinting. Lagt till route för /api/select_endpoint_llm. Brutit ut header-generering till get_auth_headers().
This commit is contained in:
+128
-52
@@ -1,7 +1,7 @@
|
||||
|
||||
# Import the necessary functions from ollama, Flask, requests, threading
|
||||
from ollama import Client
|
||||
from flask import Flask, request, jsonify, send_from_directory, render_template, session, make_response
|
||||
from flask import Flask, request, jsonify, send_from_directory, render_template, session, make_response, Response
|
||||
from flask_cors import CORS, cross_origin # CORS stands for Cross-Origin Resource Sharing. This is necessary to allow the frontend to make requests to our backend.
|
||||
import requests
|
||||
import json
|
||||
@@ -9,6 +9,7 @@ import logging
|
||||
import os
|
||||
import utils
|
||||
from utils import GlobalState
|
||||
from pathlib import Path
|
||||
|
||||
# Create a logger for this module
|
||||
global_state = GlobalState() # Import the singleton that holds global states (e.g., logger)
|
||||
@@ -38,11 +39,17 @@ app.config['SESSION_TYPE'] = 'filesystem' # Store sessions on the filesystem
|
||||
logger.debug("flask app template folder: %s", app.template_folder)
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
"""
|
||||
This route serves index.html to connecting clients
|
||||
def index() -> Response:
|
||||
"""
|
||||
This route serves index.html to connecting clients.
|
||||
|
||||
Initializes a new chat session by clearing the chat history in the session object.
|
||||
Retrieves environment variables for the backend API endpoint, host URL of LLMs, and the selected LLM model.
|
||||
Reads the client HTML template from file and passes it to the index.html template along with other necessary parameters.
|
||||
|
||||
Returns:
|
||||
Response: A Flask response containing the rendered index.html template.
|
||||
"""
|
||||
session['chat_history'] = [] # The session object (actually, a dictonary) holds the chat session
|
||||
logger.debug("Entering route '/'")
|
||||
api_endpoint = global_state.get_backend_api_ep() # Retrieve the environment variable
|
||||
@@ -75,11 +82,22 @@ def set_session():
|
||||
# return 'No user ID found'
|
||||
|
||||
|
||||
|
||||
@app.route('/<path:filename>')
def serve_static(filename: str | Path) -> Response:
    """
    Serve one file out of the application's static folder.

    Args:
        filename (str | Path): Path of the requested file, relative to the
            directory stored in app.config['STATIC_FOLDER'].

    Returns:
        Response: Whatever send_from_directory() produces for that file.
    """
    static_root = app.config['STATIC_FOLDER']
    return send_from_directory(static_root, filename)
|
||||
|
||||
|
||||
|
||||
# CORS(app, resources={
|
||||
# r"/api/chat": {
|
||||
# "origins": "*",
|
||||
@@ -87,33 +105,52 @@ def serve_static(filename):
|
||||
# }
|
||||
# })
|
||||
|
||||
CORS(app, resources={
|
||||
r"/api/chat": {
|
||||
"origins": "*"
|
||||
}
|
||||
})
|
||||
# CORS(app, resources={
|
||||
# r"/api/chat": {
|
||||
# "origins": "*"
|
||||
# }
|
||||
# })
|
||||
|
||||
@app.route('/api/tags', methods=['GET'])
def get_tags(
    url: str = "http://localhost:11434/api/tags",
    headers: dict | None = None,
    timeout: float = 10.0,
) -> dict:
    """
    Retrieve the list of available models from a server.

    Args:
        url (str): The URL to query. Defaults to
            http://localhost:11434/api/tags.
        headers (dict, optional): HTTP headers to send with the request.
            Defaults to None.
        timeout (float): Seconds to wait for the server before giving up.
            Without a timeout, requests.get() can block indefinitely.

    Returns:
        dict: The decoded JSON body of the server's reply, or
        {'error': 'Failed to process request'} when the request fails or
        the reply is not valid JSON.
    """
    # Log only the header names: the values may carry an Authorization token.
    logger.debug("url: %s header names: %s", url, sorted(headers or {}))
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose
        # JSONDecodeError does not derive from RequestException.
        logger.error("Request Exception: %s", str(e))
        return {'error': 'Failed to process request'}
|
||||
|
||||
|
||||
|
||||
|
||||
@app.route('/api/chat', methods=['POST'])
|
||||
def chat():
|
||||
# def chat(model = "phi3:mini"):
|
||||
def chat() -> dict[str, any]:
|
||||
"""
|
||||
This function handles the chat. The frontend client (web browser) calls the
|
||||
backend server through this endpoint (/api/chat) that manage queries
|
||||
to the LLM (Large Language Model) server and it also manages the response
|
||||
from the LLM server.
|
||||
Handles chat functionality by sending a query to an LLM server and
|
||||
returning the response.
|
||||
|
||||
This endpoint expects a JSON payload with the following structure:
|
||||
{
|
||||
'query': str,
|
||||
'url_server': str (optional),
|
||||
'model': str (optional)
|
||||
}
|
||||
|
||||
:return: A dictionary containing the LLM's response
|
||||
"""
|
||||
# Get the message from the JSON in the request body
|
||||
data = request.get_json()
|
||||
@@ -138,26 +175,7 @@ def chat():
|
||||
"stream": False
|
||||
}
|
||||
url = url_server
|
||||
|
||||
# TODO: This section should only run when changing to new endpoint...
|
||||
# begin refactor ###################################
|
||||
headers = { # Set default header
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
found_endpoint = False
|
||||
endpoints = global_state.get_endpoints()
|
||||
for endpoint in endpoints:
|
||||
if endpoint["url"] == url: # Look for endpoint with this url
|
||||
found_endpoint = True
|
||||
if endpoint["provider"] == "ollama": # Currently only supporting ollama servers
|
||||
if "requestOptions" in endpoint: # Check if authentication is needed
|
||||
headers.update({
|
||||
"Authorization": endpoint["requestOptions"]["headers"]["Authorization"]
|
||||
})
|
||||
if found_endpoint == False:
|
||||
# Raise some error or whatever...
|
||||
logger.debug(f"Host {url} not found")
|
||||
# end refactor ###################################
|
||||
headers = get_auth_headers(url)
|
||||
|
||||
try:
|
||||
url = url + "/api/generate"
|
||||
@@ -179,24 +197,53 @@ def chat():
|
||||
|
||||
|
||||
@app.route('/api/endpoints', methods=['GET'])
def get_endpoints() -> Response:
    """
    Return every (endpoint title, LLM) combination known to the global state.

    For each endpoint returned by global_state.get_endpoints(), the LLMs
    reported by global_state.get_list_of_available_llms() are listed, and one
    entry is produced per endpoint/LLM pair.

    Returns:
        Response: A JSON response whose body is a list of dictionaries of the
        form {'title': <endpoint title>, 'llm': <llm>}.
    """
    endpoints = [
        {'title': ep.get('title'), 'llm': llm}
        for ep in global_state.get_endpoints()
        for llm in global_state.get_list_of_available_llms(ep)
    ]
    return jsonify(endpoints)
|
||||
|
||||
@app.route('/api/select_endpoint_llm', methods=['POST'])
def select_endpoint_llm() -> Response:
    """
    Select the endpoint and LLM named in the request body.

    Request Body (JSON object):
        - title: str - The title of the endpoint to select.
        - llm: str - The LLM to set.

    Returns:
        Response: A JSON success message once the host URL and LLM have been
        stored in the global state, or a 400 response with an 'error' field
        when the body is not a JSON object containing both 'title' and 'llm'.

    Raises:
        ValueError: If there is not exactly one endpoint with the specified title.
    """
    data = request.get_json()

    # Reject malformed payloads explicitly instead of failing with a
    # KeyError/TypeError (and therefore a 500) on the lookups below.
    if not isinstance(data, dict) or 'title' not in data or 'llm' not in data:
        return make_response(
            jsonify({'error': "Request body must be a JSON object with 'title' and 'llm'"}),
            400,
        )

    title = data['title']
    llm = data['llm']

    endpoints = global_state.get_endpoints_with_key_value('title', title)
    if len(endpoints) != 1:
        raise ValueError(f"Expected exactly one endpoint with title '{title}', found {len(endpoints)}")

    host_url = endpoints[0]['url']
    global_state.set_host_url(host_url)
    global_state.set_llm(llm)
    logger.debug("Updated to host url %s and LLM %s", host_url, llm)

    return jsonify({'message': 'Endpoint and LLM selected successfully'})
|
||||
|
||||
|
||||
@app.route('/smartassist', methods=["POST"])
|
||||
def smartassist():
|
||||
@@ -216,6 +263,35 @@ def get_response(user_query):
|
||||
response = client.generate_response(user_query) # Generate and retrieve the response based on user's query
|
||||
return response
|
||||
|
||||
def get_auth_headers(url: str) -> dict:
    """
    Build the HTTP headers to use when calling the endpoint at `url`.

    The endpoints held in the global state are searched for entries whose
    "url" equals `url`. When such an entry carries an Authorization value
    under requestOptions -> headers, that value is added to the default
    headers. If several entries match, the last one providing an
    Authorization value wins.

    Args:
        url (str): Base URL of the endpoint to look up.

    Returns:
        dict: Always contains "Content-Type": "application/json"; also
        contains "Authorization" when a matching endpoint defines one.
    """
    # TODO: The full operation should only have to run when changing to new endpoint.

    # Default header, used as-is when no endpoint matches or none needs auth.
    headers = {
        "Content-Type": "application/json",
    }

    found_endpoint = False
    for endpoint in global_state.get_endpoints():
        if endpoint.get("url") != url:
            continue
        found_endpoint = True
        # Walk the nested options defensively: an endpoint may define
        # requestOptions without headers, or headers without Authorization.
        request_options = endpoint.get("requestOptions") or {}
        option_headers = request_options.get("headers") or {}
        authorization = option_headers.get("Authorization")
        if authorization is not None:
            headers["Authorization"] = authorization

    if not found_endpoint:
        logger.debug("Host %s not found", url)

    return headers
|
||||
|
||||
def run_flask(fport=5005):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user