2022-11-25 10:29:55 -05:00
|
|
|
import argparse
|
2022-11-25 10:52:34 -05:00
|
|
|
import sys
|
2022-12-04 14:00:01 -05:00
|
|
|
from pathlib import Path
|
|
|
|
import multiprocessing
|
|
|
|
from itertools import product
|
2022-11-25 10:52:34 -05:00
|
|
|
|
2022-11-12 00:52:07 -05:00
|
|
|
import numpy as np
|
2022-11-25 10:29:55 -05:00
|
|
|
from rich.console import Console
|
2022-11-25 10:52:34 -05:00
|
|
|
from scipy.stats import moment
|
|
|
|
from skimage.io import imread, imsave
|
2022-12-04 14:00:01 -05:00
|
|
|
from PIL import Image
|
|
|
|
from numba import njit, prange
|
|
|
|
from scipy.spatial import KDTree
|
2022-11-25 10:29:55 -05:00
|
|
|
|
|
|
|
# Shared rich console used by every function in this script for log lines
# (console.log) and the progress spinner (console.status).
console = Console()
|
|
|
|
|
2022-11-25 10:52:34 -05:00
|
|
|
|
2022-11-25 10:29:55 -05:00
|
|
|
def save_image(data, name, resolution):
    """Restore ``data`` to the shape ``resolution`` and write it to ``name``.

    ``data`` is a flat (or differently shaped) pixel array; ``resolution`` is
    the target shape passed to ``reshape`` (main() passes the input image's
    ``shape``). The file is written with ``imsave``.
    """
    imsave(name, data.reshape(resolution))
|
2022-11-12 00:52:07 -05:00
|
|
|
|
2022-11-25 10:52:34 -05:00
|
|
|
|
2022-11-12 00:52:07 -05:00
|
|
|
def find_nearest_point(data, target):
    """Return the element of ``data`` closest to ``target`` (Euclidean).

    Args:
        data: Sequence/array of candidate points, indexed along the first
            axis (e.g. an (N, 3) array of RGB colors, or N scalars).
        target: The query point, broadcast-compatible with one element of
            ``data``.

    Returns:
        ``data[i]`` for the closest candidate; the first one wins on ties.

    Raises:
        ValueError: if ``data`` contains no candidates.
    """
    points = np.asarray(data)
    if points.shape[0] == 0:
        raise ValueError("find_nearest_point() requires at least one candidate point")
    # Compute in float64: subtracting unsigned pixel arrays (uint8) directly
    # would wrap around instead of going negative and pick the wrong point.
    diffs = points.astype(np.float64) - np.asarray(target, dtype=np.float64)
    # Flatten each candidate so scalars, vectors and matrices all reduce to
    # one Euclidean (Frobenius) distance per candidate.
    distances = np.linalg.norm(diffs.reshape(points.shape[0], -1), axis=1)
    return data[int(distances.argmin())]
|
|
|
|
|
|
|
|
|
2022-11-25 10:29:55 -05:00
|
|
|
def centroidnp(arr):
    """Return the per-axis average of the first three columns of ``arr``.

    ``arr`` is indexed as a 2-D array with at least three columns (e.g. RGB
    pixels). The result is a 3-tuple: each column's sum divided by the
    number of rows.
    """
    n_rows = arr.shape[0]
    return tuple(np.sum(arr[:, axis]) / n_rows for axis in (0, 1, 2))
