314 lines
12 KiB
Python
314 lines
12 KiB
Python
|
|
# Import the required modules: the ollama client, Flask and flask_cors, requests, standard-library helpers, and the project's utils module
|
|
from ollama import Client
|
|
from flask import Flask, request, jsonify, send_from_directory, render_template, session, make_response, Response
|
|
from flask_cors import CORS, cross_origin # CORS stands for Cross-Origin Resource Sharing. This is necessary to allow the frontend to make requests to our backend.
|
|
import requests
|
|
import json
|
|
import logging
|
|
import os
|
|
import utils
|
|
from utils import GlobalState
|
|
from pathlib import Path
|
|
|
|
# Create a logger for this module
global_state = GlobalState()  # Singleton from utils that holds global state (e.g., logger, endpoints, selected LLM)
logger = global_state.get_logger(__name__)  # Logger for this module

# Log the current working directory: index() opens the client HTML through a CWD-relative path.
logger.debug("Current working directory: %s", os.getcwd())

# Initialize a Flask application
app = Flask(__name__)
app.config['STATIC_FOLDER'] = 'static'  # Custom key (not a built-in Flask setting); read by serve_static()

# Cookie settings
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
# NOTE(review): 'SESSION_COOKIE_SIZE_LIMIT' is not a Flask/Werkzeug setting, so this line has no effect.
# Flask's related key is 'MAX_COOKIE_SIZE', which only controls when Werkzeug *warns* about large
# cookies; browsers typically drop cookies larger than ~4 KB regardless of server configuration.
app.config['SESSION_COOKIE_SIZE_LIMIT'] = 4096 * 2

# Flask signs the session cookie with SECRET_KEY. A key generated with os.urandom() changes every
# time the process starts, which invalidates all existing sessions after a restart (including
# restarts by the debug reloader). Set FLASK_SECRET_KEY in the environment to keep sessions valid
# across restarts; otherwise a random per-process key is used as before.
secret_key = os.environ.get('FLASK_SECRET_KEY') or os.urandom(24)
app.config['SECRET_KEY'] = secret_key

# NOTE(review): SESSION_PERMANENT and SESSION_TYPE are Flask-Session extension settings. Flask-Session is
# not initialized in this file (verify elsewhere); without it Flask ignores them and keeps session data
# in the signed cookie. With Flask-Session, SESSION_PERMANENT=False means the session cookie lasts until
# the browser session ends -- it does not expire after each request.
app.config['SESSION_PERMANENT'] = False
app.config['SESSION_TYPE'] = 'filesystem'  # Flask-Session: store sessions on the filesystem

logger.debug("flask app template folder: %s", app.template_folder)
|
|
|
|
@app.route('/')
def index() -> Response:
    """
    Serve index.html to a connecting client and start a fresh chat session.

    Clears the chat history stored in the session. Looks up the backend API endpoint, the
    host URL of the LLMs and the selected LLM from the global state. Then renders
    index.html, passing the API endpoint, the model, and the client HTML fragment read
    from smartassist/src/html/client.html.

    Returns:
        Response: The rendered index.html page.

    Raises:
        OSError: If the client HTML file cannot be opened. Its path is relative to the current
            working directory, so the server presumably must be started from the repository
            root -- confirm against the launch script.
    """
    session['chat_history'] = []  # The session (a dict-like mapping) holds this client's chat history
    logger.debug("Entering route '/'")

    api_endpoint = global_state.get_backend_api_ep()
    host_url = global_state.get_host_url()
    use_model = global_state.get_llm()

    logger.debug("Backend API endpoint:\t%s", api_endpoint)
    logger.debug("Host of LLMs:\t\t%s", host_url)
    logger.debug("LLM to use:\t\t\t%s", use_model)

    # Decode explicitly as UTF-8 so the result does not depend on the platform's locale encoding.
    with open('smartassist/src/html/client.html', 'r', encoding='utf-8') as f:
        client_html = f.read()

    return render_template('index.html', api_endpoint=api_endpoint, use_model=use_model, client_content=client_html)
