Nu används informationen från konfigurationsfilen när url för LLM-server bestäms

This commit is contained in:
2024-08-05 00:50:45 +02:00
parent cb1caceee7
commit 8452c7569b
+33 -14
View File
@@ -12,7 +12,7 @@ from utils import GlobalState
# Create a logger for this module
global_state = GlobalState() # Import the singleton that holds global states (e.g., logger)
logger = global_state.getLogger(__name__) # Logger for this module, inherit properties of the root logger
logger = global_state.get_logger(__name__) # Logger for this module, inherit properties of the root logger
# Find out the path to current directory according to the Python interpreter (venv)
@@ -45,13 +45,15 @@ def index():
session['chat_history'] = [] # The session object (actually, a dictonary) holds the chat session
logger.debug("Entering route '/'")
# api_endpoint = os.environ['BE_API_ENDPOINT'] # Retrieve the environment variable
api_endpoint = global_state.get_backend_api_ep() # Retrieve the environment variable
logger.debug("Backend API endpoint: %s", api_endpoint)
host_url = global_state.get_host_url()
use_model = global_state.get_llm()
logger.debug("Backend API endpoint:\t%s", api_endpoint)
logger.debug("Host of LLMs:\t\t%s", host_url)
logger.debug("LLM to use:\t\t\t%s", use_model)
with open('smartassist/src/html/client.html', 'r') as f:
client_html = f.read()
logger.debug("Client HTML (first few characters): %s", client_html[:50]) # Print to see if it's loading
# logger.debug("Client HTML (first few characters): %s", client_html[:50]) # Print to see if it's loading
# logger.debug("Client HTML (all characters): %s", client_html) # Print to see if it's loading
return render_template('index.html', api_endpoint=api_endpoint, use_model = use_model, client_content=client_html)
@@ -93,7 +95,6 @@ CORS(app, resources={
@app.route('/api/tags', methods=['GET'])
def tag(url = "http://localhost:11434/api/tags", headers = None):
# def tag(url = "http://localhost:11434/api/tags", headers = {"Content-Type": "application/json"}):
"""Get a list of models for the server located at url."""
try:
logger.debug(f"url: {url} headers: {headers}")
@@ -106,8 +107,8 @@ def tag(url = "http://localhost:11434/api/tags", headers = None):
@app.route('/api/chat', methods=['POST'])
def chat(model = "phi3:mini"):
# def chat(url_server = "http://localhost:11434/api/generate", model = "phi3:mini"):
def chat():
# def chat(model = "phi3:mini"):
"""
This function handles the chat. The frontend client (web browser) calls the
backend server through this endpoint (/api/chat) that manage queries
@@ -117,8 +118,9 @@ def chat(model = "phi3:mini"):
# Get the message from the JSON in the request body
data = request.get_json()
message = data.get('query')
url_server = data.get('url_server', "https://ollama-test.wara-ops.org/api/generate") # Use provided URL or default
model = data.get('model', model) # Use provided model or default if not provided
url_server = data.get('url_server', global_state.get_host_url()) # Use provided URL or current if not provided
# url_server = data.get('url_server', "https://ollama-test.wara-ops.org/api/generate") # Use provided URL or default
model = data.get('model', global_state.get_llm()) # Use provided model or current if not provided
# Get chat history from session storage (e.g., a dictionary)
chat_history = session.get('chat_history', [])
@@ -135,13 +137,30 @@ def chat(model = "phi3:mini"):
'prompt': '\n'.join([f"{item['role']}: {item['message']}" for item in chat_history]),
"stream": False
}
url = url_server
# TODO: This section should only run when changing to new endpoint...
# begin refactor ###################################
headers = { # Set default header
"Content-Type": "application/json",
}
found_endpoint = False
endpoints = global_state.get_endpoints()
for endpoint in endpoints:
if endpoint["url"] == url: # Look for endpoint with this url
found_endpoint = True
if endpoint["provider"] == "ollama": # Currently only supporting ollama servers
if "requestOptions" in endpoint: # Check if authentication is needed
headers.update({
"Authorization": endpoint["requestOptions"]["headers"]["Authorization"]
})
if found_endpoint == False:
# Raise some error or whatever...
logger.debug(f"Host {url} not found")
# end refactor ###################################
try:
url = url_server
headers = {
"Content-Type": "application/json",
"Authorization": "Basic ZWNzanBlcjoxM2JjMTU4ZDhmNmY5YTU4YTkzZDNmY2I="
}
url = url + "/api/generate"
logger.debug(f"url: {url} headers: {headers}")
response = requests.post(url,
headers=headers,