|
2022-11-12 00:52:07 -05:00
|
|
|
|
|
|
|
|
|
|
|
def calc_distance(x, y):
    """Return the Euclidean distance between points ``x`` and ``y``.

    Both inputs are converted to float64 first. Subtracting unsigned integer
    arrays (such as uint8 image pixels) directly would wrap around,
    e.g. 10 - 200 == 66, and yield wrong distances.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    # A norm is already non-negative, so no absolute value is needed.
    return np.linalg.norm(diff)
|
|
|
|
|
|
|
|
|
|
|
|
def k_means(data, count, tolerance=5):
    """Find ``count`` representative points (e.g. colors) of ``data``.

    This is a k-means variant. After every update step, each cluster's
    centroid is snapped to the nearest point that actually occurs in
    ``data``, so the returned means are always rows of ``data`` with its
    dtype. Iteration stops once the average distance the means moved in one
    iteration differs from the previous iteration's average by at most
    ``tolerance``.

    Args:
        data: 2-D array of points, shape (N, D), e.g. (N, 3) RGB pixels.
        count: Number of means to find. Must be between 1 and the number of
            distinct rows in ``data``.
        tolerance: Convergence threshold described above. Defaults to 5, the
            value that used to be hard-coded.

    Returns:
        Array of shape (count, D) holding the final means.

    Raises:
        ValueError: if ``count`` is outside ``1..number of distinct rows``.
    """
    data = np.asarray(data)
    # Do all geometry in float64 so differences of uint8 pixels don't wrap.
    points = data.astype(np.float64)

    unique_rows = np.unique(data, axis=0)
    if count < 1 or count > unique_rows.shape[0]:
        raise ValueError(
            f"count must be between 1 and {unique_rows.shape[0]} "
            f"(number of distinct points), got {count}"
        )

    # Seed with `count` distinct points chosen at random. Nothing is removed
    # from `data`: every point must take part in the clustering.
    seed_index = np.random.choice(unique_rows.shape[0], count, replace=False)
    means = unique_rows[seed_index]

    distance_delta = float("inf")  # guarantees at least one iteration
    means_distance = 0.0

    with console.status("[bold blue] Finding means..."):
        while distance_delta > tolerance:
            means_f = means.astype(np.float64)

            # Assign every point to its closest mean (first mean wins ties).
            dists = np.linalg.norm(points[:, None, :] - means_f[None, :, :], axis=2)
            labels = dists.argmin(axis=1)

            new_means = means.copy()
            total_movement = 0.0
            for j in range(count):
                members = points[labels == j]
                if members.shape[0] == 0:
                    # Empty cluster: keep its mean in place instead of crashing.
                    continue
                centroid = members.mean(axis=0)
                # Snap the centroid to the closest real data point.
                nearest = int(np.linalg.norm(points - centroid, axis=1).argmin())
                new_means[j] = data[nearest]
                total_movement += np.linalg.norm(points[nearest] - means_f[j])

            previous_distance = means_distance
            means_distance = total_movement / float(count)
            distance_delta = abs(previous_distance - means_distance)
            means = new_means

    return means
|
|
|
|
|
|
|
|
|
2022-12-04 14:00:01 -05:00
|
|
|
def find_closest_points(points1, points2):
    """Snap every row of ``points1`` to its nearest row in ``points2``.

    A k-d tree is built over ``points2`` and queried with ``points1``.
    The result holds, for each query row, the matching row of ``points2``.
    """
    _, nearest_idx = KDTree(points2).query(points1)
    return points2[nearest_idx]
|
|
|
|
|
|
|
|
|
|
|
|
def main(image_name, output_name, colors):
    """Recolor an image so that its palette becomes ``colors``.

    Steps:
      1. Find ``len(colors)`` representative colors of a 150x150 resized copy
         of the image with k_means().
      2. Label every full-resolution pixel with its nearest representative
         color.
      3. Pair each user color, in order, with the closest representative
         color not yet paired.
      4. Paint each cluster with its paired user color and save the result.

    Args:
        image_name: Path of the input image.
        output_name: Destination path, or None to write
            ``<input dir>/recolored-<input name>``.
        colors: Array-like of shape (K, 3) holding the RGB colors to use.
    """
    im = imread(image_name)
    # The pipeline works on 3-channel pixels; reshape(-1, 3) would silently
    # scramble grayscale or RGBA data, so expand gray and drop alpha first.
    if im.ndim == 2:
        im = np.stack([im] * 3, axis=-1)
    elif im.shape[-1] == 4:
        im = np.ascontiguousarray(im[..., :3])
    starting_resolution = im.shape
    im_resized = np.array(Image.fromarray(im).resize(size=(150, 150)))
    console.log("[blue] Starting with image of size: ", starting_resolution)
    raw_pixels_resized = im_resized.reshape(-1, 3)

    # Find the colors that most represent the image
    color_means = k_means(raw_pixels_resized, len(colors))
    console.log("[green] Found cluster centers: ", color_means)

    # Remap image to the center points: label each pixel with the index of
    # its nearest mean.
    raw_pixels = im.reshape(-1, 3)
    console.log("[purple] Re-mapping image")
    labels = KDTree(color_means).query(raw_pixels)[1]
    console.log("[purple] Re-mapping image complete (phase 1)")

    # Map means to the colors provided by the user. This is greedy and
    # follows the user's order; each mean is consumed exactly once, even if
    # two means happen to be equal.
    palette = np.empty_like(color_means)
    means_f = color_means.astype(np.float64)
    remaining = list(range(len(color_means)))
    for color in colors:
        dists = np.linalg.norm(
            means_f[remaining] - np.asarray(color, dtype=np.float64), axis=1
        )
        chosen = remaining.pop(int(dists.argmin()))
        palette[chosen] = color

    # Recolor the image with a single palette lookup. Replacing colors one
    # pair at a time could recolor a pixel twice when a user color equals a
    # later mean.
    output_raw = palette[labels]

    if output_name is None:
        source = Path(image_name)
        output_name = source.parent / f"recolored-{source.name}"
    save_image(output_raw, output_name, starting_resolution)
|
2022-11-25 10:29:55 -05:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="Recolor",
        description="Recolor changes the color palette of an image to the one provided by the user",
    )
    parser.add_argument("filename", help="Input image")
    parser.add_argument(
        "-o",
        "--output",
        help="Output image destination, defaults to the same path as the input titled [path]/recolored-[name]",
        dest="outname",
    )
    color_loader_group = parser.add_mutually_exclusive_group(required=True)
    color_loader_group.add_argument(
        "-f",
        "--file",
        help="A file of RGB color values, with one color per line",
        dest="cfpath",
    )
    color_loader_group.add_argument(
        "-l",
        "--list",
        nargs="+",
        help="A list of RGB color values. Example: '123,90,89 212,7,0'",
        dest="clist",
    )

    # Check inputs
    args = parser.parse_args()
    input_image = Path(args.filename)
    if not input_image.exists():
        console.log("[red] Error: input image path does not exist")
        sys.exit(-1)
    if not input_image.is_file():
        console.log("[red] Error: input image path is not a file")
        sys.exit(-1)

    # Gather the raw "R,G,B" entries from whichever source was given.
    from_file = args.cfpath is not None
    if from_file:
        try:
            with open(args.cfpath) as f:
                # Blank lines (e.g. a trailing empty line) carry no color.
                entries = [line.strip() for line in f if line.strip()]
        except OSError as err:
            console.log("[red] Error: could not read color file", args.cfpath, err)
            sys.exit(-2)
    else:
        entries = args.clist

    # Parse each entry into an [R, G, B] triple of ints in 0..255.
    colors = []
    for entry in entries:
        parts = entry.split(",")
        if len(parts) != 3:
            if from_file:
                console.log("[red] Error: RGB value in file is malformed -- not exactly 3 values", args.cfpath, entry)
            else:
                console.log("[red] Error: RGB value in list is malformed -- not exactly 3 values", entry)
            sys.exit(-2)
        try:
            rgb = [int(p) for p in parts]
        except ValueError:
            console.log("[red] Error: RGB value is not made of integers", entry)
            sys.exit(-2)
        if any(not 0 <= v <= 255 for v in rgb):
            console.log("[red] Error: RGB components must be between 0 and 255", entry)
            sys.exit(-2)
        colors.append(np.array(rgb))

    if not colors:
        console.log("[red] Error: no colors were provided")
        sys.exit(-2)

    main(input_image, Path(args.outname) if args.outname is not None else None, np.array(colors))
|