diff --git a/main.py b/main.py index 179ac45..d3f1f4e 100644 --- a/main.py +++ b/main.py @@ -22,6 +22,8 @@ import logging from flask import Flask, g, jsonify, request, send_from_directory from flask_socketio import SocketIO, emit from pydantic import BaseModel +from werkzeug.utils import secure_filename +import base64 from models import model_manager from tools import DefaultToolManager @@ -38,6 +40,9 @@ CONFIG_FILE = "config.ini" processing_thread = None processing_thread_started = False +ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'} +MAX_IMAGE_SIZE = 1 * 1024 * 1024 # 1MB + def create_default_config(): config = configparser.ConfigParser() @@ -564,6 +569,7 @@ def answer_question_tools_api( tools=tool_manager.get_tools_for_ollama_dict(), stream=False, ) + logger.info(f"API Response: {response}") assistant_message = response["message"] conversation_history.append(assistant_message) @@ -574,6 +580,7 @@ def answer_question_tools_api( tool_args = tool_call["function"]["arguments"] tool_response = tool_manager.get_tool(tool_name).execute(tool_args) conversation_history.append({"role": "tool", "content": tool_response}) + logger.info(f"API Tool response: {tool_response}") else: if "" in assistant_message["content"].lower(): reply_content = re.search( @@ -645,7 +652,15 @@ def process_queries(): db.commit() logger.info(f"Updated query {query_id} status to PROCESSING") - conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}] + # Fetch conversation history if it exists + cursor.execute("SELECT conversation_history FROM Queries WHERE id = ?", (query_id,)) + conversation_history_result = cursor.fetchone() + + if conversation_history_result and conversation_history_result[0]: + conversation_history = json.loads(conversation_history_result[0]) + else: + conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}] + logger.info(f"Starting answer_question_tools_api for query {query_id}") final_conversation_history = answer_question_tools_api(user_input, conversation_history) logger.info(f"Finished answer_question_tools_api for query {query_id}") @@ -660,7 +675,7 @@ def process_queries(): logger.info(f"Updated query {query_id} status to DONE") else: logger.info("No queued queries found. Waiting...") - time.sleep(random.uniform(5, 10)) # Wait for 5 seconds before checking again if no queries are found + time.sleep(5) # Wait for 5 seconds before checking again if no queries are found except Exception as e: logger.exception(f"Error processing query: {str(e)}") time.sleep(1) # Wait for 1 second before retrying in case of an error @@ -726,6 +741,80 @@ def start_processing_thread(): logger.info("Query processing thread started") +def allowed_file(filename): + return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + + +@app.post("/api/v1/query_with_image") +def api_query_with_image(): + """ + Submit a new query to the LLM Chat Server with an optional image. + + This endpoint requires authentication via an API key. + + Sample cURL: + curl -X POST http://localhost:5001/api/v1/query_with_image \ + -H "X-API-Key: your-api-key" \ + -F "message=What's in this image?" \ + -F "image=@path/to/your/image.jpg" + """ + if not ENABLE_API_ENDPOINTS: + return jsonify({"error": "API endpoints are disabled"}), 404 + + api_key = request.headers.get('X-API-Key') + if not api_key: + return jsonify({"error": "API key is required"}), 401 + + api_key_id = validate_api_key(api_key) + if not api_key_id: + return jsonify({"error": "Invalid API key"}), 401 + + if 'message' not in request.form: + return jsonify({"error": "Message is required"}), 400 + + user_input = request.form['message'] + query_id = str(uuid.uuid4()) + + image_base64 = None + if 'image' in request.files: + file = request.files['image'] + if file and allowed_file(file.filename): + if file.content_length > MAX_IMAGE_SIZE: + return jsonify({"error": "Image size exceeds 1MB limit"}), 400 + + # Read and encode the image + image_data = file.read() + image_base64 = base64.b64encode(image_data).decode('utf-8') + + try: + db = get_db() + cursor = db.cursor() + cursor.execute( + "INSERT INTO Queries (id, ip, query, api_key_id, status) VALUES (?, ?, ?, ?, ?)", + (query_id, request.remote_addr, user_input, api_key_id, QueryStatus.QUEUED.value) + ) + db.commit() + logger.info(f"Added new query with image to database: {query_id}") + + # If there's an image, add it to the conversation history + if image_base64: + conversation_history = [ + {"role": "system", "content": ANSWER_QUESTION_PROMPT}, + {"role": "user", "content": f"[An image was uploaded with this message] {user_input}"}, + {"role": "system", "content": f"An image was uploaded. You can analyze it using the analyze_image tool with the following base64 string: {image_base64}"} + ] + cursor.execute( + "UPDATE Queries SET conversation_history = ? WHERE id = ?", + (json.dumps(conversation_history), query_id) + ) + db.commit() + + return jsonify({"query_id": query_id}) + except Exception as e: + logger.exception(f"Error during API query processing with image: {str(e)}") + return jsonify({"error": str(e)}), 500 + + # Replace the if __main__ block with this: if __name__ == "__main__": logger.info("Starting LLM Chat Server") diff --git a/tools.py b/tools.py index 5f25c77..265ca79 100644 --- a/tools.py +++ b/tools.py @@ -1,12 +1,20 @@ import subprocess import tempfile import time - -import duckduckgo_search +import json import requests from markdownify import markdownify as md from readability.readability import Document - +import duckduckgo_search +import datetime +import random +import math +import re +import base64 +from io import BytesIO +from PIL import Image, ImageDraw, ImageFont +import ollama +import os class Tool: def __init__(self, name: str, description: str, arguments: dict, returns: str): @@ -56,6 +64,12 @@ class DefaultToolManager(ToolManager): self.add_tool(GetReadablePageContentsTool()) self.add_tool(CalculatorTool()) self.add_tool(PythonCodeTool()) + self.add_tool(DateTimeTool()) + self.add_tool(RandomNumberTool()) + self.add_tool(RegexTool()) + self.add_tool(Base64Tool()) + self.add_tool(SimpleChartTool()) + self.add_tool(LLAVAImageAnalysisTool()) class SearchTool(Tool): @@ -73,8 +87,11 @@ class SearchTool(Tool): ) def execute(self, arg: dict) -> str: - res = duckduckgo_search.DDGS().text(arg["query"], max_results=5) - return "\n\n".join([f"{r['title']}\n{r['body']}\n{r['href']}" for r in res]) + try: + res = duckduckgo_search.DDGS().text(arg["query"], max_results=5) + return "\n\n".join([f"{r['title']}\n{r['body']}\n{r['href']}" for r in res]) + except Exception as e: + return f"Error searching the web: {str(e)}" def get_readable_page_contents(url: str) -> str: @@ -180,3 +197,173 @@ class PythonCodeTool(Tool): return f"Error executing code: {str(e)}" return "\n".join([f"{k}:\n{v}" for k, v in result.items()]) + + +class DateTimeTool(Tool): + def __init__(self): + super().__init__( + "get_current_datetime", + "Get the current date and time", + {"type": "object", "properties": {}}, + "datetime:string" + ) + + def execute(self, arg: dict) -> str: + return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + +class RandomNumberTool(Tool): + def __init__(self): + super().__init__( + "generate_random_number", + "Generate a random number within a given range", + { + "type": "object", + "properties": { + "min": {"type": "number", "description": "The minimum value"}, + "max": {"type": "number", "description": "The maximum value"} + } + }, + "random_number:number" + ) + + def execute(self, arg: dict) -> str: + return str(random.uniform(arg["min"], arg["max"])) + + +class RegexTool(Tool): + def __init__(self): + super().__init__( + "regex_match", + "Perform a regex match on a given text", + { + "type": "object", + "properties": { + "text": {"type": "string", "description": "The text to search in"}, + "pattern": {"type": "string", "description": "The regex pattern to match"} + } + }, + "matches:list[string]" + ) + + def execute(self, arg: dict) -> str: + matches = re.findall(arg["pattern"], arg["text"]) + return json.dumps(matches) + + +class Base64Tool(Tool): + def __init__(self): + super().__init__( + "base64_encode_decode", + "Encode or decode a string using Base64", + { + "type": "object", + "properties": { + "action": {"type": "string", "enum": ["encode", "decode"], "description": "Whether to encode or decode"}, + "text": {"type": "string", "description": "The text to encode or decode"} + } + }, + "result:string" + ) + + def execute(self, arg: dict) -> str: + if arg["action"] == "encode": + return base64.b64encode(arg["text"].encode()).decode() + elif arg["action"] == "decode": + return base64.b64decode(arg["text"].encode()).decode() + else: + return "Invalid action. Use 'encode' or 'decode'." + + +class SimpleChartTool(Tool): + def __init__(self): + super().__init__( + "generate_simple_chart", + "Generate a simple bar chart image", + { + "type": "object", + "properties": { + "data": {"type": "array", "items": {"type": "number"}, "description": "List of numerical values for the chart"}, + "labels": {"type": "array", "items": {"type": "string"}, "description": "Labels for each bar"} + } + }, + "image_base64:string" + ) + + def execute(self, arg: dict) -> str: + data = arg["data"] + labels = arg["labels"] + + # Create a simple bar chart + width, height = 400, 300 + img = Image.new('RGB', (width, height), color='white') + draw = ImageDraw.Draw(img) + + # Draw bars + max_value = max(data) + bar_width = width // (len(data) + 1) + for i, value in enumerate(data): + bar_height = (value / max_value) * (height - 50) + left = (i + 1) * bar_width + draw.rectangle([left, height - bar_height, left + bar_width, height], fill='blue') + + # Add labels + font = ImageFont.load_default() + for i, label in enumerate(labels): + left = (i + 1) * bar_width + bar_width // 2 + draw.text((left, height - 20), label, fill='black', anchor='ms', font=font) + + # Convert to base64 + buffered = BytesIO() + img.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode() + return img_str + + +class LLAVAImageAnalysisTool(Tool): + def __init__(self): + super().__init__( + "analyze_image", + "Analyze an image using the LLAVA model", + { + "type": "object", + "properties": { + "image_base64": {"type": "string", "description": "Base64 encoded image"}, + "question": {"type": "string", "description": "Question about the image"} + } + }, + "analysis:string" + ) + + def execute(self, arg: dict) -> str: + try: + # Decode base64 image + image_data = base64.b64decode(arg["image_base64"]) + image = Image.open(BytesIO(image_data)) + + # Save image to a temporary file + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file: + image.save(temp_file, format="PNG") + temp_file_path = temp_file.name + + # Call LLAVA model + response = ollama.chat( + model="llava:7b", + messages=[ + { + "role": "user", + "content": arg["question"], + "images": [temp_file_path] + } + ] + ) + + # Clean up temporary file + os.remove(temp_file_path) + + # Unload LLAVA model + ollama.delete("llava:7b") + + return response['message']['content'] + except Exception as e: + return f"Error analyzing image: {str(e)}" \ No newline at end of file