Classification of food
In this notebook we attempt, with only moderate success, to classify photos of fruit and vegetables using only their overall colour.
from PIL import Image, ImageDraw
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from IPython.display import display  # available automatically in Jupyter; imported explicitly here
We first make a list of the photo files that we will use. There are 48 of them, in six groups of eight. Each one is $500\times 500$ pixels.
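# Folder holding the photos (renamed from dir, which shadows a Python built-in).
image_dir = (Path('..') / 'data/images/fruit/mixed').resolve()
# Sort so that the file order, and hence every layout below, is the same on each run.
files = sorted(image_dir.glob('*.jpg'))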
Here we display all 48 photos as $100\times 100$ thumbnails, arranged in an $8\times 6$ grid in a single image.
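# White 800 x 600 canvas: 8 columns and 6 rows of 100 x 100 thumbnails.
img = Image.new('RGB', (800, 600), color='white')
for i, file in enumerate(files):
    # Thumbnail i goes in column i % 8 and row i // 8.
    img.paste(Image.open(file).resize((100, 100)), (i % 8 * 100, i // 8 * 100))
display(img)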
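The paste position of the $i$th thumbnail uses integer division: i % 8 is its column and i // 8 is its row, with cells 100 pixels apart. As a quick check that needs no image files, here is a small sketch computing these corners with divmod; the helper grid_position is our own illustration, not part of the notebook.

```python
def grid_position(i, columns=8, size=100):
    """Top-left corner (x, y) of the i-th thumbnail in a grid with `columns` columns."""
    row, col = divmod(i, columns)  # same as (i // columns, i % columns)
    return (col * size, row * size)

# 48 thumbnails in 8 columns fill 6 rows, i.e. an 800 x 600 canvas.
positions = [grid_position(i) for i in range(48)]
print(positions[:3], positions[8], positions[-1])
# [(0, 0), (100, 0), (200, 0)] (0, 100) (700, 500)
```

The last thumbnail starts at (700, 500) and so ends exactly at the canvas corner (800, 600).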
We now want to find the overall colour of each image. We start by reading the image using Image.open. We apply np.asarray to get an array A_rgb of shape (500,500,3). This means that A_rgb[i,j] is a vector of length three, containing the amounts of red, green and blue light at position $(i,j)$. All these amounts are integers in the range from $0$ to $255$. For convenience, we reshape A_rgb to make an array a_rgb of shape (250000, 3) instead of (500, 500, 3).
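The flattening step can be illustrated on a tiny synthetic array standing in for A_rgb (random values, not a real photo). Because NumPy reshapes in row-major order, pixel $(i,j)$ of an image of width $w$ becomes row $iw + j$ of the flattened array. Passing -1 to reshape, as the code below does, lets NumPy infer the number of rows from the total size.

```python
import numpy as np

# Synthetic 4 x 5 "image" with 3 channels, values 0..255 (a stand-in for A_rgb).
rng = np.random.default_rng(0)
A_rgb = rng.integers(0, 256, size=(4, 5, 3), dtype=np.uint8)

# -1 lets NumPy work out the number of rows: 4 * 5 = 20.
a_rgb = A_rgb.reshape(-1, 3)
print(a_rgb.shape)  # (20, 3)

# Pixel (i, j) becomes row i * width + j of the flattened array.
height, width, _ = A_rgb.shape
i, j = 2, 3
print(np.array_equal(a_rgb[i * width + j], A_rgb[i, j]))  # True
```

For the $500\times 500$ photos the same rule sends pixel $(i,j)$ to row $500i + j$, giving 250000 rows in total.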
We then want to calculate the average colour of each image, ignoring the white background. For a pure white pixel, the red, green and blue channels are all equal to 255, so their sum is 765. We choose to treat a pixel as white if the sum of its three channels is at least 700. The code mask = np.sum(a_rgb, axis=1) < 700 gives a one-dimensional Boolean array of length 250000, whose $i$th entry is True if and only if the $i$th pixel is not white. The code a_rgb[mask] then gives an array of shape (n, 3), where n is the number of non-white pixels, containing their amounts of red, green and blue. Applying .mean(axis=0) gives an array of length three, containing the average amounts of red, green and blue over the non-white pixels. These averages will typically not be integers, so we round them with np.round() and convert the result to integers with .astype(int).
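Here is the same masking and averaging applied to a hypothetical five-pixel image rather than a real photo, so that the result can be checked by hand. The first two pixels have channel sums of 765 and 720, both at least 700, so only the last three are averaged. NumPy's sum accumulates small unsigned integer types such as uint8 in the platform's larger unsigned integer type, so sums above 255 do not wrap around.

```python
import numpy as np

# Hypothetical flattened image: two (near-)white pixels and three reddish ones.
a_rgb = np.array([
    [255, 255, 255],  # pure white, sum 765
    [250, 230, 240],  # near-white, sum 720, also treated as background
    [200,  40,  30],
    [180,  60,  20],
    [220,  50,  25],
], dtype=np.uint8)

# True where the channel sum is below 700, i.e. the pixel is not white.
mask = np.sum(a_rgb, axis=1) < 700
print(mask)  # [False False  True  True  True]

# Average over the non-white rows only, rounded to whole numbers.
c = np.round(a_rgb[mask].mean(axis=0)).astype(int)
print(c)  # [200  50  25]
```

The red channel, for example, averages $(200 + 180 + 220)/3 = 200$; the two background pixels do not dilute the result.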
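average_colours = []
for file in files:
    # convert('RGB') guarantees three channels, even for greyscale or RGBA files.
    img = Image.open(file).convert('RGB')
    # Flatten the (500, 500, 3) pixel array to (250000, 3): one row per pixel.
    a_rgb = np.asarray(img).reshape(-1, 3)
    # True for pixels whose channel sum is below 700, i.e. not white background.
    mask = np.sum(a_rgb, axis=1) < 700
    # Average colour of the non-white pixels, rounded to whole numbers.
    c = np.round(a_rgb[mask].mean(axis=0)).astype(int)
    average_colours.append(c)
average_colours = np.array(average_colours)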
We now want to make a scatter plot showing the average colours. This plot will have one small disc for each image, coloured with the average colour of that image. The position of the disc will also reflect the colour, so that discs of similar colours are close to each other. To choose the positions we use principal component analysis (PCA), which projects the three-dimensional colour vectors onto the two-dimensional plane that preserves as much of their spread (variance) as possible; we will not explain the details here.
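For readers who want a little more detail: PCA first subtracts the mean colour, then finds orthogonal directions along which the centred colours vary most. One standard way to compute these directions is the singular value decomposition (SVD) of the centred data. The sketch below uses NumPy alone; the function pca_2d and the sample colours are our own illustrations, not the notebook's data. scikit-learn's PCA should give the same projection up to the sign of each axis, although its internals may differ.

```python
import numpy as np

def pca_2d(X):
    """Project the rows of X onto their first two principal components (SVD sketch)."""
    X = np.asarray(X, dtype=float)
    Xc = X - X.mean(axis=0)              # centre each colour channel
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    components = Vt[:2]                  # the two directions of greatest variance
    return Xc @ components.T, components

# Hypothetical average colours: two reds, two greens, two yellows.
colours = np.array([[200, 50, 25], [210, 60, 30], [60, 150, 40],
                    [70, 160, 50], [230, 200, 40], [240, 210, 60]])
scores, components = pca_2d(colours)
print(scores.shape)  # (6, 2)
```

Each row of scores plays the role of one (x, y) position in the scatter plot; the first coordinate always carries at least as much variance as the second.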
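pca = PCA(n_components=2)
# Project the colour vectors to 2D; transposing makes q[0], q[1] the x and y coordinates.
q = pca.fit_transform(average_colours).T
plt.axis('off')
# Colour each disc with its image's average colour, scaled to matplotlib's 0-1 range.
plt.scatter(q[0], q[1], color=np.minimum(1, average_colours/255))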
We can now use the standard k-means clustering algorithm to divide our colour vectors into six groups, aiming for one cluster per type of fruit or vegetable.
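k-means looks for centres that minimise the total squared distance from each point to its nearest centre. The classic way to do this (Lloyd's algorithm) alternates two steps: assign each vector to its nearest centre, then move each centre to the mean of the vectors assigned to it, until the centres stop moving. Below is a from-scratch NumPy sketch of that loop on made-up 2D points; the function names are our own. scikit-learn's KMeans, used in the next cell, differs in details such as choosing its starting centres with k-means++ by default.

```python
import numpy as np

def _nearest_centre(X, centres):
    """Index of the nearest centre (squared Euclidean distance) for each row of X."""
    d2 = ((X[:, None, :] - centres[None, :, :]) ** 2).sum(axis=2)
    return d2.argmin(axis=1)

def kmeans_lloyd(X, k, n_iter=100, seed=0):
    """Minimal k-means via Lloyd's algorithm; a sketch, not scikit-learn's implementation."""
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    # Initialise with k different rows of X chosen at random.
    centres = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Assignment step: attach every point to its nearest centre.
        labels = _nearest_centre(X, centres)
        # Update step: move each centre to the mean of its points (an empty cluster keeps its centre).
        new_centres = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centres[j]
            for j in range(k)
        ])
        if np.allclose(new_centres, centres):
            break  # converged: the centres no longer move
        centres = new_centres
    return _nearest_centre(X, centres), centres

# Hypothetical 2D data with two well-separated groups of three points.
X = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]])
labels, centres = kmeans_lloyd(X, k=2)
print(labels)
```

The cluster numbers themselves are arbitrary: the first three points share one label and the last three the other, but which group is called 0 depends on the random start.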
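# n_init=10 runs k-means from ten starting points and keeps the best; random_state makes the result reproducible.
# fit() returns the fitted estimator, so assigning its result avoids echoing the model in the notebook.
kmeans = KMeans(n_clusters=6, n_init=10, random_state=0).fit(average_colours)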
The clusters can be visualised as follows:
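ac = kmeans.cluster_centers_          # the six cluster centres, as RGB colours
qc = pca.transform(ac).T              # their positions in the same 2D PCA coordinates
mqc = ['s', 'p', 'P', '*', 'X', 'D']  # one marker shape per cluster
plt.axis('off')
for i in range(6):
    # Large marker: the cluster centre, drawn in its own colour.
    plt.scatter(qc[0, i], qc[1, i], color=ac[i]/255, marker=mqc[i], s=150)
    # Small markers of the same shape: the images assigned to cluster i.
    mask = kmeans.labels_ == i
    plt.scatter(q[0][mask], q[1][mask], color=average_colours[mask]/255, marker=mqc[i], s=20)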
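We now display the images in each cluster, one row per cluster; each row begins with a square filled with the colour of the cluster centre. We have had some success in separating the different types of fruit and vegetables, but a significant number of images are still placed in the wrong cluster.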
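The notebook judges the errors by eye. If the true type of each photo were recorded (the notebook does not do this, so the labels below are hypothetical), one simple way to put a number on the errors is cluster purity: in each cluster, count the images of its most common true type, add these counts up, and divide by the total number of images. A score of 1 means every cluster contains only one type.

```python
import numpy as np

def cluster_purity(true_labels, cluster_labels):
    """Fraction of items that belong to the majority true class of their cluster."""
    true_labels = np.asarray(true_labels)
    cluster_labels = np.asarray(cluster_labels)
    correct = 0
    for c in np.unique(cluster_labels):
        members = true_labels[cluster_labels == c]
        # Size of the largest group of identical true labels within this cluster.
        _, counts = np.unique(members, return_counts=True)
        correct += int(counts.max())
    return correct / len(true_labels)

# Hypothetical example: six items of two true types, one of them in the wrong cluster.
true_types = ['apple', 'apple', 'apple', 'lime', 'lime', 'lime']
clusters = [0, 0, 1, 1, 1, 1]
print(cluster_purity(true_types, clusters))  # 0.8333333333333334
```

Here cluster 0 contributes 2 and cluster 1 contributes 3, giving $5/6$. Purity is easy to read but rewards many small clusters, so it is most meaningful with a fixed number of clusters, as here.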
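n = len(kmeans.labels_)
for i in range(6):
    # Indices of the images assigned to cluster i.
    J = [j for j in range(n) if kmeans.labels_[j] == i]
    # A 100-pixel-high strip: a colour swatch followed by one thumbnail per image.
    img = Image.new('RGB', (100*len(J) + 100, 100))
    d = ImageDraw.Draw(img)
    # Swatch of the centre colour; PIL includes both corners, so (99, 99) fills exactly 100 x 100.
    d.rectangle([(0, 0), (99, 99)], fill=tuple(int(v) for v in np.round(ac[i])))
    for j in range(len(J)):
        img1 = Image.open(files[J[j]])
        img1.thumbnail((100, 100))  # shrinks in place, keeping the aspect ratio
        img.paste(img1, (100*(j+1), 0))
    display(img)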