"""Recolor an image so that its dominant colors are replaced by a user palette.

The script clusters the pixels of a down-scaled copy of the image with a
k-means style search (one cluster per requested color), maps every pixel of
the full-size image to its nearest cluster center, and finally substitutes
each cluster center with the closest color from the user-provided palette.
"""

import argparse
import multiprocessing
import sys
from itertools import product
from pathlib import Path

import numpy as np
from numba import njit, prange
from PIL import Image
from rich.console import Console
from scipy.spatial import KDTree
from scipy.stats import moment
from skimage.io import imread, imsave

# NOTE: multiprocessing, product, njit, prange and moment are not used by the
# code below; the imports are retained in case other tooling relies on them.

console = Console()


def save_image(data, name, resolution):
    """Reshape flat pixel ``data`` to ``resolution`` and write it to ``name``.

    Args:
        data: Array of pixel values whose size matches ``resolution``.
        name: Destination path (``str`` or ``pathlib.Path``).
        resolution: Target array shape, e.g. ``(height, width, 3)``.
    """
    final_image = data.reshape(resolution)
    # Pass a plain string so extension-based plugin selection always works.
    imsave(str(name), final_image)


def find_nearest_point(data, target):
    """Return the element of ``data`` with the smallest Euclidean distance to ``target``.

    Args:
        data: Sequence/array of points, one point per row.
        target: A single point with the same dimensionality as the rows.

    Returns:
        ``data[i]`` for the first row ``i`` that minimizes the distance.

    Raises:
        ValueError: If ``data`` contains no points.
    """
    # Work in float64 so that unsigned integer inputs cannot wrap around.
    points = np.asarray(data, dtype=np.float64)
    if points.shape[0] == 0:
        raise ValueError("data must contain at least one point")
    offsets = points - np.asarray(target, dtype=np.float64)
    distances = np.linalg.norm(offsets.reshape(points.shape[0], -1), axis=1)
    return data[int(distances.argmin())]


def centroidnp(arr):
    """Return the per-axis mean of the first three columns of ``arr`` as a 3-tuple."""
    cx, cy, cz = np.asarray(arr)[:, :3].mean(axis=0, dtype=np.float64)
    return cx, cy, cz


def calc_distance(x, y):
    """Return the Euclidean distance between points ``x`` and ``y``.

    Both inputs are converted to float64 before subtracting: pixel data is
    usually uint8, and subtracting two uint8 arrays wraps modulo 256.
    """
    delta = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.linalg.norm(delta)


def _nearest_indices(points, candidates):
    """Return, for every row of ``points``, the index of the nearest row in ``candidates``."""
    tree = KDTree(np.asarray(candidates, dtype=np.float64))
    return tree.query(np.asarray(points, dtype=np.float64))[1]


def k_means(data, count, tolerance=5):
    """Find ``count`` representative points of ``data`` with a k-means style search.

    Every center is always an actual row of ``data``: after computing a
    cluster centroid, the center is snapped to the data point nearest to it.

    Args:
        data: 2-D array with one point per row.
        count: Number of cluster centers to find.
        tolerance: The search stops once the average distance moved by the
            centers changes by no more than this amount between iterations.

    Returns:
        Array of shape ``(count, data.shape[1])`` with the cluster centers.

    Raises:
        ValueError: If ``count`` is smaller than 1 or larger than the number
            of distinct points in ``data``.
    """
    data = np.asarray(data)
    unique_points = np.unique(data, axis=0)
    if count < 1:
        raise ValueError("count must be at least 1")
    if count > unique_points.shape[0]:
        raise ValueError(
            f"cannot pick {count} distinct starting points from "
            f"{unique_points.shape[0]} unique data points"
        )

    # Pick `count` distinct random points to start from. The starting points
    # stay in `data`; they simply belong to their own clusters.
    start = np.random.choice(unique_points.shape[0], count, replace=False)
    means = unique_points[start]

    distance_delta = float("inf")
    means_distance = 0.0

    with console.status("[bold blue] Finding means..."):
        while distance_delta > tolerance:
            # Assign every point to its closest center.
            labels = _nearest_indices(data, means)

            new_means = []
            previous_distance = means_distance
            moved = 0.0
            for i, mean in enumerate(means):
                members = data[labels == i]
                if members.shape[0] == 0:
                    # Can happen when two centers snapped to the same data
                    # point; keep the center where it is instead of failing.
                    new_means.append(mean)
                    continue
                # Snap the centroid to the nearest real data point.
                nearest = find_nearest_point(data, centroidnp(members))
                moved += calc_distance(mean, nearest)
                new_means.append(nearest)

            means_distance = moved / float(count)
            distance_delta = abs(previous_distance - means_distance)
            means = np.stack(new_means)

    return means


def find_closest_points(points1, points2):
    """Return, for each row of ``points1``, the nearest row of ``points2``.

    Args:
        points1: Query points, one per row.
        points2: Candidate points, one per row.

    Returns:
        Array with the same number of rows as ``points1`` whose rows are
        taken from ``points2``.
    """
    points2 = np.asarray(points2)
    return points2[_nearest_indices(points1, points2)]


def _as_rgb(im):
    """Return ``im`` as an ``(H, W, 3)`` array.

    Grayscale images are expanded to three identical channels and the alpha
    channel of RGBA images is dropped.

    Raises:
        ValueError: If the array has any other layout.
    """
    if im.ndim == 2:
        return np.stack((im,) * 3, axis=-1)
    if im.ndim == 3 and im.shape[2] == 3:
        return im
    if im.ndim == 3 and im.shape[2] == 4:
        return np.ascontiguousarray(im[:, :, :3])
    raise ValueError(f"unsupported image shape: {im.shape}")


