tool_use #1
93
main.py
93
main.py
@ -22,6 +22,8 @@ import logging
|
|||||||
from flask import Flask, g, jsonify, request, send_from_directory
|
from flask import Flask, g, jsonify, request, send_from_directory
|
||||||
from flask_socketio import SocketIO, emit
|
from flask_socketio import SocketIO, emit
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from werkzeug.utils import secure_filename
|
||||||
|
import base64
|
||||||
|
|
||||||
from models import model_manager
|
from models import model_manager
|
||||||
from tools import DefaultToolManager
|
from tools import DefaultToolManager
|
||||||
@ -38,6 +40,9 @@ CONFIG_FILE = "config.ini"
|
|||||||
processing_thread = None
|
processing_thread = None
|
||||||
processing_thread_started = False
|
processing_thread_started = False
|
||||||
|
|
||||||
|
# File extensions accepted for uploaded images; allowed_file() lower-cases the
# candidate extension before testing membership in this set.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}
|
||||||
|
# Upper bound, in bytes, on an image uploaded to /api/v1/query_with_image.
MAX_IMAGE_SIZE = 1 * 1024 * 1024 # 1MB
|
||||||
|
|
||||||
|
|
||||||
def create_default_config():
|
def create_default_config():
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
@ -564,6 +569,7 @@ def answer_question_tools_api(
|
|||||||
tools=tool_manager.get_tools_for_ollama_dict(),
|
tools=tool_manager.get_tools_for_ollama_dict(),
|
||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
|
logger.info(f"API Response: {response}")
|
||||||
assistant_message = response["message"]
|
assistant_message = response["message"]
|
||||||
|
|
||||||
conversation_history.append(assistant_message)
|
conversation_history.append(assistant_message)
|
||||||
@ -574,6 +580,7 @@ def answer_question_tools_api(
|
|||||||
tool_args = tool_call["function"]["arguments"]
|
tool_args = tool_call["function"]["arguments"]
|
||||||
tool_response = tool_manager.get_tool(tool_name).execute(tool_args)
|
tool_response = tool_manager.get_tool(tool_name).execute(tool_args)
|
||||||
conversation_history.append({"role": "tool", "content": tool_response})
|
conversation_history.append({"role": "tool", "content": tool_response})
|
||||||
|
logger.info(f"API Tool response: {tool_response}")
|
||||||
else:
|
else:
|
||||||
if "<reply>" in assistant_message["content"].lower():
|
if "<reply>" in assistant_message["content"].lower():
|
||||||
reply_content = re.search(
|
reply_content = re.search(
|
||||||
@ -645,7 +652,15 @@ def process_queries():
|
|||||||
db.commit()
|
db.commit()
|
||||||
logger.info(f"Updated query {query_id} status to PROCESSING")
|
logger.info(f"Updated query {query_id} status to PROCESSING")
|
||||||
|
|
||||||
conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}]
|
# Fetch conversation history if it exists
|
||||||
|
cursor.execute("SELECT conversation_history FROM Queries WHERE id = ?", (query_id,))
|
||||||
|
conversation_history_result = cursor.fetchone()
|
||||||
|
|
||||||
|
if conversation_history_result and conversation_history_result[0]:
|
||||||
|
conversation_history = json.loads(conversation_history_result[0])
|
||||||
|
else:
|
||||||
|
conversation_history = [{"role": "system", "content": ANSWER_QUESTION_PROMPT}]
|
||||||
|
|
||||||
logger.info(f"Starting answer_question_tools_api for query {query_id}")
|
logger.info(f"Starting answer_question_tools_api for query {query_id}")
|
||||||
final_conversation_history = answer_question_tools_api(user_input, conversation_history)
|
final_conversation_history = answer_question_tools_api(user_input, conversation_history)
|
||||||
logger.info(f"Finished answer_question_tools_api for query {query_id}")
|
logger.info(f"Finished answer_question_tools_api for query {query_id}")
|
||||||
@ -660,7 +675,7 @@ def process_queries():
|
|||||||
logger.info(f"Updated query {query_id} status to DONE")
|
logger.info(f"Updated query {query_id} status to DONE")
|
||||||
else:
|
else:
|
||||||
logger.info("No queued queries found. Waiting...")
|
logger.info("No queued queries found. Waiting...")
|
||||||
time.sleep(random.uniform(5, 10)) # Wait for 5 seconds before checking again if no queries are found
|
time.sleep(5) # Wait for 5 seconds before checking again if no queries are found
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error processing query: {str(e)}")
|
logger.exception(f"Error processing query: {str(e)}")
|
||||||
time.sleep(1) # Wait for 1 second before retrying in case of an error
|
time.sleep(1) # Wait for 1 second before retrying in case of an error
|
||||||
@ -726,6 +741,80 @@ def start_processing_thread():
|
|||||||
logger.info("Query processing thread started")
|
logger.info("Query processing thread started")
|
||||||
|
|
||||||
|
|
||||||
|
def allowed_file(filename):
    """Return True when *filename* has an extension listed in ALLOWED_EXTENSIONS.

    The extension is whatever follows the last dot and is compared
    case-insensitively. Names without a dot are rejected.
    """
    if '.' not in filename:
        return False
    _, extension = filename.rsplit('.', 1)
    return extension.lower() in ALLOWED_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/v1/query_with_image")
def api_query_with_image():
    """
    Submit a new query to the LLM Chat Server with an optional image.

    This endpoint requires authentication via an API key.

    The request is multipart form data with a required ``message`` field and
    an optional ``image`` file part. An image whose extension is not accepted
    by ``allowed_file`` is ignored and the query is queued without it.

    Sample cURL:
    curl -X POST http://localhost:5001/api/v1/query_with_image \
    -H "X-API-Key: your-api-key" \
    -F "message=What's in this image?" \
    -F "image=@path/to/your/image.jpg"

    Returns:
        ``{"query_id": ...}`` on success; otherwise ``{"error": ...}`` with
        status 404 (API disabled), 401 (missing/invalid key), 400 (missing
        message or oversized image) or 500 (unexpected failure).
    """
    if not ENABLE_API_ENDPOINTS:
        return jsonify({"error": "API endpoints are disabled"}), 404

    api_key = request.headers.get('X-API-Key')
    if not api_key:
        return jsonify({"error": "API key is required"}), 401

    api_key_id = validate_api_key(api_key)
    if not api_key_id:
        return jsonify({"error": "Invalid API key"}), 401

    if 'message' not in request.form:
        return jsonify({"error": "Message is required"}), 400

    user_input = request.form['message']
    query_id = str(uuid.uuid4())

    image_base64 = None
    file = request.files.get('image')
    if file and allowed_file(file.filename):
        # FileStorage.content_length only mirrors a per-part Content-Length
        # header, which multipart clients normally omit (it is then 0), so the
        # limit has to be enforced on the bytes actually received. Reading one
        # byte past the limit is enough to detect an oversized upload.
        image_data = file.read(MAX_IMAGE_SIZE + 1)
        if len(image_data) > MAX_IMAGE_SIZE:
            return jsonify({"error": "Image size exceeds 1MB limit"}), 400

        # Read and encode the image
        image_base64 = base64.b64encode(image_data).decode('utf-8')

    try:
        db = get_db()
        cursor = db.cursor()
        cursor.execute(
            "INSERT INTO Queries (id, ip, query, api_key_id, status) VALUES (?, ?, ?, ?, ?)",
            (query_id, request.remote_addr, user_input, api_key_id, QueryStatus.QUEUED.value)
        )

        # If there's an image, add it to the conversation history
        if image_base64:
            conversation_history = [
                {"role": "system", "content": ANSWER_QUESTION_PROMPT},
                {"role": "user", "content": f"[An image was uploaded with this message] {user_input}"},
                {"role": "system", "content": f"An image was uploaded. You can analyze it using the analyze_image tool with the following base64 string: {image_base64}"}
            ]
            cursor.execute(
                "UPDATE Queries SET conversation_history = ? WHERE id = ?",
                (json.dumps(conversation_history), query_id)
            )

        # Commit the row and its conversation history together: committing the
        # QUEUED row first would let the query-processing thread pick it up
        # before the image context has been stored.
        db.commit()
        logger.info(f"Added new query with image to database: {query_id}")

        return jsonify({"query_id": query_id})
    except Exception as e:
        logger.exception(f"Error during API query processing with image: {str(e)}")
        return jsonify({"error": str(e)}), 500
|
||||||
|
|
||||||
|
|
||||||
# Replace the if __main__ block with this:
|
# Replace the if __main__ block with this:
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logger.info("Starting LLM Chat Server")
|
logger.info("Starting LLM Chat Server")
|
||||||
|
197
tools.py
197
tools.py
@ -1,12 +1,20 @@
|
|||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
import json
|
||||||
import duckduckgo_search
|
|
||||||
import requests
|
import requests
|
||||||
from markdownify import markdownify as md
|
from markdownify import markdownify as md
|
||||||
from readability.readability import Document
|
from readability.readability import Document
|
||||||
|
import duckduckgo_search
|
||||||
|
import datetime
|
||||||
|
import random
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
import ollama
|
||||||
|
import os
|
||||||
|
|
||||||
class Tool:
|
class Tool:
|
||||||
def __init__(self, name: str, description: str, arguments: dict, returns: str):
|
def __init__(self, name: str, description: str, arguments: dict, returns: str):
|
||||||
@ -56,6 +64,12 @@ class DefaultToolManager(ToolManager):
|
|||||||
self.add_tool(GetReadablePageContentsTool())
|
self.add_tool(GetReadablePageContentsTool())
|
||||||
self.add_tool(CalculatorTool())
|
self.add_tool(CalculatorTool())
|
||||||
self.add_tool(PythonCodeTool())
|
self.add_tool(PythonCodeTool())
|
||||||
|
self.add_tool(DateTimeTool())
|
||||||
|
self.add_tool(RandomNumberTool())
|
||||||
|
self.add_tool(RegexTool())
|
||||||
|
self.add_tool(Base64Tool())
|
||||||
|
self.add_tool(SimpleChartTool())
|
||||||
|
self.add_tool(LLAVAImageAnalysisTool())
|
||||||
|
|
||||||
|
|
||||||
class SearchTool(Tool):
|
class SearchTool(Tool):
|
||||||
@ -73,8 +87,11 @@ class SearchTool(Tool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def execute(self, arg: dict) -> str:
    """Search DuckDuckGo for ``arg["query"]`` and return up to five results.

    Each result is rendered as title, body and link on separate lines, with
    a blank line between results. Any failure is reported as an error string
    instead of being raised.
    """
    try:
        hits = duckduckgo_search.DDGS().text(arg["query"], max_results=5)
        formatted = [f"{hit['title']}\n{hit['body']}\n{hit['href']}" for hit in hits]
        return "\n\n".join(formatted)
    except Exception as e:
        return f"Error searching the web: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
def get_readable_page_contents(url: str) -> str:
|
def get_readable_page_contents(url: str) -> str:
|
||||||
@ -180,3 +197,173 @@ class PythonCodeTool(Tool):
|
|||||||
return f"Error executing code: {str(e)}"
|
return f"Error executing code: {str(e)}"
|
||||||
|
|
||||||
return "\n".join([f"{k}:\n{v}" for k, v in result.items()])
|
return "\n".join([f"{k}:\n{v}" for k, v in result.items()])
|
||||||
|
|
||||||
|
|
||||||
|
class DateTimeTool(Tool):
    """Tool that reports the current local date and time; takes no arguments."""

    def __init__(self):
        # The tool accepts an empty argument object.
        no_arguments = {"type": "object", "properties": {}}
        super().__init__(
            name="get_current_datetime",
            description="Get the current date and time",
            arguments=no_arguments,
            returns="datetime:string",
        )

    def execute(self, arg: dict) -> str:
        """Return the current time formatted as ``YYYY-MM-DD HH:MM:SS``."""
        now = datetime.datetime.now()
        return now.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
|
||||||
|
class RandomNumberTool(Tool):
    """Tool that draws a random floating-point number between two bounds."""

    def __init__(self):
        bounds_schema = {
            "type": "object",
            "properties": {
                "min": {"type": "number", "description": "The minimum value"},
                "max": {"type": "number", "description": "The maximum value"},
            },
        }
        super().__init__(
            name="generate_random_number",
            description="Generate a random number within a given range",
            arguments=bounds_schema,
            returns="random_number:number",
        )

    def execute(self, arg: dict) -> str:
        """Return ``random.uniform(arg["min"], arg["max"])`` rendered as a string."""
        low, high = arg["min"], arg["max"]
        return str(random.uniform(low, high))
|
||||||
|
|
||||||
|
|
||||||
|
class RegexTool(Tool):
    """Tool that finds all matches of a regular expression in a text."""

    def __init__(self):
        super().__init__(
            "regex_match",
            "Perform a regex match on a given text",
            {
                "type": "object",
                "properties": {
                    "text": {"type": "string", "description": "The text to search in"},
                    "pattern": {"type": "string", "description": "The regex pattern to match"}
                }
            },
            "matches:list[string]"
        )

    def execute(self, arg: dict) -> str:
        """Return the ``re.findall`` result for the pattern as a JSON array.

        Args:
            arg: Mapping with ``"pattern"`` (the regex) and ``"text"`` (the
                string to search).

        Returns:
            A JSON-encoded list of matches, or an error message when the
            pattern does not compile. The pattern is supplied by the model,
            so a malformed one is reported rather than raised.
        """
        try:
            matches = re.findall(arg["pattern"], arg["text"])
        except re.error as e:
            return f"Error in regex pattern: {str(e)}"
        return json.dumps(matches)
|
||||||
|
|
||||||
|
|
||||||
|
class Base64Tool(Tool):
    """Tool that Base64-encodes a string or decodes a Base64 string to text."""

    def __init__(self):
        super().__init__(
            "base64_encode_decode",
            "Encode or decode a string using Base64",
            {
                "type": "object",
                "properties": {
                    "action": {"type": "string", "enum": ["encode", "decode"], "description": "Whether to encode or decode"},
                    "text": {"type": "string", "description": "The text to encode or decode"}
                }
            },
            "result:string"
        )

    def execute(self, arg: dict) -> str:
        """Encode or decode ``arg["text"]`` according to ``arg["action"]``.

        Args:
            arg: Mapping with ``"action"`` (``"encode"`` or ``"decode"``) and
                ``"text"``.

        Returns:
            The encoded or decoded string, or an error message when the action
            is missing or unknown, or when the text cannot be decoded.
        """
        action = arg.get("action")
        if action == "encode":
            return base64.b64encode(arg["text"].encode()).decode()
        if action == "decode":
            try:
                return base64.b64decode(arg["text"].encode()).decode()
            except ValueError as e:
                # binascii.Error (malformed Base64) and UnicodeDecodeError
                # (decoded bytes are not UTF-8 text) both derive from ValueError.
                return f"Error decoding Base64: {str(e)}"
        return "Invalid action. Use 'encode' or 'decode'."
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleChartTool(Tool):
    """Tool that renders a list of numbers as a simple bar chart image."""

    def __init__(self):
        super().__init__(
            "generate_simple_chart",
            "Generate a simple bar chart image",
            {
                "type": "object",
                "properties": {
                    "data": {"type": "array", "items": {"type": "number"}, "description": "List of numerical values for the chart"},
                    "labels": {"type": "array", "items": {"type": "string"}, "description": "Labels for each bar"}
                }
            },
            "image_base64:string"
        )

    def execute(self, arg: dict) -> str:
        """Draw a 400x300 bar chart and return it as a Base64-encoded PNG.

        Args:
            arg: Mapping with ``"data"`` (numbers, one per bar) and
                ``"labels"`` (text drawn under each bar).

        Returns:
            The Base64 string of the PNG image, or an error message when
            ``data`` is empty.
        """
        data = arg["data"]
        labels = arg["labels"]

        # max() below cannot handle an empty sequence.
        if not data:
            return "Error generating chart: no data provided"

        # Create a simple bar chart
        width, height = 400, 300
        img = Image.new('RGB', (width, height), color='white')
        draw = ImageDraw.Draw(img)

        # Draw bars
        max_value = max(data)
        bar_width = width // (len(data) + 1)
        for i, value in enumerate(data):
            if max_value > 0:
                # Negative values are clamped to a zero-height bar so the
                # rectangle never gets inverted coordinates.
                bar_height = (max(value, 0) / max_value) * (height - 50)
            else:
                # All values are zero or negative: nothing to scale against.
                bar_height = 0
            left = (i + 1) * bar_width
            draw.rectangle([left, height - bar_height, left + bar_width, height], fill='blue')

        # Add labels
        font = ImageFont.load_default()
        for i, label in enumerate(labels):
            left = (i + 1) * bar_width + bar_width // 2
            draw.text((left, height - 20), label, fill='black', anchor='ms', font=font)

        # Convert to base64
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return img_str
|
||||||
|
|
||||||
|
|
||||||
|
class LLAVAImageAnalysisTool(Tool):
    """Tool that asks the ``llava:7b`` model a question about an image."""

    def __init__(self):
        super().__init__(
            "analyze_image",
            "Analyze an image using the LLAVA model",
            {
                "type": "object",
                "properties": {
                    "image_base64": {"type": "string", "description": "Base64 encoded image"},
                    "question": {"type": "string", "description": "Question about the image"}
                }
            },
            "analysis:string"
        )

    def execute(self, arg: dict) -> str:
        """Decode the image, send it with the question to LLAVA, return the answer.

        Args:
            arg: Mapping with ``"image_base64"`` (the encoded image) and
                ``"question"`` (what to ask about it).

        Returns:
            The model's reply text, or an error message if any step fails.
        """
        temp_file_path = None
        try:
            # Decode base64 image
            image_data = base64.b64decode(arg["image_base64"])
            image = Image.open(BytesIO(image_data))

            # Save image to a temporary file
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
                # Record the path first so cleanup still runs if saving fails.
                temp_file_path = temp_file.name
                image.save(temp_file, format="PNG")

            # Call LLAVA model. keep_alive=0 asks Ollama to unload the model
            # as soon as the response is produced; ollama.delete() would
            # instead remove the model from local storage.
            response = ollama.chat(
                model="llava:7b",
                messages=[
                    {
                        "role": "user",
                        "content": arg["question"],
                        "images": [temp_file_path]
                    }
                ],
                keep_alive=0,
            )

            return response['message']['content']
        except Exception as e:
            return f"Error analyzing image: {str(e)}"
        finally:
            # Clean up temporary file, also when the model call failed.
            if temp_file_path is not None and os.path.exists(temp_file_path):
                os.remove(temp_file_path)
|
Loading…
Reference in New Issue
Block a user