Tuesday, June 23, 2020

Mean Shift Algorithm

Mean shift is another popular and powerful clustering algorithm used in unsupervised learning. It is a non-parametric algorithm: it makes no assumption about the shape of the clusters and does not need the number of clusters to be specified in advance. It works by seeking the modes (peaks) of the data density and is also known as mean shift cluster analysis. The basic steps of the algorithm are as follows:

  1. First, every data point is treated as a cluster of its own, that is, as an initial candidate centroid.
  2. Next, the algorithm computes the mean of the points near each centroid and moves the centroid to that location.
  3. By repeating this process, each centroid moves closer to the peak of its cluster, i.e. towards the region of higher density.
  4. The algorithm stops when the centroids no longer move.
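
To make the four steps above concrete, here is a minimal from-scratch sketch using only NumPy. It is not how Scikit-learn implements the algorithm: the function name mean_shift, the flat (uniform) window of radius bandwidth, the tolerance tol and the rule for merging nearby centroids are all assumptions chosen for illustration.

```python
import numpy as np

def mean_shift(points, bandwidth, max_iter=300, tol=1e-3):
    """Illustrative flat-kernel mean shift; returns (centers, labels)."""
    points = np.asarray(points, dtype=float)
    # Step 1: every data point starts as its own candidate centroid.
    centroids = points.copy()
    for _ in range(max_iter):
        shifted = np.empty_like(centroids)
        for i, c in enumerate(centroids):
            # Step 2: move the candidate to the mean of the points
            # lying within `bandwidth` of its current position.
            in_window = np.linalg.norm(points - c, axis=1) <= bandwidth
            shifted[i] = points[in_window].mean(axis=0)
        # Step 3: repeating the update drifts candidates towards denser regions.
        moved = np.linalg.norm(shifted - centroids, axis=1).max()
        centroids = shifted
        # Step 4: stop once no candidate moves by more than `tol`.
        if moved < tol:
            break
    # Candidates that ended up on (nearly) the same peak are merged
    # (assumed rule: closer than half the bandwidth counts as the same peak).
    centers = []
    for c in centroids:
        if all(np.linalg.norm(c - k) > bandwidth / 2 for k in centers):
            centers.append(c)
    centers = np.array(centers)
    # Each point is labelled with the index of its nearest surviving center.
    dists = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=2)
    labels = dists.argmin(axis=1)
    return centers, labels
```

Calling mean_shift(X, bandwidth=2.0) on a two-dimensional array X returns the merged centers and one label per point. The bandwidth value here is only a placeholder; the result depends strongly on it.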

The following code implements the Mean Shift clustering algorithm in Python using the Scikit-learn module. Let us import the necessary packages:

import numpy as np
from sklearn.cluster import MeanShift
import matplotlib.pyplot as plt
from matplotlib import style
style.use("ggplot")


The following code imports make_blobs from the sklearn.datasets package, which we will use to generate a two-dimensional dataset containing three blobs.

from sklearn.datasets import make_blobs

We can generate and visualize the dataset with the following code:

centers = [[2,2],[4,5],[3,10]]
X, _ = make_blobs(n_samples = 500, centers = centers, cluster_std = 1)
plt.scatter(X[:,0],X[:,1])
plt.show()


Now, we need to train the Mean Shift cluster model with the input data.

ms = MeanShift()
ms.fit(X)
labels = ms.labels_
cluster_centers = ms.cluster_centers_
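
When MeanShift() is created without arguments, Scikit-learn estimates the bandwidth (the radius of the window used in each shift) with estimate_bandwidth. Since the bandwidth largely determines how many clusters are found, it can be useful to set it explicitly. The sketch below shows one way to do that; the quantile value of 0.2 and random_state=0 are arbitrary choices for illustration, not recommended settings.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Same three blob centers as in the article; random_state is added
# here only so that the sketch is reproducible.
centers = [[2, 2], [4, 5], [3, 10]]
X, _ = make_blobs(n_samples=500, centers=centers, cluster_std=1, random_state=0)

# Estimate a bandwidth from the data; a smaller quantile gives a smaller
# window and usually more clusters (0.2 is an illustrative value).
bandwidth = estimate_bandwidth(X, quantile=0.2, random_state=0)

# bin_seeding=True starts from a coarse grid of seeds instead of every
# point, which speeds up fitting.
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)

print("Bandwidth:", bandwidth)
print("Estimated clusters:", len(ms.cluster_centers_))
```

Trying a few different quantile values and comparing the number of clusters is a quick way to see how sensitive the result is to the bandwidth.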


The following code will print the cluster centers and the estimated number of clusters for the input data. Because the dataset is generated randomly, the exact values will differ from run to run:

print(cluster_centers)
n_clusters_ = len(np.unique(labels))
print("Estimated clusters:", n_clusters_)

[[ 3.23005036  3.84771893]
 [ 3.02057451  9.88928991]]
Estimated clusters: 2
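
Besides the cluster centers, it is often helpful to check how many points fell into each cluster and to assign new observations to the clusters that were found. The following is a small sketch of both; the two entries in new_points are made-up values, and random_state=0 is added only for reproducibility.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

centers = [[2, 2], [4, 5], [3, 10]]
X, _ = make_blobs(n_samples=500, centers=centers, cluster_std=1, random_state=0)
ms = MeanShift().fit(X)

# Count how many training points were assigned to each cluster label.
cluster_sizes = np.bincount(ms.labels_)
print("Points per cluster:", cluster_sizes)

# Hypothetical new observations; predict() assigns each one to the
# closest cluster center found during fitting.
new_points = np.array([[3.0, 10.0], [2.0, 2.0]])
new_labels = ms.predict(new_points)
print("Labels for new points:", new_labels)
```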


The code given below plots the data points, colored by the cluster they were assigned to, and marks each cluster center with a cross.

colors = 10*['r.','g.','b.','c.','k.','y.','m.']
for i in range(len(X)):
    plt.plot(X[i][0], X[i][1], colors[labels[i]], markersize = 10)
plt.scatter(cluster_centers[:,0],cluster_centers[:,1],
    marker="x",color='k', s=150, linewidths = 5, zorder=10)
plt.show()


This concludes our discussion of the mean shift algorithm. Next, we will see why we need to measure clustering performance.