recolor/main.py

203 lines
6.5 KiB
Python
Raw Normal View History

2022-11-25 10:29:55 -05:00
import argparse
2022-11-25 10:52:34 -05:00
import sys
2022-12-04 14:00:01 -05:00
from pathlib import Path
import multiprocessing
from itertools import product
2022-11-25 10:52:34 -05:00
2022-11-12 00:52:07 -05:00
import numpy as np
2022-11-25 10:29:55 -05:00
from rich.console import Console
2022-11-25 10:52:34 -05:00
from scipy.stats import moment
from skimage.io import imread, imsave
2022-12-04 14:00:01 -05:00
from PIL import Image
from numba import njit, prange
from scipy.spatial import KDTree
2022-11-25 10:29:55 -05:00
# Shared rich console, used below for status spinners and log output.
console = Console()
def save_image(data, name, resolution):
    """Write *data* to the image file *name* after reshaping it to *resolution*.

    Args:
        data: pixel buffer (typically flat rows of channel values).
        name: destination path handed straight to ``imsave``.
        resolution: target array shape, e.g. the ``shape`` of the source image.
    """
    imsave(name, data.reshape(resolution))
def find_nearest_point(data, target):
    """Return the element of *data* with the smallest Euclidean distance to *target*.

    Distances are computed in float64: subtracting unsigned integers (e.g.
    uint8 pixels) directly would wrap around (10 - 20 -> 246) and pick the
    wrong point. Ties resolve to the first closest element, as ``argmin`` does.

    Args:
        data: sequence/array of candidate points; each element is a scalar or
            a coordinate vector broadcastable against *target*.
        target: the query point.
    Returns:
        ``data[i]`` for the closest index ``i``.
    Raises:
        ValueError: if *data* is empty.
    """
    points = np.asarray(data, dtype=np.float64)
    if len(points) == 0:
        raise ValueError("find_nearest_point() received no candidate points")
    diffs = points - np.asarray(target, dtype=np.float64)
    # Flatten everything after the first axis so scalars and vectors both work.
    distances = np.linalg.norm(diffs.reshape(len(points), -1), axis=1)
    return data[distances.argmin()]
def centroidnp(arr):
    """Average the first three columns of the 2-D array *arr*.

    Each column is summed and divided by the row count, giving the centroid
    of the point cloud as a ``(x, y, z)`` tuple.
    """
    n_rows = arr.shape[0]
    return tuple(np.sum(arr[:, axis]) / n_rows for axis in range(3))
def calc_distance(x, y):
    """Return the Euclidean distance between points *x* and *y*.

    Both inputs are promoted to float64 before subtracting: with unsigned
    integer inputs (such as uint8 pixels) ``x - y`` would otherwise wrap
    around instead of going negative, producing a wildly wrong distance.
    The norm is never negative, so no ``abs`` is needed.
    """
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return np.linalg.norm(diff)
def k_means(data, count, tolerance=5, max_iterations=100):
    """Pick *count* representative rows of *data* with a k-means-style loop.

    Each iteration assigns every row to its nearest current mean and computes
    each cluster's centroid. The new mean is then snapped to the data row
    nearest that centroid, so every returned mean is an actual row of *data*.

    Args:
        data: (N, D) array of points, e.g. RGB pixels.
        count: number of means to return; must be between 1 and the number of
            distinct rows in *data*.
        tolerance: iteration stops once the average per-mean shift changes by
            no more than this between consecutive iterations (previously a
            hard-coded 5).
        max_iterations: hard cap so the loop always terminates.
    Returns:
        (count, D) array of means with the dtype of *data*.
    Raises:
        ValueError: if *count* is < 1 or exceeds the number of distinct rows.
    """
    data = np.asarray(data)
    if count < 1:
        raise ValueError(f"k_means() needs count >= 1, got {count}")
    # Pick `count` distinct points to start.
    distinct = np.unique(data, axis=0)
    if count > distinct.shape[0]:
        raise ValueError(
            f"k_means() asked for {count} means but data has only "
            f"{distinct.shape[0]} distinct points"
        )
    means = distinct[np.random.choice(distinct.shape[0], count, replace=False)]
    # Float copy of the data: subtracting uint8 pixels directly would wrap.
    points = data.astype(np.float64)

    distance_delta = float("inf")  # guarantees at least one iteration
    means_distance = 0.0
    iteration = 0
    with console.status("[bold blue] Finding means..."):
        while distance_delta > tolerance and iteration < max_iterations:
            iteration += 1
            # Assign each point to its nearest mean via an (N, count) distance matrix.
            offsets = points[:, np.newaxis, :] - means.astype(np.float64)[np.newaxis, :, :]
            labels = np.linalg.norm(offsets, axis=2).argmin(axis=1)

            new_means = []
            previous_distance = means_distance
            total_shift = 0.0
            for k, mean in enumerate(means):
                members = points[labels == k]
                if members.shape[0] == 0:
                    # No point chose this mean (e.g. it duplicates another one);
                    # keep it rather than fail on an empty cluster.
                    new_means.append(mean)
                    continue
                centroid = members.mean(axis=0)
                # Snap to the nearest real data row.
                nearest = data[np.linalg.norm(points - centroid, axis=1).argmin()]
                total_shift += np.linalg.norm(mean.astype(np.float64) - nearest)
                new_means.append(nearest)
            means_distance = total_shift / count
            distance_delta = abs(previous_distance - means_distance)
            means = np.stack(new_means)
    return means