def main(image_name, output_name, colors):
    """Recolor the image at ``image_name`` using the palette ``colors``.

    Args:
        image_name: Path of the input image.
        output_name: Destination path, or ``None`` to write
            ``recolored-<name>`` next to the input image.
        colors: Array-like of shape ``(k, 3)`` holding the RGB palette.

    Raises:
        ValueError: If ``colors`` is not a ``(k, 3)`` array, if the image has
            an unsupported layout, or if it has fewer than ``k`` distinct
            colors after down-scaling.
    """
    colors = np.asarray(colors)
    if colors.ndim != 2 or colors.shape[1] != 3:
        raise ValueError("colors must be an array of shape (k, 3)")

    im = _as_rgb(imread(image_name))
    starting_resolution = im.shape
    # Cluster on a small copy of the image to keep the search fast.
    im_resized = np.array(Image.fromarray(im).resize(size=(150, 150)))
    console.log("[blue] Starting with image of size: ", starting_resolution)
    raw_pixels_resized = im_resized.reshape(-1, 3)

    # Find the colors that most represent the image
    color_means = k_means(raw_pixels_resized, len(colors))
    console.log("[green] Found cluster centers: ", color_means)

    # Remap image to the center points (stored as an index per pixel)
    raw_pixels = im.reshape(-1, 3)
    console.log("[purple] Re-mapping image")
    labels = _nearest_indices(raw_pixels, color_means)
    console.log("[purple] Re-mapping image complete (phase 1)")

    # Map means to the colors provided by the user: each color, in the order
    # given, claims the closest center that has not been claimed yet. Pairing
    # by index (not by pixel value) keeps a user color that happens to equal
    # another center from being remapped a second time.
    palette = np.zeros((len(color_means), 3), dtype=color_means.dtype)
    unclaimed = list(range(len(color_means)))
    for color in colors:
        candidates = color_means[unclaimed].astype(np.float64)
        nearest = int(np.linalg.norm(candidates - color, axis=1).argmin())
        palette[unclaimed.pop(nearest)] = color

    # Recolor the image
    output_raw = palette[labels]

    if output_name is None:
        image_path = Path(image_name)
        output_name = image_path.parent / f"recolored-{image_path.name}"

    save_image(output_raw, output_name, starting_resolution)


def _parse_color(text):
    """Parse ``"R,G,B"`` into an integer array of three components.

    Raises:
        ValueError: If ``text`` does not hold exactly three integers in the
            range 0-255. The exception message states the reason.
    """
    parts = text.split(",")
    if len(parts) != 3:
        raise ValueError("not exactly 3 values")
    try:
        values = [int(part) for part in parts]
    except ValueError:
        raise ValueError("values must be integers") from None
    if any(value < 0 or value > 255 for value in values):
        raise ValueError("values must be between 0 and 255")
    return np.array(values)


def _build_parser():
    """Create the command-line argument parser."""
    parser = argparse.ArgumentParser(
        prog="Recolor",
        description="Recolor changes the color palette of an image to the one provided by the user",
    )
    parser.add_argument("filename", help="Input image")
    parser.add_argument(
        "-o",
        "--output",
        help="Output image destination, defaults to the same path as the input titled [path]/recolored-[name]",
        dest="outname",
    )
    color_loader_group = parser.add_mutually_exclusive_group(required=True)
    color_loader_group.add_argument(
        "-f",
        "--file",
        help="A file of RGB color values, with one color per line",
        dest="cfpath",
    )
    color_loader_group.add_argument(
        "-l",
        "--list",
        nargs="+",
        help="A list of RGB color values. Example: '123,90,89 212,7,0'",
        dest="clist",
    )
    return parser


def _cli():
    """Parse the command line, validate the inputs and run ``main``.

    Exits with status -1 for unusable paths and -2 for malformed colors.
    """
    args = _build_parser().parse_args()

    # Check inputs
    input_image = Path(args.filename)
    if not input_image.exists():
        console.log("[red] Error: input image path does not exist")
        sys.exit(-1)
    if not input_image.is_file():
        console.log("[red] Error: input image path is not a file")
        sys.exit(-1)

    # Collect the raw color strings
    if args.cfpath is not None:
        try:
            with open(args.cfpath, encoding="utf-8") as f:
                stripped = [line.strip() for line in f]
        except OSError as err:
            console.log("[red] Error: could not read color file", args.cfpath, str(err))
            sys.exit(-1)
        # Blank lines carry no color and are ignored.
        entries = [entry for entry in stripped if entry]
        source, context = "file", (args.cfpath,)
    else:
        entries, source, context = args.clist, "list", ()

    # Parse out color
    colors = []
    for entry in entries:
        try:
            colors.append(_parse_color(entry))
        except ValueError as err:
            console.log(
                f"[red] Error: RGB value in {source} is malformed -- {err}",
                *context,
                entry,
            )
            sys.exit(-2)

    if not colors:
        console.log("[red] Error: no colors were provided")
        sys.exit(-2)

    main(
        input_image,
        Path(args.outname) if args.outname is not None else None,
        np.array(colors),
    )


if __name__ == "__main__":
    _cli()