tool_use #1

main.py
@@ -22,6 +22,8 @@ import logging
from flask import Flask, g, jsonify, request, send_from_directory
from flask_socketio import SocketIO, emit
from pydantic import BaseModel
from werkzeug.utils import secure_filename
import base64

from models import model_manager
from tools import DefaultToolManager
@@ -38,6 +40,9 @@ CONFIG_FILE = "config.ini"
processing_thread = None
processing_thread_started = False

ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
MAX_IMAGE_SIZE = 1 * 1024 * 1024  # 1MB


def create_default_config():
    config = configparser.ConfigParser()
@@ -564,6 +569,7 @@ def answer_question_tools_api(
            tools=tool_manager.get_tools_for_ollama_dict(),
            stream=False,
        )
        logger.info(f"API Response: {response}")
        assistant_message = response["message"]

        conversation_history.append(assistant_message)
@@ -574,6 +580,7 @@ def answer_question_tools_api(
                tool_args = tool_call["function"]["arguments"]
                tool_response = tool_manager.get_tool(tool_name).execute(tool_args)
                conversation_history.append({"role": "tool", "content": tool_response})
                logger.info(f"API Tool response: {tool_response}")
        else:
            if "<reply>" in assistant_message["content"].lower():
                reply_content = re.search(
@@ -645,7 +652,15 @@ def process_queries():
                db.commit()
                logger.info(f"Updated query {query_id} status to PROCESSING")

                conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}]
                # Fetch conversation history if it exists
                cursor.execute("SELECT conversation_history FROM Queries WHERE id = ?", (query_id,))
                conversation_history_result = cursor.fetchone()

                if conversation_history_result and conversation_history_result[0]:
                    conversation_history = json.loads(conversation_history_result[0])
                else:
                    conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}]

                logger.info(f"Starting answer_question_tools_api for query {query_id}")
                final_conversation_history = answer_question_tools_api(user_input, conversation_history)
                logger.info(f"Finished answer_question_tools_api for query {query_id}")
@@ -660,7 +675,7 @@ def process_queries():
                logger.info(f"Updated query {query_id} status to DONE")
            else:
                logger.info("No queued queries found. Waiting...")
                time.sleep(random.uniform(5, 10))  # Wait for 5 seconds before checking again if no queries are found
                time.sleep(5)  # Wait for 5 seconds before checking again if no queries are found
        except Exception as e:
            logger.exception(f"Error processing query: {str(e)}")
            time.sleep(1)  # Wait for 1 second before retrying in case of an error
@@ -726,6 +741,80 @@ def start_processing_thread():
    logger.info("Query processing thread started")


def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


@app.post("/api/v1/query_with_image")
def api_query_with_image():
    """
    Submit a new query to the LLM Chat Server with an optional image.

    This endpoint requires authentication via an API key.

    Sample cURL:
    curl -X POST http://localhost:5001/api/v1/query_with_image \
      -H "X-API-Key: your-api-key" \
      -F "message=What's in this image?" \
      -F "image=@path/to/your/image.jpg"
    """
    if not ENABLE_API_ENDPOINTS:
        return jsonify({"error": "API endpoints are disabled"}), 404

    api_key = request.headers.get('X-API-Key')
    if not api_key:
        return jsonify({"error": "API key is required"}), 401

    api_key_id = validate_api_key(api_key)
    if not api_key_id:
        return jsonify({"error": "Invalid API key"}), 401

    if 'message' not in request.form:
        return jsonify({"error": "Message is required"}), 400

    user_input = request.form['message']
    query_id = str(uuid.uuid4())

    image_base64 = None
    if 'image' in request.files:
        file = request.files['image']
        if file and allowed_file(file.filename):
            if file.content_length > MAX_IMAGE_SIZE:
                return jsonify({"error": "Image size exceeds 1MB limit"}), 400

            # Read and encode the image
            image_data = file.read()
            image_base64 = base64.b64encode(image_data).decode('utf-8')

    try:
        db = get_db()
        cursor = db.cursor()
        cursor.execute(
            "INSERT INTO Queries (id, ip, query, api_key_id, status) VALUES (?, ?, ?, ?, ?)",
            (query_id, request.remote_addr, user_input, api_key_id, QueryStatus.QUEUED.value)
        )
        db.commit()
        logger.info(f"Added new query with image to database: {query_id}")

        # If there's an image, add it to the conversation history
        if image_base64:
            conversation_history = [
                {"role": "system", "content": ANSWER_QUESTION_PROMPT},
                {"role": "user", "content": f"[An image was uploaded with this message] {user_input}"},
                {"role": "system", "content": f"An image was uploaded. You can analyze it using the analyze_image tool with the following base64 string: {image_base64}"}
            ]
            cursor.execute(
                "UPDATE Queries SET conversation_history = ? WHERE id = ?",
                (json.dumps(conversation_history), query_id)
            )
            db.commit()

        return jsonify({"query_id": query_id})
    except Exception as e:
        logger.exception(f"Error during API query processing with image: {str(e)}")
        return jsonify({"error": str(e)}), 500


# Replace the if __main__ block with this:
if __name__ == "__main__":
    logger.info("Starting LLM Chat Server")
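
A minimal client-side sketch (not part of the patch itself) of how the new /api/v1/query_with_image endpoint could be exercised; the host, port, API key, and image path are placeholders taken from the sample cURL in the docstring, and the requests library is assumed to be available:

# Hypothetical smoke test for the new endpoint; adjust host, port, key, and image path to your deployment.
import requests

API_URL = "http://localhost:5001/api/v1/query_with_image"  # port from the sample cURL above
API_KEY = "your-api-key"  # placeholder

with open("path/to/your/image.jpg", "rb") as f:
    response = requests.post(
        API_URL,
        headers={"X-API-Key": API_KEY},
        data={"message": "What's in this image?"},
        files={"image": ("image.jpg", f, "image/jpeg")},
    )

print(response.status_code, response.json())  # expect {"query_id": "<uuid>"} on success
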
tools.py
@@ -1,12 +1,20 @@
import subprocess
import tempfile
import time

import duckduckgo_search
import json
import requests
from markdownify import markdownify as md
from readability.readability import Document

import duckduckgo_search
import datetime
import random
import math
import re
import base64
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
import ollama
import os

class Tool:
    def __init__(self, name: str, description: str, arguments: dict, returns: str):
@@ -56,6 +64,12 @@ class DefaultToolManager(ToolManager):
        self.add_tool(GetReadablePageContentsTool())
        self.add_tool(CalculatorTool())
        self.add_tool(PythonCodeTool())
        self.add_tool(DateTimeTool())
        self.add_tool(RandomNumberTool())
        self.add_tool(RegexTool())
        self.add_tool(Base64Tool())
        self.add_tool(SimpleChartTool())
        self.add_tool(LLAVAImageAnalysisTool())


class SearchTool(Tool):
@@ -73,8 +87,11 @@ class SearchTool(Tool):
        )

    def execute(self, arg: dict) -> str:
        res = duckduckgo_search.DDGS().text(arg["query"], max_results=5)
        return "\n\n".join([f"{r['title']}\n{r['body']}\n{r['href']}" for r in res])
        try:
            res = duckduckgo_search.DDGS().text(arg["query"], max_results=5)
            return "\n\n".join([f"{r['title']}\n{r['body']}\n{r['href']}" for r in res])
        except Exception as e:
            return f"Error searching the web: {str(e)}"


def get_readable_page_contents(url: str) -> str:
@@ -180,3 +197,173 @@ class PythonCodeTool(Tool):
            return f"Error executing code: {str(e)}"

        return "\n".join([f"{k}:\n{v}" for k, v in result.items()])


class DateTimeTool(Tool):
    def __init__(self):
        super().__init__(
            "get_current_datetime",
            "Get the current date and time",
            {"type": "object", "properties": {}},
            "datetime:string"
        )

    def execute(self, arg: dict) -> str:
        return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")


class RandomNumberTool(Tool):
    def __init__(self):
        super().__init__(
            "generate_random_number",
            "Generate a random number within a given range",
            {
                "type": "object",
                "properties": {
                    "min": {"type": "number", "description": "The minimum value"},
                    "max": {"type": "number", "description": "The maximum value"}
                }
            },
            "random_number:number"
        )

    def execute(self, arg: dict) -> str:
        return str(random.uniform(arg["min"], arg["max"]))


class RegexTool(Tool):
    def __init__(self):
        super().__init__(
            "regex_match",
            "Perform a regex match on a given text",
            {
                "type": "object",
                "properties": {
                    "text": {"type": "string", "description": "The text to search in"},
                    "pattern": {"type": "string", "description": "The regex pattern to match"}
                }
            },
            "matches:list[string]"
        )

    def execute(self, arg: dict) -> str:
        matches = re.findall(arg["pattern"], arg["text"])
        return json.dumps(matches)


class Base64Tool(Tool):
    def __init__(self):
        super().__init__(
            "base64_encode_decode",
            "Encode or decode a string using Base64",
            {
                "type": "object",
                "properties": {
                    "action": {"type": "string", "enum": ["encode", "decode"], "description": "Whether to encode or decode"},
                    "text": {"type": "string", "description": "The text to encode or decode"}
                }
            },
            "result:string"
        )

    def execute(self, arg: dict) -> str:
        if arg["action"] == "encode":
            return base64.b64encode(arg["text"].encode()).decode()
        elif arg["action"] == "decode":
            return base64.b64decode(arg["text"].encode()).decode()
        else:
            return "Invalid action. Use 'encode' or 'decode'."


class SimpleChartTool(Tool):
    def __init__(self):
        super().__init__(
            "generate_simple_chart",
            "Generate a simple bar chart image",
            {
                "type": "object",
                "properties": {
                    "data": {"type": "array", "items": {"type": "number"}, "description": "List of numerical values for the chart"},
                    "labels": {"type": "array", "items": {"type": "string"}, "description": "Labels for each bar"}
                }
            },
            "image_base64:string"
        )

    def execute(self, arg: dict) -> str:
        data = arg["data"]
        labels = arg["labels"]

        # Create a simple bar chart
        width, height = 400, 300
        img = Image.new('RGB', (width, height), color='white')
        draw = ImageDraw.Draw(img)

        # Draw bars
        max_value = max(data)
        bar_width = width // (len(data) + 1)
        for i, value in enumerate(data):
            bar_height = (value / max_value) * (height - 50)
            left = (i + 1) * bar_width
            draw.rectangle([left, height - bar_height, left + bar_width, height], fill='blue')

        # Add labels
        font = ImageFont.load_default()
        for i, label in enumerate(labels):
            left = (i + 1) * bar_width + bar_width // 2
            draw.text((left, height - 20), label, fill='black', anchor='ms', font=font)

        # Convert to base64
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return img_str


class LLAVAImageAnalysisTool(Tool):
    def __init__(self):
        super().__init__(
            "analyze_image",
            "Analyze an image using the LLAVA model",
            {
                "type": "object",
                "properties": {
                    "image_base64": {"type": "string", "description": "Base64 encoded image"},
                    "question": {"type": "string", "description": "Question about the image"}
                }
            },
            "analysis:string"
        )

    def execute(self, arg: dict) -> str:
        try:
            # Decode base64 image
            image_data = base64.b64decode(arg["image_base64"])
            image = Image.open(BytesIO(image_data))

            # Save image to a temporary file
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
                image.save(temp_file, format="PNG")
                temp_file_path = temp_file.name

            # Call LLAVA model
            response = ollama.chat(
                model="llava:7b",
                messages=[
                    {
                        "role": "user",
                        "content": arg["question"],
                        "images": [temp_file_path]
                    }
                ]
            )

            # Clean up temporary file
            os.remove(temp_file_path)

            # Unload LLAVA model
            ollama.delete("llava:7b")

            return response['message']['content']
        except Exception as e:
            return f"Error analyzing image: {str(e)}"
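
A rough local sanity check (not part of the patch) for a few of the newly added tools, invoked directly through DefaultToolManager the same way main.py calls get_tool(...).execute(...). It assumes DefaultToolManager() can be constructed without arguments, that Pillow and the other dependencies are installed, and, for analyze_image, that a local Ollama server with the llava:7b model pulled is running:

# Exercise the Base64, chart, and image-analysis tools end to end (assumptions noted above).
from tools import DefaultToolManager

tm = DefaultToolManager()

# Base64 round trip
encoded = tm.get_tool("base64_encode_decode").execute({"action": "encode", "text": "hello"})
decoded = tm.get_tool("base64_encode_decode").execute({"action": "decode", "text": encoded})
assert decoded == "hello"

# Simple bar chart (returns a base64-encoded PNG string)
chart_b64 = tm.get_tool("generate_simple_chart").execute(
    {"data": [3, 7, 5], "labels": ["a", "b", "c"]}
)

# Feed the chart into the image analysis tool (requires Ollama with llava:7b available locally)
print(tm.get_tool("analyze_image").execute(
    {"image_base64": chart_b64, "question": "Describe this chart."}
))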