def find_closest_points(points1, points2):
    """Map every row of *points1* to its nearest neighbour among *points2*.

    A k-d tree built over *points2* answers the nearest-neighbour queries.
    The result has one row per query row, taken from *points2*, so it keeps
    the dtype of *points2*.
    """
    _, nearest_index = KDTree(points2).query(points1)
    return points2[nearest_index]
def _split_channels(im):
    """Split *im* into an (H, W, 3) color array and an optional alpha plane.

    Grayscale inputs (2-D, or one channel) are replicated to three channels.
    A trailing alpha channel (gray+alpha or RGBA) is returned separately so it
    can be re-attached unchanged after recoloring; otherwise alpha is None.
    """
    if im.ndim == 2:
        im = im[..., np.newaxis]
    has_alpha = im.shape[-1] in (2, 4)
    alpha = im[..., -1:] if has_alpha else None
    color = im[..., :-1] if has_alpha else im
    if color.shape[-1] == 1:
        color = np.repeat(color, 3, axis=-1)
    return np.ascontiguousarray(color), alpha


def _pair_means_with_colors(means, colors):
    """Greedily match each user color to its nearest still-unused mean.

    Returns a list of ``(mean, color)`` tuples in the order of *colors*. If
    duplicate means exhaust the pool early, the remaining colors are skipped
    (with a warning) instead of crashing on an empty candidate set.
    """
    pairs = []
    remaining = means
    for color in colors:
        if len(remaining) == 0:
            console.log("[yellow] Warning: ran out of distinct cluster centers; some colors were not used")
            break
        mean = find_nearest_point(remaining, color)
        pairs.append((mean, color))
        # Drop every copy of the chosen mean so it is not matched again.
        remaining = remaining[~np.all(remaining == mean, axis=1)]
    return pairs


