import atexit
import colorsys
import copy
import hashlib
import os
import random
import secrets
import threading
import time
from datetime import datetime

import toml
from apscheduler.schedulers.background import BackgroundScheduler
from colorthief import ColorThief
from flask import Flask, request, jsonify, render_template, redirect, url_for, flash, session, send_from_directory
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from PIL import Image, ExifTags
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
from werkzeug.utils import secure_filename

from config import load_or_create_config
from models import Session as DBSession, Photo
from steganography import embed_message, extract_message

# Flask application object; every route below is registered on it.
app = Flask(__name__)
# NOTE(review): this random key is replaced again further down in the module
# (secrets.token_hex) before any request is served, so it is effectively unused.
app.secret_key = os.urandom(24)
# This is the load_or_create_config imported from the `config` module: the
# function of the same name defined later in this file does not exist yet here.
config = load_or_create_config()

# Storage locations are read once at import time; later config reloads that
# update `config` in place do not change these module constants.
UPLOAD_FOLDER = config['directories']['upload']
THUMBNAIL_FOLDER = config['directories']['thumbnail']
# Extensions accepted by allowed_file(), which lower-cases before comparing.
ALLOWED_EXTENSIONS = {'jpg', 'jpeg', 'png', 'gif'}
# Bounding-box sizes (max width/height in px) generated by generate_thumbnails().
THUMBNAIL_SIZES = [256, 512, 768, 1024, 1536, 2048]

# Create upload and thumbnail directories if they don't exist
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(THUMBNAIL_FOLDER, exist_ok=True)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['THUMBNAIL_FOLDER'] = THUMBNAIL_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 80 * 1024 * 1024  # 80MB limit; Flask rejects larger request bodies

# Background scheduler for thumbnail generation; jobs are added further down.
# NOTE(review): started at import time. Under the Werkzeug reloader
# (debug=True) the module is imported by two processes, so two schedulers may
# run concurrently — confirm for your deployment.
scheduler = BackgroundScheduler()
scheduler.start()

# Default settings that merge_configs() fills in underneath the user's
# config.toml inside load_or_create_config(). The `config` loaded at startup
# above comes from the `config` module, whose own defaults may differ — verify.
DEFAULT_CONFIG = {
    'server': {
        'host': '0.0.0.0',  # listen on all network interfaces
        'port': 5000
    },
    'directories': {
        'upload': 'uploads',        # original uploaded images
        'thumbnail': 'thumbnails'   # generated resized copies
    },
    'admin': {
        'password': 'changeme'  # Default password — override it in config.toml
    },
    'appearance': {
        'accent_color': '#007bff'  # passed to every template as accent_color
    }
}

def merge_configs(default, user):
    """Recursively merge the *user* config over the *default* config.

    Tables (dicts) present in both are merged key by key. Any other value from
    *user* replaces the default. Keys keep the default's order, and keys that
    only *user* has are appended.

    Neither input is modified, and the result shares no mutable objects with
    either of them. A shallow ``dict.copy()`` would hand out the defaults' own
    nested dicts, so mutating the returned config would silently change
    DEFAULT_CONFIG.
    """
    result = copy.deepcopy(default)
    for key, value in user.items():
        base = result.get(key)
        if isinstance(base, dict) and isinstance(value, dict):
            result[key] = merge_configs(base, value)
        else:
            result[key] = copy.deepcopy(value)
    return result

