# -*- coding: utf-8 -*-
"""
    colorthief
    ~~~~~~~~~~

    Grabbing the color palette from an image.

    :copyright: (c) 2015 by Shipeng Feng.
    :license: BSD, see LICENSE for more details.
"""
|
|
__version__ = '0.2.1'
|
|
|
|
import math
|
|
|
|
from PIL import Image
|
|
|
|
|
|
class cached_property(object):
    """Decorator that converts a method with a single ``self`` argument
    into a property cached on the instance.

    The wrapped function runs on first access and its result is stored in
    the instance ``__dict__`` under the function's name. Because this class
    defines no ``__set__``, the stored value shadows the descriptor on every
    later lookup, so the function is not called again for that instance.
    """

    def __init__(self, func):
        """Remember the function whose result will be cached.

        :param func: a callable taking the instance as its only argument
        """
        self.func = func

    def __get__(self, instance, type=None):
        """Compute, cache and return the value for ``instance``.

        :param instance: the object the attribute is read from, or ``None``
                         when the attribute is read from the class itself
        :param type: the owner class (unused)
        :return: the cached value, or this descriptor for class access
        """
        # Class-level access (e.g. ``SomeClass.attr``) has no instance to
        # cache on; hand back the descriptor instead of failing on
        # ``None.__dict__``.
        if instance is None:
            return self
        res = instance.__dict__[self.func.__name__] = self.func(instance)
        return res
|
|
|
|
|
|
class ColorThief(object):
    """Extracts a dominant color or a color palette from one image."""

    def __init__(self, file):
        """Open the image the colors will be taken from.

        :param file: A filename (string) or a file object. The file object
                     must implement `read()`, `seek()`, and `tell()` methods,
                     and be opened in binary mode.
        """
        self.image = Image.open(file)

    def get_color(self, quality=10):
        """Return the first entry of a five-color palette.

        :param quality: quality settings, 1 is the highest quality, the bigger
                        the number, the faster a color will be returned but
                        the greater the likelihood that it will not be the
                        visually most dominant color
        :return tuple: (r, g, b)
        """
        return self.get_palette(5, quality)[0]

    def get_palette(self, color_count=10, quality=10):
        """Build a color palette by clustering similar colors with the
        median cut algorithm.

        Only every ``quality``-th pixel is sampled. Pixels whose alpha is
        below 125, and pixels whose red, green and blue are all above 250,
        are left out before quantizing.

        :param color_count: the size of the palette, max number of colors
        :param quality: quality settings, 1 is the highest quality, the bigger
                        the number, the faster the palette generation, but the
                        greater the likelihood that colors will be missed.
        :return list: a list of tuple in the form (r, g, b)
        """
        rgba_image = self.image.convert('RGBA')
        columns, rows = rgba_image.size
        data = rgba_image.getdata()

        def is_usable(pixel):
            # Keep pixels that are mostly opaque and not (near) white.
            red, green, blue, alpha = pixel
            if alpha < 125:
                return False
            return red <= 250 or green <= 250 or blue <= 250

        sampled = (data[offset]
                   for offset in range(0, columns * rows, quality))
        candidates = [pixel[:3] for pixel in sampled if is_usable(pixel)]

        # Hand the samples to the quantizer, which clusters them using the
        # median cut algorithm.
        return MMCQ.quantize(candidates, color_count).palette
