import argparse
import sys

import numpy as np
from numpy import all, array, uint8  # NOTE: this `all` shadows the builtin all()
from rich.console import Console
from scipy.stats import moment
from skimage.io import imread, imsave

console = Console()


def save_image(data, name, resolution):
    """Reshape ``data`` to ``resolution`` and write it to ``<name>.png``.

    Args:
        data: Pixel array whose total size matches ``resolution``.
        name: Output path without the ``.png`` suffix.
        resolution: Shape ``data`` is reshaped into before saving.
    """
    final_image = data.reshape(resolution)
    imsave(f"{name}.png", final_image)


def find_nearest_point(data, target):
    """Return the row of ``data`` closest (Euclidean distance) to ``target``.

    Ties go to the first such row. Both sides are compared as float64 so
    unsigned pixel values (``uint8``) cannot wrap around when subtracted.

    Raises:
        ValueError: if ``data`` has no rows.
    """
    points = np.asarray(data)
    if len(points) == 0:
        raise ValueError("find_nearest_point() needs at least one candidate row")
    diff = points.astype(np.float64) - np.asarray(target, dtype=np.float64)
    # Squared distance has the same argmin as distance and skips the sqrt.
    sq_dist = (diff.reshape(len(points), -1) ** 2).sum(axis=1)
    return data[int(sq_dist.argmin())]


def centroidnp(arr):
    """Return the mean of the first three columns of ``arr`` as a 3-tuple."""
    length = arr.shape[0]
    sum_x = np.sum(arr[:, 0])
    sum_y = np.sum(arr[:, 1])
    sum_z = np.sum(arr[:, 2])
    return sum_x / length, sum_y / length, sum_z / length