|
|
|
|
@app.route('/set_session')
def set_session():
    """
    Return an empty response that sets a cookie named 'session' to a placeholder value.

    The cookie is sent with SameSite=None and the Secure flag.

    NOTE(review): 'session' is also Flask's default session cookie name, so this presumably
    replaces the signed session cookie (and with it the stored chat history). Confirm this
    route is only a debugging aid.
    """
    cookie_flags = {'samesite': 'None', 'secure': True}
    reply = make_response()
    reply.set_cookie('session', 'some-value', **cookie_flags)
    return reply
|
|
|
|
# @app.route('/profile')
|
|
# def profile():
|
|
# # Retrieve data from the session
|
|
# user_id = session.get('user_id')
|
|
|
|
# if user_id:
|
|
# return f'User ID: {user_id}'
|
|
# else:
|
|
# return 'No user ID found'
|
|
|
|
|
|
|
|
@app.route('/<path:filename>')
def serve_static(filename: str | Path) -> Response:
    """
    Send a file from the folder named by app.config['STATIC_FOLDER'].

    Args:
        filename (str or os.PathLike[str]): Path of the requested file, relative to that folder.

    Returns:
        Response: A Flask response carrying the file's contents.
    """
    static_dir = app.config['STATIC_FOLDER']
    # send_from_directory refuses paths that would escape static_dir.
    return send_from_directory(static_dir, filename)
|
|
|
|
|
|
|
|
# CORS(app, resources={
|
|
# r"/api/chat": {
|
|
# "origins": "*",
|
|
# "headers": ["Origin", "Content-Type", "Authorization"],
|
|
# }
|
|
# })
|
|
|
|
# CORS(app, resources={
|
|
# r"/api/chat": {
|
|
# "origins": "*"
|
|
# }
|
|
# })
|
|
|
|
@app.route('/api/tags', methods=['GET'])
def get_tags(url: str = "http://localhost:11434/api/tags", headers: dict | None = None,
             timeout: float = 30.0) -> dict:
    """
    Retrieve the list of available models from a server.

    When invoked as a Flask view no arguments are passed, so the defaults are used. The
    query therefore always goes to http://localhost:11434/api/tags, whichever endpoint is
    currently selected.

    Args:
        url (str): The URL of the server to query. Defaults to http://localhost:11434/api/tags.
        headers (dict, optional): HTTP headers to include in the request. Defaults to None.
        timeout (float): Seconds to wait for the server to connect / send data before giving up.
            Defaults to 30. Without a timeout an unresponsive server would block the worker forever.

    Returns:
        dict: The server's decoded JSON reply. If the request fails or the reply is not valid
            JSON, {'error': 'Failed to process request'} is returned instead (HTTP status 200).
            In that case the error is logged, not raised.
    """
    try:
        logger.debug("url: %s headers: %s", url, headers)
        response = requests.get(url, headers=headers, timeout=timeout)
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers JSON decode errors on requests versions where they are not RequestExceptions.
        logger.error("Request Exception: %s", str(e))
        return {'error': 'Failed to process request'}