|
|
|
|
|
|
class MMCQ(object):
    """Basic Python port of the MMCQ (modified median cut quantization)
    algorithm from the Leptonica library (http://www.leptonica.com/).
    """

    # Bits kept per color channel when quantizing.
    SIGBITS = 5
    # Right shift that reduces an 8-bit channel value to SIGBITS bits.
    RSHIFT = 8 - SIGBITS
    # Upper bound on split attempts within one pass of quantize().
    MAX_ITERATION = 1000
    # Share of max_color produced by the pass that orders boxes by count.
    FRACT_BY_POPULATIONS = 0.75

    @staticmethod
    def get_color_index(r, g, b):
        """Pack three quantized channel values into one histogram key."""
        return (r << (2 * MMCQ.SIGBITS)) + (g << MMCQ.SIGBITS) + b

    @staticmethod
    def get_histo(pixels):
        """histo (1-d array, giving the number of pixels in each quantized
        region of color space)

        :param pixels: iterable of pixels indexable as (r, g, b)
        :return dict: packed color index -> number of pixels
        """
        histo = dict()
        for pixel in pixels:
            index = MMCQ.get_color_index(
                pixel[0] >> MMCQ.RSHIFT,
                pixel[1] >> MMCQ.RSHIFT,
                pixel[2] >> MMCQ.RSHIFT,
            )
            histo[index] = histo.get(index, 0) + 1
        return histo

    @staticmethod
    def vbox_from_pixels(pixels, histo):
        """Return the VBox spanning the quantized values of ``pixels``.

        :param pixels: iterable of pixels indexable as (r, g, b)
        :param histo: histogram handed to the new box
        """
        rmin = gmin = bmin = 1000000
        rmax = gmax = bmax = 0
        for pixel in pixels:
            rval = pixel[0] >> MMCQ.RSHIFT
            gval = pixel[1] >> MMCQ.RSHIFT
            bval = pixel[2] >> MMCQ.RSHIFT
            rmin = min(rval, rmin)
            rmax = max(rval, rmax)
            gmin = min(gval, gmin)
            gmax = max(gval, gmax)
            bmin = min(bval, bmin)
            bmax = max(bval, bmax)
        return VBox(rmin, rmax, gmin, gmax, bmin, bmax, histo)

    @staticmethod
    def _partial_sums(histo, vbox, axis):
        """Return ``(partialsum, total)`` for ``vbox`` along ``axis``.

        ``partialsum[p]`` is the number of pixels in planes ``first..p`` of
        the given axis ('r', 'g' or 'b'); ``total`` is the count over all
        planes of the box.
        """
        axis_pos = 'rgb'.index(axis)
        first = getattr(vbox, axis + '1')
        last = getattr(vbox, axis + '2')

        # Pixels per plane of the chosen axis. One traversal serves all
        # three axes because the per-plane totals do not depend on the
        # order in which cells are visited.
        per_plane = dict.fromkeys(range(first, last + 1), 0)
        for r in range(vbox.r1, vbox.r2 + 1):
            for g in range(vbox.g1, vbox.g2 + 1):
                for b in range(vbox.b1, vbox.b2 + 1):
                    hits = histo.get(MMCQ.get_color_index(r, g, b), 0)
                    if hits:
                        per_plane[(r, g, b)[axis_pos]] += hits

        total = 0
        partialsum = {}
        for plane in range(first, last + 1):
            total += per_plane[plane]
            partialsum[plane] = total
        return partialsum, total

    @staticmethod
    def median_cut_apply(histo, vbox):
        """Split ``vbox`` in two along its widest axis near the median.

        :param histo: histogram used to count pixels per plane
        :param vbox: the box to split
        :return tuple: ``(None, None)`` when the box holds no pixels,
                       ``(copy, None)`` when it holds exactly one pixel,
                       otherwise ``(vbox1, vbox2)``

        NOTE(review): when every pixel of the box lies in a single plane of
        the cut axis, the second box appears to come back holding no
        pixels; callers should not assume ``vbox2.count > 0``.
        """
        if not vbox.count:
            return (None, None)
        # only one pixel, no split
        if vbox.count == 1:
            return (vbox.copy, None)

        rw = vbox.r2 - vbox.r1 + 1
        gw = vbox.g2 - vbox.g1 + 1
        bw = vbox.b2 - vbox.b1 + 1
        maxw = max(rw, gw, bw)
        # Ties resolve in the order r, g, b.
        if maxw == rw:
            do_cut_color = 'r'
        elif maxw == gw:
            do_cut_color = 'g'
        else:
            do_cut_color = 'b'

        # Find the partial sum arrays along the selected axis.
        partialsum, total = MMCQ._partial_sums(histo, vbox, do_cut_color)
        lookaheadsum = {plane: total - acc
                        for plane, acc in partialsum.items()}

        # determine the cut planes
        dim1 = do_cut_color + '1'
        dim2 = do_cut_color + '2'
        dim1_val = getattr(vbox, dim1)
        dim2_val = getattr(vbox, dim2)
        for i in range(dim1_val, dim2_val + 1):
            if partialsum[i] > (total / 2):
                vbox1 = vbox.copy
                vbox2 = vbox.copy
                left = i - dim1_val
                right = dim2_val - i
                # Place the cut inside the larger side of the median plane.
                if left <= right:
                    d2 = min([dim2_val - 1, int(i + right / 2)])
                else:
                    d2 = max([dim1_val, int(i - 1 - left / 2)])
                # avoid 0-count boxes
                while not partialsum.get(d2, False):
                    d2 += 1
                count2 = lookaheadsum.get(d2)
                while not count2 and partialsum.get(d2 - 1, False):
                    d2 -= 1
                    count2 = lookaheadsum.get(d2)
                # set dimensions
                setattr(vbox1, dim2, d2)
                setattr(vbox2, dim1, getattr(vbox1, dim2) + 1)
                return (vbox1, vbox2)
        return (None, None)

    @staticmethod
    def _split_boxes(queue, target, histo):
        """Split the highest-priority box until ``queue`` holds ``target``
        boxes or MAX_ITERATION attempts have been made.

        The stop condition reads the queue's real size rather than a
        counter restarted per pass, so successive passes over the same
        boxes cannot end up above or below the requested total.
        """
        n_iter = 0
        while n_iter < MMCQ.MAX_ITERATION and queue.size() < target:
            vbox = queue.pop()
            n_iter += 1
            if not vbox.count:  # just put it back
                queue.push(vbox)
                continue
            # do the cut
            vbox1, vbox2 = MMCQ.median_cut_apply(histo, vbox)
            if not vbox1:
                raise Exception("vbox1 not defined; shouldn't happen!")
            queue.push(vbox1)
            if vbox2:  # vbox2 can be null
                queue.push(vbox2)

    @staticmethod
    def quantize(pixels, max_color):
        """Quantize.

        :param pixels: a list of pixel in the form (r, g, b)
        :param max_color: max number of colors
        :return: a CMap holding one entry per resulting box
        :raises Exception: if ``pixels`` is empty or ``max_color`` is
                           outside 2..256
        """
        if not pixels:
            raise Exception('Empty pixels when quantize.')
        if max_color < 2 or max_color > 256:
            raise Exception('Wrong number of max colors when quantize.')

        histo = MMCQ.get_histo(pixels)

        # get the beginning vbox from the colors
        vbox = MMCQ.vbox_from_pixels(pixels, histo)
        pq = PQueue(lambda x: x.count)
        pq.push(vbox)

        # first set of colors, sorted by population
        MMCQ._split_boxes(pq, MMCQ.FRACT_BY_POPULATIONS * max_color, histo)

        # Re-sort by the product of pixel occupancy times the size in
        # color space.
        pq2 = PQueue(lambda x: x.count * x.volume)
        while pq.size():
            pq2.push(pq.pop())

        # next set - generate the median cuts using the (npix * vol) sorting.
        MMCQ._split_boxes(pq2, max_color, histo)

        # calculate the actual colors
        cmap = CMap()
        while pq2.size():
            cmap.push(pq2.pop())
        return cmap
