formatting

This commit is contained in:
Tanishq Dubey 2022-11-25 10:52:34 -05:00
parent 4a72a0fc95
commit 3624c33c51

47
main.py
View File

@ -1,17 +1,20 @@
import sys
import argparse import argparse
from skimage.io import imread, imsave import sys
from scipy.stats import moment
import numpy as np import numpy as np
from numpy import array, all, uint8 from numpy import all, array, uint8
from rich.console import Console from rich.console import Console
from scipy.stats import moment
from skimage.io import imread, imsave
console = Console() console = Console()
def save_image(data, name, resolution): def save_image(data, name, resolution):
final_image = data.reshape(resolution) final_image = data.reshape(resolution)
imsave(f"{name}.png", final_image) imsave(f"{name}.png", final_image)
def find_nearest_point(data, target): def find_nearest_point(data, target):
idx = np.array([calc_distance(p, target) for p in data]).argmin() idx = np.array([calc_distance(p, target) for p in data]).argmin()
return data[idx] return data[idx]
@ -62,7 +65,9 @@ def k_means(data, count):
# Calculate new mean # Calculate new mean
raw_mean = centroidnp(clusters[mean_key]) raw_mean = centroidnp(clusters[mean_key])
nearest_mean_point = find_nearest_point(data, raw_mean) nearest_mean_point = find_nearest_point(data, raw_mean)
means_distance = means_distance + calc_distance(mean, nearest_mean_point) means_distance = means_distance + calc_distance(
mean, nearest_mean_point
)
new_means.append(nearest_mean_point) new_means.append(nearest_mean_point)
means_distance = means_distance / float(count) means_distance = means_distance / float(count)
distance_delta = abs(previous_distance - means_distance) distance_delta = abs(previous_distance - means_distance)
@ -77,7 +82,9 @@ console.log("[blue] Starting with image of size: ", starting_resolution)
raw_pixels = im.reshape(-1, 3) raw_pixels = im.reshape(-1, 3)
raw_shape = raw_pixels.shape raw_shape = raw_pixels.shape
colors = np.array([np.array([0,43,54]), colors = np.array(
[
np.array([0, 43, 54]),
np.array([7, 54, 66]), np.array([7, 54, 66]),
np.array([88, 110, 117]), np.array([88, 110, 117]),
np.array([101, 123, 131]), np.array([101, 123, 131]),
@ -92,7 +99,10 @@ colors = np.array([np.array([0,43,54]),
np.array([108, 113, 196]), np.array([108, 113, 196]),
np.array([38, 139, 210]), np.array([38, 139, 210]),
np.array([42, 161, 152]), np.array([42, 161, 152]),
np.array([133,153,0])]) np.array([133, 153, 0]),
]
)
def main(): def main():
# Find the colors that most represent the image # Find the colors that most represent the image
@ -111,17 +121,36 @@ def main():
for color in colors: for color in colors:
m = find_nearest_point(tmp_means, color) m = find_nearest_point(tmp_means, color)
pairs.append((m, color)) pairs.append((m, color))
idxs, = np.where(np.all(tmp_means == m, axis=1)) (idxs,) = np.where(np.all(tmp_means == m, axis=1))
tmp_means = np.delete(tmp_means, idxs, axis=0) tmp_means = np.delete(tmp_means, idxs, axis=0)
# Recolor the image # Recolor the image
for pair in pairs: for pair in pairs:
idxs, = np.where(np.all(output_raw == pair[0], axis=1)) (idxs,) = np.where(np.all(output_raw == pair[0], axis=1))
output_raw[idxs] = pair[1] output_raw[idxs] = pair[1]
save_image(output_raw, "final", starting_resolution) save_image(output_raw, "final", starting_resolution)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="Recolor",
description="Recolor changes the color palette of an image to the one provided by the user",
)
color_loader_group = parser.add_mutually_exclusive_group(required=True)
color_loader_group.add_argument(
"-f",
"--file",
description="A file of RGB color values, with one color per line",
dest="fpath",
default=None,
)
color_loader_group.add_argument(
"-l",
"--list",
description="A list of RGB color values. Example: '123,90,89 212,7,0'",
dest="clist",
default=None,
)
main() main()
pass pass