|
|
|
|
|
|
|
|
|
|
@app.route('/api/chat', methods=['POST'])
def chat() -> dict | tuple[Response, int]:
    """
    Handle one chat turn: send the conversation so far to an LLM server and return its reply.

    This endpoint expects a JSON payload with the following structure:
    {
        'query': str,
        'url_server': str (optional),
        'model': str (optional)
    }

    'url_server' is the base URL of the LLM host; '/api/generate' is appended to it.
    'model' is the model name. When either is missing, null or empty, the host URL or LLM
    currently selected in the global state is used.

    The conversation is kept in the Flask session under 'chat_history' as a list of
    {'role': ..., 'message': ...} dicts. The whole history is flattened into one prompt,
    with one "role: message" line per turn, and sent with "stream": False.

    :return: The LLM server's decoded JSON reply on success. Otherwise a JSON error body with
             status 400 (request body is not a JSON object or lacks 'query') or 500 (the upstream
             request failed, or its reply was not JSON containing a 'response' key).
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or data.get('query') is None:
        return jsonify({'error': "Request body must be a JSON object with a 'query' field"}), 400
    message = data['query']
    # `or` instead of a .get() default so that an explicit null/empty value also falls back.
    url_server = data.get('url_server') or global_state.get_host_url()
    model = data.get('model') or global_state.get_llm()

    # NOTE(review): with Flask's default cookie-based session this history grows every turn and can
    # exceed browser cookie size limits (~4 KB); consider trimming it or using server-side sessions.
    chat_history = session.get('chat_history', [])
    chat_history.append({'role': 'user', 'message': message})
    session['chat_history'] = chat_history

    data_to_send = {
        "model": model,
        'prompt': '\n'.join(f"{item['role']}: {item['message']}" for item in chat_history),
        "stream": False,
    }
    headers = get_auth_headers(url_server)
    url = url_server + "/api/generate"

    logger.debug("Sending request to:\n\turl:\t%s\nmodel:\n\t%s", url_server, model)
    # Never write the Authorization token to the logs.
    loggable_headers = {k: ('<redacted>' if k.lower() == 'authorization' else v) for k, v in headers.items()}
    logger.debug("url: %s headers: %s", url, loggable_headers)

    try:
        # Bound only the connect phase: a non-streamed generation can legitimately take a long time.
        response = requests.post(url,
                                 headers=headers,
                                 data=json.dumps(data_to_send),
                                 timeout=(10, None))
        response.raise_for_status()  # Raise an exception for bad status codes
    except requests.exceptions.RequestException as e:
        logger.error("Request Exception: %s", str(e))
        return jsonify({'error': 'Failed to process request'}), 500

    try:
        reply = response.json()
        llm_response = reply['response']  # The generated text is under the 'response' key
    except ValueError as e:
        # json.JSONDecodeError and requests' own JSONDecodeError both subclass ValueError.
        logger.error("JSON Decode Error: %s", str(e))
        return jsonify({'error': 'Invalid JSON response from server'}), 500
    except (KeyError, TypeError) as e:
        logger.error("Unexpected reply from LLM server (no 'response' field): %r", e)
        return jsonify({'error': 'Invalid JSON response from server'}), 500

    chat_history.append({'role': 'assistant', 'message': llm_response})
    # Reassign rather than rely on in-place mutation so the session is saved with the assistant turn.
    session['chat_history'] = chat_history
    logger.debug("Chat History: %s", chat_history)
    return reply
|
|
|
|
|
|
@app.route('/api/endpoints', methods=['GET'])
def get_endpoints() -> Response:
    """
    Return every (endpoint title, LLM) pair known to the global state as JSON.

    For each endpoint returned by global_state.get_endpoints(), one entry is emitted for each
    LLM that global_state.get_list_of_available_llms(endpoint) returns.

    :return: A JSON response whose body is a flat list of dicts, e.g.
             [{'title': 'title1', 'llm': 'llm1'}, {'title': 'title1', 'llm': 'llm2'}, ...].
    """
    endpoints = []  # Flat list of {'title': ..., 'llm': ...} dicts
    for ep in global_state.get_endpoints():
        for llm in global_state.get_list_of_available_llms(ep):
            endpoints.append({'title': ep.get('title'), 'llm': llm})
    return jsonify(endpoints)
|
|
|
|
@app.route('/api/select_endpoint_llm', methods=['POST'])
def select_endpoint_llm() -> Response | tuple[Response, int]:
    """
    Select the endpoint with the given title and the given LLM as the current ones.

    On success, the endpoint's 'url' becomes the global host URL and 'llm' becomes the global LLM.
    These are the defaults /api/chat falls back to.

    Request Body (JSON object):
        - title: str - The title of the endpoint to select.
        - llm: str - The LLM to set.

    Returns:
        A JSON response confirming the selection. A JSON error with status 400 is returned
        instead if the body is malformed or the title does not match exactly one endpoint.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'title' not in data or 'llm' not in data:
        return jsonify({'error': "Request body must be a JSON object with 'title' and 'llm'"}), 400
    title = data['title']
    llm = data['llm']

    endpoints = global_state.get_endpoints_with_key_value('title', title)
    if len(endpoints) != 1:
        # Reported as a client error rather than raised: an uncaught exception becomes an HTTP 500
        # (and, when the server runs with debug=True, the interactive debugger page).
        error_message = f"Expected exactly one endpoint with title '{title}', found {len(endpoints)}"
        logger.error("%s", error_message)
        return jsonify({'error': error_message}), 400

    host_url = endpoints[0]['url']
    global_state.set_host_url(host_url)
    global_state.set_llm(llm)
    logger.debug("Updated to host url %s and LLM %s", host_url, llm)

    return jsonify({'message': 'Endpoint and LLM selected successfully'})
