formatting

This commit is contained in:
Tanishq Dubey 2022-11-25 10:52:34 -05:00
parent 4a72a0fc95
commit 3624c33c51

85
main.py
View File

@ -1,17 +1,20 @@
import sys
import argparse import argparse
from skimage.io import imread, imsave import sys
from scipy.stats import moment
import numpy as np import numpy as np
from numpy import array, all, uint8 from numpy import all, array, uint8
from rich.console import Console from rich.console import Console
from scipy.stats import moment
from skimage.io import imread, imsave
console = Console() console = Console()
def save_image(data, name, resolution): def save_image(data, name, resolution):
final_image = data.reshape(resolution) final_image = data.reshape(resolution)
imsave(f"{name}.png", final_image) imsave(f"{name}.png", final_image)
def find_nearest_point(data, target): def find_nearest_point(data, target):
idx = np.array([calc_distance(p, target) for p in data]).argmin() idx = np.array([calc_distance(p, target) for p in data]).argmin()
return data[idx] return data[idx]
@ -22,7 +25,7 @@ def centroidnp(arr):
sum_x = np.sum(arr[:, 0]) sum_x = np.sum(arr[:, 0])
sum_y = np.sum(arr[:, 1]) sum_y = np.sum(arr[:, 1])
sum_z = np.sum(arr[:, 2]) sum_z = np.sum(arr[:, 2])
return sum_x/length, sum_y/length, sum_z/length return sum_x / length, sum_y / length, sum_z / length
def calc_distance(x, y): def calc_distance(x, y):
@ -62,7 +65,9 @@ def k_means(data, count):
# Calculate new mean # Calculate new mean
raw_mean = centroidnp(clusters[mean_key]) raw_mean = centroidnp(clusters[mean_key])
nearest_mean_point = find_nearest_point(data, raw_mean) nearest_mean_point = find_nearest_point(data, raw_mean)
means_distance = means_distance + calc_distance(mean, nearest_mean_point) means_distance = means_distance + calc_distance(
mean, nearest_mean_point
)
new_means.append(nearest_mean_point) new_means.append(nearest_mean_point)
means_distance = means_distance / float(count) means_distance = means_distance / float(count)
distance_delta = abs(previous_distance - means_distance) distance_delta = abs(previous_distance - means_distance)
@ -77,51 +82,75 @@ console.log("[blue] Starting with image of size: ", starting_resolution)
raw_pixels = im.reshape(-1, 3) raw_pixels = im.reshape(-1, 3)
raw_shape = raw_pixels.shape raw_shape = raw_pixels.shape
colors = np.array([np.array([0,43,54]), colors = np.array(
np.array([7,54,66]), [
np.array([88,110,117]), np.array([0, 43, 54]),
np.array([101,123,131]), np.array([7, 54, 66]),
np.array([131,148,150]), np.array([88, 110, 117]),
np.array([147,161,161]), np.array([101, 123, 131]),
np.array([238,232,213]), np.array([131, 148, 150]),
np.array([253,246,227]), np.array([147, 161, 161]),
np.array([181,137,0]), np.array([238, 232, 213]),
np.array([203,75,22]), np.array([253, 246, 227]),
np.array([220,50,47]), np.array([181, 137, 0]),
np.array([211,54,130]), np.array([203, 75, 22]),
np.array([108,113,196]), np.array([220, 50, 47]),
np.array([38,139,210]), np.array([211, 54, 130]),
np.array([42,161,152]), np.array([108, 113, 196]),
np.array([133,153,0])]) np.array([38, 139, 210]),
np.array([42, 161, 152]),
np.array([133, 153, 0]),
]
)
def main(): def main():
# Find the colors that most represent the image # Find the colors that most represent the image
color_means = k_means(raw_pixels, len(colors)) color_means = k_means(raw_pixels, len(colors))
console.log("[green] Found cluster centers: ", color_means) console.log("[green] Found cluster centers: ", color_means)
# Remap image to the center points # Remap image to the center points
console.log("[purple] Re-mapping image") console.log("[purple] Re-mapping image")
output_raw = np.zeros_like(raw_pixels) output_raw = np.zeros_like(raw_pixels)
for i in range(len(raw_pixels)): for i in range(len(raw_pixels)):
output_raw[i] = find_nearest_point(color_means, raw_pixels[i]) output_raw[i] = find_nearest_point(color_means, raw_pixels[i])
# Map means to the colors provided by the user # Map means to the colors provided by the user
pairs = [] pairs = []
tmp_means = color_means tmp_means = color_means
for color in colors: for color in colors:
m = find_nearest_point(tmp_means, color) m = find_nearest_point(tmp_means, color)
pairs.append((m, color)) pairs.append((m, color))
idxs, = np.where(np.all(tmp_means == m, axis=1)) (idxs,) = np.where(np.all(tmp_means == m, axis=1))
tmp_means = np.delete(tmp_means, idxs, axis=0) tmp_means = np.delete(tmp_means, idxs, axis=0)
# Recolor the image # Recolor the image
for pair in pairs: for pair in pairs:
idxs, = np.where(np.all(output_raw == pair[0], axis=1)) (idxs,) = np.where(np.all(output_raw == pair[0], axis=1))
output_raw[idxs] = pair[1] output_raw[idxs] = pair[1]
save_image(output_raw, "final", starting_resolution) save_image(output_raw, "final", starting_resolution)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="Recolor",
description="Recolor changes the color palette of an image to the one provided by the user",
)
color_loader_group = parser.add_mutually_exclusive_group(required=True)
color_loader_group.add_argument(
"-f",
"--file",
description="A file of RGB color values, with one color per line",
dest="fpath",
default=None,
)
color_loader_group.add_argument(
"-l",
"--list",
description="A list of RGB color values. Example: '123,90,89 212,7,0'",
dest="clist",
default=None,
)
main() main()
pass pass