tool_use #1
152
main.py
152
main.py
@ -3,6 +3,7 @@ import json
|
||||
import os
|
||||
import pprint
|
||||
import queue
|
||||
import random
|
||||
import re
|
||||
import secrets
|
||||
import sqlite3
|
||||
@ -11,24 +12,32 @@ import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
import enum
|
||||
|
||||
import GPUtil
|
||||
import ollama
|
||||
import psutil
|
||||
import structlog
|
||||
import logging
|
||||
from flask import Flask, g, jsonify, request, send_from_directory
|
||||
from flask_openapi3 import Info, OpenAPI
|
||||
from flask_socketio import SocketIO, emit
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models import model_manager
|
||||
from tools import DefaultToolManager
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Configuration setup
|
||||
CONFIG_FILE = "config.ini"
|
||||
|
||||
# Add this near the top of the file, after imports
|
||||
processing_thread = None
|
||||
processing_thread_started = False
|
||||
|
||||
|
||||
def create_default_config():
|
||||
config = configparser.ConfigParser()
|
||||
@ -65,8 +74,7 @@ ENABLE_API_ENDPOINTS = config["SERVER_FEATURES"].getboolean("EnableAPIEndpoints"
|
||||
PRIMARY_MODEL = config["MODEL"]["PrimaryModel"]
|
||||
UPDATE_INTERVAL = config["PERFORMANCE"].getfloat("UpdateInterval")
|
||||
|
||||
openapi = OpenAPI(__name__, info=Info(title="LLM Chat Server", version="1.0.0"))
|
||||
app = openapi
|
||||
app = Flask(__name__)
|
||||
socketio = SocketIO(app, cors_allowed_origins="*")
|
||||
|
||||
tool_manager = DefaultToolManager()
|
||||
@ -88,6 +96,12 @@ def close_connection(exception):
|
||||
db.close()
|
||||
|
||||
|
||||
class QueryStatus(enum.Enum):
|
||||
QUEUED = "queued"
|
||||
PROCESSING = "processing"
|
||||
DONE = "done"
|
||||
|
||||
|
||||
def init_db():
|
||||
with app.app_context():
|
||||
db = get_db()
|
||||
@ -105,6 +119,7 @@ def init_db():
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
query TEXT NOT NULL,
|
||||
api_key_id INTEGER,
|
||||
status TEXT NOT NULL,
|
||||
conversation_history TEXT,
|
||||
FOREIGN KEY (api_key_id) REFERENCES Keys (id)
|
||||
)
|
||||
@ -126,6 +141,7 @@ CREATE TABLE IF NOT EXISTS Queries (
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
query TEXT NOT NULL,
|
||||
api_key_id INTEGER,
|
||||
status TEXT NOT NULL,
|
||||
conversation_history TEXT,
|
||||
FOREIGN KEY (api_key_id) REFERENCES Keys (id)
|
||||
);
|
||||
@ -432,14 +448,9 @@ class QueryStatusResponse(BaseModel):
|
||||
|
||||
|
||||
@app.post(
|
||||
"/api/v1/query",
|
||||
responses={
|
||||
"200": QueryResponse,
|
||||
"401": {"description": "Unauthorized"},
|
||||
"500": {"description": "Internal Server Error"},
|
||||
},
|
||||
"/api/v1/query"
|
||||
)
|
||||
def api_query(body: QueryRequest):
|
||||
def api_query():
|
||||
"""
|
||||
Submit a new query to the LLM Chat Server.
|
||||
|
||||
@ -462,31 +473,31 @@ def api_query(body: QueryRequest):
|
||||
if not api_key_id:
|
||||
return jsonify({"error": "Invalid API key"}), 401
|
||||
|
||||
user_input = body.message
|
||||
data = request.get_json()
|
||||
if not data or 'message' not in data:
|
||||
return jsonify({"error": "Invalid request body"}), 400
|
||||
|
||||
user_input = data['message']
|
||||
query_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
db = get_db()
|
||||
cursor = db.cursor()
|
||||
cursor.execute(
|
||||
"INSERT INTO Queries (id, ip, query, api_key_id) VALUES (?, ?, ?, ?)",
|
||||
(query_id, request.remote_addr, user_input, api_key_id)
|
||||
"INSERT INTO Queries (id, ip, query, api_key_id, status) VALUES (?, ?, ?, ?, ?)",
|
||||
(query_id, request.remote_addr, user_input, api_key_id, QueryStatus.QUEUED.value)
|
||||
)
|
||||
db.commit()
|
||||
logger.info(f"Added new query to database: {query_id}")
|
||||
|
||||
return jsonify({"query_id": query_id})
|
||||
except Exception as e:
|
||||
logger.exception("Error during API query processing", error=str(e))
|
||||
logger.exception(f"Error during API query processing: {str(e)}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
@app.get(
|
||||
"/api/v1/query_status/<string:query_id>",
|
||||
responses={
|
||||
"200": QueryStatusResponse,
|
||||
"404": {"description": "Query not found"},
|
||||
"500": {"description": "Internal Server Error"},
|
||||
},
|
||||
"/api/v1/query_status/<string:query_id>"
|
||||
)
|
||||
def get_query_status(query_id: str):
|
||||
"""
|
||||
@ -509,23 +520,19 @@ def get_query_status(query_id: str):
|
||||
try:
|
||||
db = get_db()
|
||||
cursor = db.cursor()
|
||||
cursor.execute("SELECT conversation_history FROM Queries WHERE id = ?", (query_id,))
|
||||
cursor.execute("SELECT status, conversation_history FROM Queries WHERE id = ?", (query_id,))
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result is None:
|
||||
return jsonify({"error": "Query not found"}), 404
|
||||
|
||||
conversation_history = result[0]
|
||||
status, conversation_history = result
|
||||
|
||||
if conversation_history is None:
|
||||
return jsonify({"status": "processing"}), 202
|
||||
else:
|
||||
return jsonify(
|
||||
{
|
||||
"status": "completed",
|
||||
"conversation_history": json.loads(conversation_history),
|
||||
}
|
||||
)
|
||||
response = {"status": status}
|
||||
if status == QueryStatus.DONE.value:
|
||||
response["conversation_history"] = json.loads(conversation_history)
|
||||
|
||||
return jsonify(response)
|
||||
except Exception as e:
|
||||
logger.exception("Error retrieving query status", error=str(e))
|
||||
return jsonify({"error": str(e)}), 500
|
||||
@ -600,30 +607,62 @@ def answer_question_tools_api(
|
||||
|
||||
|
||||
def process_queries():
|
||||
logger.info("Query processing thread started")
|
||||
with app.app_context():
|
||||
while True:
|
||||
try:
|
||||
db = get_db()
|
||||
cursor = db.cursor()
|
||||
|
||||
# First, check if there are any PROCESSING queries
|
||||
cursor.execute(
|
||||
"SELECT id, query FROM Queries WHERE conversation_history IS NULL ORDER BY timestamp ASC LIMIT 1"
|
||||
"SELECT id FROM Queries WHERE status = ? LIMIT 1",
|
||||
(QueryStatus.PROCESSING.value,)
|
||||
)
|
||||
processing_query = cursor.fetchone()
|
||||
if processing_query:
|
||||
logger.info(f"Found processing query: {processing_query[0]}. Waiting...")
|
||||
db.commit()
|
||||
time.sleep(10)
|
||||
continue
|
||||
|
||||
# If no PROCESSING queries, get the oldest QUEUED query
|
||||
cursor.execute(
|
||||
"SELECT id, query FROM Queries WHERE status = ? ORDER BY timestamp ASC LIMIT 1",
|
||||
(QueryStatus.QUEUED.value,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
query_id, user_input = result
|
||||
conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}]
|
||||
final_conversation_history = answer_question_tools_api(user_input, conversation_history)
|
||||
logger.info(f"Processing query: {query_id}")
|
||||
|
||||
# Update status to PROCESSING
|
||||
cursor.execute(
|
||||
"UPDATE Queries SET conversation_history = ? WHERE id = ?",
|
||||
(json.dumps(final_conversation_history), query_id)
|
||||
"UPDATE Queries SET status = ? WHERE id = ?",
|
||||
(QueryStatus.PROCESSING.value, query_id)
|
||||
)
|
||||
db.commit()
|
||||
logger.info(f"Updated query {query_id} status to PROCESSING")
|
||||
|
||||
conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}]
|
||||
logger.info(f"Starting answer_question_tools_api for query {query_id}")
|
||||
final_conversation_history = answer_question_tools_api(user_input, conversation_history)
|
||||
logger.info(f"Finished answer_question_tools_api for query {query_id}")
|
||||
|
||||
# Update with final result and set status to DONE
|
||||
db.execute("BEGIN TRANSACTION")
|
||||
cursor.execute(
|
||||
"UPDATE Queries SET conversation_history = ?, status = ? WHERE id = ?",
|
||||
(json.dumps(final_conversation_history), QueryStatus.DONE.value, query_id)
|
||||
)
|
||||
db.commit()
|
||||
logger.info(f"Updated query {query_id} status to DONE")
|
||||
else:
|
||||
time.sleep(1) # Wait for 1 second before checking again if no queries are found
|
||||
logger.info("No queued queries found. Waiting...")
|
||||
time.sleep(random.uniform(5, 10)) # Wait for 5 seconds before checking again if no queries are found
|
||||
except Exception as e:
|
||||
logger.exception("Error processing query", error=str(e))
|
||||
logger.exception(f"Error processing query: {str(e)}")
|
||||
time.sleep(1) # Wait for 1 second before retrying in case of an error
|
||||
|
||||
|
||||
@ -638,14 +677,9 @@ class GenerateKeyResponse(BaseModel):
|
||||
|
||||
|
||||
@app.post(
|
||||
"/admin/generate_key",
|
||||
responses={
|
||||
"200": GenerateKeyResponse,
|
||||
"401": {"description": "Unauthorized"},
|
||||
"500": {"description": "Internal Server Error"},
|
||||
},
|
||||
"/admin/generate_key"
|
||||
)
|
||||
def generate_api_key(body: GenerateKeyRequest):
|
||||
def generate_api_key():
|
||||
"""
|
||||
Generate a new API key for a user.
|
||||
|
||||
@ -661,7 +695,11 @@ def generate_api_key(body: GenerateKeyRequest):
|
||||
if not admin_key or admin_key != ADMIN_KEY:
|
||||
return jsonify({"error": "Invalid admin key"}), 401
|
||||
|
||||
username = body.username
|
||||
data = request.get_json()
|
||||
if not data or 'username' not in data:
|
||||
return jsonify({"error": "Invalid request body"}), 400
|
||||
|
||||
username = data['username']
|
||||
api_key = secrets.token_urlsafe(32)
|
||||
|
||||
try:
|
||||
@ -679,16 +717,30 @@ def generate_api_key(body: GenerateKeyRequest):
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
|
||||
def start_processing_thread():
|
||||
global processing_thread, processing_thread_started
|
||||
if not processing_thread_started:
|
||||
processing_thread = threading.Thread(target=process_queries, daemon=True)
|
||||
processing_thread.start()
|
||||
processing_thread_started = True
|
||||
logger.info("Query processing thread started")
|
||||
|
||||
|
||||
# Replace the if __main__ block with this:
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting LLM Chat Server")
|
||||
init_db() # Initialize the database
|
||||
|
||||
if ENABLE_FRONTEND or ENABLE_CHAT_ENDPOINTS:
|
||||
threading.Thread(target=send_system_resources, daemon=True).start()
|
||||
logger.info("System resources thread started")
|
||||
|
||||
if ENABLE_API_ENDPOINTS:
|
||||
threading.Thread(
|
||||
target=lambda: app.app_context().push() and process_queries(), daemon=True
|
||||
).start()
|
||||
start_processing_thread()
|
||||
|
||||
logger.info("Starting Flask application")
|
||||
socketio.run(app, debug=True, host="0.0.0.0", port=5001)
|
||||
else:
|
||||
# This will run when the module is imported, e.g., by the reloader
|
||||
if ENABLE_API_ENDPOINTS:
|
||||
start_processing_thread()
|
@ -10,6 +10,7 @@ CREATE TABLE IF NOT EXISTS Queries (
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
query TEXT NOT NULL,
|
||||
api_key_id INTEGER,
|
||||
status TEXT NOT NULL,
|
||||
conversation_history TEXT,
|
||||
FOREIGN KEY (api_key_id) REFERENCES Keys (id)
|
||||
);
|
Loading…
Reference in New Issue
Block a user