class ConfigFileHandler(FileSystemEventHandler):
    """Watchdog handler that reloads ``config.toml`` into the live ``config`` dict.

    The file is parsed and merged over DEFAULT_CONFIG here rather than by
    calling load_or_create_config(), for two reasons:

    * load_or_create_config() writes config.toml back to disk, which fires
      another modification event and reloads again without end.
    * On a parse error it returns the defaults, which would silently reset
      settings such as the admin password while the file is being edited.

    On any error the current settings are kept and the error is logged.
    """

    def on_modified(self, event):
        if not event.is_directory and os.path.basename(event.src_path) == 'config.toml':
            self._reload(event.src_path)

    def on_created(self, event):
        self.on_modified(event)

    def on_moved(self, event):
        # Editors that save by writing a temp file and renaming it report a
        # move whose destination is config.toml.
        if not event.is_directory and os.path.basename(event.dest_path) == 'config.toml':
            self._reload(event.dest_path)

    def _reload(self, path):
        """Parse *path*, merge it over the defaults and update ``config`` in place."""
        try:
            with open(path, 'r') as f:
                user_config = toml.load(f)
            new_config = merge_configs(DEFAULT_CONFIG, user_config)
        except Exception as e:
            # Keep the current settings; the next save triggers another reload.
            app.logger.error(f"Error reloading configuration: {e}")
            return
        # Mutate in place (no rebinding), so code holding `config` sees the update.
        config.update(new_config)
        app.logger.info("Configuration reloaded successfully")

def load_or_create_config():
    """Load ``config.toml`` merged over DEFAULT_CONFIG, creating or completing the file.

    The merged config is written back only when it differs from what was read,
    i.e. when the file is missing or lacks some default keys. Writing on every
    load would discard hand-written comments (toml.dump does not keep them).
    It would also emit a file-modification event that any watcher on
    config.toml reacts to.

    Returns:
        The merged configuration dict. If reading, parsing or writing fails,
        the error is logged and an independent deep copy of DEFAULT_CONFIG is
        returned instead.
    """
    config_path = 'config.toml'

    try:
        user_config = None
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                user_config = toml.load(f)

        # Merge with defaults
        final_config = merge_configs(DEFAULT_CONFIG, user_config or {})

        # Persist only when something was added (or the file did not exist).
        if final_config != user_config:
            with open(config_path, 'w') as f:
                toml.dump(final_config, f)

        return final_config

    except Exception as e:
        app.logger.error(f"Error loading config: {e}")
        # Deep copy so callers cannot mutate the nested default tables.
        return copy.deepcopy(DEFAULT_CONFIG)

def start_config_watcher():
    """Watch the current directory (non-recursively) with ConfigFileHandler.

    The observer is stopped and joined when the interpreter exits.
    """
    watcher = Observer()
    watcher.schedule(ConfigFileHandler(), path='.', recursive=False)
    watcher.start()

    @atexit.register
    def _stop_watcher():
        watcher.stop()
        watcher.join()

# Begin watching config.toml for edits as soon as this module is imported.
start_config_watcher()

def allowed_file(filename):
    """Return True when *filename*'s last extension (case-insensitive) is in ALLOWED_EXTENSIONS."""
    _stem, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS

def get_highlight_color(image_path):
    """Return the most saturated colour of the image's 6-colour palette as ``#rrggbb``.

    The palette is built with ColorThief at quality=1, its slowest and most
    thorough setting. The upload handler in this file passes the 256px
    thumbnail.
    """
    palette = ColorThief(image_path).get_palette(color_count=6, quality=1)

    def saturation(rgb):
        # colorsys expects 0-1 floats, but saturation = (max - min) / max is
        # scale-invariant, so ColorThief's 0-255 ints rank identically.
        return colorsys.rgb_to_hsv(*rgb)[1]

    most_saturated = max(palette, key=saturation)
    return '#{:02x}{:02x}{:02x}'.format(*most_saturated)

