recolor/main.py
2022-11-25 10:29:55 -05:00

129 lines
4.1 KiB
Python

import sys
import argparse
from skimage.io import imread, imsave
from scipy.stats import moment
import numpy as np
from numpy import array, all, uint8
from rich.console import Console
# Shared rich console used for the status spinner and log messages below.
console = Console()
def save_image(data, name, resolution):
    """Reshape ``data`` to ``resolution`` and write it to ``<name>.png``.

    ``data.size`` must match the product of ``resolution``; in this script it
    is the flat pixel array and the original image shape.
    """
    shaped = data.reshape(resolution)
    target_path = f"{name}.png"
    imsave(target_path, shaped)
def find_nearest_point(data, target):
    """Return the row of ``data`` closest (Euclidean) to ``target``.

    Args:
        data: 2-D array, one candidate point per row.
        target: a single point with the same length as a row of ``data``.

    Returns:
        The first row of ``data`` at minimum distance (ties resolve to the
        lowest index, as with ``argmin``).

    Raises:
        ValueError: if ``data`` is empty.
    """
    # Promote to float64 before subtracting: with unsigned pixel dtypes
    # (uint8) a plain ``data - target`` wraps around, e.g. 10 - 20 == 246.
    diffs = np.asarray(data, dtype=np.float64) - np.asarray(target, dtype=np.float64)
    return data[np.linalg.norm(diffs, axis=1).argmin()]
def centroidnp(arr):
    """Return the mean of columns 0, 1 and 2 of ``arr`` as a 3-tuple.

    ``arr`` is a 2-D array with one point per row and at least three columns;
    any further columns are ignored.
    """
    n_rows = arr.shape[0]
    return tuple(np.sum(arr[:, col]) / n_rows for col in (0, 1, 2))
