recolor/main.py

158 lines
4.6 KiB
Python
Raw Normal View History

2022-11-25 10:29:55 -05:00
import argparse
2022-11-25 10:52:34 -05:00
import sys
2022-11-12 00:52:07 -05:00
import numpy as np
2022-11-25 10:52:34 -05:00
from numpy import all, array, uint8
2022-11-25 10:29:55 -05:00
from rich.console import Console
2022-11-25 10:52:34 -05:00
from scipy.stats import moment
from skimage.io import imread, imsave
2022-11-25 10:29:55 -05:00
# Shared rich console: k_means() uses it for its status spinner, and the
# script reports progress through console.log().
console = Console()
2022-11-25 10:52:34 -05:00
2022-11-25 10:29:55 -05:00
def save_image(data, name, resolution):
    """Save a flat pixel buffer as ``<name>.png``.

    Args:
        data: array whose elements are reshaped to ``resolution`` before
            being handed to ``skimage.io.imsave``.
        name: file stem; the image is written to ``f"{name}.png"``.
        resolution: target shape passed to ``data.reshape``.
    """
    image = data.reshape(resolution)
    target = f"{name}.png"
    imsave(target, image)
2022-11-12 00:52:07 -05:00
2022-11-25 10:52:34 -05:00
2022-11-12 00:52:07 -05:00
def find_nearest_point(data, target):
    """Return the element of ``data`` closest to ``target`` (Euclidean distance).

    Args:
        data: non-empty sequence/array of points, indexed along axis 0.
        target: a single point, broadcast-compatible with each row of ``data``.

    Returns:
        ``data[i]`` for the row with the smallest distance to ``target``;
        ties resolve to the lowest index (``argmin`` semantics).

    Raises:
        ValueError: if ``data`` is empty.
    """
    points = np.asarray(data)
    if len(points) == 0:
        raise ValueError("find_nearest_point() requires at least one candidate point")
    # Subtract in float64: for uint8 pixels ``p - target`` would wrap modulo
    # 256 (10 - 20 -> 246) and select the wrong neighbour.
    diffs = points.astype(np.float64) - np.asarray(target, dtype=np.float64)
    # Squared distance has the same argmin as the norm, without the sqrt;
    # one vectorised pass replaces the per-row Python loop.
    sq_dist = (diffs.reshape(len(points), -1) ** 2).sum(axis=1)
    return data[int(sq_dist.argmin())]
2022-11-25 10:29:55 -05:00
def centroidnp(arr):
    """Return the per-column mean (centroid) of a 2-D point array as a tuple.

    Every column is averaged, so ``(N, 3)`` RGB arrays give a 3-tuple and
    other widths are handled as well.

    Args:
        arr: ``(N, D)`` array-like of points; integer inputs (e.g. uint8
            pixels) are converted to float64 before averaging.

    Returns:
        tuple of ``D`` float64 values.

    Raises:
        ValueError: if ``arr`` is not 2-D or has no rows (the mean of zero
            points is undefined).
    """
    points = np.asarray(arr, dtype=np.float64)
    if points.ndim != 2 or points.shape[0] == 0:
        raise ValueError("centroidnp() expects a non-empty 2-D array of points")
    return tuple(points.mean(axis=0))