def generate_thumbnails(filename):
    """Create any missing resized copies of ``UPLOAD_FOLDER/<filename>``.

    Each copy is written to ``THUMBNAIL_FOLDER/<stem>/<size>_<filename>`` for
    every size in THUMBNAIL_SIZES, fitted inside a size x size box. Copies that
    already exist are left untouched. The original's raw EXIF block, if any,
    is copied into each thumbnail.
    """
    source = os.path.join(UPLOAD_FOLDER, filename)
    stem = os.path.splitext(filename)[0]
    target_dir = os.path.join(THUMBNAIL_FOLDER, stem)
    os.makedirs(target_dir, exist_ok=True)

    for edge in THUMBNAIL_SIZES:
        target = os.path.join(target_dir, f"{edge}_{filename}")
        if os.path.exists(target):
            continue
        # thumbnail() shrinks the image in place, so every size starts from a
        # freshly opened original.
        with Image.open(source) as picture:
            raw_exif = picture.info.get("exif")
            picture.thumbnail((edge, edge), Image.LANCZOS)
            save_options = {"optimize": True, "quality": 85}
            if raw_exif:
                save_options["exif"] = raw_exif
            picture.save(target, **save_options)

def generate_all_thumbnails():
    """Backfill missing thumbnails for every allowed image in UPLOAD_FOLDER.

    This runs as a background scheduler job. A file that fails to process (for
    example a corrupt or not-yet-complete upload) is logged and skipped, so it
    does not abort thumbnail generation for every file after it.
    """
    for filename in os.listdir(UPLOAD_FOLDER):
        if not allowed_file(filename):
            continue
        try:
            generate_thumbnails(filename)
        except Exception:
            # Top-level boundary of a scheduled job: log with traceback, move on.
            app.logger.exception("Failed to generate thumbnails for %s", filename)

# Backfill missing thumbnails every 5 minutes ...
scheduler.add_job(generate_all_thumbnails, 'interval', minutes=5)
# ... plus once right away: a 'date' trigger fires a single time at run_date.
scheduler.add_job(generate_all_thumbnails, 'date', run_date=datetime.now())  # Run once at startup

@app.route('/')
def index():
    """Render the public gallery page using the configured accent colour."""
    accent = config['appearance']['accent_color']
    return render_template('index.html', accent_color=accent)

@app.route('/api/images')
def get_images():
    """Return one page of photos, newest first, as JSON for the gallery.

    Query parameters:
        page: 1-based page number. Missing, non-integer or < 1 values mean 1.

    Response: ``{'images': [...], 'hasMore': bool}`` with up to 20 images per
    page. Width and height are swapped for EXIF orientations 6 and 8. When both
    sides are at least 4000 px, they are divided by a random factor of 2 or 3.
    """
    # type=int makes a non-numeric value fall back to the default instead of
    # raising; max() stops a negative offset.
    page = max(request.args.get('page', 1, type=int), 1)
    per_page = 20

    db_session = DBSession()
    try:
        # Fetch one extra row: its presence says whether another page exists,
        # without a separate COUNT over the whole table.
        rows = (
            db_session.query(Photo)
            .order_by(Photo.date_taken.desc())
            .offset((page - 1) * per_page)
            .limit(per_page + 1)
            .all()
        )
        has_more = len(rows) > per_page

        # Build the payload while the session is still open.
        images = []
        for photo in rows[:per_page]:
            if photo.width >= 4000 and photo.height >= 4000:
                factor = random.randint(2, 3)
            else:
                factor = 1
            # EXIF orientations 6 and 8 are 90-degree rotations: swap the sides.
            if photo.orientation in (6, 8):
                width, height = photo.height, photo.width
            else:
                width, height = photo.width, photo.height
            images.append({
                'imgSrc': f'/static/thumbnails/{os.path.splitext(photo.input_filename)[0]}/1536_{photo.input_filename}',
                'width': width / factor,
                'height': height / factor,
                'caption': photo.input_filename,
                'date': photo.date_taken.strftime('%y %m %d'),
                'technicalInfo': f"{photo.focal_length}MM | F/{photo.aperture} | {photo.shutter_speed} | ISO{photo.iso}",
                'highlightColor': photo.highlight_color
            })
    finally:
        db_session.close()

    return jsonify({'images': images, 'hasMore': has_more})

