This commit is contained in:
Tanishq Dubey 2024-10-06 23:19:57 -04:00
parent 0c9059bcbc
commit f56d9b59c0
2 changed files with 105 additions and 52 deletions

156
main.py
View File

@ -3,6 +3,7 @@ import json
import os
import pprint
import queue
import random
import re
import secrets
import sqlite3
@ -11,24 +12,32 @@ import time
import uuid
from datetime import datetime
from typing import List, Optional
import enum
import GPUtil
import ollama
import psutil
import structlog
import logging
from flask import Flask, g, jsonify, request, send_from_directory
from flask_openapi3 import Info, OpenAPI
from flask_socketio import SocketIO, emit
from pydantic import BaseModel
from models import model_manager
from tools import DefaultToolManager
# Configure logging: the root logger emits INFO and above, formatted as the
# bare message only (no timestamp or level prefix).
logging.basicConfig(level=logging.INFO, format="%(message)s")
# NOTE(review): console_handler is created and given a level here, but it is
# not attached to any logger in the visible code - confirm it is used elsewhere.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# Module-wide structlog logger used by the request handlers and worker threads.
logger = structlog.get_logger()
# Configuration setup: name of the INI file holding the server configuration.
CONFIG_FILE = "config.ini"
# Background query-processing worker state. start_processing_thread() stores
# the Thread object in processing_thread and sets processing_thread_started so
# that the worker is launched at most once per process.
processing_thread = None
processing_thread_started = False
def create_default_config():
config = configparser.ConfigParser()
@ -65,8 +74,7 @@ ENABLE_API_ENDPOINTS = config["SERVER_FEATURES"].getboolean("EnableAPIEndpoints"
PRIMARY_MODEL = config["MODEL"]["PrimaryModel"]
UPDATE_INTERVAL = config["PERFORMANCE"].getfloat("UpdateInterval")
openapi = OpenAPI(__name__, info=Info(title="LLM Chat Server", version="1.0.0"))
app = openapi
app = Flask(__name__)
socketio = SocketIO(app, cors_allowed_origins="*")
tool_manager = DefaultToolManager()
@ -88,6 +96,12 @@ def close_connection(exception):
db.close()
@enum.unique
class QueryStatus(enum.Enum):
    """Lifecycle states of a row in the ``Queries`` table.

    The ``.value`` of each member is the exact string stored in the
    ``status`` column and returned to API clients, so the strings must not
    be changed without migrating existing rows. ``@enum.unique`` makes an
    accidental duplicate value fail at import time instead of silently
    aliasing two states.
    """

    # Inserted by the query endpoint; waiting to be picked up by the worker.
    QUEUED = "queued"
    # Claimed by the worker; an answer is currently being generated.
    PROCESSING = "processing"
    # Finished; conversation_history holds the JSON-encoded result.
    DONE = "done"
def init_db():
with app.app_context():
db = get_db()
@ -105,6 +119,7 @@ def init_db():
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
query TEXT NOT NULL,
api_key_id INTEGER,
status TEXT NOT NULL,
conversation_history TEXT,
FOREIGN KEY (api_key_id) REFERENCES Keys (id)
)
@ -126,6 +141,7 @@ CREATE TABLE IF NOT EXISTS Queries (
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
query TEXT NOT NULL,
api_key_id INTEGER,
status TEXT NOT NULL,
conversation_history TEXT,
FOREIGN KEY (api_key_id) REFERENCES Keys (id)
);
@ -432,14 +448,9 @@ class QueryStatusResponse(BaseModel):
@app.post(
"/api/v1/query",
responses={
"200": QueryResponse,
"401": {"description": "Unauthorized"},
"500": {"description": "Internal Server Error"},
},
"/api/v1/query"
)
def api_query(body: QueryRequest):
def api_query():
"""
Submit a new query to the LLM Chat Server.
@ -462,31 +473,31 @@ def api_query(body: QueryRequest):
if not api_key_id:
return jsonify({"error": "Invalid API key"}), 401
user_input = body.message
data = request.get_json()
if not data or 'message' not in data:
return jsonify({"error": "Invalid request body"}), 400
user_input = data['message']
query_id = str(uuid.uuid4())
try:
db = get_db()
cursor = db.cursor()
cursor.execute(
"INSERT INTO Queries (id, ip, query, api_key_id) VALUES (?, ?, ?, ?)",
(query_id, request.remote_addr, user_input, api_key_id)
"INSERT INTO Queries (id, ip, query, api_key_id, status) VALUES (?, ?, ?, ?, ?)",
(query_id, request.remote_addr, user_input, api_key_id, QueryStatus.QUEUED.value)
)
db.commit()
logger.info(f"Added new query to database: {query_id}")
return jsonify({"query_id": query_id})
except Exception as e:
logger.exception("Error during API query processing", error=str(e))
logger.exception(f"Error during API query processing: {str(e)}")
return jsonify({"error": str(e)}), 500
@app.get(
"/api/v1/query_status/<string:query_id>",
responses={
"200": QueryStatusResponse,
"404": {"description": "Query not found"},
"500": {"description": "Internal Server Error"},
},
"/api/v1/query_status/<string:query_id>"
)
def get_query_status(query_id: str):
"""
@ -509,23 +520,19 @@ def get_query_status(query_id: str):
try:
db = get_db()
cursor = db.cursor()
cursor.execute("SELECT conversation_history FROM Queries WHERE id = ?", (query_id,))
cursor.execute("SELECT status, conversation_history FROM Queries WHERE id = ?", (query_id,))
result = cursor.fetchone()
if result is None:
return jsonify({"error": "Query not found"}), 404
conversation_history = result[0]
status, conversation_history = result
if conversation_history is None:
return jsonify({"status": "processing"}), 202
else:
return jsonify(
{
"status": "completed",
"conversation_history": json.loads(conversation_history),
}
)
response = {"status": status}
if status == QueryStatus.DONE.value:
response["conversation_history"] = json.loads(conversation_history)
return jsonify(response)
except Exception as e:
logger.exception("Error retrieving query status", error=str(e))
return jsonify({"error": str(e)}), 500
@ -600,30 +607,62 @@ def answer_question_tools_api(
def process_queries():
    """Run forever, answering queued API queries one at a time.

    Each iteration:

    1. If any row is already PROCESSING, wait 10 seconds and re-check, so
       only one query is answered at a time.
    2. Otherwise claim the oldest QUEUED row by flipping it to PROCESSING.
       The claim is conditional on the row still being QUEUED, so a second
       worker racing for the same row backs off instead of duplicating work.
    3. Run ``answer_question_tools_api`` and store the JSON-encoded
       conversation with status DONE.

    If answering fails, the row is marked ``"failed"`` instead of being left
    in PROCESSING; a row stuck in PROCESSING would make step 1 wait forever
    and stall every later query.

    Intended to run in a daemon thread; never returns. All exceptions are
    logged and the loop continues after a 1 second pause.
    """
    # Terminal status for queries whose processing raised. Kept local because
    # the QueryStatus enum has no failure member; the status endpoint reports
    # whatever string is stored in the status column.
    failed_status = "failed"

    logger.info("Query processing thread started")
    with app.app_context():
        # NOTE(review): a row left in PROCESSING by a previous process that
        # died mid-query still blocks the queue; recovering those rows needs a
        # policy decision (requeue vs. fail) and is not handled here.
        while True:
            try:
                db = get_db()
                cursor = db.cursor()

                # First, check if there are any PROCESSING queries
                cursor.execute(
                    "SELECT id FROM Queries WHERE status = ? LIMIT 1",
                    (QueryStatus.PROCESSING.value,),
                )
                processing_query = cursor.fetchone()
                if processing_query:
                    logger.info(f"Found processing query: {processing_query[0]}. Waiting...")
                    db.commit()
                    time.sleep(10)
                    continue

                # If no PROCESSING queries, get the oldest QUEUED query
                cursor.execute(
                    "SELECT id, query FROM Queries WHERE status = ? ORDER BY timestamp ASC LIMIT 1",
                    (QueryStatus.QUEUED.value,),
                )
                result = cursor.fetchone()
                if not result:
                    logger.info("No queued queries found. Waiting...")
                    # Wait 5-10 seconds (randomised so that multiple workers
                    # do not poll in lockstep) before checking again.
                    time.sleep(random.uniform(5, 10))
                    continue

                query_id, user_input = result
                logger.info(f"Processing query: {query_id}")

                # Claim the query: only succeeds if it is still QUEUED.
                cursor.execute(
                    "UPDATE Queries SET status = ? WHERE id = ? AND status = ?",
                    (QueryStatus.PROCESSING.value, query_id, QueryStatus.QUEUED.value),
                )
                claimed = cursor.rowcount == 1
                db.commit()
                if not claimed:
                    logger.info(f"Query {query_id} was claimed by another worker. Skipping...")
                    continue
                logger.info(f"Updated query {query_id} status to PROCESSING")

                try:
                    conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}]
                    logger.info(f"Starting answer_question_tools_api for query {query_id}")
                    final_conversation_history = answer_question_tools_api(user_input, conversation_history)
                    logger.info(f"Finished answer_question_tools_api for query {query_id}")
                    serialized_history = json.dumps(final_conversation_history)
                except Exception:
                    # Do not leave the row in PROCESSING, or the queue stalls.
                    cursor.execute(
                        "UPDATE Queries SET status = ? WHERE id = ?",
                        (failed_status, query_id),
                    )
                    db.commit()
                    logger.info(f"Updated query {query_id} status to FAILED")
                    raise  # logged by the outer handler below

                # Update with final result and set status to DONE. A single
                # UPDATE followed by commit is atomic; no explicit BEGIN needed.
                cursor.execute(
                    "UPDATE Queries SET conversation_history = ?, status = ? WHERE id = ?",
                    (serialized_history, QueryStatus.DONE.value, query_id),
                )
                db.commit()
                logger.info(f"Updated query {query_id} status to DONE")
            except Exception as e:
                logger.exception(f"Error processing query: {str(e)}")
                time.sleep(1)  # Wait for 1 second before retrying in case of an error
