diff --git a/main.py b/main.py index 6f352b8..179ac45 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ import json import os import pprint import queue +import random import re import secrets import sqlite3 @@ -11,24 +12,32 @@ import time import uuid from datetime import datetime from typing import List, Optional +import enum import GPUtil import ollama import psutil import structlog +import logging from flask import Flask, g, jsonify, request, send_from_directory -from flask_openapi3 import Info, OpenAPI from flask_socketio import SocketIO, emit from pydantic import BaseModel from models import model_manager from tools import DefaultToolManager +# Configure logging +logging.basicConfig(level=logging.INFO, format="%(message)s") +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) logger = structlog.get_logger() - # Configuration setup CONFIG_FILE = "config.ini" +# Add this near the top of the file, after imports +processing_thread = None +processing_thread_started = False + def create_default_config(): config = configparser.ConfigParser() @@ -65,8 +74,7 @@ ENABLE_API_ENDPOINTS = config["SERVER_FEATURES"].getboolean("EnableAPIEndpoints" PRIMARY_MODEL = config["MODEL"]["PrimaryModel"] UPDATE_INTERVAL = config["PERFORMANCE"].getfloat("UpdateInterval") -openapi = OpenAPI(__name__, info=Info(title="LLM Chat Server", version="1.0.0")) -app = openapi +app = Flask(__name__) socketio = SocketIO(app, cors_allowed_origins="*") tool_manager = DefaultToolManager() @@ -88,6 +96,12 @@ def close_connection(exception): db.close() +class QueryStatus(enum.Enum): + QUEUED = "queued" + PROCESSING = "processing" + DONE = "done" + + def init_db(): with app.app_context(): db = get_db() @@ -105,6 +119,7 @@ def init_db(): timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, query TEXT NOT NULL, api_key_id INTEGER, + status TEXT NOT NULL, conversation_history TEXT, FOREIGN KEY (api_key_id) REFERENCES Keys (id) ) @@ -126,6 +141,7 @@ CREATE TABLE IF NOT EXISTS Queries ( timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, query TEXT NOT NULL, api_key_id INTEGER, + status TEXT NOT NULL, conversation_history TEXT, FOREIGN KEY (api_key_id) REFERENCES Keys (id) ); @@ -432,14 +448,9 @@ class QueryStatusResponse(BaseModel): @app.post( - "/api/v1/query", - responses={ - "200": QueryResponse, - "401": {"description": "Unauthorized"}, - "500": {"description": "Internal Server Error"}, - }, + "/api/v1/query" ) -def api_query(body: QueryRequest): +def api_query(): """ Submit a new query to the LLM Chat Server. @@ -462,31 +473,31 @@ def api_query(body: QueryRequest): if not api_key_id: return jsonify({"error": "Invalid API key"}), 401 - user_input = body.message + data = request.get_json() + if not data or 'message' not in data: + return jsonify({"error": "Invalid request body"}), 400 + + user_input = data['message'] query_id = str(uuid.uuid4()) try: db = get_db() cursor = db.cursor() cursor.execute( - "INSERT INTO Queries (id, ip, query, api_key_id) VALUES (?, ?, ?, ?)", - (query_id, request.remote_addr, user_input, api_key_id) + "INSERT INTO Queries (id, ip, query, api_key_id, status) VALUES (?, ?, ?, ?, ?)", + (query_id, request.remote_addr, user_input, api_key_id, QueryStatus.QUEUED.value) ) db.commit() + logger.info(f"Added new query to database: {query_id}") return jsonify({"query_id": query_id}) except Exception as e: - logger.exception("Error during API query processing", error=str(e)) + logger.exception(f"Error during API query processing: {str(e)}") return jsonify({"error": str(e)}), 500 @app.get( - "/api/v1/query_status/", - responses={ - "200": QueryStatusResponse, - "404": {"description": "Query not found"}, - "500": {"description": "Internal Server Error"}, - }, + "/api/v1/query_status/" ) def get_query_status(query_id: str): """ @@ -509,23 +520,19 @@ def get_query_status(query_id: str): try: db = get_db() cursor = db.cursor() - cursor.execute("SELECT conversation_history FROM Queries WHERE id = ?", (query_id,)) + cursor.execute("SELECT status, conversation_history FROM Queries WHERE id = ?", (query_id,)) result = cursor.fetchone() if result is None: return jsonify({"error": "Query not found"}), 404 - conversation_history = result[0] + status, conversation_history = result - if conversation_history is None: - return jsonify({"status": "processing"}), 202 - else: - return jsonify( - { - "status": "completed", - "conversation_history": json.loads(conversation_history), - } - ) + response = {"status": status} + if status == QueryStatus.DONE.value: + response["conversation_history"] = json.loads(conversation_history) + + return jsonify(response) except Exception as e: logger.exception("Error retrieving query status", error=str(e)) return jsonify({"error": str(e)}), 500 @@ -600,30 +607,62 @@ def answer_question_tools_api( def process_queries(): + logger.info("Query processing thread started") with app.app_context(): while True: try: db = get_db() cursor = db.cursor() + + # First, check if there are any PROCESSING queries cursor.execute( - "SELECT id, query FROM Queries WHERE conversation_history IS NULL ORDER BY timestamp ASC LIMIT 1" + "SELECT id FROM Queries WHERE status = ? LIMIT 1", + (QueryStatus.PROCESSING.value,) + ) + processing_query = cursor.fetchone() + if processing_query: + logger.info(f"Found processing query: {processing_query[0]}. Waiting...") + db.commit() + time.sleep(10) + continue + + # If no PROCESSING queries, get the oldest QUEUED query + cursor.execute( + "SELECT id, query FROM Queries WHERE status = ? ORDER BY timestamp ASC LIMIT 1", + (QueryStatus.QUEUED.value,) ) result = cursor.fetchone() if result: query_id, user_input = result - conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}] - final_conversation_history = answer_question_tools_api(user_input, conversation_history) - + logger.info(f"Processing query: {query_id}") + + # Update status to PROCESSING cursor.execute( - "UPDATE Queries SET conversation_history = ? WHERE id = ?", - (json.dumps(final_conversation_history), query_id) + "UPDATE Queries SET status = ? WHERE id = ?", + (QueryStatus.PROCESSING.value, query_id) ) db.commit() + logger.info(f"Updated query {query_id} status to PROCESSING") + + conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}] + logger.info(f"Starting answer_question_tools_api for query {query_id}") + final_conversation_history = answer_question_tools_api(user_input, conversation_history) + logger.info(f"Finished answer_question_tools_api for query {query_id}") + + # Update with final result and set status to DONE + db.execute("BEGIN TRANSACTION") + cursor.execute( + "UPDATE Queries SET conversation_history = ?, status = ? WHERE id = ?", + (json.dumps(final_conversation_history), QueryStatus.DONE.value, query_id) + ) + db.commit() + logger.info(f"Updated query {query_id} status to DONE") else: - time.sleep(1) # Wait for 1 second before checking again if no queries are found + logger.info("No queued queries found. Waiting...") + time.sleep(random.uniform(5, 10)) # Wait for 5 seconds before checking again if no queries are found except Exception as e: - logger.exception("Error processing query", error=str(e)) + logger.exception(f"Error processing query: {str(e)}") time.sleep(1) # Wait for 1 second before retrying in case of an error @@ -638,14 +677,9 @@ class GenerateKeyResponse(BaseModel): @app.post( - "/admin/generate_key", - responses={ - "200": GenerateKeyResponse, - "401": {"description": "Unauthorized"}, - "500": {"description": "Internal Server Error"}, - }, + "/admin/generate_key" ) -def generate_api_key(body: GenerateKeyRequest): +def generate_api_key(): """ Generate a new API key for a user. @@ -661,7 +695,11 @@ def generate_api_key(body: GenerateKeyRequest): if not admin_key or admin_key != ADMIN_KEY: return jsonify({"error": "Invalid admin key"}), 401 - username = body.username + data = request.get_json() + if not data or 'username' not in data: + return jsonify({"error": "Invalid request body"}), 400 + + username = data['username'] api_key = secrets.token_urlsafe(32) try: @@ -679,16 +717,30 @@ def generate_api_key(body: GenerateKeyRequest): return jsonify({"error": str(e)}), 500 +def start_processing_thread(): + global processing_thread, processing_thread_started + if not processing_thread_started: + processing_thread = threading.Thread(target=process_queries, daemon=True) + processing_thread.start() + processing_thread_started = True + logger.info("Query processing thread started") + + +# Replace the if __main__ block with this: if __name__ == "__main__": logger.info("Starting LLM Chat Server") init_db() # Initialize the database if ENABLE_FRONTEND or ENABLE_CHAT_ENDPOINTS: threading.Thread(target=send_system_resources, daemon=True).start() + logger.info("System resources thread started") if ENABLE_API_ENDPOINTS: - threading.Thread( - target=lambda: app.app_context().push() and process_queries(), daemon=True - ).start() + start_processing_thread() - socketio.run(app, debug=True, host="0.0.0.0", port=5001) \ No newline at end of file + logger.info("Starting Flask application") + socketio.run(app, debug=True, host="0.0.0.0", port=5001) +else: + # This will run when the module is imported, e.g., by the reloader + if ENABLE_API_ENDPOINTS: + start_processing_thread() \ No newline at end of file diff --git a/schema.sql b/schema.sql index 1e5dbfd..5a0a29e 100644 --- a/schema.sql +++ b/schema.sql @@ -10,6 +10,7 @@ CREATE TABLE IF NOT EXISTS Queries ( timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, query TEXT NOT NULL, api_key_id INTEGER, + status TEXT NOT NULL, conversation_history TEXT, FOREIGN KEY (api_key_id) REFERENCES Keys (id) ); \ No newline at end of file