@app.route('/admin')
def admin():
    """Render the admin dashboard listing every photo, newest first.

    Visitors who are not logged in are redirected to the login page.
    """
    if 'logged_in' not in session:
        return redirect(url_for('admin_login'))

    db_session = DBSession()
    try:
        photos = db_session.query(Photo).order_by(Photo.date_taken.desc()).all()
    finally:
        # Close even when the query fails, so the session's resources are released.
        db_session.close()

    return render_template('admin.html', photos=photos, accent_color=config['appearance']['accent_color'])

@app.route('/admin/login', methods=['GET', 'POST'])
def admin_login():
    """Show the login form (GET) or check the submitted admin password (POST).

    On success the session is marked ``logged_in`` and the user is redirected
    to the dashboard. Otherwise the form is re-rendered with a flash message.
    """
    if request.method == 'POST':
        # .get(): a form without the field is a failed login, not a 400 error.
        supplied = request.form.get('password', '')
        # str(): a TOML password such as 1234 would be parsed as an int.
        expected = str(config['admin']['password'])
        # Constant-time comparison so response timing does not reveal how much
        # of the password matched. Bytes, because compare_digest rejects
        # non-ASCII str.
        if secrets.compare_digest(supplied.encode('utf-8'), expected.encode('utf-8')):
            session['logged_in'] = True
            return redirect(url_for('admin'))
        flash('Invalid password')
    return render_template('admin_login.html', accent_color=config['appearance']['accent_color'])

def _read_exif(file_path):
    """Return ``(raw_exif, tags)`` for the image stored at *file_path*.

    ``raw_exif`` is the undecoded EXIF block from ``img.info``, or None when
    the file has none. ``tags`` maps ``ExifTags.TAGS`` names to values and is
    empty when no EXIF can be decoded.

    Raises:
        OSError: the file cannot be opened as an image (Pillow's
            UnidentifiedImageError is an OSError).
    """
    with Image.open(file_path) as img:
        raw_exif = img.info.get('exif')
        # _getexif() is a private Pillow helper that some image classes may not
        # provide, and it returns None when the file carries no EXIF.
        reader = getattr(img, '_getexif', None)
        decoded = (reader() if reader is not None else None) or {}
    tags = {ExifTags.TAGS[k]: v for k, v in decoded.items() if k in ExifTags.TAGS}
    return raw_exif, tags


def _format_exposure(exposure_time):
    """Render an EXIF ``ExposureTime`` value as a shutter-speed string.

    A ``(numerator, denominator)`` tuple is shown verbatim as ``n/d``. Numeric
    values shorter than one second become ``1/N``, and longer exposures are
    shown in seconds (``2s``). Missing or non-positive values give ``''``.
    """
    if exposure_time is None:
        return ''
    if isinstance(exposure_time, tuple):
        return f"{exposure_time[0]}/{exposure_time[1]}"
    seconds = float(exposure_time)
    if not seconds > 0:  # also rejects NaN (e.g. a rational with denominator 0)
        return ''
    if seconds >= 1:
        return f"{seconds:g}s"
    # round() rather than int(), so float error cannot truncate the denominator.
    return f"1/{round(1 / seconds)}"


def _exif_date_taken(exif):
    """Return the capture time from decoded EXIF *exif* tags.

    Prefers ``DateTimeOriginal`` (when the shot was taken) over ``DateTime``
    (last file change). Unparseable values are skipped. Falls back to
    1970-01-01 00:00 when neither tag is usable.
    """
    for tag in ('DateTimeOriginal', 'DateTime'):
        value = exif.get(tag)
        if not value:
            continue
        try:
            return datetime.strptime(str(value).strip(), '%Y:%m:%d %H:%M:%S')
        except ValueError:
            continue
    return datetime(1970, 1, 1)


