program is fast now! can package it up

This commit is contained in:
Tanishq Dubey 2022-12-04 14:00:01 -05:00
parent 3624c33c51
commit 600db2f029

125
main.py
View File

@ -1,18 +1,23 @@
import argparse import argparse
import sys import sys
from pathlib import Path
import multiprocessing
from itertools import product
import numpy as np import numpy as np
from numpy import all, array, uint8
from rich.console import Console from rich.console import Console
from scipy.stats import moment from scipy.stats import moment
from skimage.io import imread, imsave from skimage.io import imread, imsave
from PIL import Image
from numba import njit, prange
from scipy.spatial import KDTree
console = Console() console = Console()
def save_image(data, name, resolution): def save_image(data, name, resolution):
final_image = data.reshape(resolution) final_image = data.reshape(resolution)
imsave(f"{name}.png", final_image) imsave(name, final_image)
def find_nearest_point(data, target): def find_nearest_point(data, target):
@ -76,44 +81,37 @@ def k_means(data, count):
return means return means
im = imread("zarin.jpg") def find_closest_points(points1, points2):
# Build a k-d tree from the points in the second array
tree = KDTree(points2)
# Find the closest point in the second array for each element in the first array
closest_points_indices = tree.query(points1)[1]
closest_points = points2[closest_points_indices]
# Return the result
return closest_points
def main(image_name, output_name, colors):
im = imread(image_name)
starting_resolution = im.shape starting_resolution = im.shape
im_resized = np.array(Image.fromarray(im).resize(size=(150, 150)))
console.log("[blue] Starting with image of size: ", starting_resolution) console.log("[blue] Starting with image of size: ", starting_resolution)
raw_pixels = im.reshape(-1, 3) raw_pixels_resized = im_resized.reshape(-1, 3)
raw_shape = raw_pixels.shape
colors = np.array(
[
np.array([0, 43, 54]),
np.array([7, 54, 66]),
np.array([88, 110, 117]),
np.array([101, 123, 131]),
np.array([131, 148, 150]),
np.array([147, 161, 161]),
np.array([238, 232, 213]),
np.array([253, 246, 227]),
np.array([181, 137, 0]),
np.array([203, 75, 22]),
np.array([220, 50, 47]),
np.array([211, 54, 130]),
np.array([108, 113, 196]),
np.array([38, 139, 210]),
np.array([42, 161, 152]),
np.array([133, 153, 0]),
]
)
def main():
# Find the colors that most represent the image # Find the colors that most represent the image
color_means = k_means(raw_pixels, len(colors)) color_means = k_means(raw_pixels_resized, len(colors))
console.log("[green] Found cluster centers: ", color_means) console.log("[green] Found cluster centers: ", color_means)
# Remap image to the center points # Remap image to the center points
raw_pixels = im.reshape(-1, 3)
raw_shape = raw_pixels.shape
console.log("[purple] Re-mapping image") console.log("[purple] Re-mapping image")
output_raw = np.zeros_like(raw_pixels) output_raw = np.zeros_like(raw_pixels)
for i in range(len(raw_pixels)): output_raw = find_closest_points(raw_pixels, color_means)
output_raw[i] = find_nearest_point(color_means, raw_pixels[i]) console.log("[purple] Re-mapping image complete (phase 1)")
# Map means to the colors provided by the user # Map means to the colors provided by the user
pairs = [] pairs = []
@ -128,7 +126,10 @@ def main():
for pair in pairs: for pair in pairs:
(idxs,) = np.where(np.all(output_raw == pair[0], axis=1)) (idxs,) = np.where(np.all(output_raw == pair[0], axis=1))
output_raw[idxs] = pair[1] output_raw[idxs] = pair[1]
save_image(output_raw, "final", starting_resolution)
if output_name is None:
output_name = f"{input_image.parents[0]}/recolored-{input_image.name}"
save_image(output_raw, output_name, starting_resolution)
if __name__ == "__main__": if __name__ == "__main__":
@ -136,22 +137,66 @@ if __name__ == "__main__":
prog="Recolor", prog="Recolor",
description="Recolor changes the color palette of an image to the one provided by the user", description="Recolor changes the color palette of an image to the one provided by the user",
) )
parser.add_argument(
"filename", help="Input image"
)
parser.add_argument(
"-o", "--output", help="Output image destination, defaults to the same path as the input titled [path]/recolored-[name]", dest="outname"
)
color_loader_group = parser.add_mutually_exclusive_group(required=True) color_loader_group = parser.add_mutually_exclusive_group(required=True)
color_loader_group.add_argument( color_loader_group.add_argument(
"-f", "-f",
"--file", "--file",
description="A file of RGB color values, with one color per line", help="A file of RGB color values, with one color per line",
dest="fpath", dest="cfpath",
default=None,
) )
color_loader_group.add_argument( color_loader_group.add_argument(
"-l", "-l",
"--list", "--list",
description="A list of RGB color values. Example: '123,90,89 212,7,0'", nargs='+',
help="A list of RGB color values. Example: '123,90,89 212,7,0'",
dest="clist", dest="clist",
default=None,
) )
main() # Check inputs
pass args = parser.parse_args()
input_image = Path(args.filename)
if not input_image.exists():
console.log("[red] Error: input image path does not exist")
sys.exit(-1)
else:
if not input_image.is_file():
console.log("[red] Error: input image path is not a file")
sys.exit(-1)
sys.exit(0) # Parse out color
colors = []
if args.cfpath is not None:
# Load from file
lines = []
input_colors = Path(args.cfpath)
with open(input_colors) as f:
lines = [line.rstrip() for line in f]
# Split each line
for line in lines:
vals = []
sp = line.split(',')
if len(sp) != 3:
console.log("[red] Error: RGB value in file is malformed -- not exactly 3 values", args.cfpath, line)
sys.exit(-2)
for v in sp:
vals.append(int(v))
colors.append(np.array(vals))
else:
#Load from list
for line in args.clist:
vals = []
sp = line.split(',')
if len(sp) != 3:
console.log("[red] Error: RGB value in list is malformed -- not exactly 3 values", line)
sys.exit(-2)
for v in sp:
vals.append(int(v))
colors.append(np.array(vals))
main(input_image, Path(args.outname) if args.outname is not None else None, np.array(colors))