Machine Learning Algorithms Part 13: Mean Shift Clustering
Example In Python
Mean Shift is a centroid-based clustering algorithm that finds clusters by locating the modes (local density peaks) of the data. Unlike supervised machine learning algorithms, clustering attempts to group data without first being trained on labeled data. Clustering is used in a wide variety of applications such as search engines, academic rankings and medicine. Unlike K-Means, Mean Shift does not require you to know the number of categories (clusters) beforehand; instead, it depends on a bandwidth parameter that sets the size of the window. The downside to Mean Shift is that it is computationally expensive: a naive implementation costs O(n²) per iteration, because every window has to be compared against every data point.
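To make the contrast with K-Means concrete, here is a small sketch using scikit-learn, assuming a recent version where make_blobs is imported from sklearn.datasets. K-Means has to be told to look for three clusters, while Mean Shift works out the count on its own. The random_state values are arbitrary choices added only for reproducibility.

```python
from sklearn.cluster import KMeans, MeanShift
from sklearn.datasets import make_blobs

# Three well-separated 3D blobs (same centers as the example later in this post)
centers = [[1, 1, 1], [5, 5, 5], [3, 10, 10]]
X, _ = make_blobs(n_samples=150, centers=centers, cluster_std=0.60, random_state=0)

# K-Means must be given the number of clusters up front
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

# Mean Shift infers the number of clusters from the data
# (it still needs a bandwidth, which scikit-learn estimates when none is passed)
ms = MeanShift().fit(X)

print("K-Means clusters:", len(kmeans.cluster_centers_))
print("Mean Shift clusters:", len(ms.cluster_centers_))
```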
How it works
1. Define a window (the bandwidth of the kernel) and place the window on a data point (in practice, one window is started at every data point)
2. Calculate the mean for all the points in the window
3. Move the center of the window to the location of the mean
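4. Repeat steps 2 and 3 until convergence, that is, until the window stops moving
5. Assign points whose windows converged to the same location (mode) to the same cluster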
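The steps above can be sketched from scratch in NumPy. This is a minimal illustration under stated assumptions, not scikit-learn's implementation: it uses a flat kernel (every point inside the window counts equally), starts one window at every data point, and treats modes closer than the bandwidth as the same cluster. The helper name mean_shift and its max_iter and tol parameters are hypothetical.

```python
import numpy as np


def mean_shift(X, bandwidth, max_iter=300, tol=1e-3):
    """Flat-kernel Mean Shift sketch: returns (cluster_centers, labels)."""
    X = np.asarray(X, dtype=float)
    # Step 1: start one window at every data point
    modes = X.copy()
    for _ in range(max_iter):
        # Distances from every window center to every data point: O(n^2) work
        dists = np.linalg.norm(modes[:, None, :] - X[None, :, :], axis=2)
        # Flat kernel: a point has weight 1 only if it lies inside the window
        weights = (dists <= bandwidth).astype(float)
        # Step 2: mean of the points inside each window
        new_modes = (weights @ X) / weights.sum(axis=1, keepdims=True)
        # Largest distance any window moved during this iteration
        shift = np.linalg.norm(new_modes - modes, axis=1).max()
        # Step 3: move each window to its mean
        modes = new_modes
        # Step 4: stop once the windows have (almost) stopped moving
        if shift < tol:
            break

    # Step 5: group points whose windows converged to (nearly) the same mode.
    # Assumption: modes closer than the bandwidth are merged into one cluster.
    centers = []
    labels = np.empty(len(X), dtype=int)
    for i, mode in enumerate(modes):
        for k, center in enumerate(centers):
            if np.linalg.norm(mode - center) < bandwidth:
                labels[i] = k
                break
        else:
            centers.append(mode)
            labels[i] = len(centers) - 1
    return np.array(centers), labels


# Usage: three 3D blobs similar to the scikit-learn example in this post
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(loc, 0.6, size=(50, 3))
               for loc in ([1, 1, 1], [5, 5, 5], [3, 10, 10])])
centers, labels = mean_shift(X, bandwidth=2.0)
print("clusters found:", len(centers))
```

Computing every pairwise distance on each iteration is what makes this naive version O(n²) per iteration; scikit-learn's MeanShift instead looks up the points inside each window with a nearest-neighbors radius query.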
Example in Python
Let’s take a look at how we could label the data using the Mean Shift algorithm in Python.
import numpy as np
import pandas as pd
from sklearn.cluster import MeanShift
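from sklearn.datasets import make_blobs  # sklearn.datasets.samples_generator no longer exists in current scikit-learn
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # only needed on Matplotlib < 3.2 to register the '3d' projection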
We generate our own data using the make_blobs function.
clusters = [[1,1,1],[5,5,5],[3,10,10]]
X, _ = make_blobs(n_samples = 150, centers = clusters, cluster_std = 0.60)
After training the model, we store the coordinates for the cluster centers.
ms = MeanShift()  # no bandwidth given, so scikit-learn estimates one from X
ms.fit(X)
cluster_centers = ms.cluster_centers_  # one row per cluster center found
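The bandwidth is the main setting Mean Shift exposes, and it controls how many clusters come out. The sketch below calls scikit-learn's estimate_bandwidth explicitly with its default quantile=0.3, which matches what MeanShift() does when no bandwidth is passed, and then reads the per-point labels from labels_. The smaller bandwidth of 0.5 is an arbitrary value chosen only to illustrate the effect, and random_state=0 is added for reproducibility.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

clusters = [[1, 1, 1], [5, 5, 5], [3, 10, 10]]
X, _ = make_blobs(n_samples=150, centers=clusters, cluster_std=0.60, random_state=0)

# Same estimate MeanShift() computes internally when bandwidth=None
bw = estimate_bandwidth(X, quantile=0.3)
ms = MeanShift(bandwidth=bw).fit(X)

labels = ms.labels_                      # cluster index assigned to each point
n_clusters = len(ms.cluster_centers_)    # number of modes found
print(f"bandwidth={bw:.2f}, clusters found={n_clusters}")

# A much smaller window splits the data into many more, smaller clusters
ms_small = MeanShift(bandwidth=0.5).fit(X)
print("clusters with bandwidth=0.5:", len(ms_small.cluster_centers_))
```

Too large a bandwidth merges distinct groups into one cluster, while too small a bandwidth fragments a single group, so it is worth checking the estimated value against the data.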
Finally, we plot the data points and centroids in a 3D graph.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(X[:,0], X[:,1], X[:,2], marker='o')
ax.scatter(cluster_centers[:,0], cluster_centers[:,1], cluster_centers[:,2], marker='x', color='red', s=300, linewidth=5, zorder=10)
plt.show()
Cory Maklin