|
|
|
|
|
|
class VBox(object):
    """3d color space box.

    The box spans the inclusive ranges r1..r2, g1..g2 and b1..b2 and reads
    pixel counts from the histogram it was given.
    """

    def __init__(self, r1, r2, g1, g2, b1, b2, histo):
        self.r1, self.r2 = r1, r2
        self.g1, self.g2 = g1, g2
        self.b1, self.b2 = b1, b2
        self.histo = histo

    @cached_property
    def volume(self):
        """Number of cells covered by the box."""
        red_span = self.r2 - self.r1 + 1
        green_span = self.g2 - self.g1 + 1
        blue_span = self.b2 - self.b1 + 1
        return red_span * green_span * blue_span

    @property
    def copy(self):
        """A new VBox with the same bounds, sharing the same histogram."""
        bounds = (self.r1, self.r2, self.g1, self.g2, self.b1, self.b2)
        return VBox(*bounds, histo=self.histo)

    @cached_property
    def avg(self):
        """Representative (r, g, b) color of the box.

        This is the histogram-weighted mean of the cell centers, scaled
        back by ``1 << (8 - SIGBITS)``. When the box holds no pixels the
        scaled midpoint of its bounds is returned instead.
        """
        scale = 1 << (8 - MMCQ.SIGBITS)
        population = 0
        red_total = 0
        green_total = 0
        blue_total = 0
        for red in range(self.r1, self.r2 + 1):
            for green in range(self.g1, self.g2 + 1):
                for blue in range(self.b1, self.b2 + 1):
                    cell = MMCQ.get_color_index(red, green, blue)
                    weight = self.histo.get(cell, 0)
                    population += weight
                    red_total += weight * (red + 0.5) * scale
                    green_total += weight * (green + 0.5) * scale
                    blue_total += weight * (blue + 0.5) * scale

        if not population:
            return (
                int(scale * (self.r1 + self.r2 + 1) / 2),
                int(scale * (self.g1 + self.g2 + 1) / 2),
                int(scale * (self.b1 + self.b2 + 1) / 2),
            )
        return (
            int(red_total / population),
            int(green_total / population),
            int(blue_total / population),
        )

    def contains(self, pixel):
        """Tell whether the quantized ``pixel`` lies inside the box."""
        red = pixel[0] >> MMCQ.RSHIFT
        green = pixel[1] >> MMCQ.RSHIFT
        blue = pixel[2] >> MMCQ.RSHIFT
        return (self.r1 <= red <= self.r2 and
                self.g1 <= green <= self.g2 and
                self.b1 <= blue <= self.b2)

    @cached_property
    def count(self):
        """Number of histogram pixels that fall inside the box."""
        return sum(
            self.histo.get(MMCQ.get_color_index(red, green, blue), 0)
            for red in range(self.r1, self.r2 + 1)
            for green in range(self.g1, self.g2 + 1)
            for blue in range(self.b1, self.b2 + 1)
        )
