diff --git a/main.py b/main.py index 3639b06..615cac1 100644 --- a/main.py +++ b/main.py @@ -1,18 +1,23 @@ import argparse import sys +from pathlib import Path +import multiprocessing +from itertools import product import numpy as np -from numpy import all, array, uint8 from rich.console import Console from scipy.stats import moment from skimage.io import imread, imsave +from PIL import Image +from numba import njit, prange +from scipy.spatial import KDTree console = Console() def save_image(data, name, resolution): final_image = data.reshape(resolution) - imsave(f"{name}.png", final_image) + imsave(name, final_image) def find_nearest_point(data, target): @@ -76,44 +81,37 @@ def k_means(data, count): return means -im = imread("zarin.jpg") -starting_resolution = im.shape -console.log("[blue] Starting with image of size: ", starting_resolution) -raw_pixels = im.reshape(-1, 3) -raw_shape = raw_pixels.shape +def find_closest_points(points1, points2): + # Build a k-d tree from the points in the second array + tree = KDTree(points2) -colors = np.array( - [ - np.array([0, 43, 54]), - np.array([7, 54, 66]), - np.array([88, 110, 117]), - np.array([101, 123, 131]), - np.array([131, 148, 150]), - np.array([147, 161, 161]), - np.array([238, 232, 213]), - np.array([253, 246, 227]), - np.array([181, 137, 0]), - np.array([203, 75, 22]), - np.array([220, 50, 47]), - np.array([211, 54, 130]), - np.array([108, 113, 196]), - np.array([38, 139, 210]), - np.array([42, 161, 152]), - np.array([133, 153, 0]), - ] -) + # Find the closest point in the second array for each element in the first array + closest_points_indices = tree.query(points1)[1] + closest_points = points2[closest_points_indices] + + # Return the result + return closest_points -def main(): +def main(image_name, output_name, colors): + im = imread(image_name) + starting_resolution = im.shape + im_resized = np.array(Image.fromarray(im).resize(size=(150, 150))) + console.log("[blue] Starting with image of size: ", starting_resolution) + raw_pixels_resized = im_resized.reshape(-1, 3) + # Find the colors that most represent the image - color_means = k_means(raw_pixels, len(colors)) + color_means = k_means(raw_pixels_resized, len(colors)) console.log("[green] Found cluster centers: ", color_means) # Remap image to the center points + raw_pixels = im.reshape(-1, 3) + raw_shape = raw_pixels.shape console.log("[purple] Re-mapping image") output_raw = np.zeros_like(raw_pixels) - for i in range(len(raw_pixels)): - output_raw[i] = find_nearest_point(color_means, raw_pixels[i]) + output_raw = find_closest_points(raw_pixels, color_means) + console.log("[purple] Re-mapping image complete (phase 1)") + # Map means to the colors provided by the user pairs = [] @@ -128,7 +126,10 @@ def main(): for pair in pairs: (idxs,) = np.where(np.all(output_raw == pair[0], axis=1)) output_raw[idxs] = pair[1] - save_image(output_raw, "final", starting_resolution) + + if output_name is None: + output_name = f"{input_image.parents[0]}/recolored-{input_image.name}" + save_image(output_raw, output_name, starting_resolution) if __name__ == "__main__": @@ -136,22 +137,66 @@ if __name__ == "__main__": prog="Recolor", description="Recolor changes the color palette of an image to the one provided by the user", ) + parser.add_argument( + "filename", help="Input image" + ) + parser.add_argument( + "-o", "--output", help="Output image destination, defaults to the same path as the input titled [path]/recolored-[name]", dest="outname" + ) color_loader_group = parser.add_mutually_exclusive_group(required=True) color_loader_group.add_argument( "-f", "--file", - description="A file of RGB color values, with one color per line", - dest="fpath", - default=None, + help="A file of RGB color values, with one color per line", + dest="cfpath", ) color_loader_group.add_argument( "-l", "--list", - description="A list of RGB color values. Example: '123,90,89 212,7,0'", + nargs='+', + help="A list of RGB color values. Example: '123,90,89 212,7,0'", dest="clist", - default=None, ) - main() - pass + # Check inputs + args = parser.parse_args() + input_image = Path(args.filename) + if not input_image.exists(): + console.log("[red] Error: input image path does not exist") + sys.exit(-1) + else: + if not input_image.is_file(): + console.log("[red] Error: input image path is not a file") + sys.exit(-1) -sys.exit(0) + # Parse out color + colors = [] + + if args.cfpath is not None: + # Load from file + lines = [] + input_colors = Path(args.cfpath) + with open(input_colors) as f: + lines = [line.rstrip() for line in f] + # Split each line + for line in lines: + vals = [] + sp = line.split(',') + if len(sp) != 3: + console.log("[red] Error: RGB value in file is malformed -- not exactly 3 values", args.cfpath, line) + sys.exit(-2) + for v in sp: + vals.append(int(v)) + colors.append(np.array(vals)) + else: + #Load from list + for line in args.clist: + vals = [] + sp = line.split(',') + if len(sp) != 3: + console.log("[red] Error: RGB value in list is malformed -- not exactly 3 values", line) + sys.exit(-2) + for v in sp: + vals.append(int(v)) + colors.append(np.array(vals)) + + main(input_image, Path(args.outname) if args.outname is not None else None, np.array(colors))