# recolor/main.py
# 2022-12-04 14:00:01 -05:00
#
# 203 lines
# 6.5 KiB
# Python

import argparse
import sys
from pathlib import Path
import multiprocessing
from itertools import product
import numpy as np
from rich.console import Console
from scipy.stats import moment
from skimage.io import imread, imsave
from PIL import Image
from numba import njit, prange
from scipy.spatial import KDTree
# Shared rich console: used for status spinners (k_means) and for progress
# and error logging throughout this script.
console = Console()
def save_image(data, name, resolution):
    """Write pixel data to disk as an image.

    :param data: pixel array (typically flattened to one row per pixel).
    :param name: destination path handed to ``imsave``.
    :param resolution: target shape, e.g. the ``shape`` of the source image;
        ``data`` is reshaped to it before writing.
    """
    imsave(name, np.reshape(data, resolution))
def find_nearest_point(data, target):
    """Return the row of ``data`` closest (Euclidean distance) to ``target``.

    All distances are computed at once in float64. The float conversion
    matters: pixel arrays in this script are usually uint8, and subtracting
    two uint8 arrays wraps around modulo 256 (10 - 20 == 246), which made
    the previous per-row version pick the wrong point.
    Ties resolve to the first matching row (``argmin`` semantics).

    :param data: 2-D array with one candidate point per row.
    :param target: a single point (array or sequence) with as many
        coordinates as ``data`` has columns.
    :returns: the selected row of ``data``, in ``data``'s own dtype.
    :raises ValueError: if ``data`` has no rows.
    """
    candidates = np.asarray(data, dtype=np.float64)
    offsets = candidates - np.asarray(target, dtype=np.float64)
    distances = np.linalg.norm(offsets, axis=1)
    return data[distances.argmin()]
def centroidnp(arr):
    """Return the centroid of a set of points as a 3-tuple.

    Only the first three columns of ``arr`` are averaged; each component is
    the column sum divided by the number of rows.

    :param arr: 2-D array, one point per row, at least three columns.
    :returns: ``(mean_col0, mean_col1, mean_col2)``.
    """
    row_count = arr.shape[0]
    return tuple(np.sum(arr[:, column]) / row_count for column in range(3))
def calc_distance(x, y):
    """Return the Euclidean distance between two points.

    Both inputs are converted to float64 before subtracting. Pixel arrays in
    this script are usually uint8, and subtracting two uint8 arrays wraps
    around modulo 256 (10 - 20 == 246), which silently corrupted distances.
    A norm is never negative, so no ``abs`` is needed.

    :param x: first point (array or sequence of coordinates).
    :param y: second point, same length as ``x``.
    :returns: the distance as a float.
    """
    difference = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.linalg.norm(difference)
