tool_use #1

Open
dubey wants to merge 21 commits from tool_use into main
2 changed files with 105 additions and 52 deletions
Showing only changes of commit f56d9b59c0 - Show all commits

156
main.py
View File

@ -3,6 +3,7 @@ import json
import os import os
import pprint import pprint
import queue import queue
import random
import re import re
import secrets import secrets
import sqlite3 import sqlite3
@ -11,24 +12,32 @@ import time
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
import enum
import GPUtil import GPUtil
import ollama import ollama
import psutil import psutil
import structlog import structlog
import logging
from flask import Flask, g, jsonify, request, send_from_directory from flask import Flask, g, jsonify, request, send_from_directory
from flask_openapi3 import Info, OpenAPI
from flask_socketio import SocketIO, emit from flask_socketio import SocketIO, emit
from pydantic import BaseModel from pydantic import BaseModel
from models import model_manager from models import model_manager
from tools import DefaultToolManager from tools import DefaultToolManager
# Configure logging
logging.basicConfig(level=logging.INFO, format="%(message)s")
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
logger = structlog.get_logger() logger = structlog.get_logger()
# Configuration setup # Configuration setup
CONFIG_FILE = "config.ini" CONFIG_FILE = "config.ini"
# Add this near the top of the file, after imports
processing_thread = None
processing_thread_started = False
def create_default_config(): def create_default_config():
config = configparser.ConfigParser() config = configparser.ConfigParser()
@ -65,8 +74,7 @@ ENABLE_API_ENDPOINTS = config["SERVER_FEATURES"].getboolean("EnableAPIEndpoints"
PRIMARY_MODEL = config["MODEL"]["PrimaryModel"] PRIMARY_MODEL = config["MODEL"]["PrimaryModel"]
UPDATE_INTERVAL = config["PERFORMANCE"].getfloat("UpdateInterval") UPDATE_INTERVAL = config["PERFORMANCE"].getfloat("UpdateInterval")
openapi = OpenAPI(__name__, info=Info(title="LLM Chat Server", version="1.0.0")) app = Flask(__name__)
app = openapi
socketio = SocketIO(app, cors_allowed_origins="*") socketio = SocketIO(app, cors_allowed_origins="*")
tool_manager = DefaultToolManager() tool_manager = DefaultToolManager()
@ -88,6 +96,12 @@ def close_connection(exception):
db.close() db.close()
class QueryStatus(enum.Enum):
QUEUED = "queued"
PROCESSING = "processing"
DONE = "done"
def init_db(): def init_db():
with app.app_context(): with app.app_context():
db = get_db() db = get_db()
@ -105,6 +119,7 @@ def init_db():
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
query TEXT NOT NULL, query TEXT NOT NULL,
api_key_id INTEGER, api_key_id INTEGER,
status TEXT NOT NULL,
conversation_history TEXT, conversation_history TEXT,
FOREIGN KEY (api_key_id) REFERENCES Keys (id) FOREIGN KEY (api_key_id) REFERENCES Keys (id)
) )
@ -126,6 +141,7 @@ CREATE TABLE IF NOT EXISTS Queries (
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
query TEXT NOT NULL, query TEXT NOT NULL,
api_key_id INTEGER, api_key_id INTEGER,
status TEXT NOT NULL,
conversation_history TEXT, conversation_history TEXT,
FOREIGN KEY (api_key_id) REFERENCES Keys (id) FOREIGN KEY (api_key_id) REFERENCES Keys (id)
); );
@ -432,14 +448,9 @@ class QueryStatusResponse(BaseModel):
@app.post( @app.post(
"/api/v1/query", "/api/v1/query"
responses={
"200": QueryResponse,
"401": {"description": "Unauthorized"},
"500": {"description": "Internal Server Error"},
},
) )
def api_query(body: QueryRequest): def api_query():
""" """
Submit a new query to the LLM Chat Server. Submit a new query to the LLM Chat Server.
@ -462,31 +473,31 @@ def api_query(body: QueryRequest):
if not api_key_id: if not api_key_id:
return jsonify({"error": "Invalid API key"}), 401 return jsonify({"error": "Invalid API key"}), 401
user_input = body.message data = request.get_json()
if not data or 'message' not in data:
return jsonify({"error": "Invalid request body"}), 400
user_input = data['message']
query_id = str(uuid.uuid4()) query_id = str(uuid.uuid4())
try: try:
db = get_db() db = get_db()
cursor = db.cursor() cursor = db.cursor()
cursor.execute( cursor.execute(
"INSERT INTO Queries (id, ip, query, api_key_id) VALUES (?, ?, ?, ?)", "INSERT INTO Queries (id, ip, query, api_key_id, status) VALUES (?, ?, ?, ?, ?)",
(query_id, request.remote_addr, user_input, api_key_id) (query_id, request.remote_addr, user_input, api_key_id, QueryStatus.QUEUED.value)
) )
db.commit() db.commit()
logger.info(f"Added new query to database: {query_id}")
return jsonify({"query_id": query_id}) return jsonify({"query_id": query_id})
except Exception as e: except Exception as e:
logger.exception("Error during API query processing", error=str(e)) logger.exception(f"Error during API query processing: {str(e)}")
return jsonify({"error": str(e)}), 500 return jsonify({"error": str(e)}), 500
@app.get( @app.get(
"/api/v1/query_status/<string:query_id>", "/api/v1/query_status/<string:query_id>"
responses={
"200": QueryStatusResponse,
"404": {"description": "Query not found"},
"500": {"description": "Internal Server Error"},
},
) )
def get_query_status(query_id: str): def get_query_status(query_id: str):
""" """
@ -509,23 +520,19 @@ def get_query_status(query_id: str):
try: try:
db = get_db() db = get_db()
cursor = db.cursor() cursor = db.cursor()
cursor.execute("SELECT conversation_history FROM Queries WHERE id = ?", (query_id,)) cursor.execute("SELECT status, conversation_history FROM Queries WHERE id = ?", (query_id,))
result = cursor.fetchone() result = cursor.fetchone()
if result is None: if result is None:
return jsonify({"error": "Query not found"}), 404 return jsonify({"error": "Query not found"}), 404
conversation_history = result[0] status, conversation_history = result
if conversation_history is None: response = {"status": status}
return jsonify({"status": "processing"}), 202 if status == QueryStatus.DONE.value:
else: response["conversation_history"] = json.loads(conversation_history)
return jsonify(
{ return jsonify(response)
"status": "completed",
"conversation_history": json.loads(conversation_history),
}
)
except Exception as e: except Exception as e:
logger.exception("Error retrieving query status", error=str(e)) logger.exception("Error retrieving query status", error=str(e))
return jsonify({"error": str(e)}), 500 return jsonify({"error": str(e)}), 500
@ -600,30 +607,62 @@ def answer_question_tools_api(
def process_queries(): def process_queries():
logger.info("Query processing thread started")
with app.app_context(): with app.app_context():
while True: while True:
try: try:
db = get_db() db = get_db()
cursor = db.cursor() cursor = db.cursor()
# First, check if there are any PROCESSING queries
cursor.execute( cursor.execute(
"SELECT id, query FROM Queries WHERE conversation_history IS NULL ORDER BY timestamp ASC LIMIT 1" "SELECT id FROM Queries WHERE status = ? LIMIT 1",
(QueryStatus.PROCESSING.value,)
)
processing_query = cursor.fetchone()
if processing_query:
logger.info(f"Found processing query: {processing_query[0]}. Waiting...")
db.commit()
time.sleep(10)
continue
# If no PROCESSING queries, get the oldest QUEUED query
cursor.execute(
"SELECT id, query FROM Queries WHERE status = ? ORDER BY timestamp ASC LIMIT 1",
(QueryStatus.QUEUED.value,)
) )
result = cursor.fetchone() result = cursor.fetchone()
if result: if result:
query_id, user_input = result query_id, user_input = result
conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}] logger.info(f"Processing query: {query_id}")
final_conversation_history = answer_question_tools_api(user_input, conversation_history)
# Update status to PROCESSING
cursor.execute( cursor.execute(
"UPDATE Queries SET conversation_history = ? WHERE id = ?", "UPDATE Queries SET status = ? WHERE id = ?",
(json.dumps(final_conversation_history), query_id) (QueryStatus.PROCESSING.value, query_id)
) )
db.commit() db.commit()
logger.info(f"Updated query {query_id} status to PROCESSING")
conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}]
logger.info(f"Starting answer_question_tools_api for query {query_id}")
final_conversation_history = answer_question_tools_api(user_input, conversation_history)
logger.info(f"Finished answer_question_tools_api for query {query_id}")
# Update with final result and set status to DONE
db.execute("BEGIN TRANSACTION")
cursor.execute(
"UPDATE Queries SET conversation_history = ?, status = ? WHERE id = ?",
(json.dumps(final_conversation_history), QueryStatus.DONE.value, query_id)
)
db.commit()
logger.info(f"Updated query {query_id} status to DONE")
else: else:
time.sleep(1) # Wait for 1 second before checking again if no queries are found logger.info("No queued queries found. Waiting...")
time.sleep(random.uniform(5, 10)) # Wait for 5 seconds before checking again if no queries are found
except Exception as e: except Exception as e:
logger.exception("Error processing query", error=str(e)) logger.exception(f"Error processing query: {str(e)}")
time.sleep(1) # Wait for 1 second before retrying in case of an error time.sleep(1) # Wait for 1 second before retrying in case of an error
@ -638,14 +677,9 @@ class GenerateKeyResponse(BaseModel):
@app.post( @app.post(
"/admin/generate_key", "/admin/generate_key"
responses={
"200": GenerateKeyResponse,
"401": {"description": "Unauthorized"},
"500": {"description": "Internal Server Error"},
},
) )
def generate_api_key(body: GenerateKeyRequest): def generate_api_key():
""" """
Generate a new API key for a user. Generate a new API key for a user.
@ -661,7 +695,11 @@ def generate_api_key(body: GenerateKeyRequest):
if not admin_key or admin_key != ADMIN_KEY: if not admin_key or admin_key != ADMIN_KEY:
return jsonify({"error": "Invalid admin key"}), 401 return jsonify({"error": "Invalid admin key"}), 401
username = body.username data = request.get_json()
if not data or 'username' not in data:
return jsonify({"error": "Invalid request body"}), 400
username = data['username']
api_key = secrets.token_urlsafe(32) api_key = secrets.token_urlsafe(32)
try: try:
@ -679,16 +717,30 @@ def generate_api_key(body: GenerateKeyRequest):
return jsonify({"error": str(e)}), 500 return jsonify({"error": str(e)}), 500
def start_processing_thread():
global processing_thread, processing_thread_started
if not processing_thread_started:
processing_thread = threading.Thread(target=process_queries, daemon=True)
processing_thread.start()
processing_thread_started = True
logger.info("Query processing thread started")
# Replace the if __main__ block with this:
if __name__ == "__main__": if __name__ == "__main__":
logger.info("Starting LLM Chat Server") logger.info("Starting LLM Chat Server")
init_db() # Initialize the database init_db() # Initialize the database
if ENABLE_FRONTEND or ENABLE_CHAT_ENDPOINTS: if ENABLE_FRONTEND or ENABLE_CHAT_ENDPOINTS:
threading.Thread(target=send_system_resources, daemon=True).start() threading.Thread(target=send_system_resources, daemon=True).start()
logger.info("System resources thread started")
if ENABLE_API_ENDPOINTS: if ENABLE_API_ENDPOINTS:
threading.Thread( start_processing_thread()
target=lambda: app.app_context().push() and process_queries(), daemon=True
).start()
socketio.run(app, debug=True, host="0.0.0.0", port=5001) logger.info("Starting Flask application")
socketio.run(app, debug=True, host="0.0.0.0", port=5001)
else:
# This will run when the module is imported, e.g., by the reloader
if ENABLE_API_ENDPOINTS:
start_processing_thread()

View File

@ -10,6 +10,7 @@ CREATE TABLE IF NOT EXISTS Queries (
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
query TEXT NOT NULL, query TEXT NOT NULL,
api_key_id INTEGER, api_key_id INTEGER,
status TEXT NOT NULL,
conversation_history TEXT, conversation_history TEXT,
FOREIGN KEY (api_key_id) REFERENCES Keys (id) FOREIGN KEY (api_key_id) REFERENCES Keys (id)
); );