diff --git a/main.py b/main.py index 3203faa..3639b06 100644 --- a/main.py +++ b/main.py @@ -1,17 +1,20 @@ -import sys import argparse -from skimage.io import imread, imsave -from scipy.stats import moment +import sys + import numpy as np -from numpy import array, all, uint8 +from numpy import all, array, uint8 from rich.console import Console +from scipy.stats import moment +from skimage.io import imread, imsave console = Console() + def save_image(data, name, resolution): final_image = data.reshape(resolution) imsave(f"{name}.png", final_image) + def find_nearest_point(data, target): idx = np.array([calc_distance(p, target) for p in data]).argmin() return data[idx] @@ -22,7 +25,7 @@ def centroidnp(arr): sum_x = np.sum(arr[:, 0]) sum_y = np.sum(arr[:, 1]) sum_z = np.sum(arr[:, 2]) - return sum_x/length, sum_y/length, sum_z/length + return sum_x / length, sum_y / length, sum_z / length def calc_distance(x, y): @@ -62,7 +65,9 @@ def k_means(data, count): # Calculate new mean raw_mean = centroidnp(clusters[mean_key]) nearest_mean_point = find_nearest_point(data, raw_mean) - means_distance = means_distance + calc_distance(mean, nearest_mean_point) + means_distance = means_distance + calc_distance( + mean, nearest_mean_point + ) new_means.append(nearest_mean_point) means_distance = means_distance / float(count) distance_delta = abs(previous_distance - means_distance) @@ -77,51 +82,75 @@ console.log("[blue] Starting with image of size: ", starting_resolution) raw_pixels = im.reshape(-1, 3) raw_shape = raw_pixels.shape -colors = np.array([np.array([0,43,54]), - np.array([7,54,66]), - np.array([88,110,117]), - np.array([101,123,131]), - np.array([131,148,150]), - np.array([147,161,161]), - np.array([238,232,213]), - np.array([253,246,227]), - np.array([181,137,0]), - np.array([203,75,22]), - np.array([220,50,47]), - np.array([211,54,130]), - np.array([108,113,196]), - np.array([38,139,210]), - np.array([42,161,152]), - np.array([133,153,0])]) +colors = np.array( + [ + np.array([0, 43, 54]), + np.array([7, 54, 66]), + np.array([88, 110, 117]), + np.array([101, 123, 131]), + np.array([131, 148, 150]), + np.array([147, 161, 161]), + np.array([238, 232, 213]), + np.array([253, 246, 227]), + np.array([181, 137, 0]), + np.array([203, 75, 22]), + np.array([220, 50, 47]), + np.array([211, 54, 130]), + np.array([108, 113, 196]), + np.array([38, 139, 210]), + np.array([42, 161, 152]), + np.array([133, 153, 0]), + ] +) + def main(): -# Find the colors that most represent the image - color_means = k_means(raw_pixels, len(colors)) + # Find the colors that most represent the image + color_means = k_means(raw_pixels, len(colors)) console.log("[green] Found cluster centers: ", color_means) -# Remap image to the center points + # Remap image to the center points console.log("[purple] Re-mapping image") output_raw = np.zeros_like(raw_pixels) for i in range(len(raw_pixels)): output_raw[i] = find_nearest_point(color_means, raw_pixels[i]) -# Map means to the colors provided by the user + # Map means to the colors provided by the user pairs = [] tmp_means = color_means for color in colors: m = find_nearest_point(tmp_means, color) pairs.append((m, color)) - idxs, = np.where(np.all(tmp_means == m, axis=1)) + (idxs,) = np.where(np.all(tmp_means == m, axis=1)) tmp_means = np.delete(tmp_means, idxs, axis=0) -# Recolor the image + # Recolor the image for pair in pairs: - idxs, = np.where(np.all(output_raw == pair[0], axis=1)) + (idxs,) = np.where(np.all(output_raw == pair[0], axis=1)) output_raw[idxs] = pair[1] save_image(output_raw, "final", starting_resolution) if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="Recolor", + description="Recolor changes the color palette of an image to the one provided by the user", + ) + color_loader_group = parser.add_mutually_exclusive_group(required=True) + color_loader_group.add_argument( + "-f", + "--file", + description="A file of RGB color values, with one color per line", + dest="fpath", + default=None, + ) + color_loader_group.add_argument( + "-l", + "--list", + description="A list of RGB color values. Example: '123,90,89 212,7,0'", + dest="clist", + default=None, + ) main() pass