def _exif_iso(exif):
    """Return ``ISOSpeedRatings`` as an int (first entry of a sequence), or 0 if absent."""
    value = exif.get('ISOSpeedRatings', 0)
    if isinstance(value, (tuple, list)):
        value = value[0] if value else 0
    return int(value)


@app.route('/admin/upload', methods=['POST'])
def admin_upload():
    """Handle an admin image upload: store it, watermark it and record it.

    The steps are:

    1. Save the file under UPLOAD_FOLDER.
    2. Read its EXIF.
    3. Embed a random 16-character key with embed_message().
    4. Generate thumbnails.
    5. Insert a Photo row.

    Every outcome redirects back to the admin page with a flash message.
    Callers who are not logged in are sent to the login page.
    """
    if 'logged_in' not in session:
        return redirect(url_for('admin_login'))

    if 'file' not in request.files:
        flash('No file part')
        return redirect(url_for('admin'))

    file = request.files['file']
    if file.filename == '':
        flash('No selected file')
        return redirect(url_for('admin'))

    # Validate the sanitized name. secure_filename() can strip characters; a
    # fully non-ASCII stem loses its dot too (e.g. it can collapse to "jpg").
    filename = secure_filename(file.filename)
    if not allowed_file(filename):
        flash('Invalid file type')
        return redirect(url_for('admin'))

    file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    # Refuse to overwrite. The existing Photo row and its embedded key would no
    # longer match the file, and generate_thumbnails() skips thumbnails that
    # already exist, so the old ones would keep being served.
    if os.path.exists(file_path):
        flash('A file with that name already exists')
        return redirect(url_for('admin'))
    file.save(file_path)

    try:
        exifraw, exif = _read_exif(file_path)
    except OSError as e:
        os.remove(file_path)
        flash(f"Could not read image: {e}")
        return redirect(url_for('admin'))

    if not exifraw:
        # embed_message() is handed the raw EXIF block. How it would handle a
        # missing one is not visible here, so such files are rejected up front.
        os.remove(file_path)
        flash('Image has no EXIF data')
        return redirect(url_for('admin'))

    # Random, not derived from filename + timestamp, so it cannot be guessed.
    # token_hex(8) yields 16 hex characters, the length verify_image() extracts.
    unique_key = secrets.token_hex(8)

    # Embed the unique key into the image
    try:
        embed_message(file_path, unique_key, exifraw)
    except ValueError as e:
        flash(f"Error embedding key: {str(e)}")
        os.remove(file_path)
        return redirect(url_for('admin'))

    generate_thumbnails(filename)

    # Read dimensions after embed_message(), which may have rewritten the file.
    with Image.open(file_path) as img:
        width, height = img.size

    thumbnail_rel = f"{os.path.splitext(filename)[0]}/256_{filename}"
    db_session = DBSession()
    try:
        new_photo = Photo(
            input_filename=filename,
            thumbnail_filename=thumbnail_rel,
            focal_length=str(exif.get('FocalLengthIn35mmFilm', exif.get('FocalLength', ''))),
            aperture=str(exif.get('FNumber', '')),
            shutter_speed=_format_exposure(exif.get('ExposureTime')),
            date_taken=_exif_date_taken(exif),
            iso=_exif_iso(exif),
            orientation=int(exif.get('Orientation', 1)),
            width=width,
            height=height,
            highlight_color=get_highlight_color(os.path.join(THUMBNAIL_FOLDER, thumbnail_rel)),
            unique_key=unique_key
        )
        db_session.add(new_photo)
        db_session.commit()
    finally:
        db_session.close()

    flash('File uploaded successfully')
    return redirect(url_for('admin'))

@app.route('/admin/logout')
def admin_logout():
    """Drop the admin login flag and send the visitor back to the login page."""
    if 'logged_in' in session:
        del session['logged_in']
    flash('You have been logged out')
    return redirect(url_for('admin_login'))

