implemented k means
This commit is contained in:
parent
9a15628705
commit
78272b1646
68
main.py
68
main.py
@ -1,4 +1,70 @@
|
||||
from skimage.io import imread
|
||||
from scipy.stats import moment
|
||||
import numpy as np
|
||||
|
||||
def find_nearest_point(data, target):
|
||||
idx = np.array([calc_distance(p, target) for p in data]).argmin()
|
||||
return data[idx]
|
||||
|
||||
|
||||
def centeroidnp(arr):
|
||||
length = arr.shape[0]
|
||||
sum_x = np.sum(arr[:, 0])
|
||||
sum_y = np.sum(arr[:, 1])
|
||||
sum_z = np.sum(arr[:, 2])
|
||||
return sum_x/length, sum_y/length, sum_z/length
|
||||
|
||||
|
||||
def calc_distance(x, y):
|
||||
return np.absolute(np.linalg.norm(x - y))
|
||||
|
||||
|
||||
def k_means(data, count):
|
||||
# Pick n random points to start
|
||||
index = np.random.choice(data.shape[0], count, replace=False)
|
||||
means = data[index]
|
||||
data = np.delete(data, index, axis=0)
|
||||
|
||||
distance_delta = 100
|
||||
means_distance = 0
|
||||
while distance_delta > 0.1:
|
||||
print(f"new iteration, distance moved: {distance_delta}")
|
||||
# Initialize cluster map
|
||||
clusters = {}
|
||||
for m in means:
|
||||
clusters[str(m)] = []
|
||||
|
||||
# Find closest mean to each point
|
||||
for point in data:
|
||||
closest = find_nearest_point(means, point)
|
||||
clusters[str(closest)].append(point)
|
||||
|
||||
# Find the centeroid of each mean
|
||||
new_means = []
|
||||
previous_distance = means_distance
|
||||
means_distance = 0
|
||||
for mean in means:
|
||||
mean_key = str(mean)
|
||||
# Clean up the results a little bit
|
||||
clusters[mean_key] = np.stack(clusters[str(mean)])
|
||||
# Calculate new mean
|
||||
raw_mean = centeroidnp(clusters[mean_key])
|
||||
nearest_mean_point = find_nearest_point(data, raw_mean)
|
||||
means_distance = means_distance + calc_distance(mean, nearest_mean_point)
|
||||
new_means.append(nearest_mean_point)
|
||||
means_distance = means_distance / float(count)
|
||||
distance_delta = abs(previous_distance - means_distance)
|
||||
means = np.stack(new_means)
|
||||
print(means)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
im = imread("image.png")
|
||||
print(im)
|
||||
starting_resolution = im.shape
|
||||
raw_pixels = im.reshape(-1, 3)
|
||||
|
||||
colors = [[45, 85, 255], [0, 181, 204], [243, 225, 107]]
|
||||
|
||||
k_means(raw_pixels, len(colors))
|
||||
|
Loading…
Reference in New Issue
Block a user