|
|
|
|
|
|
@app.route('/smartassist', methods=["POST"])
def smartassist():
    """
    Answer a single query, without chat history, and return {"response": <answer>} as JSON.

    Expects a JSON object body with a 'query' key; responds with a JSON error and status 400
    otherwise.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'query' not in data:
        return jsonify({'error': "Request body must be a JSON object with a 'query' field"}), 400
    user_query = data['query']

    # Get the response from the Ollama API based on the user's query
    # NOTE: Should we append message history here? Maybe interact with SQLite?
    response = get_response(user_query)

    # Return the response as a JSON object in the HTTP response
    return jsonify({"response": response})
|
|
|
|
def get_response(user_query, model=None, host=None):
    """
    Send a single prompt to an Ollama server and return the generated text.

    Args:
        user_query: The prompt to send.
        model: Model name. Defaults to the LLM currently selected in the global state.
        host: Base URL of the Ollama server. Defaults to the host URL currently selected in the
            global state, which is the same server /api/chat uses by default.

    Returns:
        The 'response' field of the server's reply, i.e. the generated text.
    """
    if model is None:
        model = global_state.get_llm()
    if host is None:
        host = global_state.get_host_url()
    # ollama's Client has no generate_response(); generate() is its one-shot completion call.
    # Forward the endpoint's auth headers so that protected servers accept the request.
    client = Client(host=host, headers=get_auth_headers(host))
    result = client.generate(model=model, prompt=user_query)
    return result['response']
|
|
|
|
def get_auth_headers(url: str) -> dict:
    """
    Return the HTTP headers to use for requests to the endpoint at `url`.

    The result always contains "Content-Type: application/json". Endpoints from the global
    state whose "url" equals `url` exactly are searched. If a matching endpoint defines
    requestOptions.headers.Authorization, that value is added. When several endpoints match,
    the last one with an Authorization header wins.

    Endpoints that lack a "url", or whose requestOptions has no headers/Authorization, are
    tolerated rather than raising KeyError. A debug message is logged if no endpoint matches.

    Args:
        url: Base URL of the endpoint, compared verbatim (no trailing-slash normalization).

    Returns:
        dict: The header mapping.
    """
    # TODO: The full operation should only have to run when changing to new endpoint.
    headers = {"Content-Type": "application/json"}

    found_endpoint = False
    for endpoint in global_state.get_endpoints():
        if endpoint.get("url") != url:  # Look for endpoints with this URL
            continue
        found_endpoint = True
        # Currently only ollama-style APIs are supported; no per-provider handling is needed.
        request_headers = (endpoint.get("requestOptions") or {}).get("headers") or {}
        authorization = request_headers.get("Authorization")
        if authorization is not None:  # Authentication configured for this endpoint
            headers["Authorization"] = authorization

    if not found_endpoint:
        logger.debug("Host %s not found", url)

    return headers
|
|
|
|
def run_flask(fport=5005, debug=True):
    """
    Start the Flask development server and block until it stops.

    Args:
        fport: TCP port to listen on. Defaults to 5005.
        debug: Run in Flask debug mode (auto-reloader and the interactive Werkzeug debugger).
            Defaults to True to keep the existing behavior. Pass False outside local development,
            since the debugger lets anyone who can reach the page execute arbitrary code.
    """
    logger.debug("Entering run_flask()")
    app.run(port=int(fport), debug=debug)
    logger.debug("Exiting run_flask()")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: start the Flask server with run_flask()'s default settings.
    run_flask()
|
|
|
|
|