def calc_distance(x, y):
    """Return the Euclidean distance between points ``x`` and ``y``.

    Both inputs are promoted to float64 first so unsigned integer pixels
    (e.g. uint8) cannot wrap around on subtraction. The norm is already
    non-negative, so no absolute value is needed.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.linalg.norm(diff)
def _nearest_indices(points, centers, chunk_size=1 << 16):
    """Return, for each row of ``points``, the index of the nearest row of ``centers``.

    Squared Euclidean distances are compared (same argmin as the distance
    itself); ties resolve to the lowest centre index. Rows are processed in
    chunks so the temporary (chunk, len(centers), dim) array stays bounded.
    """
    labels = np.empty(points.shape[0], dtype=np.intp)
    for start in range(0, points.shape[0], chunk_size):
        block = points[start:start + chunk_size]
        d2 = ((block[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels[start:start + chunk_size] = d2.argmin(axis=1)
    return labels


def k_means(data, count, *, tol=5.0, max_iter=100, rng=None):
    """Pick ``count`` representative points of ``data`` with k-means.

    After every update each mean is snapped to the closest point that
    actually occurs in ``data``, so every returned mean is a real data row.

    Args:
        data: 2-D array, one point per row (here: pixels reshaped to (-1, 3)).
        count: number of clusters / means to find.
        tol: stop once the average distance the means moved in one
            iteration is at most this value.
        max_iter: upper bound on iterations; snapping to data points can make
            the means oscillate, so the loop must be bounded.
        rng: optional ``numpy.random.Generator`` for choosing the initial
            means; defaults to the global ``numpy.random`` state.

    Returns:
        Array of shape ``(count, data.shape[1])`` with ``data``'s dtype.

    Raises:
        ValueError: if ``count`` exceeds the number of distinct rows.
    """
    # Cluster distinct rows weighted by how often they occur: identical rows
    # always land in the same cluster, so the weighted centroids equal the
    # per-pixel centroids at a fraction of the cost.
    points, weights = np.unique(data, axis=0, return_counts=True)
    # Float math so unsigned pixel values cannot wrap around when subtracted.
    pts = points.astype(np.float64)
    chooser = np.random if rng is None else rng
    # Pick `count` distinct points to start from.
    means = pts[chooser.choice(pts.shape[0], count, replace=False)]
    with console.status("[bold blue] Finding means..."):
        for _ in range(max_iter):
            labels = _nearest_indices(pts, means)
            new_means = means.copy()
            for k in range(count):
                members = labels == k
                if not members.any():
                    # Empty cluster: keep its previous mean instead of crashing.
                    continue
                w = weights[members]
                centroid = (pts[members] * w[:, None]).sum(axis=0) / w.sum()
                # Snap the centroid to the nearest real data point.
                new_means[k] = pts[np.linalg.norm(pts - centroid, axis=1).argmin()]
            moved = np.linalg.norm(new_means - means, axis=1).mean()
            means = new_means
            if moved <= tol:
                break
    # Means are copies of data rows, so casting back is exact.
    return means.astype(data.dtype)
# Load the source image and flatten it to one 3-channel row per pixel
# (presumably RGB, matching the palette below).
im = imread("zarin.jpg")
starting_resolution = im.shape
console.log("[blue] Starting with image of size: ", starting_resolution)
raw_pixels = np.reshape(im, (-1, 3))
raw_shape = np.shape(raw_pixels)
# Target palette: the 16 Solarized colours as (R, G, B) rows -- base03,
# base02, base01, base00, base0, base1, base2, base3, then yellow, orange,
# red, magenta, violet, blue, cyan and green.
colors = np.array(
    [
        [0, 43, 54],
        [7, 54, 66],
        [88, 110, 117],
        [101, 123, 131],
        [131, 148, 150],
        [147, 161, 161],
        [238, 232, 213],
        [253, 246, 227],
        [181, 137, 0],
        [203, 75, 22],
        [220, 50, 47],
        [211, 54, 130],
        [108, 113, 196],
        [38, 139, 210],
        [42, 161, 152],
        [133, 153, 0],
    ]
)
def main():
    """Recolour the loaded image into the ``colors`` palette and save it.

    Finds ``len(colors)`` representative colours with ``k_means``, labels
    every pixel with its nearest representative, greedily pairs each palette
    colour with the nearest still-unclaimed representative, then paints each
    pixel with the palette colour of its representative and writes
    ``final.png`` through ``save_image``.
    """
    # Find the colors that most represent the image
    color_means = k_means(raw_pixels, len(colors))
    console.log("[green] Found cluster centers: ", color_means)
    means_f = np.asarray(color_means, dtype=np.float64)

    # Remap image to the center points: label each pixel with the index of
    # its nearest mean. Float math avoids uint8 wrap-around; chunking keeps
    # the (pixels x means) distance matrix small.
    console.log("[purple] Re-mapping image")
    labels = np.empty(len(raw_pixels), dtype=np.intp)
    chunk = 1 << 16
    for start in range(0, len(raw_pixels), chunk):
        block = raw_pixels[start:start + chunk].astype(np.float64)
        d2 = ((block[:, None, :] - means_f[None, :, :]) ** 2).sum(axis=2)
        labels[start:start + chunk] = d2.argmin(axis=1)

    # Map means to the colors provided by the user. Each palette colour
    # claims the nearest unclaimed mean; exactly one mean is claimed per
    # colour, so duplicate means cannot exhaust the pool early.
    # Unclaimed means keep their own colour.
    lookup = np.array(color_means, dtype=raw_pixels.dtype)
    unclaimed = list(range(len(means_f)))
    for color in colors:
        if not unclaimed:
            break
        dists = np.linalg.norm(means_f[unclaimed] - color, axis=1)
        lookup[unclaimed.pop(int(dists.argmin()))] = color

    # Recolor the image with one table lookup, so a pixel can never be
    # repainted by a later pair whose mean equals an earlier palette colour.
    output_raw = lookup[labels]
    save_image(output_raw, "final", starting_resolution)
if __name__ == "__main__":
    # Exit only when run as a script, so importing this module never
    # terminates the importing process.
    main()
    sys.exit(0)