import sys
import argparse
from skimage.io import imread, imsave
from scipy.stats import moment
import numpy as np
from numpy import array, all, uint8
from rich.console import Console

console = Console()

def save_image(data, name, resolution):
    """Write flat pixel ``data`` to ``<name>.png`` after reshaping it.

    Args:
        data: array holding the pixel values in flattened form.
        name: output file name without extension; ``.png`` is appended.
        resolution: target shape passed to ``reshape`` before saving.
    """
    image = data.reshape(resolution)
    output_path = f"{name}.png"
    imsave(output_path, image)

def find_nearest_point(data, target):
    """Return the element of ``data`` closest (Euclidean) to ``target``.

    Args:
        data: array of candidate points, one point per entry along axis 0.
        target: a single point (array or sequence) to compare against.

    Returns:
        ``data[i]`` for the index ``i`` with the smallest distance; ties
        resolve to the lowest index.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("find_nearest_point() got an empty 'data'")
    # Work on float copies: subtracting two uint8 arrays (e.g. pixels) wraps
    # around modulo 256 and would make far-away points look close.
    candidates = np.asarray(data, dtype=float).reshape(len(data), -1)
    point = np.asarray(target, dtype=float).ravel()
    distances = np.linalg.norm(candidates - point, axis=1)
    return data[distances.argmin()]


def centroidnp(arr):
    """Return the column means of the first three columns of ``arr``.

    Args:
        arr: 2-D array with at least three columns.

    Returns:
        A 3-tuple ``(mean_col0, mean_col1, mean_col2)``; any further columns
        are ignored.
    """
    n_rows = arr.shape[0]
    return tuple(np.sum(arr[:, col]) / n_rows for col in range(3))


def calc_distance(x, y):
    """Return the Euclidean distance between vectors ``x`` and ``y``.

    Both operands are converted to float first. Subtracting two ``uint8``
    arrays (such as image pixels) wraps around modulo 256, so ``[0] - [1]``
    would otherwise give 255 instead of -1.
    """
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    # A norm is already non-negative; no absolute value needed.
    return np.linalg.norm(diff)


def k_means(data, count, tol=5.0, max_iter=100):
    """Cluster the rows of ``data`` into ``count`` groups.

    This is a k-means variant. After each update step, every mean is snapped
    to the nearest row of ``data``. The returned means are therefore values
    that actually occur in the input, with the input's dtype.

    Args:
        data: 2-D array with one sample per row (e.g. ``(n_pixels, 3)``).
        count: number of clusters, between 1 and the number of distinct rows.
        tol: stop once the means moved, on average, by at most this
            Euclidean distance in a single iteration.
        max_iter: upper bound on iterations. Snapping means to samples can
            make them oscillate instead of settling.

    Returns:
        Array of shape ``(count, data.shape[1])`` holding the cluster means.

    Raises:
        ValueError: if ``count`` is not in ``[1, number of distinct rows]``.
    """
    data = np.asarray(data)
    # Float working copy: differences of uint8 pixels would wrap around.
    points = data.astype(float)

    unique_rows = np.unique(data, axis=0)
    if not 1 <= count <= unique_rows.shape[0]:
        raise ValueError(
            f"count must be between 1 and {unique_rows.shape[0]} "
            f"(the number of distinct rows), got {count}"
        )
    # Pick `count` distinct rows at random to start. Every sample stays in
    # `points`; nothing is removed from the data set.
    seeds = np.random.choice(unique_rows.shape[0], count, replace=False)
    means = unique_rows[seeds]

    def distances_to(target):
        # Euclidean distance from every sample to one target point.
        return np.linalg.norm(points - target, axis=1)

    with console.status("[bold blue] Finding means..."):
        for _ in range(max_iter):
            means_f = means.astype(float)
            # Label each sample with its closest mean (ties -> lowest index).
            labels = np.stack(
                [distances_to(m) for m in means_f], axis=1
            ).argmin(axis=1)

            new_means = means.copy()
            for k in range(count):
                members = points[labels == k]
                if members.shape[0] == 0:
                    # Empty cluster: keep its previous mean rather than fail.
                    continue
                centroid = members.mean(axis=0)
                # Snap the centroid to the closest actual sample.
                new_means[k] = data[distances_to(centroid).argmin()]

            # Average distance the means moved in this iteration.
            shift = np.linalg.norm(
                new_means.astype(float) - means_f, axis=1
            ).mean()
            means = new_means
            if shift <= tol:
                break

    return means


im = imread("zarin.jpg")
# Everything below treats the image as a flat list of RGB triplets: it is
# reshaped to (-1, 3) here and back to `starting_resolution` on save.
# Normalize other channel layouts first. Otherwise an RGBA image would be
# silently scrambled, and a grayscale one would fail or be scrambled.
if im.ndim == 2:
    # Grayscale: replicate the single channel into R, G and B.
    im = np.stack([im, im, im], axis=-1)
elif im.ndim == 3 and im.shape[-1] == 4:
    # RGBA: drop the alpha channel.
    im = im[..., :3]
starting_resolution = im.shape
console.log("[blue] Starting with image of size: ", starting_resolution)
raw_pixels = im.reshape(-1, 3)
raw_shape = raw_pixels.shape

# Target palette. These 16 RGB values are the Solarized color scheme.
colors = np.array([
    [0, 43, 54],       # base03
    [7, 54, 66],       # base02
    [88, 110, 117],    # base01
    [101, 123, 131],   # base00
    [131, 148, 150],   # base0
    [147, 161, 161],   # base1
    [238, 232, 213],   # base2
    [253, 246, 227],   # base3
    [181, 137, 0],     # yellow
    [203, 75, 22],     # orange
    [220, 50, 47],     # red
    [211, 54, 130],    # magenta
    [108, 113, 196],   # violet
    [38, 139, 210],    # blue
    [42, 161, 152],    # cyan
    [133, 153, 0],     # green
])

def _label_pixels(pixels, centers):
    """Return, for each row of ``pixels``, the index of its nearest center."""
    # Float copies: subtracting uint8 values would wrap around modulo 256.
    pixels_f = np.asarray(pixels, dtype=float)
    centers_f = np.asarray(centers, dtype=float)
    distances = np.stack(
        [np.linalg.norm(pixels_f - c, axis=1) for c in centers_f], axis=1
    )
    # Ties resolve to the lowest center index.
    return distances.argmin(axis=1)


def _match_palette(centers, palette):
    """Greedily give each palette color the closest still-unused center.

    Palette colors are visited in order. Each one claims the nearest center
    that no earlier color has taken. The result has the same shape as
    ``centers``: row ``k`` is the color assigned to center ``k``. If the
    palette is shorter than ``centers``, the leftover centers keep their own
    value.
    """
    replacement = np.array(centers, copy=True)
    centers_f = np.asarray(centers, dtype=float)
    unused = list(range(len(centers)))
    for color in palette:
        if not unused:
            break
        dists = np.linalg.norm(
            centers_f[unused] - np.asarray(color, dtype=float), axis=1
        )
        replacement[unused.pop(int(dists.argmin()))] = color
    return replacement


def main():
    """Quantize the loaded image to ``len(colors)`` clusters and recolor it.

    The pixels in ``raw_pixels`` are clustered with ``k_means``. Each cluster
    center is then paired with a color from ``colors``, and the recolored
    image is written with ``save_image`` under the name ``"final"``.
    """
    # Find the colors that most represent the image.
    color_means = k_means(raw_pixels, len(colors))
    console.log("[green] Found cluster centers: ", color_means)

    # Remap every pixel to its nearest cluster center, kept as an index.
    console.log("[purple] Re-mapping image")
    labels = _label_pixels(raw_pixels, color_means)

    # Pair each palette color with a center, then recolor in one lookup.
    # Indexing by label means a pixel is recolored exactly once. The old
    # approach replaced pixel values pair by pair, so a pixel already
    # recolored to a palette color equal to a later center got recolored
    # again.
    replacement = _match_palette(color_means, colors)
    output_raw = replacement[labels].astype(raw_pixels.dtype, copy=False)
    save_image(output_raw, "final", starting_resolution)


if __name__ == "__main__":
    # Exit only when run as a script. A module-level sys.exit() would also
    # fire on import and raise SystemExit in whoever imported this module.
    # main() returns None, which sys.exit maps to status 0.
    sys.exit(main())