def calc_distance(x, y):
    """Return the Euclidean distance between ``x`` and ``y``.

    Inputs are converted to float64 first: subtracting two ``uint8`` values
    wraps modulo 256 (``10 - 20 == 246``), which would give wrong distances.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.linalg.norm(diff)


def _nearest_indices(points, centers, chunk_size=65536):
    """Return, for each row of ``points``, the index of its nearest ``centers`` row.

    Work is done in float64 (no ``uint8`` wraparound) and in chunks so the
    ``(chunk, k, d)`` broadcast temporary stays small. Ties go to the lowest
    center index, matching ``find_nearest_point``.
    """
    centers_f = np.asarray(centers, dtype=np.float64)
    labels = np.empty(len(points), dtype=np.intp)
    for start in range(0, len(points), chunk_size):
        block = np.asarray(points[start : start + chunk_size], dtype=np.float64)
        sq_dist = ((block[:, None, :] - centers_f[None, :, :]) ** 2).sum(axis=2)
        labels[start : start + chunk_size] = sq_dist.argmin(axis=1)
    return labels


def k_means(data, count, tol=5, max_iter=100):
    """Cluster the rows of ``data`` into ``count`` groups and return the centers.

    This k-means variant snaps each new center to the row of ``data`` that
    lies nearest the cluster's centroid. Every returned center is therefore a
    value that actually occurs in ``data``.

    Args:
        data: 2-D array with one point per row.
        count: Number of clusters (at least 1, and at most the number of
            distinct rows).
        tol: Stop once the centers move, on average, ``tol`` or less in one
            iteration.
        max_iter: Upper bound on iterations. Snapping centers to data points
            can make them oscillate, so the loop must be bounded.

    Returns:
        ``(count, d)`` array of centers with the same dtype as ``data``.

    Raises:
        ValueError: if ``count`` is < 1 or exceeds the number of distinct rows.
    """
    data = np.asarray(data)
    candidates = np.unique(data, axis=0)
    if count < 1 or count > len(candidates):
        raise ValueError(
            f"cannot pick {count} clusters from {len(candidates)} distinct points"
        )

    # Seed with `count` distinct points picked at random. Every row of `data`
    # (including the seeds themselves) takes part in the clustering.
    means = candidates[np.random.choice(len(candidates), count, replace=False)]

    with console.status("[bold blue] Finding means..."):
        for _ in range(max_iter):
            labels = _nearest_indices(data, means)
            new_means = means.copy()
            for k in range(count):
                members = data[labels == k]
                if len(members) == 0:
                    # Empty cluster: keep its previous center instead of crashing.
                    continue
                new_means[k] = find_nearest_point(data, members.mean(axis=0))
            # Average distance the centers moved during this iteration.
            shift = np.linalg.norm(
                new_means.astype(np.float64) - means.astype(np.float64), axis=1
            ).mean()
            means = new_means
            if shift <= tol:
                break
    return means


im = imread("zarin.jpg")
starting_resolution = im.shape
console.log("[blue] Starting with image of size: ", starting_resolution)
raw_pixels = im.reshape(-1, 3)
raw_shape = raw_pixels.shape

# Default target palette, one RGB triple per row (the values appear to be
# the Solarized colour scheme).
colors = np.array(
    [
        [0, 43, 54],
        [7, 54, 66],
        [88, 110, 117],
        [101, 123, 131],
        [131, 148, 150],
        [147, 161, 161],
        [238, 232, 213],
        [253, 246, 227],
        [181, 137, 0],
        [203, 75, 22],
        [220, 50, 47],
        [211, 54, 130],
        [108, 113, 196],
        [38, 139, 210],
        [42, 161, 152],
        [133, 153, 0],
    ]
)


def _assign_to_centers(pixels, centers, chunk_size=65536):
    """Return, for each row of ``pixels``, the index of the nearest ``centers`` row.

    Distances are computed in float64 because ``uint8`` subtraction would
    wrap around. They are computed in chunks so the ``(chunk, k, 3)``
    temporary stays small. Ties go to the lowest center index.
    """
    centers_f = np.asarray(centers, dtype=np.float64)
    labels = np.empty(len(pixels), dtype=np.intp)
    for start in range(0, len(pixels), chunk_size):
        block = np.asarray(pixels[start : start + chunk_size], dtype=np.float64)
        sq_dist = ((block[:, None, :] - centers_f[None, :, :]) ** 2).sum(axis=2)
        labels[start : start + chunk_size] = sq_dist.argmin(axis=1)
    return labels


def _pair_palette(centers, palette):
    """Return the replacement colour for each cluster center.

    Palette colours are processed in order. Each one claims the nearest
    center that has not yet been claimed. Identical centers are claimed
    together because they cover the same pixels. Centers left unclaimed keep
    their own colour. The loop stops early once every center is claimed,
    instead of failing on an empty candidate set.

    Returns:
        ``(k, 3)`` int64 array aligned with ``centers``.
    """
    centers = np.asarray(centers)
    centers_f = centers.astype(np.float64)
    targets = centers.astype(np.int64)
    available = np.ones(len(centers), dtype=bool)
    for color in palette:
        if not available.any():
            break
        sq_dist = ((centers_f - np.asarray(color, dtype=np.float64)) ** 2).sum(axis=1)
        sq_dist[~available] = np.inf
        chosen = sq_dist.argmin()
        same = np.all(centers == centers[chosen], axis=1)
        targets[same] = color
        available &= ~same
    return targets


def main(palette=None):
    """Recolor the loaded image so it uses ``palette`` and write ``final.png``.

    1. Cluster the image pixels into ``len(palette)`` representative colours.
    2. Pair each palette colour with its nearest cluster center.
    3. Replace every pixel with the palette colour paired to its cluster.

    Args:
        palette: Array-like of RGB triples. Defaults to the module-level
            ``colors``.

    Raises:
        ValueError: if ``palette`` is not a non-empty ``(k, 3)`` array.
    """
    palette = colors if palette is None else np.asarray(palette)
    if palette.ndim != 2 or palette.shape[1] != 3 or len(palette) == 0:
        raise ValueError("palette must be a non-empty sequence of RGB triples")

    # Find the colors that most represent the image
    color_means = k_means(raw_pixels, len(palette))
    console.log("[green] Found cluster centers: ", color_means)

    # Label every pixel with its cluster, then look up that cluster's palette
    # colour. Going through labels (rather than matching pixel values one
    # pair at a time) guarantees a recoloured pixel is never recoloured again.
    console.log("[purple] Re-mapping image")
    labels = _assign_to_centers(raw_pixels, color_means)
    replacement = _pair_palette(color_means, palette)
    output_raw = replacement[labels].astype(raw_pixels.dtype)

    save_image(output_raw, "final", starting_resolution)


def _parse_color(text):
    """Parse an ``"R,G,B"`` string into a list of three ints in 0-255.

    Raises:
        ValueError: if the entry is malformed or out of range.
    """
    parts = text.split(",")
    if len(parts) != 3:
        raise ValueError(f"expected 'R,G,B', got {text!r}")
    try:
        rgb = [int(part) for part in parts]
    except ValueError:
        raise ValueError(f"non-integer component in colour {text!r}") from None
    if any(value < 0 or value > 255 for value in rgb):
        raise ValueError(f"colour components must be in 0-255, got {text!r}")
    return rgb


def _load_palette(fpath=None, clist=None):
    """Build a palette array from ``--file`` or ``--list`` input.

    Args:
        fpath: Path to a text file with one ``R,G,B`` colour per line.
            Blank lines are ignored.
        clist: Whitespace-separated ``R,G,B`` colours, e.g.
            ``"123,90,89 212,7,0"``.

    Returns:
        ``(k, 3)`` integer array.

    Raises:
        OSError: if ``fpath`` cannot be read.
        ValueError: if no colours are given or an entry is malformed.
    """
    if fpath is not None:
        with open(fpath, encoding="utf-8") as handle:
            entries = [line.strip() for line in handle if line.strip()]
    else:
        entries = (clist or "").split()
    if not entries:
        raise ValueError("no colours were provided")
    return np.array([_parse_color(entry) for entry in entries])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="Recolor",
        description="Recolor changes the color palette of an image to the one provided by the user",
    )
    color_loader_group = parser.add_mutually_exclusive_group(required=True)
    # `help=` is the add_argument keyword for option text; `description=` is
    # only accepted by ArgumentParser itself and raises TypeError here.
    color_loader_group.add_argument(
        "-f",
        "--file",
        help="A file of RGB color values, with one color per line",
        dest="fpath",
        default=None,
    )
    color_loader_group.add_argument(
        "-l",
        "--list",
        help="A list of RGB color values. Example: '123,90,89 212,7,0'",
        dest="clist",
        default=None,
    )
    args = parser.parse_args()
    try:
        user_palette = _load_palette(args.fpath, args.clist)
    except (OSError, ValueError) as err:
        parser.error(str(err))
    main(user_palette)
    sys.exit(0)