# recolor/main.py
# Tanishq Dubey 3624c33c51 formatting
# 2022-11-25 10:52:34 -05:00
#
# 158 lines
# 4.6 KiB
# Python

import argparse
import sys
import numpy as np
from numpy import all, array, uint8
from rich.console import Console
from scipy.stats import moment
from skimage.io import imread, imsave
# Shared rich Console: k_means() uses it for its "Finding means..." status
# spinner, and the module-level setup and main() use it for log lines.
console = Console()
def save_image(data, name, resolution):
    """Reshape flat pixel ``data`` to ``resolution`` and write it as a PNG.

    The file is written to ``<name>.png`` via skimage's ``imsave``.
    """
    image = data.reshape(resolution)
    output_path = f"{name}.png"
    imsave(output_path, image)
def find_nearest_point(data, target):
    """Return the row of ``data`` closest to ``target`` by Euclidean distance.

    Coordinates are promoted to float64 before subtracting, so unsigned
    integer inputs (e.g. uint8 pixels) do not wrap around. Ties resolve to
    the earliest row.

    Args:
        data: candidate points, one per row.
        target: a single point with the same number of coordinates.

    Returns:
        ``data[i]`` for the closest row ``i``.

    Raises:
        ValueError: if ``data`` contains no candidate points.
    """
    candidates = np.asarray(data)
    if candidates.shape[0] == 0:
        raise ValueError("find_nearest_point() needs at least one candidate point")
    flat = candidates.reshape(candidates.shape[0], -1).astype(np.float64)
    offsets = flat - np.asarray(target, dtype=np.float64).reshape(-1)
    # Squared distance ranks rows exactly like the true distance, without sqrt.
    idx = int(np.sum(offsets * offsets, axis=1).argmin())
    return data[idx]
def centroidnp(arr):
    """Return the mean of each of the first three columns of ``arr``.

    Each component is that column's ``np.sum`` divided by the number of
    rows; the result is a 3-tuple (one value per column).
    """
    rows = arr.shape[0]
    return tuple(np.sum(arr[:, column]) / rows for column in range(3))
def calc_distance(x, y):
    """Return the Euclidean distance between points ``x`` and ``y``.

    Both points are promoted to float64 first: subtracting two uint8 arrays
    directly wraps modulo 256 (e.g. 0 - 1 == 255), which corrupted every
    pixel distance. A norm is never negative, so no ``abs`` is needed.
    """
    difference = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.linalg.norm(difference)
def _nearest_center_labels(points, centers):
    """Return, for each row of float ``points``, the index of its nearest row of ``centers``.

    Squared Euclidean distance is used (same ordering as the true distance).
    The strict ``<`` sends ties to the lowest center index, like ``argmin``.
    """
    best = np.full(points.shape[0], np.inf)
    labels = np.zeros(points.shape[0], dtype=np.intp)
    # One pass per center keeps peak memory at O(points), not O(points * centers).
    for index, center in enumerate(centers):
        sq_dist = np.sum((points - center) ** 2, axis=1)
        closer = sq_dist < best
        best[closer] = sq_dist[closer]
        labels[closer] = index
    return labels


def _nearest_row_index(points, target):
    """Return the index of the row of float ``points`` closest to ``target``."""
    return int(np.sum((points - target) ** 2, axis=1).argmin())


def k_means(data, count, *, tolerance=5.0, max_iter=300):
    """Find ``count`` representative points of ``data`` with a k-means variant.

    Each iteration assigns every row to its nearest center, takes the
    centroid of each cluster, and snaps the new center to the nearest actual
    row of ``data``. The returned centers are therefore always rows that
    occur in the input.

    Args:
        data: 2-D array with one point (e.g. pixel) per row.
        count: number of centers to find.
        tolerance: stop once the centers move, on average, no more than this
            Euclidean distance in a single iteration.
        max_iter: upper bound on iterations. Snapping to data rows can make
            centers oscillate, so the loop is always bounded.

    Returns:
        A ``(count, n_columns)`` array with the dtype of ``data``.

    Raises:
        ValueError: if ``count`` < 1 or ``data`` has fewer than ``count``
            distinct rows.
    """
    data = np.asarray(data)
    if count < 1:
        raise ValueError(f"k_means() needs count >= 1, got {count}")
    # Seed with `count` distinct rows so no two starting centers coincide.
    unique_rows = np.unique(data, axis=0)
    if count > unique_rows.shape[0]:
        raise ValueError(
            f"cannot pick {count} distinct starting means from "
            f"{unique_rows.shape[0]} distinct points"
        )
    start = np.random.choice(unique_rows.shape[0], count, replace=False)
    means = unique_rows[start]
    # Promote once: integer (e.g. uint8) subtraction would wrap around.
    points = data.astype(np.float64)
    with console.status("[bold blue] Finding means..."):
        for _ in range(max_iter):
            means_f = means.astype(np.float64)
            labels = _nearest_center_labels(points, means_f)
            new_means = means.copy()
            total_shift = 0.0
            for index in range(count):
                members = points[labels == index]
                if members.shape[0] == 0:
                    # Empty cluster: keep its previous center instead of crashing.
                    continue
                nearest = _nearest_row_index(points, members.mean(axis=0))
                new_means[index] = data[nearest]
                total_shift += float(np.linalg.norm(points[nearest] - means_f[index]))
            means = new_means
            # Converged when the centers themselves barely move.
            if total_shift / count <= tolerance:
                break
    return means
