sql gen
This commit is contained in:
parent
0c9059bcbc
commit
f56d9b59c0
156
main.py
156
main.py
@ -3,6 +3,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import pprint
|
import pprint
|
||||||
import queue
|
import queue
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
import sqlite3
|
import sqlite3
|
||||||
@ -11,24 +12,32 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
import enum
|
||||||
|
|
||||||
import GPUtil
|
import GPUtil
|
||||||
import ollama
|
import ollama
|
||||||
import psutil
|
import psutil
|
||||||
import structlog
|
import structlog
|
||||||
|
import logging
|
||||||
from flask import Flask, g, jsonify, request, send_from_directory
|
from flask import Flask, g, jsonify, request, send_from_directory
|
||||||
from flask_openapi3 import Info, OpenAPI
|
|
||||||
from flask_socketio import SocketIO, emit
|
from flask_socketio import SocketIO, emit
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from models import model_manager
|
from models import model_manager
|
||||||
from tools import DefaultToolManager
|
from tools import DefaultToolManager
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||||
|
console_handler = logging.StreamHandler()
|
||||||
|
console_handler.setLevel(logging.INFO)
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
# Configuration setup
|
# Configuration setup
|
||||||
CONFIG_FILE = "config.ini"
|
CONFIG_FILE = "config.ini"
|
||||||
|
|
||||||
|
# Add this near the top of the file, after imports
|
||||||
|
processing_thread = None
|
||||||
|
processing_thread_started = False
|
||||||
|
|
||||||
|
|
||||||
def create_default_config():
|
def create_default_config():
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
@ -65,8 +74,7 @@ ENABLE_API_ENDPOINTS = config["SERVER_FEATURES"].getboolean("EnableAPIEndpoints"
|
|||||||
PRIMARY_MODEL = config["MODEL"]["PrimaryModel"]
|
PRIMARY_MODEL = config["MODEL"]["PrimaryModel"]
|
||||||
UPDATE_INTERVAL = config["PERFORMANCE"].getfloat("UpdateInterval")
|
UPDATE_INTERVAL = config["PERFORMANCE"].getfloat("UpdateInterval")
|
||||||
|
|
||||||
openapi = OpenAPI(__name__, info=Info(title="LLM Chat Server", version="1.0.0"))
|
app = Flask(__name__)
|
||||||
app = openapi
|
|
||||||
socketio = SocketIO(app, cors_allowed_origins="*")
|
socketio = SocketIO(app, cors_allowed_origins="*")
|
||||||
|
|
||||||
tool_manager = DefaultToolManager()
|
tool_manager = DefaultToolManager()
|
||||||
@ -88,6 +96,12 @@ def close_connection(exception):
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
class QueryStatus(enum.Enum):
|
||||||
|
QUEUED = "queued"
|
||||||
|
PROCESSING = "processing"
|
||||||
|
DONE = "done"
|
||||||
|
|
||||||
|
|
||||||
def init_db():
|
def init_db():
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
db = get_db()
|
db = get_db()
|
||||||
@ -105,6 +119,7 @@ def init_db():
|
|||||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
query TEXT NOT NULL,
|
query TEXT NOT NULL,
|
||||||
api_key_id INTEGER,
|
api_key_id INTEGER,
|
||||||
|
status TEXT NOT NULL,
|
||||||
conversation_history TEXT,
|
conversation_history TEXT,
|
||||||
FOREIGN KEY (api_key_id) REFERENCES Keys (id)
|
FOREIGN KEY (api_key_id) REFERENCES Keys (id)
|
||||||
)
|
)
|
||||||
@ -126,6 +141,7 @@ CREATE TABLE IF NOT EXISTS Queries (
|
|||||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
query TEXT NOT NULL,
|
query TEXT NOT NULL,
|
||||||
api_key_id INTEGER,
|
api_key_id INTEGER,
|
||||||
|
status TEXT NOT NULL,
|
||||||
conversation_history TEXT,
|
conversation_history TEXT,
|
||||||
FOREIGN KEY (api_key_id) REFERENCES Keys (id)
|
FOREIGN KEY (api_key_id) REFERENCES Keys (id)
|
||||||
);
|
);
|
||||||
@ -432,14 +448,9 @@ class QueryStatusResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@app.post(
|
@app.post(
|
||||||
"/api/v1/query",
|
"/api/v1/query"
|
||||||
responses={
|
|
||||||
"200": QueryResponse,
|
|
||||||
"401": {"description": "Unauthorized"},
|
|
||||||
"500": {"description": "Internal Server Error"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
def api_query(body: QueryRequest):
|
def api_query():
|
||||||
"""
|
"""
|
||||||
Submit a new query to the LLM Chat Server.
|
Submit a new query to the LLM Chat Server.
|
||||||
|
|
||||||
@ -462,31 +473,31 @@ def api_query(body: QueryRequest):
|
|||||||
if not api_key_id:
|
if not api_key_id:
|
||||||
return jsonify({"error": "Invalid API key"}), 401
|
return jsonify({"error": "Invalid API key"}), 401
|
||||||
|
|
||||||
user_input = body.message
|
data = request.get_json()
|
||||||
|
if not data or 'message' not in data:
|
||||||
|
return jsonify({"error": "Invalid request body"}), 400
|
||||||
|
|
||||||
|
user_input = data['message']
|
||||||
query_id = str(uuid.uuid4())
|
query_id = str(uuid.uuid4())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
db = get_db()
|
db = get_db()
|
||||||
cursor = db.cursor()
|
cursor = db.cursor()
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"INSERT INTO Queries (id, ip, query, api_key_id) VALUES (?, ?, ?, ?)",
|
"INSERT INTO Queries (id, ip, query, api_key_id, status) VALUES (?, ?, ?, ?, ?)",
|
||||||
(query_id, request.remote_addr, user_input, api_key_id)
|
(query_id, request.remote_addr, user_input, api_key_id, QueryStatus.QUEUED.value)
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
logger.info(f"Added new query to database: {query_id}")
|
||||||
|
|
||||||
return jsonify({"query_id": query_id})
|
return jsonify({"query_id": query_id})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error during API query processing", error=str(e))
|
logger.exception(f"Error during API query processing: {str(e)}")
|
||||||
return jsonify({"error": str(e)}), 500
|
return jsonify({"error": str(e)}), 500
|
||||||
|
|
||||||
|
|
||||||
@app.get(
|
@app.get(
|
||||||
"/api/v1/query_status/<string:query_id>",
|
"/api/v1/query_status/<string:query_id>"
|
||||||
responses={
|
|
||||||
"200": QueryStatusResponse,
|
|
||||||
"404": {"description": "Query not found"},
|
|
||||||
"500": {"description": "Internal Server Error"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
def get_query_status(query_id: str):
|
def get_query_status(query_id: str):
|
||||||
"""
|
"""
|
||||||
@ -509,23 +520,19 @@ def get_query_status(query_id: str):
|
|||||||
try:
|
try:
|
||||||
db = get_db()
|
db = get_db()
|
||||||
cursor = db.cursor()
|
cursor = db.cursor()
|
||||||
cursor.execute("SELECT conversation_history FROM Queries WHERE id = ?", (query_id,))
|
cursor.execute("SELECT status, conversation_history FROM Queries WHERE id = ?", (query_id,))
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
return jsonify({"error": "Query not found"}), 404
|
return jsonify({"error": "Query not found"}), 404
|
||||||
|
|
||||||
conversation_history = result[0]
|
status, conversation_history = result
|
||||||
|
|
||||||
if conversation_history is None:
|
response = {"status": status}
|
||||||
return jsonify({"status": "processing"}), 202
|
if status == QueryStatus.DONE.value:
|
||||||
else:
|
response["conversation_history"] = json.loads(conversation_history)
|
||||||
return jsonify(
|
|
||||||
{
|
return jsonify(response)
|
||||||
"status": "completed",
|
|
||||||
"conversation_history": json.loads(conversation_history),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error retrieving query status", error=str(e))
|
logger.exception("Error retrieving query status", error=str(e))
|
||||||
return jsonify({"error": str(e)}), 500
|
return jsonify({"error": str(e)}), 500
|
||||||
@ -600,30 +607,62 @@ def answer_question_tools_api(
|
|||||||
|
|
||||||
|
|
||||||
def process_queries():
|
def process_queries():
|
||||||
|
logger.info("Query processing thread started")
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
db = get_db()
|
db = get_db()
|
||||||
cursor = db.cursor()
|
cursor = db.cursor()
|
||||||
|
|
||||||
|
# First, check if there are any PROCESSING queries
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT id, query FROM Queries WHERE conversation_history IS NULL ORDER BY timestamp ASC LIMIT 1"
|
"SELECT id FROM Queries WHERE status = ? LIMIT 1",
|
||||||
|
(QueryStatus.PROCESSING.value,)
|
||||||
|
)
|
||||||
|
processing_query = cursor.fetchone()
|
||||||
|
if processing_query:
|
||||||
|
logger.info(f"Found processing query: {processing_query[0]}. Waiting...")
|
||||||
|
db.commit()
|
||||||
|
time.sleep(10)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# If no PROCESSING queries, get the oldest QUEUED query
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT id, query FROM Queries WHERE status = ? ORDER BY timestamp ASC LIMIT 1",
|
||||||
|
(QueryStatus.QUEUED.value,)
|
||||||
)
|
)
|
||||||
result = cursor.fetchone()
|
result = cursor.fetchone()
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
query_id, user_input = result
|
query_id, user_input = result
|
||||||
conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}]
|
logger.info(f"Processing query: {query_id}")
|
||||||
final_conversation_history = answer_question_tools_api(user_input, conversation_history)
|
|
||||||
|
# Update status to PROCESSING
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"UPDATE Queries SET conversation_history = ? WHERE id = ?",
|
"UPDATE Queries SET status = ? WHERE id = ?",
|
||||||
(json.dumps(final_conversation_history), query_id)
|
(QueryStatus.PROCESSING.value, query_id)
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
logger.info(f"Updated query {query_id} status to PROCESSING")
|
||||||
|
|
||||||
|
conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}]
|
||||||
|
logger.info(f"Starting answer_question_tools_api for query {query_id}")
|
||||||
|
final_conversation_history = answer_question_tools_api(user_input, conversation_history)
|
||||||
|
logger.info(f"Finished answer_question_tools_api for query {query_id}")
|
||||||
|
|
||||||
|
# Update with final result and set status to DONE
|
||||||
|
db.execute("BEGIN TRANSACTION")
|
||||||
|
cursor.execute(
|
||||||
|
"UPDATE Queries SET conversation_history = ?, status = ? WHERE id = ?",
|
||||||
|
(json.dumps(final_conversation_history), QueryStatus.DONE.value, query_id)
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
logger.info(f"Updated query {query_id} status to DONE")
|
||||||
else:
|
else:
|
||||||
time.sleep(1) # Wait for 1 second before checking again if no queries are found
|
logger.info("No queued queries found. Waiting...")
|
||||||
|
time.sleep(random.uniform(5, 10)) # Wait for 5 seconds before checking again if no queries are found
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error processing query", error=str(e))
|
logger.exception(f"Error processing query: {str(e)}")
|
||||||
time.sleep(1) # Wait for 1 second before retrying in case of an error
|
time.sleep(1) # Wait for 1 second before retrying in case of an error
|
||||||
|
|
||||||
|
|
||||||
@ -638,14 +677,9 @@ class GenerateKeyResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@app.post(
|
@app.post(
|
||||||
"/admin/generate_key",
|
"/admin/generate_key"
|
||||||
responses={
|
|
||||||
"200": GenerateKeyResponse,
|
|
||||||
"401": {"description": "Unauthorized"},
|
|
||||||
"500": {"description": "Internal Server Error"},
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
def generate_api_key(body: GenerateKeyRequest):
|
def generate_api_key():
|
||||||
"""
|
"""
|
||||||
Generate a new API key for a user.
|
Generate a new API key for a user.
|
||||||
|
|
||||||
@ -661,7 +695,11 @@ def generate_api_key(body: GenerateKeyRequest):
|
|||||||
if not admin_key or admin_key != ADMIN_KEY:
|
if not admin_key or admin_key != ADMIN_KEY:
|
||||||
return jsonify({"error": "Invalid admin key"}), 401
|
return jsonify({"error": "Invalid admin key"}), 401
|
||||||
|
|
||||||
username = body.username
|
data = request.get_json()
|
||||||
|
if not data or 'username' not in data:
|
||||||
|
return jsonify({"error": "Invalid request body"}), 400
|
||||||
|
|
||||||
|
username = data['username']
|
||||||
api_key = secrets.token_urlsafe(32)
|
api_key = secrets.token_urlsafe(32)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -679,16 +717,30 @@ def generate_api_key(body: GenerateKeyRequest):
|
|||||||
return jsonify({"error": str(e)}), 500
|
return jsonify({"error": str(e)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
def start_processing_thread():
|
||||||
|
global processing_thread, processing_thread_started
|
||||||
|
if not processing_thread_started:
|
||||||
|
processing_thread = threading.Thread(target=process_queries, daemon=True)
|
||||||
|
processing_thread.start()
|
||||||
|
processing_thread_started = True
|
||||||
|
logger.info("Query processing thread started")
|
||||||
|
|
||||||
|
|
||||||
|
# Replace the if __main__ block with this:
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger.info("Starting LLM Chat Server")
|
logger.info("Starting LLM Chat Server")
|
||||||
init_db() # Initialize the database
|
init_db() # Initialize the database
|
||||||
|
|
||||||
if ENABLE_FRONTEND or ENABLE_CHAT_ENDPOINTS:
|
if ENABLE_FRONTEND or ENABLE_CHAT_ENDPOINTS:
|
||||||
threading.Thread(target=send_system_resources, daemon=True).start()
|
threading.Thread(target=send_system_resources, daemon=True).start()
|
||||||
|
logger.info("System resources thread started")
|
||||||
|
|
||||||
if ENABLE_API_ENDPOINTS:
|
if ENABLE_API_ENDPOINTS:
|
||||||
threading.Thread(
|
start_processing_thread()
|
||||||
target=lambda: app.app_context().push() and process_queries(), daemon=True
|
|
||||||
).start()
|
|
||||||
|
|
||||||
socketio.run(app, debug=True, host="0.0.0.0", port=5001)
|
logger.info("Starting Flask application")
|
||||||
|
socketio.run(app, debug=True, host="0.0.0.0", port=5001)
|
||||||
|
else:
|
||||||
|
# This will run when the module is imported, e.g., by the reloader
|
||||||
|
if ENABLE_API_ENDPOINTS:
|
||||||
|
start_processing_thread()
|
@ -10,6 +10,7 @@ CREATE TABLE IF NOT EXISTS Queries (
|
|||||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
query TEXT NOT NULL,
|
query TEXT NOT NULL,
|
||||||
api_key_id INTEGER,
|
api_key_id INTEGER,
|
||||||
|
status TEXT NOT NULL,
|
||||||
conversation_history TEXT,
|
conversation_history TEXT,
|
||||||
FOREIGN KEY (api_key_id) REFERENCES Keys (id)
|
FOREIGN KEY (api_key_id) REFERENCES Keys (id)
|
||||||
);
|
);
|
Loading…
Reference in New Issue
Block a user