@app.route('/static/thumbnails/<path:filename>')
def serve_thumbnail(filename):
    """Serve a generated thumbnail from THUMBNAIL_FOLDER.

    send_from_directory refuses paths that would escape that folder.
    """
    thumbnail_root = THUMBNAIL_FOLDER
    return send_from_directory(thumbnail_root, filename)

@app.route('/admin/update_photo/<int:photo_id>', methods=['POST'])
def update_photo(photo_id):
    """Apply a JSON object of ``{field: value}`` edits to one photo (admin only).

    ``date_taken`` must use the format ``YYYY-MM-DD HH:MM:SS``, and ``iso`` is
    converted to int. All other values are assigned as given.

    Responses:
        401 if not logged in.
        400 for a non-object body, an unknown field or a bad value.
        404 if the photo does not exist.
        500 if the database commit fails.
    """
    if 'logged_in' not in session:
        return jsonify({'success': False, 'error': 'Not logged in'}), 401

    # silent=True: a missing or invalid JSON body yields None instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'error': 'Expected a JSON object'}), 400

    db_session = DBSession()
    try:
        photo = db_session.query(Photo).get(photo_id)
        if not photo:
            return jsonify({'success': False, 'error': 'Photo not found'}), 404

        try:
            for field, value in data.items():
                # Only existing public attributes of the model may be set, so a
                # request cannot attach arbitrary attributes or touch internals.
                if field.startswith('_') or not hasattr(Photo, field):
                    raise ValueError(f"Unknown field: {field}")
                if field == 'date_taken':
                    value = datetime.strptime(value, '%Y-%m-%d %H:%M:%S')
                elif field == 'iso':
                    value = int(value)
                setattr(photo, field, value)
        except (TypeError, ValueError) as e:
            db_session.rollback()
            return jsonify({'success': False, 'error': str(e)}), 400

        db_session.commit()
        return jsonify({'success': True})
    except Exception as e:
        db_session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500
    finally:
        db_session.close()

def _remove_photo_files(input_filename):
    """Best-effort removal of an uploaded original and the thumbnails made from it.

    Failures are logged rather than raised, because the database row is
    already gone when this runs.
    """
    original_path = os.path.join(UPLOAD_FOLDER, input_filename)
    thumb_dir = os.path.join(THUMBNAIL_FOLDER, os.path.splitext(input_filename)[0])
    try:
        if os.path.exists(original_path):
            os.remove(original_path)
        if os.path.isdir(thumb_dir):
            # Uploads that differ only by extension (a.jpg / a.png) share this
            # directory. Thumbnails are named "<size>_<filename>", so remove
            # only those made from this upload.
            for name in os.listdir(thumb_dir):
                if name.partition('_')[2] == input_filename:
                    os.remove(os.path.join(thumb_dir, name))
            if not os.listdir(thumb_dir):
                os.rmdir(thumb_dir)
    except OSError:
        app.logger.exception("Photo %s removed from database but file cleanup failed", input_filename)


@app.route('/admin/delete_photo/<int:photo_id>', methods=['POST'])
def delete_photo(photo_id):
    """Delete a photo's database row, its original file and its thumbnails (admin only).

    The row is deleted and committed first, so a failed commit leaves all files
    intact. Responses:
        401 if not logged in.
        404 if the photo does not exist.
        500 if the database delete fails.
        ``{'success': True}`` otherwise.
    """
    if 'logged_in' not in session:
        return jsonify({'success': False, 'error': 'Not logged in'}), 401

    db_session = DBSession()
    try:
        photo = db_session.query(Photo).get(photo_id)
        if not photo:
            return jsonify({'success': False, 'error': 'Photo not found'}), 404
        # Read before deleting: the instance's attributes are unusable afterwards.
        input_filename = photo.input_filename
        db_session.delete(photo)
        db_session.commit()
    except Exception as e:
        db_session.rollback()
        return jsonify({'success': False, 'error': str(e)}), 500
    finally:
        db_session.close()

    _remove_photo_files(input_filename)
    return jsonify({'success': True})