# Module-level setup (runs on import): load the source image and flatten it
# to one row per pixel for k_means() and main().
im = imread("zarin.jpg")
# reshape(-1, 3) below only yields real RGB triples when the image has exactly
# three channels. Normalise first: grayscale (2-D) and gray+alpha images are
# replicated to RGB, and any channel past the third (e.g. alpha) is dropped.
# Three-channel images pass through unchanged.
if im.ndim == 2:
    im = np.stack((im,) * 3, axis=-1)
elif im.shape[-1] < 3:
    im = np.stack((im[..., 0],) * 3, axis=-1)
elif im.shape[-1] > 3:
    im = im[..., :3]
starting_resolution = im.shape
console.log("[blue] Starting with image of size: ", starting_resolution)
raw_pixels = im.reshape(-1, 3)
raw_shape = raw_pixels.shape
# Default target palette, one RGB triple per row; main() reads this name.
# NOTE(review): these values appear to be the 16 Solarized colors — confirm.
colors = np.array(
    [
        [0, 43, 54],
        [7, 54, 66],
        [88, 110, 117],
        [101, 123, 131],
        [131, 148, 150],
        [147, 161, 161],
        [238, 232, 213],
        [253, 246, 227],
        [181, 137, 0],
        [203, 75, 22],
        [220, 50, 47],
        [211, 54, 130],
        [108, 113, 196],
        [38, 139, 210],
        [42, 161, 152],
        [133, 153, 0],
    ]
)
def _recolor_pixels(pixels, means, palette):
    """Replace every pixel by the palette color assigned to its nearest mean.

    Palette colors are matched greedily in palette order: each color claims
    the nearest still-unclaimed mean, together with any identical copies of
    it. Means left unclaimed keep their own color. Returns an array shaped
    like ``pixels`` with the dtype of ``means``.
    """
    pixels_f = np.asarray(pixels, dtype=np.float64)
    means_f = np.asarray(means, dtype=np.float64)

    # Label each pixel with the index of its nearest mean. Floats avoid uint8
    # wrap-around; one pass per mean keeps memory at O(pixels). Strict `<`
    # sends ties to the lowest index.
    best = np.full(pixels_f.shape[0], np.inf)
    labels = np.zeros(pixels_f.shape[0], dtype=np.intp)
    for index, mean in enumerate(means_f):
        sq_dist = np.sum((pixels_f - mean) ** 2, axis=1)
        closer = sq_dist < best
        best[closer] = sq_dist[closer]
        labels[closer] = index

    # lookup[i] is the output color for mean i.
    lookup = np.array(means, copy=True)
    unclaimed = np.ones(means_f.shape[0], dtype=bool)
    for color in np.asarray(palette):
        (candidates,) = np.where(unclaimed)
        if candidates.size == 0:
            break  # every mean is already claimed; stop instead of crashing
        gaps = np.sum((means_f[candidates] - color.astype(np.float64)) ** 2, axis=1)
        chosen = means_f[candidates[gaps.argmin()]]
        # Claim the chosen mean and any duplicates of it in one go.
        same = unclaimed & (means_f == chosen).all(axis=1)
        lookup[same] = color
        unclaimed &= ~same

    # Indexing by label (rather than rewriting pixels in place, color by
    # color) means a pixel can never be recolored twice.
    return lookup[labels]


def main():
    """Recolor the loaded image with the module-level ``colors`` palette.

    Reads the module globals ``raw_pixels``, ``colors`` and
    ``starting_resolution``, and writes the result to ``final.png``.
    """
    # Find the colors that most represent the image
    color_means = k_means(raw_pixels, len(colors))
    console.log("[green] Found cluster centers: ", color_means)
    # Remap every pixel to its cluster center, then swap centers for palette colors
    console.log("[purple] Re-mapping image")
    output_raw = _recolor_pixels(raw_pixels, color_means, colors)
    save_image(output_raw, "final", starting_resolution)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="Recolor",
        description="Recolor changes the color palette of an image to the one provided by the user",
    )
    # Exactly one palette source must be given.
    color_loader_group = parser.add_mutually_exclusive_group(required=True)
    # add_argument() takes `help=`; the former `description=` raised TypeError.
    color_loader_group.add_argument(
        "-f",
        "--file",
        help="A file of RGB color values, with one color per line",
        dest="fpath",
        default=None,
    )
    color_loader_group.add_argument(
        "-l",
        "--list",
        help="A list of RGB color values. Example: '123,90,89 212,7,0'",
        dest="clist",
        default=None,
    )
    args = parser.parse_args()

    # Gather one "R,G,B" spec per color. File lines are assumed to use the
    # same format as -l entries (commas and/or spaces between components).
    if args.fpath is not None:
        with open(args.fpath, encoding="utf-8") as color_file:
            color_specs = [line.strip() for line in color_file if line.strip()]
    else:
        color_specs = args.clist.split()

    palette = []
    for spec in color_specs:
        try:
            rgb = [int(part) for part in spec.replace(",", " ").split()]
        except ValueError:
            parser.error(f"invalid color {spec!r}: expected integers like '123,90,89'")
        # min/max rather than all(): this module imports numpy's `all`, which
        # is always truthy for a generator argument.
        if len(rgb) != 3 or min(rgb) < 0 or max(rgb) > 255:
            parser.error(f"invalid color {spec!r}: expected three values in 0-255")
        palette.append(rgb)
    if not palette:
        parser.error("no colors were provided")

    # main() reads the module-level `colors`; rebinding it swaps in the user palette.
    colors = np.array(palette)
    main()
    sys.exit(0)