tool_use #1

Open
dubey wants to merge 21 commits from tool_use into main
2 changed files with 283 additions and 7 deletions
Showing only changes of commit 6d0134c34d - Show all commits

91
main.py
View File

@ -22,6 +22,8 @@ import logging
from flask import Flask, g, jsonify, request, send_from_directory
from flask_socketio import SocketIO, emit
from pydantic import BaseModel
from werkzeug.utils import secure_filename
import base64
from models import model_manager
from tools import DefaultToolManager
@ -38,6 +40,9 @@ CONFIG_FILE = "config.ini"
processing_thread = None
processing_thread_started = False
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
MAX_IMAGE_SIZE = 1 * 1024 * 1024 # 1MB
def create_default_config():
config = configparser.ConfigParser()
@ -564,6 +569,7 @@ def answer_question_tools_api(
tools=tool_manager.get_tools_for_ollama_dict(),
stream=False,
)
logger.info(f"API Response: {response}")
assistant_message = response["message"]
conversation_history.append(assistant_message)
@ -574,6 +580,7 @@ def answer_question_tools_api(
tool_args = tool_call["function"]["arguments"]
tool_response = tool_manager.get_tool(tool_name).execute(tool_args)
conversation_history.append({"role": "tool", "content": tool_response})
logger.info(f"API Tool response: {tool_response}")
else:
if "<reply>" in assistant_message["content"].lower():
reply_content = re.search(
@ -645,7 +652,15 @@ def process_queries():
db.commit()
logger.info(f"Updated query {query_id} status to PROCESSING")
# Fetch conversation history if it exists
cursor.execute("SELECT conversation_history FROM Queries WHERE id = ?", (query_id,))
conversation_history_result = cursor.fetchone()
if conversation_history_result and conversation_history_result[0]:
conversation_history = json.loads(conversation_history_result[0])
else:
conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}]
logger.info(f"Starting answer_question_tools_api for query {query_id}")
final_conversation_history = answer_question_tools_api(user_input, conversation_history)
logger.info(f"Finished answer_question_tools_api for query {query_id}")
@ -660,7 +675,7 @@ def process_queries():
logger.info(f"Updated query {query_id} status to DONE")
else:
logger.info("No queued queries found. Waiting...")
time.sleep(random.uniform(5, 10)) # Wait for 5 seconds before checking again if no queries are found
time.sleep(5) # Wait for 5 seconds before checking again if no queries are found
except Exception as e:
logger.exception(f"Error processing query: {str(e)}")
time.sleep(1) # Wait for 1 second before retrying in case of an error
@ -726,6 +741,80 @@ def start_processing_thread():
logger.info("Query processing thread started")
def allowed_file(filename):
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
@app.post("/api/v1/query_with_image")
def api_query_with_image():
"""
Submit a new query to the LLM Chat Server with an optional image.
This endpoint requires authentication via an API key.
Sample cURL:
curl -X POST http://localhost:5001/api/v1/query_with_image \
-H "X-API-Key: your-api-key" \
-F "message=What's in this image?" \
-F "image=@path/to/your/image.jpg"
"""
if not ENABLE_API_ENDPOINTS:
return jsonify({"error": "API endpoints are disabled"}), 404
api_key = request.headers.get('X-API-Key')
if not api_key:
return jsonify({"error": "API key is required"}), 401
api_key_id = validate_api_key(api_key)
if not api_key_id:
return jsonify({"error": "Invalid API key"}), 401
if 'message' not in request.form:
return jsonify({"error": "Message is required"}), 400
user_input = request.form['message']
query_id = str(uuid.uuid4())
image_base64 = None
if 'image' in request.files:
file = request.files['image']
if file and allowed_file(file.filename):
if file.content_length > MAX_IMAGE_SIZE:
return jsonify({"error": "Image size exceeds 1MB limit"}), 400
# Read and encode the image
image_data = file.read()
image_base64 = base64.b64encode(image_data).decode('utf-8')
try:
db = get_db()
cursor = db.cursor()
cursor.execute(
"INSERT INTO Queries (id, ip, query, api_key_id, status) VALUES (?, ?, ?, ?, ?)",
(query_id, request.remote_addr, user_input, api_key_id, QueryStatus.QUEUED.value)
)
db.commit()
logger.info(f"Added new query with image to database: {query_id}")
# If there's an image, add it to the conversation history
if image_base64:
conversation_history = [
{"role": "system", "content": ANSWER_QUESTION_PROMPT},
{"role": "user", "content": f"[An image was uploaded with this message] {user_input}"},
{"role": "system", "content": f"An image was uploaded. You can analyze it using the analyze_image tool with the following base64 string: {image_base64}"}
]
cursor.execute(
"UPDATE Queries SET conversation_history = ? WHERE id = ?",
(json.dumps(conversation_history), query_id)
)
db.commit()
return jsonify({"query_id": query_id})
except Exception as e:
logger.exception(f"Error during API query processing with image: {str(e)}")
return jsonify({"error": str(e)}), 500
# Replace the if __main__ block with this:
if __name__ == "__main__":
logger.info("Starting LLM Chat Server")

193
tools.py
View File

