Nu används informationen från konfigurationsfilen när url för LLM-server bestäms
This commit is contained in:
+34
-15
@@ -12,7 +12,7 @@ from utils import GlobalState
|
||||
|
||||
# Create a logger for this module
|
||||
global_state = GlobalState() # Import the singleton that holds global states (e.g., logger)
|
||||
logger = global_state.getLogger(__name__) # Logger for this module, inherit properties of the root logger
|
||||
logger = global_state.get_logger(__name__) # Logger for this module, inherit properties of the root logger
|
||||
|
||||
|
||||
# Find out the path to current directory according to the Python interpreter (venv)
|
||||
@@ -45,13 +45,15 @@ def index():
|
||||
|
||||
session['chat_history'] = [] # The session object (actually, a dictonary) holds the chat session
|
||||
logger.debug("Entering route '/'")
|
||||
# api_endpoint = os.environ['BE_API_ENDPOINT'] # Retrieve the environment variable
|
||||
api_endpoint = global_state.get_backend_api_ep() # Retrieve the environment variable
|
||||
logger.debug("Backend API endpoint: %s", api_endpoint)
|
||||
host_url = global_state.get_host_url()
|
||||
use_model = global_state.get_llm()
|
||||
logger.debug("Backend API endpoint:\t%s", api_endpoint)
|
||||
logger.debug("Host of LLMs:\t\t%s", host_url)
|
||||
logger.debug("LLM to use:\t\t\t%s", use_model)
|
||||
with open('smartassist/src/html/client.html', 'r') as f:
|
||||
client_html = f.read()
|
||||
logger.debug("Client HTML (first few characters): %s", client_html[:50]) # Print to see if it's loading
|
||||
# logger.debug("Client HTML (first few characters): %s", client_html[:50]) # Print to see if it's loading
|
||||
# logger.debug("Client HTML (all characters): %s", client_html) # Print to see if it's loading
|
||||
|
||||
return render_template('index.html', api_endpoint=api_endpoint, use_model = use_model, client_content=client_html)
|
||||
@@ -93,7 +95,6 @@ CORS(app, resources={
|
||||
|
||||
@app.route('/api/tags', methods=['GET'])
|
||||
def tag(url = "http://localhost:11434/api/tags", headers = None):
|
||||
# def tag(url = "http://localhost:11434/api/tags", headers = {"Content-Type": "application/json"}):
|
||||
"""Get a list of models for the server located at url."""
|
||||
try:
|
||||
logger.debug(f"url: {url} headers: {headers}")
|
||||
@@ -106,8 +107,8 @@ def tag(url = "http://localhost:11434/api/tags", headers = None):
|
||||
|
||||
|
||||
@app.route('/api/chat', methods=['POST'])
|
||||
def chat(model = "phi3:mini"):
|
||||
# def chat(url_server = "http://localhost:11434/api/generate", model = "phi3:mini"):
|
||||
def chat():
|
||||
# def chat(model = "phi3:mini"):
|
||||
"""
|
||||
This function handles the chat. The frontend client (web browser) calls the
|
||||
backend server through this endpoint (/api/chat) that manage queries
|
||||
@@ -117,8 +118,9 @@ def chat(model = "phi3:mini"):
|
||||
# Get the message from the JSON in the request body
|
||||
data = request.get_json()
|
||||
message = data.get('query')
|
||||
url_server = data.get('url_server', "https://ollama-test.wara-ops.org/api/generate") # Use provided URL or default
|
||||
model = data.get('model', model) # Use provided model or default if not provided
|
||||
url_server = data.get('url_server', global_state.get_host_url()) # Use provided URL or current if not provided
|
||||
# url_server = data.get('url_server', "https://ollama-test.wara-ops.org/api/generate") # Use provided URL or default
|
||||
model = data.get('model', global_state.get_llm()) # Use provided model or current if not provided
|
||||
|
||||
# Get chat history from session storage (e.g., a dictionary)
|
||||
chat_history = session.get('chat_history', [])
|
||||
@@ -135,13 +137,30 @@ def chat(model = "phi3:mini"):
|
||||
'prompt': '\n'.join([f"{item['role']}: {item['message']}" for item in chat_history]),
|
||||
"stream": False
|
||||
}
|
||||
|
||||
url = url_server
|
||||
|
||||
# TODO: This section should only run when changing to new endpoint...
|
||||
# begin refactor ###################################
|
||||
headers = { # Set default header
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
found_endpoint = False
|
||||
endpoints = global_state.get_endpoints()
|
||||
for endpoint in endpoints:
|
||||
if endpoint["url"] == url: # Look for endpoint with this url
|
||||
found_endpoint = True
|
||||
if endpoint["provider"] == "ollama": # Currently only supporting ollama servers
|
||||
if "requestOptions" in endpoint: # Check if authentication is needed
|
||||
headers.update({
|
||||
"Authorization": endpoint["requestOptions"]["headers"]["Authorization"]
|
||||
})
|
||||
if found_endpoint == False:
|
||||
# Raise some error or whatever...
|
||||
logger.debug(f"Host {url} not found")
|
||||
# end refactor ###################################
|
||||
|
||||
try:
|
||||
url = url_server
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Basic ZWNzanBlcjoxM2JjMTU4ZDhmNmY5YTU4YTkzZDNmY2I="
|
||||
}
|
||||
url = url + "/api/generate"
|
||||
logger.debug(f"url: {url} headers: {headers}")
|
||||
response = requests.post(url,
|
||||
headers=headers,
|
||||
|
||||
Reference in New Issue
Block a user