def k_means(data, count, tolerance=5, max_iterations=100):
    """Pick ``count`` representative points of ``data`` with a k-means variant.

    Centers start as ``count`` distinct rows of ``data`` chosen at random.
    Each iteration does three things:

    * assigns every point to its nearest center;
    * computes each cluster's centroid;
    * snaps the new center to the data point nearest that centroid.

    As a result, every returned center is an actual row of ``data``.
    Iteration stops once the average distance the centers moved changes by
    at most ``tolerance`` between two iterations, or after
    ``max_iterations`` iterations.

    :param data: 2-D array, one point (e.g. an RGB pixel) per row.
    :param count: number of centers to find.
    :param tolerance: convergence threshold on the change of the average
        center movement (the previously hard-coded value was 5).
    :param max_iterations: safety cap so the loop always terminates.
    :returns: array of shape ``(count, data.shape[1])`` with ``data``'s dtype.
    :raises ValueError: if ``count`` is not between 1 and the number of
        distinct rows in ``data``.
    """
    points = np.asarray(data)
    # All distance arithmetic is done in float64: subtracting uint8 pixels
    # wraps around modulo 256 and produces wrong distances.
    points_f = points.astype(np.float64)
    unique_points = np.unique(points, axis=0)
    if not 1 <= count <= unique_points.shape[0]:
        raise ValueError(
            f"count must be between 1 and {unique_points.shape[0]} "
            f"(the number of distinct points), got {count}"
        )
    # Pick n random distinct points to start, so no two initial centers
    # coincide. Nothing is removed from `points`: the earlier version deleted
    # rows using indices into the *unique* array, i.e. unrelated pixels.
    start = np.random.choice(unique_points.shape[0], count, replace=False)
    means = unique_points[start]

    previous_distance = 0.0
    with console.status("[bold blue] Finding means..."):
        for _ in range(max_iterations):
            means_f = means.astype(np.float64)
            # (n_points, count) matrix of point-to-center distances; this is
            # n_points * count * n_columns floats, fine for the downscaled
            # sample main() passes in.
            distances = np.linalg.norm(
                points_f[:, np.newaxis, :] - means_f[np.newaxis, :, :], axis=2
            )
            labels = distances.argmin(axis=1)

            new_means = means.copy()
            moved = 0.0
            for k in range(count):
                members = points_f[labels == k]
                if members.shape[0] == 0:
                    # A center's own point is at distance 0, so a cluster can
                    # only be empty when an earlier center coincides with it;
                    # keep the old center in that case.
                    continue
                centroid = members.mean(axis=0)
                nearest = np.linalg.norm(points_f - centroid, axis=1).argmin()
                new_means[k] = points[nearest]
                moved += np.linalg.norm(means_f[k] - points_f[nearest])
            means = new_means

            means_distance = moved / count
            if abs(previous_distance - means_distance) <= tolerance:
                break
            previous_distance = means_distance
    return means
def find_closest_points(points1, points2):
    """Snap every row of ``points1`` to its nearest row in ``points2``.

    A k-d tree is built over ``points2`` once and queried with all of
    ``points1`` in a single call.

    :param points1: 2-D array of query points, one per row.
    :param points2: 2-D array of candidate points, one per row.
    :returns: array with one row of ``points2`` per row of ``points1``.
    """
    _, nearest_index = KDTree(points2).query(points1)
    return points2[nearest_index]
def _split_alpha(image):
    """Return ``(rgb, alpha)`` for an image array; ``alpha`` is None if absent.

    2-D arrays are treated as single-channel grayscale and replicated to
    three channels; 4-channel arrays have their last channel split off.
    """
    if image.ndim == 2:
        return np.stack([image] * 3, axis=-1), None
    if image.ndim == 3 and image.shape[2] == 4:
        return image[..., :3], image[..., 3]
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"unsupported image shape {image.shape}")
    return image, None


def _pair_means_with_colors(means, colors):
    """Greedily pair each user color with its nearest not-yet-used mean."""
    pairs = []
    remaining = means
    for color in colors:
        mean = find_nearest_point(remaining, color)
        pairs.append((mean, color))
        # Remove exactly one matching row so duplicate means can still be paired.
        (matches,) = np.where(np.all(remaining == mean, axis=1))
        remaining = np.delete(remaining, matches[0], axis=0)
    return pairs