|
|
|
|
|
|
class CMap(object):
    """Color map: a collection of boxes, each paired with its average
    color, prioritized by ``count * volume`` of the box.
    """

    def __init__(self):
        self.vboxes = PQueue(lambda x: x['vbox'].count * x['vbox'].volume)

    @property
    def palette(self):
        """List of the stored colors, largest ``count * volume`` first.

        The order is computed here rather than taken from the queue's
        current internal layout, which changes from push order to
        ascending order as soon as ``map()`` or ``nearest()`` has peeked
        into the queue. The sort is stable, so entries with equal priority
        keep the order the queue currently holds them in.
        """
        entries = self.vboxes.map(lambda entry: entry)
        entries.sort(
            key=lambda entry: entry['vbox'].count * entry['vbox'].volume,
            reverse=True,
        )
        return [entry['color'] for entry in entries]

    def push(self, vbox):
        """Store ``vbox`` together with its average color."""
        self.vboxes.push({
            'vbox': vbox,
            'color': vbox.avg,
        })

    def size(self):
        """Number of stored boxes."""
        return self.vboxes.size()

    def _entries(self):
        """Yield the stored entries in the queue's sorted order."""
        for position in range(self.vboxes.size()):
            yield self.vboxes.peek(position)

    def nearest(self, color):
        """Return the stored color closest to ``color``.

        :param color: a sequence indexable as (r, g, b)
        :return: the closest stored color, or ``None`` when the map is
                 empty; the first one encountered wins on ties
        """
        best_distance = None
        best_color = None
        for entry in self._entries():
            candidate = entry['color']
            # Squared Euclidean distance; the square root is skipped
            # because it does not change which candidate is closest.
            distance = (
                (color[0] - candidate[0]) ** 2 +
                (color[1] - candidate[1]) ** 2 +
                (color[2] - candidate[2]) ** 2
            )
            if best_distance is None or distance < best_distance:
                best_distance = distance
                best_color = candidate
        return best_color

    def map(self, color):
        """Return the color of the first box containing ``color``, falling
        back to the nearest stored color when no box contains it.
        """
        for entry in self._entries():
            if entry['vbox'].contains(color):
                return entry['color']
        return self.nearest(color)
|
|
|
|
|
|
class PQueue(object):
    """Simple priority queue.

    Items live in a plain list that is sorted in ascending ``sort_key``
    order only when ``peek`` or ``pop`` needs it, so the highest-priority
    item is the last one.
    """

    def __init__(self, sort_key):
        """Create an empty queue ordered by ``sort_key(item)``."""
        self.sort_key = sort_key
        self.contents = []
        self._sorted = False

    def sort(self):
        """Sort the items in place and mark the queue as sorted."""
        self.contents.sort(key=self.sort_key)
        self._sorted = True

    def push(self, o):
        """Append ``o``; the queue is re-sorted on the next peek or pop."""
        self.contents.append(o)
        self._sorted = False

    def _ensure_sorted(self):
        """Sort only if something was pushed since the last sort."""
        if self._sorted:
            return
        self.sort()

    def peek(self, index=None):
        """Return, without removing it, the item at ``index`` in sorted
        order; the last (highest-priority) item when ``index`` is None.
        """
        self._ensure_sorted()
        position = len(self.contents) - 1 if index is None else index
        return self.contents[position]

    def pop(self):
        """Remove and return the highest-priority item."""
        self._ensure_sorted()
        return self.contents.pop()

    def size(self):
        """Number of items currently held."""
        return len(self.contents)

    def map(self, f):
        """Apply ``f`` to every item in the current internal order (no
        sort is triggered) and return the results as a list.
        """
        return [f(item) for item in self.contents]
|