diff --git a/ChangeLog.rst b/ChangeLog.rst
index 4d3c226..94346be 100644
--- a/ChangeLog.rst
+++ b/ChangeLog.rst
@@ -1,6 +1,12 @@
ChangeLog
====================================================
+Release 0.5.3 (2017-12-18)
+---------------------------------------
+
+* New Features
+ * Add Clustering service (#93, #98)
+
Release 0.5.2 (2017-10-30)
---------------------------------------
diff --git a/README.rst b/README.rst
index a7e0d15..adbe7cc 100644
--- a/README.rst
+++ b/README.rst
@@ -16,7 +16,7 @@ jubakit is a Python module to access Jubatus features easily.
jubakit can be used in conjunction with `scikit-learn `_ so that you can use powerful features like cross validation and model evaluation.
See the `Jubakit Documentation `_ for the detailed description.
-Currently jubakit supports `Classifier `_, `Regression `_, `Anomaly `_, `Recommender `_ and `Weight `_ engines.
+Currently jubakit supports `Classifier `_, `Regression `_, `Anomaly `_, `Recommender `_, `Clustering `_ and `Weight `_ engines.
Install
-------
@@ -105,6 +105,8 @@ See the `example `_ dire
+-----------------------------------+-----------------------------------------------+-----------------------+
| recommender_npb.py | Recommend similar items | |
+-----------------------------------+-----------------------------------------------+-----------------------+
+| clustering_2d.py | Clustering 2-dimensional dataset | |
++-----------------------------------+-----------------------------------------------+-----------------------+
| weight_shogun.py | Tracing fv_converter behavior using Weight | |
+-----------------------------------+-----------------------------------------------+-----------------------+
| weight_model_extract.py | Extract contents of Weight model file | |
diff --git a/example/blobs.csv b/example/blobs.csv
new file mode 100644
index 0000000..ac887ce
--- /dev/null
+++ b/example/blobs.csv
@@ -0,0 +1,301 @@
+cluster,x1,x2
+1,-0.93365052788,-1.14252366433
+1,-1.30942924203,-1.17678133335
+1,-1.34024502265,-1.31247402809
+0,0.705946662732,1.08441625777
+1,-1.13881130967,-0.7085249752
+0,0.956149945383,0.628758392122
+1,-0.749811617546,-0.783501028126
+0,1.09808612117,1.20026950289
+0,1.27769396504,0.928303484449
+0,0.666131442806,0.876941115329
+0,1.10948853175,0.788762702015
+0,0.928641954222,0.737586758678
+1,-0.92868829085,-0.802790686893
+0,1.08601733009,0.743459059141
+0,1.15436399849,1.23561794118
+0,0.894167921591,1.00858976753
+0,1.01180789786,0.962622197845
+1,-1.17281887522,-1.31434156598
+0,0.722735657016,0.748793610328
+0,1.07877265618,0.953919497595
+1,-1.0616397876,-0.836511727778
+1,-0.807477769967,-0.865361420293
+0,0.953572364838,0.935215671951
+0,0.997909419762,0.957200423884
+0,0.96365322024,1.06081199546
+0,0.961237465492,1.07967600026
+1,-0.495908857816,-0.959438464457
+1,-1.69287710777,-1.23728584981
+1,-0.967713503586,-1.06734691425
+1,-0.721891018457,-0.961223455053
+0,1.32631992264,1.01557318933
+1,-0.837405444918,-1.10729247379
+0,1.01495695974,1.23246945591
+1,-0.756971166227,-1.12469423796
+1,-1.14208325927,-0.482392806743
+0,0.977752246202,0.879820916999
+0,0.862810629696,1.08769305379
+0,1.09954406122,0.938612261241
+1,-0.823182960746,-0.930780485092
+0,0.689785527641,1.19453266431
+0,0.757724405413,1.3473090632
+0,0.877481320621,0.793741068761
+0,0.928999199851,1.00270428848
+1,-0.845402369455,-0.749220815906
+1,-0.970156814095,-0.758907438626
+1,-0.87235646144,-0.880970665629
+0,0.607339613379,1.16354040159
+1,-1.34614348729,-1.02540477671
+1,-1.2145504187,-0.999724235822
+1,-0.870527368967,-1.11699430888
+1,-0.718380625072,-1.1088022103
+0,1.062650858,0.785094270862
+1,-0.806293152097,-1.09505549938
+0,0.811787817815,0.869230709957
+0,0.892888731123,1.12132309814
+1,-0.997885067324,-1.14134314191
+1,-0.910173309421,-1.10079246393
+1,-0.481630999694,-0.883757493582
+0,0.707076543774,1.15769692848
+1,-1.16068210274,-1.25297022359
+1,-0.881054298908,-1.08869920088
+0,1.01845183546,1.10855776458
+0,1.00670796987,0.511756022803
+0,1.09775194332,1.12818658085
+1,-0.867627089872,-0.522748510569
+0,0.558298055212,0.886244752846
+1,-1.19590082703,-0.580458281596
+0,0.825876418599,1.02713177735
+1,-1.10534892994,-1.01661263094
+0,1.0137580997,1.01394515234
+0,1.27443347546,1.15344843593
+1,-0.841169236496,-1.2858972431
+0,1.06639728878,1.15198632838
+0,0.732734102261,0.84064579653
+1,-0.97370887636,-1.06128053325
+0,1.14242937492,1.0584408659
+1,-0.876023303483,-1.18236840362
+0,1.07701802062,0.746070794056
+1,-0.935814391035,-1.08753717031
+1,-1.01698475648,-0.775420213246
+0,0.537324421414,1.44402657477
+1,-1.3170507062,-0.872816250137
+1,-1.13712268693,-0.589193699823
+1,-0.913274447267,-0.887080518959
+1,-1.00859225235,-1.06118433703
+1,-0.792496157825,-0.915519738475
+0,0.980027951159,0.494567435192
+0,1.33339766676,0.837165358674
+1,-0.876643120952,-1.18292066828
+1,-0.979084375782,-0.82307189748
+0,0.816902215077,0.899673400048
+1,-0.57374001078,-0.734988521948
+1,-1.0934833097,-0.801864954993
+1,-1.06061689938,-1.34832895394
+1,-0.752697423238,-0.753716927762
+1,-1.42694528231,-0.697203669377
+1,-0.841274764173,-0.781559535673
+0,0.756459398564,0.844521971092
+0,1.29976101673,1.10831386395
+1,-1.03171705276,-1.25389155479
+1,-0.890890412259,-0.812819878009
+1,-0.551192238809,-0.994639833663
+0,1.24756402929,0.776557542347
+0,1.05839562468,1.04517104829
+1,-1.12073030272,-0.837256437347
+1,-0.827044749048,-1.05624253637
+0,0.51395666615,1.11051505652
+0,0.739673389253,0.741822007211
+1,-1.27381981888,-1.19688252607
+1,-0.654682929756,-0.973799559948
+1,-1.27326329396,-0.989773871415
+1,-1.09444784539,-1.27763064589
+1,-0.993825738802,-1.09185165069
+1,-0.802710426742,-0.993350846287
+0,1.11525094936,1.00938723789
+0,0.917137550958,0.816560261526
+0,0.992330530566,0.791111411532
+1,-0.707168296188,-0.925412488074
+0,1.13771832227,0.96075461916
+0,0.890959410912,0.938086246527
+1,-1.03627843595,-0.776940390286
+0,0.915069704117,0.666640094248
+0,0.816783179949,0.786688515747
+0,0.801292624248,0.903060164209
+0,1.16113203714,1.03195696871
+0,1.26008314539,1.24444035364
+1,-0.948892165167,-0.926001056779
+1,-0.941296417563,-1.10844766803
+1,-1.10542842758,-0.92429505133
+0,1.01750150234,1.25568028746
+0,0.752380047579,1.05486143046
+1,-0.73702836877,-1.23281202586
+0,0.627606560555,1.05109826999
+0,0.517776378126,1.19789326426
+1,-0.713222432843,-0.958592252019
+0,1.04382879954,0.480124791259
+1,-1.09346632927,-0.536183575128
+0,0.797768559477,0.867412022334
+1,-1.25166127245,-1.20054168598
+0,1.13751133453,1.4897145682
+0,1.32907920212,1.16890122329
+1,-1.05221319285,-0.619013961394
+0,1.31527194689,0.928056818898
+1,-0.84617505511,-0.960290093304
+0,0.93145065596,1.0457932325
+1,-1.1406088498,-0.766721044648
+1,-1.18117331655,-1.09069328092
+0,1.08225304345,0.829386217405
+1,-0.769894971178,-0.699558950919
+0,1.02050957341,0.986519722895
+1,-1.34325173759,-1.18516751807
+1,-0.582824676028,-0.97448263386
+1,-0.91682874077,-1.30224725509
+0,1.20912285457,0.986237097157
+1,-0.83145177721,-1.04546167864
+1,-1.20867057072,-0.695569099091
+1,-0.964820787331,-0.986652971401
+1,-0.723287318177,-1.17655839686
+0,1.24904540791,0.958520057889
+1,-0.959382519428,-1.10295631686
+1,-1.01521994073,-0.81720667044
+0,0.979055380374,0.930820910756
+0,1.11319239668,1.10061025635
+1,-1.10539188939,-1.15785134438
+0,1.25569667484,0.809511757516
+1,-1.00258583913,-1.0325433727
+1,-1.33488656243,-1.09859469138
+0,0.891235180887,0.819265367588
+0,1.14944236244,1.1630855575
+1,-0.946205231191,-0.849020692018
+0,0.943022145949,0.715867364935
+1,-0.76076441316,-0.779314332251
+0,0.910738956598,1.35569832177
+0,0.721201637698,0.955888602789
+1,-0.787255827868,-0.55261444834
+1,-0.573276712201,-0.959791230078
+0,1.0087782502,1.45900319422
+0,1.01110015772,1.33955783784
+1,-0.605276760547,-0.952284518005
+1,-0.925546454866,-1.17246909632
+1,-1.02198782791,-0.866145630703
+0,0.957553827204,1.05482417288
+1,-0.811182149039,-1.04335186967
+0,1.02046011441,0.904556127609
+1,-0.916580747162,-1.08094162804
+1,-1.31552650363,-0.750744076626
+0,0.559553483781,1.12829448684
+0,1.29397635256,1.04237547388
+1,-1.13287497615,-0.81420008866
+1,-0.944853915653,-1.20704136072
+1,-1.2461332122,-0.958377844123
+1,-0.977067258698,-1.30396981597
+0,0.903306301543,0.949296054698
+0,0.952134777039,1.06964011491
+0,0.621814550845,0.623016191179
+0,0.996972697804,1.2927729619
+0,0.719837189355,0.973600117394
+0,1.24771674478,0.924379447806
+0,1.09031605882,0.710713190899
+0,0.602160571357,0.907628480382
+1,-1.30509663897,-1.22218292874
+1,-1.24449003602,-1.11234483665
+1,-0.996925647094,-1.12394888285
+1,-0.796142515273,-1.06934554615
+1,-1.01453272118,-0.935585735187
+0,0.892871032184,0.89163677657
+0,0.878883422428,0.699279561569
+1,-0.888679242366,-0.954850905534
+0,1.48164360015,0.942929378117
+0,0.877942993272,0.82188051988
+0,1.2840129128,0.782852923729
+0,0.979628162094,1.0046376355
+0,1.0011746956,1.12267964379
+0,0.830659477108,0.869425923633
+1,-1.03726570769,-1.23452374255
+0,0.987657575081,1.37662447447
+0,0.580467666554,0.908464440585
+1,-0.660081209999,-0.984522246947
+1,-0.980369691949,-1.23243599229
+0,0.606533785575,0.699995653768
+0,1.2587530271,0.85275307954
+0,0.979099363225,1.17801837723
+1,-1.01138332504,-0.896736446022
+0,0.769932593875,1.1598559749
+1,-1.04170291003,-0.912274263627
+0,0.371305675372,1.19505468158
+1,-0.896310853917,-1.15976596784
+0,0.97455113657,0.807068299647
+0,1.21648553711,0.767050284428
+0,0.831378092761,0.649644408667
+1,-1.25986459056,-1.54711480041
+0,1.13340112634,0.736359688345
+1,-1.2384024028,-1.06698951467
+1,-0.988528343692,-0.73969035651
+1,-1.1593662035,-1.37746336517
+1,-0.933518993839,-0.984825025698
+1,-1.51994563014,-0.992198993144
+1,-1.18289178831,-1.09432642972
+1,-1.1489853972,-1.25476230938
+0,0.850649329045,1.04511356058
+0,0.94721169958,0.602117008885
+0,0.769624988299,0.841565193223
+0,0.882463189363,0.830069344874
+1,-1.19048953014,-1.11239954012
+1,-1.37596644766,-0.835583923757
+0,1.15934228485,1.09186345868
+0,1.05753142965,1.21179706895
+0,1.01735561622,0.706920484683
+0,1.25766505085,1.46784431684
+1,-0.904286143701,-0.99044421338
+0,0.688260024665,0.992351467181
+0,1.08799888383,0.775358047861
+1,-1.28388131583,-0.693712470801
+1,-0.998580601632,-0.918860662659
+1,-0.777905864167,-0.665513500153
+1,-1.04125869063,-1.01713047393
+1,-0.872275784821,-0.697519693307
+0,1.15431043138,1.10149551628
+0,1.33418913895,1.40379322465
+0,1.37519870216,0.721840657971
+0,1.02407222119,0.839840744484
+1,-0.812167585814,-1.05228941383
+1,-1.17749289026,-0.901481708669
+0,1.01857858685,0.897001451661
+0,0.871092966037,1.15403550193
+0,0.998247749893,1.00736964006
+1,-0.833648095387,-1.20145439675
+0,1.02711313653,0.749427178008
+1,-0.934653682918,-0.950688169763
+1,-0.707870466892,-1.05779320848
+0,1.19638680512,0.92284423706
+0,1.2409973686,1.16302331561
+1,-1.07686675539,-0.767626617732
+0,0.877645943426,1.21517199377
+1,-0.971147020183,-1.11221146273
+1,-0.867142933604,-0.749630698221
+1,-1.044349193,-0.998343943806
+0,1.010116662,1.18559779957
+1,-1.07085798109,-1.11210027352
+0,0.807044965878,0.988275972749
+1,-0.989614222944,-0.910899342659
+0,1.06351163418,1.02178086384
+0,0.926517738237,1.0643475345
+1,-0.824690984171,-1.12187500144
+0,0.952629213312,1.28773201112
+0,1.4960145587,1.19712198723
+1,-0.875043933252,-1.2711026884
+0,0.861985501019,1.2682967345
+0,1.02214055141,0.971836886647
+1,-1.04975530185,-0.861185052048
+0,1.16225778459,0.926479092841
+1,-0.935322471815,-0.96645101227
+0,1.23139224734,0.793878878449
+1,-0.83692733936,-0.698771683734
+0,1.09847002587,1.00367339063
+1,-0.937565881378,-1.12237616805
+0,0.979596407895,0.865568046059
+1,-0.912566324192,-0.872655513966
+1,-0.606827980293,-0.626358852502
+1,-0.905890154881,-0.680739811248
\ No newline at end of file
diff --git a/example/clustering_2d.py b/example/clustering_2d.py
new file mode 100644
index 0000000..fd5bbff
--- /dev/null
+++ b/example/clustering_2d.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+"""
+Using Clustering
+========================================
+
+This is a simple example that illustrates Clustering service usage.
+
+"""
+
+from jubakit.clustering import Clustering, Schema, Dataset, Config
+from jubakit.loader.csv import CSVLoader
+
+# Load a CSV file.
+loader = CSVLoader('blobs.csv')
+
+# Define a Schema that defines types for each columns of the CSV file.
+schema = Schema({
+ 'cluster': Schema.ID,
+}, Schema.NUMBER)
+
+# Create a Dataset.
+dataset = Dataset(loader, schema)
+
+# Create an Clustering Service.
+cfg = Config(method='kmeans')
+clustering = Clustering.run(cfg)
+
+# Update the Clustering model.
+for (idx, row_id, result) in clustering.push(dataset):
+ pass
+
+# Get clusters
+clusters = clustering.get_core_members(light=False)
+# Get centers of each cluster
+centers = clustering.get_k_center()
+
+# Calculate SSE: sum of squared errors
+sse = 0.0
+for cluster, center in zip(clusters, centers):
+ # Center of clusters
+ center = {"x1": center.num_values[0][1], "x2": center.num_values[1][1]}
+ for d in cluster:
+ vector = d.point.num_values
+ x1 = [x[1] for x in vector if x[0] == 'x1'][0]
+ x2 = [x[1] for x in vector if x[0] == 'x2'][0]
+ sse += (x1 - center["x1"])**2 + (x2 - center["x2"])**2
+print('SSE:', sse)
+
+clustering.stop()
diff --git a/example/clustering_sklearn_wrapper.py b/example/clustering_sklearn_wrapper.py
new file mode 100644
index 0000000..a4fc22d
--- /dev/null
+++ b/example/clustering_sklearn_wrapper.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+"""
+Using Clustering
+========================================
+
+This is a simple example that illustrates Clustering service usage.
+
+"""
+
+from sklearn.datasets import make_blobs
+
+from jubakit.wrapper.clustering import KMeans, GMM, DBSCAN
+
+# make blob dataset using sklearn API.
+X, y = make_blobs(n_samples=200, centers=3, n_features=2, random_state=42)
+
+# launch clustering instance
+clusterings = [
+ KMeans(k=3, bucket_size=200, embedded=False),
+ GMM(k=3, bucket_size=200, embedded=False),
+ DBSCAN(eps=2.0, bucket_size=200, embedded=False)
+]
+
+for clustering in clusterings:
+ # fit and predict
+ y_pred = clustering.fit_predict(X)
+ # print result
+ labels = set(y_pred)
+ label_counts = {}
+ for label in labels:
+ label_counts[label] = y_pred.count(label)
+ print('{0}: {1}'.format(
+ clustering.__class__.__name__,
+ label_counts))
+ # stop clustering service
+ clustering.stop()
+
diff --git a/jubakit/_version.py b/jubakit/_version.py
index 7cf8a70..7db2873 100644
--- a/jubakit/_version.py
+++ b/jubakit/_version.py
@@ -1 +1 @@
-VERSION = (0, 5, 2)
+VERSION = (0, 5, 3)
diff --git a/jubakit/clustering.py b/jubakit/clustering.py
new file mode 100644
index 0000000..a590f74
--- /dev/null
+++ b/jubakit/clustering.py
@@ -0,0 +1,307 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import json
+import uuid
+
+import jubatus
+import jubatus.embedded
+
+from .base import GenericSchema, BaseDataset, BaseService, GenericConfig, Utils
+from .loader.array import ArrayLoader, ZipArrayLoader
+from .loader.sparse import SparseMatrixLoader
+from .loader.chain import ValueMapChainLoader, MergeChainLoader
+from .compat import *
+
+class Schema(GenericSchema):
+ """
+ Schema for Clustering service.
+ """
+
+ ID = 'i'
+
+ def __init__(self, mapping, fallback=None):
+ self._id_key = self._get_unique_mapping(mapping, fallback, self.ID, 'ID', True)
+ super(Schema, self).__init__(mapping, fallback)
+
+ def transform(self, row):
+ """
+ Clustering schema transforms the row into Datum, its associated ID.
+ """
+ row_id = row.get(self._id_key, None)
+ if row_id is not None:
+ row_id = unicode_t(row_id)
+ else:
+ row_id = unicode_t(uuid.uuid4())
+ d = self._transform_as_datum(row, None, [self._id_key])
+ return (row_id, d)
+
+class Dataset(BaseDataset):
+ """
+ Dataset for Clustering service.
+ """
+
+ @classmethod
+ def _predict(cls, row):
+ return Schema.predict(row, False)
+
+ @classmethod
+ def _from_loader(cls, data_loader, ids, static):
+ if ids is None:
+ loader = data_loader
+ schema = Schema({}, Schema.NUMBER)
+ else:
+ id_loader = ZipArrayLoader(_id=ids)
+ loader = MergeChainLoader(data_loader, id_loader)
+ schema = Schema({'_id': Schema.ID}, Schema.NUMBER)
+ return Dataset(loader, schema, static)
+
+ @classmethod
+ def from_data(cls, data, ids=None, feature_names=None, static=True):
+ """
+ Converts two arrays or a sparse matrix data and its associated id array to Dataset.
+
+ Parameters
+ ----------
+ data : array or scipy 2-D sparse matrix of shape [n_samples, n_features]
+ ids : array of shape [n_samples], optional
+ feature_names : array of shape [n_features], optional
+ """
+
+ if hasattr(data, 'todense'):
+ return cls.from_matrix(data, ids, feature_names, static)
+ else:
+ return cls.from_array(data, ids, feature_names, static)
+
+ @classmethod
+ def from_array(cls, data, ids=None, feature_names=None, static=True):
+ """
+ Converts two arrays (data and its associated targets) to Dataset.
+
+ Parameters
+ ----------
+ data : array of shape [n_samples, n_features]
+ ids : array of shape [n_samples], optional
+ feature_names : array of shape [n_features], optional
+ """
+
+ data_loader = ArrayLoader(data, feature_names)
+ return cls._from_loader(data_loader, ids, static)
+
+ @classmethod
+ def from_matrix(cls, data, ids=None, feature_names=None, static=True):
+ """
+ Converts a sparse matrix data and its associated target array to Dataset.
+
+ Parameters
+ ----------
+
+ data : scipy 2-D sparse matrix of shape [n_samples, n_features]
+ ids : array of shape [n_samples], optional
+ feature_names : array of shape [n_features], optional
+ """
+
+ data_loader = SparseMatrixLoader(data, feature_names)
+ return cls._from_loader(data_loader, ids, static)
+
+ def get_ids(self):
+ """
+ Returns labels of each record in the dataset.
+ """
+
+ if not self._static:
+ raise RuntimeError('non-static datasets cannot fetch list of ids')
+ for (idx, (row_id, d)) in self:
+ yield row_id
+
+class Clustering(BaseService):
+ """
+ Clustering service.
+ """
+
+ @classmethod
+ def name(cls):
+ return 'clustering'
+
+ @classmethod
+ def _client_class(cls):
+ return jubatus.clustering.client.Clustering
+
+ @classmethod
+ def _embedded_class(cls):
+ return jubatus.embedded.Clustering
+
+ def push(self, dataset):
+ """
+ Add data points.
+ """
+
+ cli = self._client()
+ for (idx, (row_id, d)) in dataset:
+ if row_id is None:
+ raise RuntimeError('each row must have `id`.')
+ result = cli.push([jubatus.clustering.types.IndexedPoint(row_id, d)])
+ yield (idx, row_id, result)
+
+ def get_revision(self):
+ """
+ Return revision of clusters
+ """
+
+ cli = self._client()
+ return cli.get_revision()
+
+ def get_core_members(self, light=False):
+ """
+ Returns coreset of cluster in datum.
+ """
+
+ cli = self._client()
+ method = self._get_method()
+ if light:
+ return cli.get_core_members_light()
+ else:
+ return cli.get_core_members()
+
+ def get_k_center(self):
+ """
+ Return k cluster centers.
+ """
+
+ cli = self._client()
+ method = self._get_method()
+ if method not in ('kmeans', 'gmm'):
+ raise RuntimeError('{0} is not supported'.format(method))
+ return cli.get_k_center()
+
+ def get_nearest_center(self, dataset):
+ """
+ Returns nearest cluster center without adding points to cluster.
+ """
+
+ cli = self._client()
+ method = self._get_method()
+ if method not in ('kmeans', 'gmm'):
+ raise RuntimeError('{0} is not supported'.format(method))
+
+ for (idx, (row_id, d)) in dataset:
+ result = cli.get_nearest_center(d)
+ yield (idx, row_id, result)
+
+ def get_nearest_members(self, dataset, light=False):
+ """
+ Returns nearest summary of cluster(coreset) from each point.
+ """
+
+ cli = self._client()
+ method = self._get_method()
+ if method not in ('kmeans', 'gmm'):
+ raise RuntimeError('{0} is not supported'.format(method))
+
+ for (idx, (row_id, d)) in dataset:
+ if light:
+ result = cli.get_nearest_members_light(d)
+ else:
+ result = cli.get_nearest_members(d)
+ yield (idx, row_id, result)
+
+ def _get_method(self):
+ method = None
+ if self._embedded:
+ config = json.loads(self._backend.model.get_config())
+ method = config['method']
+ else:
+ if 'method' in self._backend.config:
+ method = self._backend.config['method']
+ return method
+
+class Config(GenericConfig):
+ """
+ Configulation to run Clustering service.
+ """
+
+ def __init__(self, method=None, parameter=None,
+ compressor_method=None, compressor_parameter=None,
+ converter=None):
+ super(Config, self).__init__(method, parameter, converter)
+ if compressor_method is not None:
+ self['compressor_method'] = compressor_method
+ default_compressor_parameter = \
+ self._default_compressor_parameter(compressor_method)
+ if default_compressor_parameter is None:
+ if 'compressor_parameter' in self:
+ del self['compressor_parameter']
+ else:
+ self['compressor_parameter'] = default_compressor_parameter
+
+ if compressor_parameter is not None:
+ if 'compressor_parameter' in self:
+ self['compressor_parameter'].update(compressor_parameter)
+ else:
+ self['compressor_parameter'] = compressor_parameter
+
+ @classmethod
+ def _default(cls, cfg):
+ super(Config, cls)._default(cfg)
+
+ compressor_method = cls._default_compressor_method()
+ compressor_parameter = cls._default_compressor_parameter(compressor_method)
+
+ if compressor_method is not None:
+ cfg['compressor_method'] = compressor_method
+ if compressor_parameter is not None:
+ cfg['compressor_parameter'] = compressor_parameter
+
+ return cfg
+
+ @classmethod
+ def _default_method(cls):
+ return 'kmeans'
+
+ @classmethod
+ def _default_compressor_method(cls):
+ return 'simple'
+
+ @classmethod
+ def _default_parameter(cls, method):
+ if method in ('kmeans', 'gmm'):
+ return {
+ 'k': 3,
+ 'seed': 0
+ }
+ elif method in ('dbscan',):
+ return {
+ 'eps': 0.2,
+ 'min_core_point': 3
+ }
+ else:
+ raise RuntimeError('unknown method: {0}'.format(method))
+
+ @classmethod
+ def _default_compressor_parameter(cls, method):
+ if method in ('simple',):
+ return {
+ 'bucket_size': 100
+ }
+ elif method in ('compressive',):
+ return {
+ 'bucket_size': 100,
+ 'bucket_length': 2,
+ 'compressed_bucket_size': 100,
+ 'bicriteria_base_size': 10,
+ 'forgetting_factor': 0.0,
+ 'forgetting_threshold': 0.5,
+ 'seed': 0
+ }
+ else:
+ raise RuntimeError('unknown method: {0}'.format(method))
+
+ @classmethod
+ def methods(cls):
+ return ['kmeans', 'gmm', 'dbscan']
+
+ @classmethod
+ def compressor_methods(cls):
+ return ['simple', 'compressive']
+
diff --git a/jubakit/test/test_clustering.py b/jubakit/test/test_clustering.py
new file mode 100644
index 0000000..b02a667
--- /dev/null
+++ b/jubakit/test/test_clustering.py
@@ -0,0 +1,332 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+from unittest import TestCase
+
+try:
+ import numpy as np
+ from scipy.sparse import csr_matrix
+except ImportError:
+ pass
+
+from jubakit.clustering import Schema, Dataset, Clustering, Config
+from jubakit.compat import *
+
+from . import requireSklearn, requireEmbedded
+from .stub import *
+
+class SchemaTest(TestCase):
+ def test_simple(self):
+ schema = Schema({
+ 'id': Schema.ID,
+ 'k1': Schema.STRING,
+ 'k2': Schema.NUMBER
+ })
+ (row_id, d) = schema.transform({'id': 'user001', 'k1': 'abc', 'k2': '123'})
+
+ self.assertEqual(row_id, 'user001')
+ self.assertEqual({'k1': 'abc'}, dict(d.string_values))
+ self.assertEqual({'k2': 123}, dict(d.num_values))
+
+ def test_without_id(self):
+ # schema without id can be defined
+ Schema({
+ 'k1': Schema.STRING,
+ })
+
+ def test_illegal_id(self):
+ # schema with multiple IDs
+ self.assertRaises(RuntimeError, Schema, {
+ 'k1': Schema.ID,
+ 'k2': Schema.ID,
+ })
+
+ # schema fallback set to id
+ self.assertRaises(RuntimeError, Schema, {
+ 'k1': Schema.ID
+ }, Schema.ID)
+
+class DatasetTest(TestCase):
+ def test_simple(self):
+ loader = StubLoader()
+ schema = Schema({'v': Schema.ID})
+ ds = Dataset(loader, schema)
+ for (idx, (label, d)) in ds:
+ self.assertEqual(unicode_t(idx+1), label)
+ self.assertEqual(0, len(d.string_values))
+ self.assertEqual(0, len(d.num_values))
+ self.assertEqual(0, len(d.binary_values))
+ self.assertEqual(['1','2','3'], list(ds.get_ids()))
+
+ def test_predict(self):
+ loader = StubLoader()
+ dataset = Dataset(loader)
+ self.assertEqual(['v', 1.0], dataset[0][1].num_values[0])
+
+ def test_from_data(self):
+ # load from array format
+ ds = Dataset.from_data(
+ [ [10, 20, 30], [20, 10, 50], [40, 10, 30]], # data
+ ['i1', 'i2', 'i3'], # ids
+ ['k1', 'k2', 'k3'] # feature names
+ )
+
+ expected_k1s = [10, 20, 40]
+ expected_ids = ['i1', 'i2', 'i3']
+ actual_k1s = []
+ actual_ids = []
+ for (idx, (row_id, d)) in ds:
+ actual_k1s.append(dict(d.num_values).get('k1', None))
+ actual_ids.append(row_id)
+
+ self.assertEqual(expected_k1s, actual_k1s)
+ self.assertEqual(expected_ids, actual_ids)
+
+ # load from scipy.sparse format
+ ds = Dataset.from_data(
+ self._create_matrix(), # data
+ ['i1', 'i2', 'i3'], # ids
+ [ 'k1', 'k2', 'k3'], # feature_names
+ )
+
+ expected_k1s = [1, None, 4]
+ expected_k3s = [2, 3, 6]
+ expected_ids = ['i1', 'i2', 'i3']
+ actual_k1s = []
+ actual_k3s = []
+ actual_ids = []
+ for (idx, (row_id, d)) in ds:
+ actual_k1s.append(dict(d.num_values).get('k1', None))
+ actual_k3s.append(dict(d.num_values).get('k3', None))
+ actual_ids.append(row_id)
+
+ self.assertEqual(expected_k1s, actual_k1s)
+ self.assertEqual(expected_k3s, actual_k3s)
+ self.assertEqual(expected_ids, actual_ids)
+
+ def test_from_array(self):
+ ds = Dataset.from_array(
+ [ [10, 20, 30], [20, 10, 50], [40, 10, 30]], # data
+ ['i1', 'i2', 'i3'], # ids
+ ['k1', 'k2', 'k3'] # feature names
+ )
+
+ expected_k1s = [10, 20, 40]
+ expected_ids = ['i1', 'i2', 'i3']
+ actual_k1s = []
+ actual_ids = []
+ for (idx, (row_id, d)) in ds:
+ actual_k1s.append(dict(d.num_values).get('k1', None))
+ actual_ids.append(row_id)
+
+ self.assertEqual(expected_k1s, actual_k1s)
+ self.assertEqual(expected_ids, actual_ids)
+
+ def test_from_array_without_ids(self):
+ ds = Dataset.from_array(
+ [ [10, 20, 30], [20, 10, 50], [40, 10, 30]], # data
+ feature_names=['k1', 'k2', 'k3'] # feature names
+ )
+
+ expected_k1s = [10, 20, 40]
+ actual_k1s = []
+ actual_ids = []
+ for (idx, (row_id, d)) in ds:
+ actual_k1s.append(dict(d.num_values).get('k1', None))
+ actual_ids.append(row_id)
+ self.assertEqual(expected_k1s, actual_k1s)
+ self.assertEqual(len(actual_ids), 3)
+
+ @requireSklearn
+ def test_from_matrix(self):
+ ds = Dataset.from_matrix(
+ self._create_matrix(), # data
+ ['i1', 'i2', 'i3'], # ids
+ ['k1', 'k2', 'k3'] # feature names
+ )
+
+ expected_k1s = [1, None, 4]
+ expected_k3s = [2, 3, 6]
+ expected_ids = ['i1', 'i2', 'i3']
+ actual_k1s = []
+ actual_k3s = []
+ actual_ids = []
+ for (idx, (row_id, d)) in ds:
+ actual_k1s.append(dict(d.num_values).get('k1', None))
+ actual_k3s.append(dict(d.num_values).get('k3', None))
+ actual_ids.append(row_id)
+
+ self.assertEqual(expected_k1s, actual_k1s)
+ self.assertEqual(expected_k3s, actual_k3s)
+ self.assertEqual(expected_ids, actual_ids)
+
+ def test_get_ids(self):
+ ds = Dataset.from_array(
+ [ [10, 20, 30], [20, 10, 50], [40, 10, 30]], # data
+ ['i1', 'i2', 'i3'], # ids
+ static=True
+ )
+ actual_ids = []
+ expected_ids = ['i1', 'i2', 'i3']
+ for row_id in ds.get_ids():
+ actual_ids.append(row_id)
+ self.assertEqual(expected_ids, actual_ids)
+
+ ds = Dataset.from_array(
+ [ [10, 20, 30], [20, 10, 50], [40, 10, 30]], # data
+ ['i1', 'i2', 'i3'], # ids
+ static=False
+ )
+ self.assertRaises(RuntimeError, list, ds.get_ids())
+
+ def _create_matrix(self):
+ """
+ Create a sparse matrix:
+
+ [[1, 0, 2],
+ [0, 0, 3],
+ [4, 5, 6]]
+ """
+ row = np.array([0, 0, 1, 2, 2, 2])
+ col = np.array([0, 2, 2, 0, 1, 2])
+ data = np.array([1, 2, 3, 4, 5, 6])
+ return csr_matrix((data, (row, col)), shape=(3, 3))
+
+class ClusteringTest(TestCase):
+ def test_simple(self):
+ clustering = Clustering()
+ clustering.stop()
+
+ @requireEmbedded
+ def test_embedded(self):
+ clustering = Clustering.run(Config(), embedded=True)
+ clustering.stop()
+
+ def test_push(self):
+ clustering = Clustering.run(Config())
+ dataset = self._make_stub_dataset()
+ for (idx, row_id, result) in clustering.push(dataset):
+ self.assertEqual(result, True)
+ clustering.stop()
+
+ def test_get_revision(self):
+ clustering = Clustering.run(Config())
+ self.assertEqual(0, clustering.get_revision())
+ clustering.stop()
+
+ def test_get_core_members(self):
+ dataset = self._make_stub_dataset()
+ config = Config(method='kmeans', compressor_parameter={"bucket_size": 5})
+ clustering = self._make_stub_clustering(config, dataset)
+ clustering.get_core_members(light=False)
+ clustering.get_core_members(light=True)
+ clustering.stop()
+
+ def test_get_k_center(self):
+ def func(clustering, dataset):
+ clustering.get_k_center()
+ self._test_func_with_legal_and_illegal_config(func)
+
+ def test_get_nearest_center(self):
+ def func(clustering, dataset):
+ for _ in clustering.get_nearest_center(dataset): pass
+ self._test_func_with_legal_and_illegal_config(func)
+
+ def test_get_nearest_members(self):
+ def func1(clustering, dataset):
+ for _ in clustering.get_nearest_members(dataset, light=False): pass
+ self._test_func_with_legal_and_illegal_config(func1)
+
+ def func2(clustering, dataset):
+ for _ in clustering.get_nearest_members(dataset, light=True): pass
+ self._test_func_with_legal_and_illegal_config(func2)
+
+ def _test_func_with_legal_and_illegal_config(self, func):
+ dataset = self._make_stub_dataset()
+ # test illegal method
+ config = Config(method='dbscan', compressor_parameter={"bucket_size": 5})
+ clustering = self._make_stub_clustering(config, dataset)
+ self.assertRaises(RuntimeError, lambda: func(clustering, dataset))
+ clustering.stop()
+
+ # test legal method
+ config = Config(method='kmeans', compressor_parameter={"bucket_size": 5})
+ clustering = self._make_stub_clustering(config, dataset)
+ func(clustering, dataset)
+ clustering.stop()
+
+ def _make_stub_clustering(self, config, dataset):
+ dataset = self._make_stub_dataset()
+ clustering = Clustering.run(config)
+ for _ in clustering.push(dataset): pass
+ return clustering
+
+ def _make_stub_dataset(self):
+ ids = ['id1', 'id2', 'id3', 'id4', 'id5']
+ X = [
+ [0, 0, 0],
+ [1, 1, 1],
+ [2, 2, 2],
+ [3, 3, 3],
+ [4, 4, 4]
+ ]
+ dataset = Dataset.from_array(X, ids=ids)
+ return dataset
+
+
+class ConfigTest(TestCase):
+ def test_simple(self):
+ config = Config()
+ self.assertEqual('kmeans', config['method'])
+ self.assertEqual({'k': 3, 'seed': 0}, config['parameter'])
+ self.assertEqual('simple', config['compressor_method'])
+ self.assertEqual({'bucket_size': 100}, config.get('compressor_parameter'))
+
+ def test_methods(self):
+ config = Config()
+ self.assertTrue(isinstance(config.methods(), list))
+
+ def test_compressor_methods(self):
+ config = Config()
+ self.assertTrue(isinstance(config.compressor_methods(), list))
+
+ def test_illegal_comporessor_method(self):
+ self.assertRaises(RuntimeError,
+ Config._default_compressor_parameter,
+ 'invalid_compressor_method')
+
+ def test_default(self):
+ config = Config.default()
+ self.assertEqual('kmeans', config['method'])
+ self.assertEqual('simple', config['compressor_method'])
+
+ def test_method_params(self):
+ self.assertTrue('k' in Config(method='kmeans')['parameter'])
+ self.assertTrue('seed' in Config(method='kmeans')['parameter'])
+ self.assertTrue('k' in Config(method='gmm')['parameter'])
+ self.assertTrue('seed' in Config(method='gmm')['parameter'])
+ self.assertTrue('eps' in Config(method='dbscan')['parameter'])
+ self.assertTrue('min_core_point' in Config(method='dbscan')['parameter'])
+
+ def test_compressor_params(self):
+ self.assertTrue('bucket_size' in
+ Config(compressor_method='simple')['compressor_parameter'])
+ self.assertTrue('bucket_size' in
+ Config(compressor_method='compressive')['compressor_parameter'])
+ self.assertTrue('bucket_length' in
+ Config(compressor_method='compressive')['compressor_parameter'])
+ self.assertTrue('compressed_bucket_size' in
+ Config(compressor_method='compressive')['compressor_parameter'])
+ self.assertTrue('bicriteria_base_size' in
+ Config(compressor_method='compressive')['compressor_parameter'])
+ self.assertTrue('forgetting_factor' in
+ Config(compressor_method='compressive')['compressor_parameter'])
+ self.assertTrue('seed' in
+ Config(compressor_method='compressive')['compressor_parameter'])
+ config = Config(compressor_method='simple',
+ compressor_parameter={'bucket_size': 10})
+ self.assertEqual(10, config['compressor_parameter']['bucket_size'])
+
+ def test_invalid_method(self):
+ self.assertRaises(RuntimeError, Config._default_parameter, 'invalid_method')
diff --git a/jubakit/test/wrapper/test_clustering.py b/jubakit/test/wrapper/test_clustering.py
new file mode 100644
index 0000000..639306d
--- /dev/null
+++ b/jubakit/test/wrapper/test_clustering.py
@@ -0,0 +1,133 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+from unittest import TestCase
+
+try:
+ import numpy as np
+except ImportError:
+ pass
+
+from jubakit.wrapper.clustering import KMeans, GMM, DBSCAN
+from . import requireEmbedded
+
+
+class KMeansTest(TestCase):
+
+    def test_simple(self):
+        clustering = KMeans(embedded=False)
+        clustering.stop()
+
+    @requireEmbedded
+    def test_embedded(self):
+        clustering = KMeans(embedded=True)
+
+    def test_init(self):
+        clustering = KMeans(embedded=False)
+        self.assertEqual(2, clustering.k)
+        self.assertEqual('simple', clustering.compressor_method)
+        self.assertEqual(100, clustering.bucket_size)
+        self.assertEqual(100, clustering.compressed_bucket_size)
+        self.assertEqual(10, clustering.bicriteria_base_size)
+        self.assertEqual(2, clustering.bucket_length)
+        self.assertEqual(0.0, clustering.forgetting_factor)
+        self.assertEqual(0.5, clustering.forgetting_threshold)
+        self.assertEqual(0, clustering.seed)
+        self.assertTrue(not clustering.embedded)
+        clustering.stop()
+
+    def test_method(self):
+        clustering = KMeans(embedded=False)
+        self.assertEqual('kmeans', clustering._method())
+        clustering.stop()
+
+    def test_make_compressor_parameter(self):
+        clustering = KMeans(compressor_method='simple', embedded=False)
+        compressor_parameter = {'bucket_size': 100}
+        self.assertEqual(compressor_parameter,
+                         clustering._make_compressor_parameter('simple'))
+        clustering.stop()
+
+        clustering = KMeans(compressor_method='compressive', embedded=False)
+        compressor_parameter = {
+            'bucket_size': 100,
+            'compressed_bucket_size': 100,
+            'bicriteria_base_size': 10,
+            'bucket_length': 2,
+            'forgetting_factor': 0.0,
+            'forgetting_threshold': 0.5,
+            'seed': 0
+        }
+        self.assertEqual(compressor_parameter,
+                         clustering._make_compressor_parameter('compressive'))
+        clustering.stop()
+
+    def test_fit(self):
+        X = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]])
+        clustering = KMeans(k=10, embedded=False)
+        self.assertRaises(RuntimeWarning, clustering.fit, X)
+        clustering.stop()
+
+        clustering = KMeans(k=5, embedded=False)
+        self.assertTrue(not clustering.fitted)
+        clustering.fit(X)
+        self.assertTrue(clustering.fitted)
+
+    def test_predict(self):
+        X = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]])
+        clustering = KMeans(k=5, embedded=False)
+        self.assertRaises(RuntimeError, clustering.predict, X)
+        clustering.fit(X)
+        y_pred = clustering.predict(X)
+        self.assertEqual(len(y_pred), X.shape[0])
+
+    def test_fit_predict(self):
+        X = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]])
+        clustering = KMeans(k=5, embedded=False)
+        self.assertTrue(not clustering.fitted)
+        y_pred = clustering.fit_predict(X)
+        self.assertTrue(clustering.fitted)
+        self.assertEqual(len(y_pred), X.shape[0])
+
+
+class GMMTest(TestCase):  # GMM shares the KMeans wrapper; only the method name differs
+
+ def test_simple(self):
+ clustering = GMM(embedded=False)
+ clustering.stop()
+
+ @requireEmbedded
+ def test_embedded(self):
+ clustering = GMM(embedded=True)
+
+ def test_method(self):
+ clustering = GMM(embedded=False)
+ self.assertEqual('gmm', clustering._method())  # engine name reported to Jubatus
+ clustering.stop()
+
+class DBSCANTest(TestCase):
+
+    def test_simple(self):
+        clustering = DBSCAN(embedded=False)
+        clustering.stop()
+
+    @requireEmbedded
+    def test_embedded(self):
+        clustering = DBSCAN(embedded=True)
+
+    def test_init(self):
+        clustering = DBSCAN(embedded=False)
+        self.assertEqual(0.2, clustering.eps)
+        self.assertEqual(3, clustering.min_core_point)
+        self.assertEqual('simple', clustering.compressor_method)
+        self.assertEqual(100, clustering.bucket_size)
+        self.assertEqual(100, clustering.compressed_bucket_size)
+        self.assertEqual(10, clustering.bicriteria_base_size)
+        self.assertEqual(2, clustering.bucket_length)
+        self.assertEqual(0.0, clustering.forgetting_factor)
+        self.assertEqual(0.5, clustering.forgetting_threshold)
+        self.assertEqual(0, clustering.seed)
+        self.assertTrue(not clustering.embedded)
+        clustering.stop()
+
diff --git a/jubakit/wrapper/clustering.py b/jubakit/wrapper/clustering.py
new file mode 100644
index 0000000..301a8e0
--- /dev/null
+++ b/jubakit/wrapper/clustering.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import numpy as np
+from sklearn.base import BaseEstimator, ClusterMixin, TransformerMixin
+from ..clustering import Clustering, Config, Dataset
+
+
+class BaseJubatusClustering(BaseEstimator, ClusterMixin):
+    """
+    scikit-learn Wrapper for Jubatus Clustering.
+    """
+
+    def __init__(self, compressor_method='simple',
+                 bucket_size=100, compressed_bucket_size=100,
+                 bicriteria_base_size=10, bucket_length=2,
+                 forgetting_factor=0.0, forgetting_threshold=0.5,
+                 seed=0, embedded=True):
+        """
+        Creates a base class for Jubatus Clustering
+        """
+        self.compressor_method = compressor_method
+        self.bucket_size = bucket_size
+        self.compressed_bucket_size = compressed_bucket_size
+        self.bicriteria_base_size = bicriteria_base_size
+        self.bucket_length = bucket_length
+        self.forgetting_factor = forgetting_factor
+        self.forgetting_threshold = forgetting_threshold
+        self.seed = seed
+        self.embedded = embedded
+        self.compressor_parameter = \
+            self._make_compressor_parameter(self.compressor_method)
+        self.fitted = False
+
+    def _launch_clustering(self):
+        """
+        Launch Jubatus Clustering
+        """
+        raise NotImplementedError()
+
+    def _make_compressor_parameter(self, compressor_method):
+        if compressor_method == 'simple':
+            return {
+                'bucket_size': self.bucket_size,
+            }
+        elif compressor_method == 'compressive':
+            return {
+                'bucket_size': self.bucket_size,
+                'compressed_bucket_size': self.compressed_bucket_size,
+                'bicriteria_base_size': self.bicriteria_base_size,
+                'bucket_length': self.bucket_length,
+                'forgetting_factor': self.forgetting_factor,
+                'forgetting_threshold': self.forgetting_threshold,
+                'seed': self.seed
+            }
+        else:
+            raise NotImplementedError()
+
+    def fit_predict(self, X, y=None):
+        """
+        Construct clustering model and
+        Predict the closest cluster each sample in X belongs to.
+        """
+        ids = list(range(len(X)))  # row index doubles as the point id
+        dataset = Dataset.from_data(X, ids=ids)
+        self._launch_clustering()
+        self.clustering_.clear()
+        for _ in self.clustering_.push(dataset):
+            pass
+        self.fitted = True
+        clusters = self.clustering_.get_core_members(light=True)
+        labels = ['None'] * len(ids)  # NOTE(review): string 'None' marks unclustered points — confirm intended
+        for cluster_id, cluster in enumerate(clusters):
+            for point in cluster:
+                labels[int(point.id)] = cluster_id
+        return labels
+
+    def stop(self):
+        if hasattr(self, 'clustering_'): self.clustering_.stop()  # guard: may be called before launch
+
+    def clear(self):
+        if hasattr(self, 'clustering_'): self.clustering_.clear()  # guard: may be called before launch
+
+
+class BaseKFixedClustering(BaseJubatusClustering):  # base for methods that take a fixed cluster count k
+
+ def __init__(self, k=2, compressor_method='simple',
+ bucket_size=100, compressed_bucket_size=100,
+ bicriteria_base_size=10, bucket_length=2,
+ forgetting_factor=0.0, forgetting_threshold=0.5,
+ seed=0, embedded=True):
+ super(BaseKFixedClustering, self).__init__(
+ compressor_method, bucket_size, compressed_bucket_size, bicriteria_base_size,
+ bucket_length, forgetting_factor, forgetting_threshold, seed, embedded)
+ self.k = k  # number of clusters (kmeans/gmm)
+
+ def _method(self):
+ raise NotImplementedError()  # subclasses return the Jubatus method name
+
+ def _launch_clustering(self):
+ self.method = self._method()
+ self.parameter = {
+ 'k': self.k,
+ 'seed': self.seed
+ }
+ self.config_ = Config(method=self.method, parameter=self.parameter,
+ compressor_method=self.compressor_method,
+ compressor_parameter=self.compressor_parameter)
+ self.clustering_ = Clustering.run(config=self.config_,
+ embedded=self.embedded)
+
+ def fit(self, X, y=None):
+ """
+ Construct clustering model.
+ """
+ if len(X) < self.k:  # at least k points are needed to form k clusters
+ raise RuntimeWarning("At least k={0} points are needed \
+ but {1} points given".format(self.k, len(X)))
+ dataset = Dataset.from_data(X)
+ self._launch_clustering()
+ self.clustering_.clear()
+ for _ in self.clustering_.push(dataset): pass
+ self.fitted = True
+ return self
+
+ def predict(self, X):
+ """
+ Predict the closest cluster each sample in X belongs to.
+ """
+ if not self.fitted:
+ raise RuntimeError("clustering model not fitted yet.")
+ dataset = Dataset.from_data(X)
+ y_pred = []
+ mappings = {}  # nearest-center result -> 0-based label, in order of first appearance
+ count = 0
+ for idx, row_id, result in self.clustering_.get_nearest_center(dataset):
+ if result not in mappings:
+ mappings[result] = count
+ y_pred.append(count)
+ count += 1
+ else:
+ y_pred.append(mappings[result])
+ return y_pred
+
+
+class KMeans(BaseKFixedClustering):  # k-means via Jubatus
+
+ def _method(self):
+ return 'kmeans'  # Jubatus engine method name
+
+
+class GMM(BaseKFixedClustering):  # Gaussian mixture model via Jubatus
+
+ def _method(self):
+ return 'gmm'  # Jubatus engine method name
+
+
+class DBSCAN(BaseJubatusClustering):  # density-based clustering; no fixed cluster count
+
+ def __init__(self, eps=0.2, min_core_point=3,
+ bucket_size=100, compressed_bucket_size=100,
+ bicriteria_base_size=10, bucket_length=2,
+ forgetting_factor=0.0, forgetting_threshold=0.5,
+ seed=0, embedded=True):
+ super(DBSCAN, self).__init__('simple', bucket_size,  # compressor fixed to 'simple' — presumably the only one dbscan supports; confirm
+ compressed_bucket_size, bicriteria_base_size,
+ bucket_length, forgetting_factor,
+ forgetting_threshold, seed, embedded)
+ self.eps = eps  # neighborhood radius
+ self.min_core_point = min_core_point  # minimum neighbors for a core point
+
+ def _launch_clustering(self):
+ self.method = 'dbscan'
+ self.parameter = {
+ 'eps': self.eps,
+ 'min_core_point': self.min_core_point
+ }
+ self.config_ = Config(method=self.method, parameter=self.parameter,
+ compressor_method=self.compressor_method,
+ compressor_parameter=self.compressor_parameter)
+ self.clustering_ = Clustering.run(config=self.config_,
+ embedded=self.embedded)
+