@app.route('/verify/<filename>', methods=['GET'])
def verify_image(filename):
    """Report whether an uploaded file still carries the key recorded for it.

    Extracts the 16-character message embedded in ``UPLOAD_FOLDER/<filename>``
    and compares it with the ``unique_key`` of the Photo row with that
    ``input_filename``. Always answers with ``verified`` plus either a
    ``message`` or an ``error``.
    """
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    # isfile rather than exists: names such as '..' resolve to directories.
    if not os.path.isfile(file_path):
        return jsonify({'verified': False, 'error': 'Image not found'})

    try:
        extracted_key = extract_message(file_path, 16)
        db_session = DBSession()
        try:
            photo = db_session.query(Photo).filter_by(input_filename=filename).first()
            verified = photo is not None and photo.unique_key == extracted_key
        finally:
            # Release the session even if the query raises.
            db_session.close()
    except Exception as e:
        return jsonify({'verified': False, 'error': str(e)})

    if verified:
        return jsonify({'verified': True, 'message': 'Image ownership verified'})
    return jsonify({'verified': False, 'message': 'Image ownership could not be verified'})

# Rate limiting keyed on the client's remote address. The default limit applies
# to routes that set no limit of their own. "memory://" storage keeps counters
# inside this process: they are not shared between worker processes and reset
# on restart.
limiter = Limiter(
    app=app,
    key_func=get_remote_address,
    default_limits=["100 per minute"],
    storage_uri="memory://"
)

# Generate a strong secret key at startup
# NOTE(review): this replaces the key assigned near the top of the module.
# Because it is random per process, signed session cookies (admin logins) are
# invalidated on every restart and are not accepted across multiple workers.
# Consider loading a fixed key from config.
app.secret_key = secrets.token_hex(32)

# Add security headers middleware
@app.after_request
def add_security_headers(response):
    """Set a fixed group of security-related HTTP headers on every response.

    NOTE(review): the CSP permits no inline scripts; confirm the templates
    do not rely on them.
    """
    hardening = (
        ('X-Content-Type-Options', 'nosniff'),
        ('X-Frame-Options', 'SAMEORIGIN'),
        ('X-XSS-Protection', '1; mode=block'),
        ('Strict-Transport-Security', 'max-age=31536000; includeSubDomains'),
        ('Content-Security-Policy', "default-src 'self'; img-src 'self' data:; style-src 'self' 'unsafe-inline';"),
    )
    for header, value in hardening:
        response.headers[header] = value
    return response

# Add rate limiting to the login endpoint.
# '/admin/login' is already registered above. Registering the same endpoint
# again with a different (decorated) function makes Flask raise "View function
# mapping is overwriting an existing endpoint function" at import time.
# Instead, wrap the registered view with the limiter and swap it into
# app.view_functions. methods=["POST"] throttles password attempts only;
# rendering the form (GET) is not subject to this 5/minute limit.
admin_login = limiter.limit("5 per minute", methods=["POST"])(admin_login)
app.view_functions['admin_login'] = admin_login

# Rate-limit uploads to 10 per minute per client.
# '/admin/upload' is already registered above. Registering the same endpoint
# again with a different (decorated) function makes Flask raise "View function
# mapping is overwriting an existing endpoint function" at import time.
# Instead, wrap the registered view with the limiter and swap it into
# app.view_functions.
admin_upload = limiter.limit("10 per minute")(admin_upload)
app.view_functions['admin_upload'] = admin_upload

if __name__ == '__main__':
    # The Werkzeug interactive debugger allows arbitrary code execution from a
    # browser, and DEFAULT_CONFIG binds to 0.0.0.0. Debug mode is therefore
    # opt-in via `debug = true` under [server] in config.toml instead of
    # always on.
    app.run(
        debug=bool(config['server'].get('debug', False)),
        port=config['server']['port'],
        host=config['server']['host']
    )