def main(image_name, output_name, colors, sample_size=(150, 150)):
    """Recolor an image so its palette becomes *colors*, then save it.

    Steps:
    1. Downsample the image to *sample_size* and find ``len(colors)``
       representative colors with ``k_means``.
    2. Snap every full-resolution pixel to its nearest representative.
    3. Replace each representative with the user color closest to it.

    Args:
        image_name: path of the input image.
        output_name: destination path, or None to write
            ``recolored-<name>`` next to the input image.
        colors: (n, 3) array-like of RGB values.
        sample_size: (width, height) passed to ``PIL.Image.resize`` for the
            clustering sample; it only affects palette estimation.
    """
    image_path = Path(image_name)
    im = imread(image_path)
    starting_resolution = im.shape
    console.log("[blue] Starting with image of size: ", starting_resolution)

    # Work on RGB only; reshape(-1, 3) would silently garble RGBA/gray data.
    rgb, alpha = _split_channels(im)

    # Find the colors that most represent the image, on a downsampled copy
    sample = np.array(Image.fromarray(rgb).resize(size=sample_size))
    color_means = k_means(sample.reshape(-1, 3), len(colors))
    console.log("[green] Found cluster centers: ", color_means)

    # Remap image to the center points
    console.log("[purple] Re-mapping image")
    mapped = find_closest_points(rgb.reshape(-1, 3), color_means)
    console.log("[purple] Re-mapping image complete (phase 1)")

    # Recolor the image. Masks are computed against the unmodified `mapped`
    # array, so a user color that equals a later mean is not recolored twice.
    recolored = mapped.copy()
    for mean, color in _pair_means_with_colors(color_means, colors):
        recolored[np.all(mapped == mean, axis=1)] = color

    output = recolored.reshape(rgb.shape)
    if alpha is not None:
        output = np.concatenate([output, alpha.astype(output.dtype)], axis=-1)

    if output_name is None:
        output_name = str(image_path.with_name(f"recolored-{image_path.name}"))
    save_image(output, output_name, output.shape)
def _parse_color(text, where, *context):
    """Parse an ``R,G,B`` string into an int array, exiting on malformed input.

    Args:
        text: raw color string, e.g. ``"123,90,89"``.
        where: origin used in error messages (``"file"`` or ``"list"``).
        *context: extra values logged before *text* on error (e.g. the file path).
    Returns:
        np.ndarray of three ints in [0, 255].
    Exits:
        with status -2 if the value is malformed, non-integer, or out of range.
    """
    parts = text.split(",")
    if len(parts) != 3:
        console.log(f"[red] Error: RGB value in {where} is malformed -- not exactly 3 values", *context, text)
        sys.exit(-2)
    try:
        values = [int(part) for part in parts]
    except ValueError:
        console.log(f"[red] Error: RGB value in {where} is not made of integers", *context, text)
        sys.exit(-2)
    # Recolored pixels are written into the image's own array; for 8-bit images
    # values outside 0-255 would wrap silently.
    if any(v < 0 or v > 255 for v in values):
        console.log(f"[red] Error: RGB value in {where} is outside the 0-255 range", *context, text)
        sys.exit(-2)
    return np.array(values)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="Recolor",
        description="Recolor changes the color palette of an image to the one provided by the user",
    )
    parser.add_argument("filename", help="Input image")
    parser.add_argument(
        "-o",
        "--output",
        help="Output image destination, defaults to the same path as the input titled [path]/recolored-[name]",
        dest="outname",
    )
    color_loader_group = parser.add_mutually_exclusive_group(required=True)
    color_loader_group.add_argument(
        "-f",
        "--file",
        help="A file of RGB color values, with one color per line",
        dest="cfpath",
    )
    color_loader_group.add_argument(
        "-l",
        "--list",
        nargs="+",
        help="A list of RGB color values. Example: '123,90,89 212,7,0'",
        dest="clist",
    )

    # Check inputs
    args = parser.parse_args()
    input_image = Path(args.filename)
    if not input_image.exists():
        console.log("[red] Error: input image path does not exist")
        sys.exit(-1)
    elif not input_image.is_file():
        console.log("[red] Error: input image path is not a file")
        sys.exit(-1)

    # Parse out colors
    if args.cfpath is not None:
        # Load from file, one color per line; blank lines are ignored.
        with open(Path(args.cfpath), encoding="utf-8") as f:
            lines = [line.strip() for line in f]
        colors = [_parse_color(line, "file", args.cfpath) for line in lines if line]
        if not colors:
            console.log("[red] Error: no RGB values found in color file", args.cfpath)
            sys.exit(-2)
    else:
        # Load from list
        colors = [_parse_color(item, "list") for item in args.clist]

    main(input_image, Path(args.outname) if args.outname is not None else None, np.array(colors))