From b30d77775eb4077d72a2189688e041c22958e7cf Mon Sep 17 00:00:00 2001
From: Tanishq Dubey <tdubey@clearstreet.io>
Date: Fri, 25 Nov 2022 10:29:55 -0500
Subject: [PATCH] Recoloring works, adding CLI interface

---
 .gitignore | 163 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 main.py    | 134 ++++++++++++++++++++++++++++++-------------
 2 files changed, 259 insertions(+), 38 deletions(-)

diff --git a/.gitignore b/.gitignore
index b984719..42890d3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,166 @@
+# Test images
 *.png
 *.jpg
 *.jpeg
+*.gif
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
diff --git a/main.py b/main.py
index 778607a..3203faa 100644
--- a/main.py
+++ b/main.py
@@ -1,13 +1,23 @@
-from skimage.io import imread
+import sys
+import argparse
+from skimage.io import imread, imsave
 from scipy.stats import moment
 import numpy as np
+from numpy import array, all, uint8
+from rich.console import Console
+
+console = Console()
+
+def save_image(data, name, resolution):
+    final_image = data.reshape(resolution)
+    imsave(f"{name}.png", final_image)
 
 def find_nearest_point(data, target):
     idx = np.array([calc_distance(p, target) for p in data]).argmin()
     return data[idx]
 
 
-def centeroidnp(arr):
+def centroidnp(arr):
     length = arr.shape[0]
     sum_x = np.sum(arr[:, 0])
     sum_y = np.sum(arr[:, 1])
@@ -20,51 +30,99 @@ def calc_distance(x, y):
 
 
 def k_means(data, count):
-    # Pick n random points to start
-    index = np.random.choice(data.shape[0], count, replace=False)
-    means = data[index]
+    # Pick n random points to start
+    idx_data = np.unique(data, axis=0)
+    index = np.random.choice(idx_data.shape[0], count, replace=False)
+    means = idx_data[index]
     data = np.delete(data, index, axis=0)
 
     distance_delta = 100
     means_distance = 0
-    while distance_delta > 0.1:
-        print(f"new iteration, distance moved: {distance_delta}")
-        # Initialize cluster map
-        clusters = {}
-        for m in means:
-            clusters[str(m)] = []
+    clusters = {}
+    with console.status("[bold blue] Finding means...") as status:
+        while distance_delta > 5:
+            # Initialize cluster map
+            clusters = {}
+            for m in means:
+                clusters[repr(m)] = []
 
-        # Find closest mean to each point
-        for point in data:
-            closest = find_nearest_point(means, point)
-            clusters[str(closest)].append(point)
+            # Find closest mean to each point
+            for point in data:
+                closest = find_nearest_point(means, point)
+                clusters[repr(closest)].append(point)
 
-        # Find the centeroid of each mean
-        new_means = []
-        previous_distance = means_distance
-        means_distance = 0
-        for mean in means:
-            mean_key = str(mean)
-            # Clean up the results a little bit
-            clusters[mean_key] = np.stack(clusters[str(mean)])
-            # Calculate new mean
-            raw_mean = centeroidnp(clusters[mean_key])
-            nearest_mean_point = find_nearest_point(data, raw_mean)
-            means_distance = means_distance + calc_distance(mean, nearest_mean_point)
-            new_means.append(nearest_mean_point)
-        means_distance = means_distance / float(count)
-        distance_delta = abs(previous_distance - means_distance)
-        means = np.stack(new_means)
-    print(means)
+            # Find the centroid of each mean
+            new_means = []
+            previous_distance = means_distance
+            means_distance = 0
+            for mean in means:
+                mean_key = repr(mean)
+                # Clean up the results a little bit
+                clusters[mean_key] = np.stack(clusters[mean_key])
+                # Calculate new mean
+                raw_mean = centroidnp(clusters[mean_key])
+                nearest_mean_point = find_nearest_point(data, raw_mean)
+                means_distance = means_distance + calc_distance(mean, nearest_mean_point)
+                new_means.append(nearest_mean_point)
+            means_distance = means_distance / float(count)
+            distance_delta = abs(previous_distance - means_distance)
+            means = np.stack(new_means)
+
+        return means
 
 
-
-
-
-im = imread("image.png")
+im = imread("zarin.jpg")
 starting_resolution = im.shape
+console.log("[blue] Starting with image of size: ", starting_resolution)
 raw_pixels = im.reshape(-1, 3)
+raw_shape = raw_pixels.shape
 
-colors = [[45, 85, 255], [0, 181, 204], [243, 225, 107]]
+colors = np.array([[0, 43, 54],
+                   [7, 54, 66],
+                   [88, 110, 117],
+                   [101, 123, 131],
+                   [131, 148, 150],
+                   [147, 161, 161],
+                   [238, 232, 213],
+                   [253, 246, 227],
+                   [181, 137, 0],
+                   [203, 75, 22],
+                   [220, 50, 47],
+                   [211, 54, 130],
+                   [108, 113, 196],
+                   [38, 139, 210],
+                   [42, 161, 152],
+                   [133, 153, 0]])
 
-k_means(raw_pixels, len(colors))
+def main():
+    # Find the colors that most represent the image
+    color_means = k_means(raw_pixels, len(colors))
+    console.log("[green] Found cluster centers: ", color_means)
+
+    # Remap image to the center points
+    console.log("[purple] Re-mapping image")
+    output_raw = np.zeros_like(raw_pixels)
+    for i in range(len(raw_pixels)):
+        output_raw[i] = find_nearest_point(color_means, raw_pixels[i])
+
+    # Map means to the colors provided by the user
+    pairs = []
+    tmp_means = color_means
+    for color in colors:
+        m = find_nearest_point(tmp_means, color)
+        pairs.append((m, color))
+        idxs, = np.where(np.all(tmp_means == m, axis=1))
+        tmp_means = np.delete(tmp_means, idxs, axis=0)
+
+    # Recolor the image
+    for pair in pairs:
+        idxs, = np.where(np.all(output_raw == pair[0], axis=1))
+        output_raw[idxs] = pair[1]
+    save_image(output_raw, "final", starting_resolution)
+
+
+if __name__ == "__main__":
+    main()
+
+    # Exit explicitly once processing has finished
+    sys.exit(0)
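
The subject line mentions a CLI and main.py now imports argparse, but the input image ("zarin.jpg"), the palette, and the output name are still hardcoded. A minimal sketch of how the flags could be wired up, assuming hypothetical --input, --output and --colors options that are not part of this patch:

    import argparse

    import numpy as np
    from skimage.io import imread


    def parse_args():
        # Hypothetical flags; names and defaults are assumptions, not part of this patch.
        parser = argparse.ArgumentParser(description="Recolor an image toward a target palette")
        parser.add_argument("--input", required=True, help="path to the source image")
        parser.add_argument("--output", default="final", help="basename for the recolored PNG")
        parser.add_argument("--colors", nargs="+", default=["002b36", "dc322f", "268bd2"],
                            help="target palette as hex RGB strings")
        return parser.parse_args()


    def hex_to_rgb(value):
        # "002b36" -> [0, 43, 54]
        value = value.lstrip("#")
        return [int(value[i:i + 2], 16) for i in (0, 2, 4)]


    if __name__ == "__main__":
        args = parse_args()
        im = imread(args.input)
        colors = np.array([hex_to_rgb(c) for c in args.colors])
        print(im.shape, colors.shape)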
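
The remapping loop in main() calls find_nearest_point once per pixel, which is O(pixels x means) in pure Python. A possible vectorized alternative, sketched under the assumption that raw_pixels is an (N, 3) uint8 array and color_means is a (K, 3) array; it is not what the patch does, just the same mapping expressed with NumPy broadcasting:

    import numpy as np


    def remap_to_means(raw_pixels, color_means):
        # Promote to int64 so the squared differences cannot overflow uint8.
        pixels = raw_pixels.astype(np.int64)
        means = color_means.astype(np.int64)
        # (N, K) squared Euclidean distances; argmin is the same as for true Euclidean.
        dists = ((pixels[:, None, :] - means[None, :, :]) ** 2).sum(axis=2)
        nearest = dists.argmin(axis=1)
        return color_means[nearest]

On a typical photo this trades millions of Python-level function calls for a handful of array operations, at the cost of materializing the (N, K) distance matrix in memory.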