Source code for hypertools.tools.cluster
#!/usr/bin/env python
import warnings
from sklearn.cluster import KMeans, MiniBatchKMeans, AgglomerativeClustering, Birch, FeatureAgglomeration, SpectralClustering
import six
import numpy as np
from .._shared.helpers import *
from .format_data import format_data as formatter
# dictionary of models
models = {
'KMeans': KMeans,
'MiniBatchKMeans': MiniBatchKMeans,
'AgglomerativeClustering': AgglomerativeClustering,
'FeatureAgglomeration': FeatureAgglomeration,
'Birch': Birch,
'SpectralClustering': SpectralClustering,
}
try:
from hdbscan import HDBSCAN
_has_hdbscan = True
models.update({'HDBSCAN': HDBSCAN})
except ImportError:
_has_hdbscan = False
[docs]@memoize
def cluster(x, cluster='KMeans', n_clusters=3, ndims=None, format_data=True):
"""
Performs clustering analysis and returns a list of cluster labels
Parameters
----------
x : A Numpy array, Pandas Dataframe or list of arrays/dfs
The data to be clustered. You can pass a single array/df or a list.
If a list is passed, the arrays will be stacked and the clustering
will be performed across all lists (i.e. not within each list).
cluster : str or dict
Model to use to discover clusters. Support algorithms are: KMeans,
MiniBatchKMeans, AgglomerativeClustering, Birch, FeatureAgglomeration,
SpectralClustering and HDBSCAN (default: KMeans). Can be passed as a
string, but for finer control of the model parameters, pass as a
dictionary, e.g. reduce={'model' : 'KMeans', 'params' : {'max_iter' : 100}}.
See scikit-learn specific model docs for details on parameters supported for
each model.
n_clusters : int
Number of clusters to discover. Not required for HDBSCAN.
format_data : bool
Whether or not to first call the format_data function (default: True).
ndims : None
Deprecated argument. Please use new analyze function to perform
combinations of transformations
Returns
----------
cluster_labels : list
An list of cluster labels
"""
if cluster == None:
return x
elif (isinstance(cluster, six.string_types) and cluster=='HDBSCAN') or \
(isinstance(cluster, dict) and cluster['model']=='HDBSCAN'):
if not _has_hdbscan:
raise ImportError('HDBSCAN is not installed. Please install hdbscan>=0.8.11')
if ndims != None:
warnings.warn('The ndims argument is now deprecated. Ignoring dimensionality reduction step.')
if format_data:
x = formatter(x, ppca=True)
# if reduce is a string, find the corresponding model
if isinstance(cluster, six.string_types):
model = models[cluster]
if cluster != 'HDBSCAN':
model_params = {
'n_clusters' : n_clusters
}
else:
model_params = {}
# if its a dict, use custom params
elif type(cluster) is dict:
if isinstance(cluster['model'], six.string_types):
model = models[cluster['model']]
model_params = cluster['params']
# initialize model
model = model(**model_params)
# fit the model
model.fit(np.vstack(x))
# return the labels
return list(model.labels_)