2022-11-12 00:52:07 -05:00
def calc_distance(x, y):
    """Return the Euclidean distance between points ``x`` and ``y``.

    Both inputs are converted to float64 before subtracting: for uint8 pixel
    arrays ``x - y`` would otherwise wrap modulo 256 (10 - 20 -> 246) and
    yield a wildly wrong distance. ``np.linalg.norm`` is already
    non-negative, so no absolute value is needed.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.linalg.norm(diff)
def _squared_distances(points, center):
    """Return the squared Euclidean distance from each row of ``points`` to ``center``."""
    return ((points - center) ** 2).sum(axis=1)


def _nearest_center_labels(points, centers):
    """Return, for each row of ``points``, the index of its nearest row in ``centers``.

    Ties go to the lowest centre index. Scanning one centre at a time keeps
    memory at O(N) instead of materialising an (N, k, D) difference tensor.
    """
    best = np.full(len(points), np.inf)
    labels = np.zeros(len(points), dtype=np.intp)
    for k, center in enumerate(centers):
        dist = _squared_distances(points, center)
        closer = dist < best
        labels[closer] = k
        best[closer] = dist[closer]
    return labels


def k_means(data, count, tol=5, max_iter=100, rng=None):
    """Find ``count`` representative rows of ``data`` with a snapped k-means.

    The initial means are ``count`` distinct rows of ``data`` picked at
    random. Each iteration assigns every row to its nearest current mean,
    averages each cluster, and snaps that average to the nearest actual row
    of ``data``, so every returned mean occurs in the input. Iteration stops
    once the average distance the means moved in one step is ``<= tol``, or
    after ``max_iter`` iterations.

    Args:
        data: ``(N, D)`` array of points (in this script, ``(N, 3)`` pixels).
        count: number of means; must be between 1 and the number of distinct
            rows in ``data``.
        tol: convergence threshold on the average per-mean movement.
        max_iter: upper bound on iterations so a run that keeps oscillating
            still terminates.
        rng: optional ``numpy.random.Generator``/``RandomState`` used for the
            initial pick; defaults to the global ``np.random`` state.

    Returns:
        ``(count, D)`` array with the dtype of ``data``.

    Raises:
        ValueError: if ``count`` is outside ``1..number of distinct rows``.
    """
    data = np.asarray(data)
    distinct = np.unique(data, axis=0)
    if not 1 <= count <= len(distinct):
        raise ValueError(
            f"count must be between 1 and {len(distinct)} (distinct points), got {count}"
        )
    chooser = np.random if rng is None else rng
    means = distinct[chooser.choice(len(distinct), count, replace=False)]

    # Keep every point (no deletion) and work in float64 so differences of
    # uint8 pixels cannot wrap around.
    points = data.astype(np.float64)

    with console.status("[bold blue] Finding means..."):
        for _ in range(max_iter):
            labels = _nearest_center_labels(points, means.astype(np.float64))
            new_means = means.copy()
            for k in range(count):
                members = points[labels == k]
                if len(members) == 0:
                    # Empty cluster: keep the previous mean rather than crash.
                    continue
                centroid = members.mean(axis=0)
                new_means[k] = data[_squared_distances(points, centroid).argmin()]
            step = new_means.astype(np.float64) - means.astype(np.float64)
            moved = np.sqrt((step ** 2).sum(axis=1)).mean()
            means = new_means
            # Converge on the movement itself, not on the change in movement
            # between iterations: a plateaued but still-large movement is not
            # convergence.
            if moved <= tol:
                break
    return means
# Module-level setup: the image and default palette are loaded at import
# time and read as globals by main().
# NOTE(review): the input path is hard-coded; consider exposing it on the CLI.
im = imread("zarin.jpg")
# Normalise to exactly three channels so reshape(-1, 3) yields real pixels:
# grayscale (H, W) / (H, W, 1) and gray+alpha (H, W, 2) images are replicated
# into RGB, and an alpha channel (H, W, 4) is dropped. Without this, RGBA
# input is silently scrambled into "pixels" that straddle channel boundaries.
if im.ndim == 2:
    im = im[..., np.newaxis]
if im.shape[-1] in (1, 2):
    im = np.repeat(im[..., :1], 3, axis=-1)
im = im[..., :3]
starting_resolution = im.shape

console.log("[blue] Starting with image of size: ", starting_resolution)

raw_pixels = im.reshape(-1, 3)
raw_shape = raw_pixels.shape

# Default target palette: these 16 RGB values are the Solarized palette
# (base03..base3 followed by the eight accent colours).
colors = np.array(
    [
        [0, 43, 54],
        [7, 54, 66],
        [88, 110, 117],
        [101, 123, 131],
        [131, 148, 150],
        [147, 161, 161],
        [238, 232, 213],
        [253, 246, 227],
        [181, 137, 0],
        [203, 75, 22],
        [220, 50, 47],
        [211, 54, 130],
        [108, 113, 196],
        [38, 139, 210],
        [42, 161, 152],
        [133, 153, 0],
    ]
)
2022-11-25 10:29:55 -05:00
def main():
    """Recolour the loaded image into the palette held in the global ``colors``.

    Reads the module-level ``raw_pixels``, ``colors`` and
    ``starting_resolution`` and writes ``final.png`` via ``save_image``:

    1. k-means finds ``len(colors)`` representative colours of the image;
    2. every pixel is labelled with the index of its nearest representative;
    3. palette colours, in order, each claim the nearest still-unclaimed
       representative;
    4. each pixel is painted with the palette colour paired to its label.

    Working with label indices (not colour values) means duplicate
    representatives cannot exhaust the pairing. It also means a pixel that
    was already painted cannot be re-matched by a later pair.
    """
    # Find the colors that most represent the image
    color_means = k_means(raw_pixels, len(colors))
    console.log("[green] Found cluster centers: ", color_means)

    # Label every pixel with its nearest cluster centre. Float math avoids
    # uint8 wraparound; scanning one centre at a time keeps memory O(N).
    console.log("[purple] Re-mapping image")
    pixels = raw_pixels.astype(np.float64)
    best = np.full(len(pixels), np.inf)
    labels = np.zeros(len(pixels), dtype=np.intp)
    for k, mean in enumerate(color_means.astype(np.float64)):
        dist = ((pixels - mean) ** 2).sum(axis=1)
        closer = dist < best
        labels[closer] = k
        best[closer] = dist[closer]

    # Map means to the colors provided by the user. Unpaired means (only
    # possible if the palette is shorter) keep their own colour.
    lookup = color_means.copy()
    unclaimed = list(range(len(color_means)))
    for color in np.asarray(colors):
        if not unclaimed:
            break
        candidates = color_means[unclaimed].astype(np.float64)
        pick = int(((candidates - color) ** 2).sum(axis=1).argmin())
        lookup[unclaimed.pop(pick)] = color

    # Recolor the image with a single lookup-table gather.
    output_raw = lookup[labels]
    save_image(output_raw, "final", starting_resolution)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="Recolor",
        description="Recolor changes the color palette of an image to the one provided by the user",
    )
    color_loader_group = parser.add_mutually_exclusive_group(required=True)
    # add_argument() takes help=, not description=; the latter raised
    # TypeError before any argument could be parsed.
    color_loader_group.add_argument(
        "-f",
        "--file",
        help="A file of RGB color values, with one color per line",
        dest="fpath",
        default=None,
    )
    color_loader_group.add_argument(
        "-l",
        "--list",
        help="A list of RGB color values. Example: '123,90,89 212,7,0'",
        dest="clist",
        default=None,
    )
    args = parser.parse_args()

    # One colour spec per non-blank file line, or per whitespace-separated
    # token of --list. Components may be separated by commas and/or spaces.
    if args.fpath is not None:
        try:
            with open(args.fpath, encoding="utf-8") as color_file:
                color_specs = [line.strip() for line in color_file if line.strip()]
        except OSError as err:
            parser.error(f"cannot read color file {args.fpath!r}: {err}")
    else:
        color_specs = args.clist.split()

    parsed_colors = []
    for spec in color_specs:
        try:
            rgb = [int(part) for part in spec.replace(",", " ").split()]
        except ValueError:
            parser.error(f"invalid color {spec!r}: expected integers like 123,90,89")
        if len(rgb) != 3 or min(rgb) < 0 or max(rgb) > 255:
            parser.error(f"invalid color {spec!r}: expected three integers in 0-255")
        parsed_colors.append(rgb)
    if not parsed_colors:
        parser.error("no colors were provided")

    # Rebind the module-level palette that main() reads.
    colors = np.array(parsed_colors)
    main()
    sys.exit(0)