def main(image_name, output_name, colors, sample_size=150):
    """Recolor an image so its dominant colors become the user's palette.

    The steps are:

    1. Cluster a ``sample_size`` x ``sample_size`` downscaled copy of the
       image with ``k_means`` into ``len(colors)`` representative colors.
    2. Snap every full-resolution pixel to its closest representative.
    3. Replace each representative with the user color greedily paired
       with it.

    An alpha channel, if present, is carried through unchanged.

    :param image_name: path of the image to read.
    :param output_name: destination path, or None to write
        ``recolored-<name>`` next to the input image.
    :param colors: sequence of RGB triples (the script passes an array of
        shape ``(len(colors), 3)``).
    :param sample_size: width and height of the downscaled copy used to find
        the representative colors.
    """
    image_name = Path(image_name)
    im = imread(image_name)
    console.log("[blue] Starting with image of size: ", im.shape)
    rgb, alpha = _split_alpha(im)

    # Find the colors that most represent the image, on a small copy for speed.
    sample = np.array(
        Image.fromarray(np.ascontiguousarray(rgb)).resize(size=(sample_size, sample_size))
    )
    color_means = k_means(sample.reshape(-1, 3), len(colors))
    console.log("[green] Found cluster centers: ", color_means)

    # Remap image to the center points.
    console.log("[purple] Re-mapping image")
    snapped = find_closest_points(rgb.reshape(-1, 3), color_means)
    console.log("[purple] Re-mapping image complete (phase 1)")

    # Map means to the colors provided by the user. Masks are computed from
    # `snapped`, which is never modified, so a pixel already recolored by one
    # pair cannot be matched again by a later pair whose mean equals that color.
    recolored = snapped.copy()
    for mean, color in _pair_means_with_colors(color_means, colors):
        recolored[np.all(snapped == mean, axis=1)] = color

    result = recolored.reshape(rgb.shape)
    if alpha is not None:
        result = np.dstack([result, alpha])

    if output_name is None:
        output_name = image_name.parent / f"recolored-{image_name.name}"
    save_image(result, output_name, result.shape)
def parse_color(text):
    """Parse one ``R,G,B`` string into a length-3 integer array.

    Surrounding whitespace around each value is tolerated.

    :param text: e.g. ``"123,90,89"``.
    :returns: ``np.array([r, g, b])``.
    :raises ValueError: if the text does not hold exactly three integers in
        the range 0-255; the message describes the problem.
    """
    parts = text.split(",")
    if len(parts) != 3:
        raise ValueError("not exactly 3 values")
    try:
        values = [int(part) for part in parts]
    except ValueError:
        raise ValueError("values must be integers") from None
    if not all(0 <= value <= 255 for value in values):
        raise ValueError("values must be between 0 and 255")
    return np.array(values)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="Recolor",
        description="Recolor changes the color palette of an image to the one provided by the user",
    )
    parser.add_argument(
        "filename", help="Input image"
    )
    parser.add_argument(
        "-o", "--output", help="Output image destination, defaults to the same path as the input titled [path]/recolored-[name]", dest="outname"
    )
    color_loader_group = parser.add_mutually_exclusive_group(required=True)
    color_loader_group.add_argument(
        "-f",
        "--file",
        help="A file of RGB color values, with one color per line",
        dest="cfpath",
    )
    color_loader_group.add_argument(
        "-l",
        "--list",
        nargs='+',
        help="A list of RGB color values. Example: '123,90,89 212,7,0'",
        dest="clist",
    )

    # Check inputs
    args = parser.parse_args()
    input_image = Path(args.filename)
    if not input_image.exists():
        console.log("[red] Error: input image path does not exist")
        sys.exit(-1)
    if not input_image.is_file():
        console.log("[red] Error: input image path is not a file")
        sys.exit(-1)

    # Collect the raw "R,G,B" strings from whichever source was given.
    if args.cfpath is not None:
        color_file = Path(args.cfpath)
        if not color_file.is_file():
            console.log("[red] Error: color file does not exist or is not a file", args.cfpath)
            sys.exit(-2)
        with open(color_file, encoding="utf-8") as f:
            # Blank lines (e.g. a trailing newline) are skipped, not rejected.
            raw_colors = [line.strip() for line in f if line.strip()]
        origin = "file"
    else:
        raw_colors = args.clist
        origin = "list"

    # Parse out colors
    colors = []
    for raw in raw_colors:
        try:
            colors.append(parse_color(raw))
        except ValueError as err:
            console.log(f"[red] Error: RGB value in {origin} is malformed -- {err}", raw)
            sys.exit(-2)
    if not colors:
        console.log("[red] Error: no colors were provided")
        sys.exit(-2)

    main(input_image, Path(args.outname) if args.outname is not None else None, np.array(colors))