@ -638,14 +677,9 @@ class GenerateKeyResponse(BaseModel):
@app.post(
"/admin/generate_key",
responses={
"200": GenerateKeyResponse,
"401": {"description": "Unauthorized"},
"500": {"description": "Internal Server Error"},
},
"/admin/generate_key"
)
def generate_api_key(body: GenerateKeyRequest):
def generate_api_key():
"""
Generate a new API key for a user.
@ -661,7 +695,11 @@ def generate_api_key(body: GenerateKeyRequest):
if not admin_key or admin_key != ADMIN_KEY:
return jsonify({"error": "Invalid admin key"}), 401
username = body.username
data = request.get_json()
if not data or 'username' not in data:
return jsonify({"error": "Invalid request body"}), 400
username = data['username']
api_key = secrets.token_urlsafe(32)
try:
@ -679,16 +717,30 @@ def generate_api_key(body: GenerateKeyRequest):
return jsonify({"error": str(e)}), 500
def start_processing_thread():
    """Launch the background query-processing worker, at most once per process.

    The worker is a daemon thread running ``process_queries``. The thread
    object is kept in the module-level ``processing_thread`` and
    ``processing_thread_started`` records that it has been launched, so
    repeated calls are no-ops.
    """
    global processing_thread, processing_thread_started

    # Already launched in this process: nothing to do.
    if processing_thread_started:
        return

    worker = threading.Thread(target=process_queries, daemon=True)
    processing_thread = worker
    worker.start()
    processing_thread_started = True
    logger.info("Query processing thread started")
# Entry point: run the server when executed as a script; when imported, only start the query-processing thread.
if __name__ == "__main__":
    logger.info("Starting LLM Chat Server")
    init_db()  # Initialize the database

    if ENABLE_FRONTEND or ENABLE_CHAT_ENDPOINTS:
        # Daemon thread running send_system_resources for the chat/frontend side.
        threading.Thread(target=send_system_resources, daemon=True).start()
        logger.info("System resources thread started")

    if ENABLE_API_ENDPOINTS:
        # start_processing_thread() is the single place the query worker is
        # launched; it is a no-op if the worker is already running.
        start_processing_thread()

    logger.info("Starting Flask application")
    # NOTE(review): debug=True combined with host="0.0.0.0" exposes the
    # interactive debugger on every network interface; disable debug (or bind
    # to localhost) outside of local development.
    # socketio.run() blocks until the server stops, so it must be the last
    # statement on this path and must be called only once.
    socketio.run(app, debug=True, host="0.0.0.0", port=5001)
else:
    # This runs when the module is imported instead of executed as a script.
    # NOTE(review): confirm which importers rely on this; the debug reloader
    # appears to re-run the script as __main__, not import it.
    # NOTE(review): init_db() is not called on this path, so the worker
    # assumes the database tables already exist.
    if ENABLE_API_ENDPOINTS:
        start_processing_thread()

View File

@ -10,6 +10,7 @@ CREATE TABLE IF NOT EXISTS Queries (
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
query TEXT NOT NULL,
api_key_id INTEGER,
status TEXT NOT NULL,
conversation_history TEXT,
FOREIGN KEY (api_key_id) REFERENCES Keys (id)
);