@ -1,12 +1,20 @@
import subprocess
import tempfile
import time
import duckduckgo_search
import json
import requests
from markdownify import markdownify as md
from readability.readability import Document
import duckduckgo_search
import datetime
import random
import math
import re
import base64
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
import ollama
import os
class Tool:
def __init__(self, name: str, description: str, arguments: dict, returns: str):
@ -56,6 +64,12 @@ class DefaultToolManager(ToolManager):
self.add_tool(GetReadablePageContentsTool())
self.add_tool(CalculatorTool())
self.add_tool(PythonCodeTool())
self.add_tool(DateTimeTool())
self.add_tool(RandomNumberTool())
self.add_tool(RegexTool())
self.add_tool(Base64Tool())
self.add_tool(SimpleChartTool())
self.add_tool(LLAVAImageAnalysisTool())
class SearchTool(Tool):
@ -73,8 +87,11 @@ class SearchTool(Tool):
)
def execute(self, arg: dict) -> str:
try:
res = duckduckgo_search.DDGS().text(arg["query"], max_results=5)
return "\n\n".join([f"{r['title']}\n{r['body']}\n{r['href']}" for r in res])
except Exception as e:
return f"Error searching the web: {str(e)}"
def get_readable_page_contents(url: str) -> str:
@ -180,3 +197,173 @@ class PythonCodeTool(Tool):
return f"Error executing code: {str(e)}"
return "\n".join([f"{k}:\n{v}" for k, v in result.items()])
class DateTimeTool(Tool):
def __init__(self):
super().__init__(
"get_current_datetime",
"Get the current date and time",
{"type": "object", "properties": {}},
"datetime:string"
)
def execute(self, arg: dict) -> str:
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
class RandomNumberTool(Tool):
def __init__(self):
super().__init__(
"generate_random_number",
"Generate a random number within a given range",
{
"type": "object",
"properties": {
"min": {"type": "number", "description": "The minimum value"},
"max": {"type": "number", "description": "The maximum value"}
}
},
"random_number:number"
)
def execute(self, arg: dict) -> str:
return str(random.uniform(arg["min"], arg["max"]))
class RegexTool(Tool):
def __init__(self):
super().__init__(
"regex_match",
"Perform a regex match on a given text",
{
"type": "object",
"properties": {
"text": {"type": "string", "description": "The text to search in"},
"pattern": {"type": "string", "description": "The regex pattern to match"}
}
},
"matches:list[string]"
)
def execute(self, arg: dict) -> str:
matches = re.findall(arg["pattern"], arg["text"])
return json.dumps(matches)
class Base64Tool(Tool):
def __init__(self):
super().__init__(
"base64_encode_decode",
"Encode or decode a string using Base64",
{
"type": "object",
"properties": {
"action": {"type": "string", "enum": ["encode", "decode"], "description": "Whether to encode or decode"},
"text": {"type": "string", "description": "The text to encode or decode"}
}
},
"result:string"
)
def execute(self, arg: dict) -> str:
if arg["action"] == "encode":
return base64.b64encode(arg["text"].encode()).decode()
elif arg["action"] == "decode":
return base64.b64decode(arg["text"].encode()).decode()
else:
return "Invalid action. Use 'encode' or 'decode'."
class SimpleChartTool(Tool):
def __init__(self):
super().__init__(
"generate_simple_chart",
"Generate a simple bar chart image",
{
"type": "object",
"properties": {
"data": {"type": "array", "items": {"type": "number"}, "description": "List of numerical values for the chart"},
"labels": {"type": "array", "items": {"type": "string"}, "description": "Labels for each bar"}
}
},
"image_base64:string"
)
def execute(self, arg: dict) -> str:
data = arg["data"]
labels = arg["labels"]
# Create a simple bar chart
width, height = 400, 300
img = Image.new('RGB', (width, height), color='white')
draw = ImageDraw.Draw(img)
# Draw bars
max_value = max(data)
bar_width = width // (len(data) + 1)
for i, value in enumerate(data):
bar_height = (value / max_value) * (height - 50)
left = (i + 1) * bar_width
draw.rectangle([left, height - bar_height, left + bar_width, height], fill='blue')
# Add labels
font = ImageFont.load_default()
for i, label in enumerate(labels):
left = (i + 1) * bar_width + bar_width // 2
draw.text((left, height - 20), label, fill='black', anchor='ms', font=font)
# Convert to base64
buffered = BytesIO()
img.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return img_str
class LLAVAImageAnalysisTool(Tool):
def __init__(self):
super().__init__(
"analyze_image",
"Analyze an image using the LLAVA model",
{
"type": "object",
"properties": {
"image_base64": {"type": "string", "description": "Base64 encoded image"},
"question": {"type": "string", "description": "Question about the image"}
}
},
"analysis:string"
)
def execute(self, arg: dict) -> str:
try:
# Decode base64 image
image_data = base64.b64decode(arg["image_base64"])
image = Image.open(BytesIO(image_data))
# Save image to a temporary file
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
image.save(temp_file, format="PNG")
temp_file_path = temp_file.name
# Call LLAVA model
response = ollama.chat(
model="llava:7b",
messages=[
{
"role": "user",
"content": arg["question"],
"images": [temp_file_path]
}
]
)
# Clean up temporary file
os.remove(temp_file_path)
# Unload LLAVA model
ollama.delete("llava:7b")
return response['message']['content']
except Exception as e:
return f"Error analyzing image: {str(e)}"