From b30d77775eb4077d72a2189688e041c22958e7cf Mon Sep 17 00:00:00 2001 From: Tanishq Dubey Date: Fri, 25 Nov 2022 10:29:55 -0500 Subject: [PATCH] Recoloring works, adding cli interface --- .gitignore | 163 +++++++++++++++++++++++++++++++++++++++++++++++++++++ main.py | 134 ++++++++++++++++++++++++++++++------------- 2 files changed, 259 insertions(+), 38 deletions(-) diff --git a/.gitignore b/.gitignore index b984719..42890d3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,166 @@ +# Test images *.png *.jpg *.jpeg +*.gif + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/main.py b/main.py index 778607a..3203faa 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,23 @@ -from skimage.io import imread +import sys +import argparse +from skimage.io import imread, imsave from scipy.stats import moment import numpy as np +from numpy import array, all, uint8 +from rich.console import Console + +console = Console() + +def save_image(data, name, resolution): + final_image = data.reshape(resolution) + imsave(f"{name}.png", final_image) def find_nearest_point(data, target): idx = np.array([calc_distance(p, target) for p in data]).argmin() return data[idx] -def centeroidnp(arr): +def centroidnp(arr): length = arr.shape[0] sum_x = np.sum(arr[:, 0]) sum_y = np.sum(arr[:, 1]) @@ -20,51 +30,99 @@ def calc_distance(x, y): def k_means(data, count): - # Pick n random points to start - index = np.random.choice(data.shape[0], count, replace=False) - means = data[index] + # Pick n random points to startA + idx_data = np.unique(data, axis=0) + index = np.random.choice(idx_data.shape[0], count, replace=False) + means = idx_data[index] data = np.delete(data, index, axis=0) distance_delta = 100 means_distance = 0 - while distance_delta > 0.1: - print(f"new iteration, distance moved: {distance_delta}") - # Initialize cluster map - clusters = {} - for m in means: - clusters[str(m)] = [] + clusters = {} + with console.status("[bold blue] Finding means...") as status: + while distance_delta > 5: + # Initialize cluster map + clusters = {} + for m in means: + clusters[repr(m)] = [] - # Find closest mean to each point - for point in data: - closest = find_nearest_point(means, point) - clusters[str(closest)].append(point) + # Find closest mean to each point + for point in data: + closest = find_nearest_point(means, point) + clusters[repr(closest)].append(point) - # Find the centeroid of each mean - new_means = [] - previous_distance = means_distance - means_distance = 0 - for mean in means: - mean_key = str(mean) - # Clean up the results a little bit - clusters[mean_key] = np.stack(clusters[str(mean)]) - # Calculate new mean - raw_mean = centeroidnp(clusters[mean_key]) - nearest_mean_point = find_nearest_point(data, raw_mean) - means_distance = means_distance + calc_distance(mean, nearest_mean_point) - new_means.append(nearest_mean_point) - means_distance = means_distance / float(count) - distance_delta = abs(previous_distance - means_distance) - means = np.stack(new_means) - print(means) + # Find the centroid of each mean + new_means = [] + previous_distance = means_distance + means_distance = 0 + for mean in means: + mean_key = repr(mean) + # Clean up the results a little bit + clusters[mean_key] = np.stack(clusters[mean_key]) + # Calculate new mean + raw_mean = centroidnp(clusters[mean_key]) + nearest_mean_point = find_nearest_point(data, raw_mean) + means_distance = means_distance + calc_distance(mean, nearest_mean_point) + new_means.append(nearest_mean_point) + means_distance = means_distance / float(count) + distance_delta = abs(previous_distance - means_distance) + means = np.stack(new_means) + + return means - - - -im = imread("image.png") +im = imread("zarin.jpg") starting_resolution = im.shape +console.log("[blue] Starting with image of size: ", starting_resolution) raw_pixels = im.reshape(-1, 3) +raw_shape = raw_pixels.shape -colors = [[45, 85, 255], [0, 181, 204], [243, 225, 107]] +colors = np.array([np.array([0,43,54]), + np.array([7,54,66]), + np.array([88,110,117]), + np.array([101,123,131]), + np.array([131,148,150]), + np.array([147,161,161]), + np.array([238,232,213]), + np.array([253,246,227]), + np.array([181,137,0]), + np.array([203,75,22]), + np.array([220,50,47]), + np.array([211,54,130]), + np.array([108,113,196]), + np.array([38,139,210]), + np.array([42,161,152]), + np.array([133,153,0])]) -k_means(raw_pixels, len(colors)) +def main(): +# Find the colors that most represent the image + color_means = k_means(raw_pixels, len(colors)) + console.log("[green] Found cluster centers: ", color_means) + +# Remap image to the center points + console.log("[purple] Re-mapping image") + output_raw = np.zeros_like(raw_pixels) + for i in range(len(raw_pixels)): + output_raw[i] = find_nearest_point(color_means, raw_pixels[i]) + +# Map means to the colors provided by the user + pairs = [] + tmp_means = color_means + for color in colors: + m = find_nearest_point(tmp_means, color) + pairs.append((m, color)) + idxs, = np.where(np.all(tmp_means == m, axis=1)) + tmp_means = np.delete(tmp_means, idxs, axis=0) + +# Recolor the image + for pair in pairs: + idxs, = np.where(np.all(output_raw == pair[0], axis=1)) + output_raw[idxs] = pair[1] + save_image(output_raw, "final", starting_resolution) + + +if __name__ == "__main__": + main() + pass + +sys.exit(0)