diff --git a/ChangeLog.rst b/ChangeLog.rst index 4d3c226..94346be 100644 --- a/ChangeLog.rst +++ b/ChangeLog.rst @@ -1,6 +1,12 @@ ChangeLog ==================================================== +Release 0.5.3 (2017-12-18) +--------------------------------------- + +* New Features + * Add Clustering service (#93, #98) + Release 0.5.2 (2017-10-30) --------------------------------------- diff --git a/README.rst b/README.rst index a7e0d15..adbe7cc 100644 --- a/README.rst +++ b/README.rst @@ -16,7 +16,7 @@ jubakit is a Python module to access Jubatus features easily. jubakit can be used in conjunction with `scikit-learn `_ so that you can use powerful features like cross validation and model evaluation. See the `Jubakit Documentation `_ for the detailed description. -Currently jubakit supports `Classifier `_, `Regression `_, `Anomaly `_, `Recommender `_ and `Weight `_ engines. +Currently jubakit supports `Classifier `_, `Regression `_, `Anomaly `_, `Recommender `_, `Clustering `_ and `Weight `_ engines. 
Install ------- @@ -105,6 +105,8 @@ See the `example `_ dire +-----------------------------------+-----------------------------------------------+-----------------------+ | recommender_npb.py | Recommend similar items | | +-----------------------------------+-----------------------------------------------+-----------------------+ +| clustering_2d.py | Clustering 2-dimensional dataset | | ++-----------------------------------+-----------------------------------------------+-----------------------+ | weight_shogun.py | Tracing fv_converter behavior using Weight | | +-----------------------------------+-----------------------------------------------+-----------------------+ | weight_model_extract.py | Extract contents of Weight model file | | diff --git a/example/blobs.csv b/example/blobs.csv new file mode 100644 index 0000000..ac887ce --- /dev/null +++ b/example/blobs.csv @@ -0,0 +1,301 @@ +cluster,x1,x2 +1,-0.93365052788,-1.14252366433 +1,-1.30942924203,-1.17678133335 +1,-1.34024502265,-1.31247402809 +0,0.705946662732,1.08441625777 +1,-1.13881130967,-0.7085249752 +0,0.956149945383,0.628758392122 +1,-0.749811617546,-0.783501028126 +0,1.09808612117,1.20026950289 +0,1.27769396504,0.928303484449 +0,0.666131442806,0.876941115329 +0,1.10948853175,0.788762702015 +0,0.928641954222,0.737586758678 +1,-0.92868829085,-0.802790686893 +0,1.08601733009,0.743459059141 +0,1.15436399849,1.23561794118 +0,0.894167921591,1.00858976753 +0,1.01180789786,0.962622197845 +1,-1.17281887522,-1.31434156598 +0,0.722735657016,0.748793610328 +0,1.07877265618,0.953919497595 +1,-1.0616397876,-0.836511727778 +1,-0.807477769967,-0.865361420293 +0,0.953572364838,0.935215671951 +0,0.997909419762,0.957200423884 +0,0.96365322024,1.06081199546 +0,0.961237465492,1.07967600026 +1,-0.495908857816,-0.959438464457 +1,-1.69287710777,-1.23728584981 +1,-0.967713503586,-1.06734691425 +1,-0.721891018457,-0.961223455053 +0,1.32631992264,1.01557318933 +1,-0.837405444918,-1.10729247379 +0,1.01495695974,1.23246945591 
+1,-0.756971166227,-1.12469423796 +1,-1.14208325927,-0.482392806743 +0,0.977752246202,0.879820916999 +0,0.862810629696,1.08769305379 +0,1.09954406122,0.938612261241 +1,-0.823182960746,-0.930780485092 +0,0.689785527641,1.19453266431 +0,0.757724405413,1.3473090632 +0,0.877481320621,0.793741068761 +0,0.928999199851,1.00270428848 +1,-0.845402369455,-0.749220815906 +1,-0.970156814095,-0.758907438626 +1,-0.87235646144,-0.880970665629 +0,0.607339613379,1.16354040159 +1,-1.34614348729,-1.02540477671 +1,-1.2145504187,-0.999724235822 +1,-0.870527368967,-1.11699430888 +1,-0.718380625072,-1.1088022103 +0,1.062650858,0.785094270862 +1,-0.806293152097,-1.09505549938 +0,0.811787817815,0.869230709957 +0,0.892888731123,1.12132309814 +1,-0.997885067324,-1.14134314191 +1,-0.910173309421,-1.10079246393 +1,-0.481630999694,-0.883757493582 +0,0.707076543774,1.15769692848 +1,-1.16068210274,-1.25297022359 +1,-0.881054298908,-1.08869920088 +0,1.01845183546,1.10855776458 +0,1.00670796987,0.511756022803 +0,1.09775194332,1.12818658085 +1,-0.867627089872,-0.522748510569 +0,0.558298055212,0.886244752846 +1,-1.19590082703,-0.580458281596 +0,0.825876418599,1.02713177735 +1,-1.10534892994,-1.01661263094 +0,1.0137580997,1.01394515234 +0,1.27443347546,1.15344843593 +1,-0.841169236496,-1.2858972431 +0,1.06639728878,1.15198632838 +0,0.732734102261,0.84064579653 +1,-0.97370887636,-1.06128053325 +0,1.14242937492,1.0584408659 +1,-0.876023303483,-1.18236840362 +0,1.07701802062,0.746070794056 +1,-0.935814391035,-1.08753717031 +1,-1.01698475648,-0.775420213246 +0,0.537324421414,1.44402657477 +1,-1.3170507062,-0.872816250137 +1,-1.13712268693,-0.589193699823 +1,-0.913274447267,-0.887080518959 +1,-1.00859225235,-1.06118433703 +1,-0.792496157825,-0.915519738475 +0,0.980027951159,0.494567435192 +0,1.33339766676,0.837165358674 +1,-0.876643120952,-1.18292066828 +1,-0.979084375782,-0.82307189748 +0,0.816902215077,0.899673400048 +1,-0.57374001078,-0.734988521948 +1,-1.0934833097,-0.801864954993 
+1,-1.06061689938,-1.34832895394 +1,-0.752697423238,-0.753716927762 +1,-1.42694528231,-0.697203669377 +1,-0.841274764173,-0.781559535673 +0,0.756459398564,0.844521971092 +0,1.29976101673,1.10831386395 +1,-1.03171705276,-1.25389155479 +1,-0.890890412259,-0.812819878009 +1,-0.551192238809,-0.994639833663 +0,1.24756402929,0.776557542347 +0,1.05839562468,1.04517104829 +1,-1.12073030272,-0.837256437347 +1,-0.827044749048,-1.05624253637 +0,0.51395666615,1.11051505652 +0,0.739673389253,0.741822007211 +1,-1.27381981888,-1.19688252607 +1,-0.654682929756,-0.973799559948 +1,-1.27326329396,-0.989773871415 +1,-1.09444784539,-1.27763064589 +1,-0.993825738802,-1.09185165069 +1,-0.802710426742,-0.993350846287 +0,1.11525094936,1.00938723789 +0,0.917137550958,0.816560261526 +0,0.992330530566,0.791111411532 +1,-0.707168296188,-0.925412488074 +0,1.13771832227,0.96075461916 +0,0.890959410912,0.938086246527 +1,-1.03627843595,-0.776940390286 +0,0.915069704117,0.666640094248 +0,0.816783179949,0.786688515747 +0,0.801292624248,0.903060164209 +0,1.16113203714,1.03195696871 +0,1.26008314539,1.24444035364 +1,-0.948892165167,-0.926001056779 +1,-0.941296417563,-1.10844766803 +1,-1.10542842758,-0.92429505133 +0,1.01750150234,1.25568028746 +0,0.752380047579,1.05486143046 +1,-0.73702836877,-1.23281202586 +0,0.627606560555,1.05109826999 +0,0.517776378126,1.19789326426 +1,-0.713222432843,-0.958592252019 +0,1.04382879954,0.480124791259 +1,-1.09346632927,-0.536183575128 +0,0.797768559477,0.867412022334 +1,-1.25166127245,-1.20054168598 +0,1.13751133453,1.4897145682 +0,1.32907920212,1.16890122329 +1,-1.05221319285,-0.619013961394 +0,1.31527194689,0.928056818898 +1,-0.84617505511,-0.960290093304 +0,0.93145065596,1.0457932325 +1,-1.1406088498,-0.766721044648 +1,-1.18117331655,-1.09069328092 +0,1.08225304345,0.829386217405 +1,-0.769894971178,-0.699558950919 +0,1.02050957341,0.986519722895 +1,-1.34325173759,-1.18516751807 +1,-0.582824676028,-0.97448263386 +1,-0.91682874077,-1.30224725509 
+0,1.20912285457,0.986237097157 +1,-0.83145177721,-1.04546167864 +1,-1.20867057072,-0.695569099091 +1,-0.964820787331,-0.986652971401 +1,-0.723287318177,-1.17655839686 +0,1.24904540791,0.958520057889 +1,-0.959382519428,-1.10295631686 +1,-1.01521994073,-0.81720667044 +0,0.979055380374,0.930820910756 +0,1.11319239668,1.10061025635 +1,-1.10539188939,-1.15785134438 +0,1.25569667484,0.809511757516 +1,-1.00258583913,-1.0325433727 +1,-1.33488656243,-1.09859469138 +0,0.891235180887,0.819265367588 +0,1.14944236244,1.1630855575 +1,-0.946205231191,-0.849020692018 +0,0.943022145949,0.715867364935 +1,-0.76076441316,-0.779314332251 +0,0.910738956598,1.35569832177 +0,0.721201637698,0.955888602789 +1,-0.787255827868,-0.55261444834 +1,-0.573276712201,-0.959791230078 +0,1.0087782502,1.45900319422 +0,1.01110015772,1.33955783784 +1,-0.605276760547,-0.952284518005 +1,-0.925546454866,-1.17246909632 +1,-1.02198782791,-0.866145630703 +0,0.957553827204,1.05482417288 +1,-0.811182149039,-1.04335186967 +0,1.02046011441,0.904556127609 +1,-0.916580747162,-1.08094162804 +1,-1.31552650363,-0.750744076626 +0,0.559553483781,1.12829448684 +0,1.29397635256,1.04237547388 +1,-1.13287497615,-0.81420008866 +1,-0.944853915653,-1.20704136072 +1,-1.2461332122,-0.958377844123 +1,-0.977067258698,-1.30396981597 +0,0.903306301543,0.949296054698 +0,0.952134777039,1.06964011491 +0,0.621814550845,0.623016191179 +0,0.996972697804,1.2927729619 +0,0.719837189355,0.973600117394 +0,1.24771674478,0.924379447806 +0,1.09031605882,0.710713190899 +0,0.602160571357,0.907628480382 +1,-1.30509663897,-1.22218292874 +1,-1.24449003602,-1.11234483665 +1,-0.996925647094,-1.12394888285 +1,-0.796142515273,-1.06934554615 +1,-1.01453272118,-0.935585735187 +0,0.892871032184,0.89163677657 +0,0.878883422428,0.699279561569 +1,-0.888679242366,-0.954850905534 +0,1.48164360015,0.942929378117 +0,0.877942993272,0.82188051988 +0,1.2840129128,0.782852923729 +0,0.979628162094,1.0046376355 +0,1.0011746956,1.12267964379 
+0,0.830659477108,0.869425923633 +1,-1.03726570769,-1.23452374255 +0,0.987657575081,1.37662447447 +0,0.580467666554,0.908464440585 +1,-0.660081209999,-0.984522246947 +1,-0.980369691949,-1.23243599229 +0,0.606533785575,0.699995653768 +0,1.2587530271,0.85275307954 +0,0.979099363225,1.17801837723 +1,-1.01138332504,-0.896736446022 +0,0.769932593875,1.1598559749 +1,-1.04170291003,-0.912274263627 +0,0.371305675372,1.19505468158 +1,-0.896310853917,-1.15976596784 +0,0.97455113657,0.807068299647 +0,1.21648553711,0.767050284428 +0,0.831378092761,0.649644408667 +1,-1.25986459056,-1.54711480041 +0,1.13340112634,0.736359688345 +1,-1.2384024028,-1.06698951467 +1,-0.988528343692,-0.73969035651 +1,-1.1593662035,-1.37746336517 +1,-0.933518993839,-0.984825025698 +1,-1.51994563014,-0.992198993144 +1,-1.18289178831,-1.09432642972 +1,-1.1489853972,-1.25476230938 +0,0.850649329045,1.04511356058 +0,0.94721169958,0.602117008885 +0,0.769624988299,0.841565193223 +0,0.882463189363,0.830069344874 +1,-1.19048953014,-1.11239954012 +1,-1.37596644766,-0.835583923757 +0,1.15934228485,1.09186345868 +0,1.05753142965,1.21179706895 +0,1.01735561622,0.706920484683 +0,1.25766505085,1.46784431684 +1,-0.904286143701,-0.99044421338 +0,0.688260024665,0.992351467181 +0,1.08799888383,0.775358047861 +1,-1.28388131583,-0.693712470801 +1,-0.998580601632,-0.918860662659 +1,-0.777905864167,-0.665513500153 +1,-1.04125869063,-1.01713047393 +1,-0.872275784821,-0.697519693307 +0,1.15431043138,1.10149551628 +0,1.33418913895,1.40379322465 +0,1.37519870216,0.721840657971 +0,1.02407222119,0.839840744484 +1,-0.812167585814,-1.05228941383 +1,-1.17749289026,-0.901481708669 +0,1.01857858685,0.897001451661 +0,0.871092966037,1.15403550193 +0,0.998247749893,1.00736964006 +1,-0.833648095387,-1.20145439675 +0,1.02711313653,0.749427178008 +1,-0.934653682918,-0.950688169763 +1,-0.707870466892,-1.05779320848 +0,1.19638680512,0.92284423706 +0,1.2409973686,1.16302331561 +1,-1.07686675539,-0.767626617732 +0,0.877645943426,1.21517199377 
+1,-0.971147020183,-1.11221146273 +1,-0.867142933604,-0.749630698221 +1,-1.044349193,-0.998343943806 +0,1.010116662,1.18559779957 +1,-1.07085798109,-1.11210027352 +0,0.807044965878,0.988275972749 +1,-0.989614222944,-0.910899342659 +0,1.06351163418,1.02178086384 +0,0.926517738237,1.0643475345 +1,-0.824690984171,-1.12187500144 +0,0.952629213312,1.28773201112 +0,1.4960145587,1.19712198723 +1,-0.875043933252,-1.2711026884 +0,0.861985501019,1.2682967345 +0,1.02214055141,0.971836886647 +1,-1.04975530185,-0.861185052048 +0,1.16225778459,0.926479092841 +1,-0.935322471815,-0.96645101227 +0,1.23139224734,0.793878878449 +1,-0.83692733936,-0.698771683734 +0,1.09847002587,1.00367339063 +1,-0.937565881378,-1.12237616805 +0,0.979596407895,0.865568046059 +1,-0.912566324192,-0.872655513966 +1,-0.606827980293,-0.626358852502 +1,-0.905890154881,-0.680739811248 \ No newline at end of file diff --git a/example/clustering_2d.py b/example/clustering_2d.py new file mode 100644 index 0000000..fd5bbff --- /dev/null +++ b/example/clustering_2d.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import absolute_import, division, print_function, unicode_literals + +""" +Using Clustering +======================================== + +This is a simple example that illustrates Clustering service usage. + +""" + +from jubakit.clustering import Clustering, Schema, Dataset, Config +from jubakit.loader.csv import CSVLoader + +# Load a CSV file. +loader = CSVLoader('blobs.csv') + +# Define a Schema that defines types for each columns of the CSV file. +schema = Schema({ + 'cluster': Schema.ID, +}, Schema.NUMBER) + +# Create a Dataset. +dataset = Dataset(loader, schema) + +# Create an Clustering Service. +cfg = Config(method='kmeans') +clustering = Clustering.run(cfg) + +# Update the Clustering model. 
+for (idx, row_id, result) in clustering.push(dataset): + pass + +# Get clusters +clusters = clustering.get_core_members(light=False) +# Get centers of each cluster +centers = clustering.get_k_center() + +# Calculate SSE: sum of squared errors +sse = 0.0 +for cluster, center in zip(clusters, centers): + # Center of clusters + center = {"x1": center.num_values[0][1], "x2": center.num_values[1][1]} + for d in cluster: + vector = d.point.num_values + x1 = [x[1] for x in vector if x[0] == 'x1'][0] + x2 = [x[1] for x in vector if x[0] == 'x2'][0] + sse += (x1 - center["x1"])**2 + (x2- center["x2"])**2 +print('SSE:', sse) + +clustering.stop() diff --git a/example/clustering_sklearn_wrapper.py b/example/clustering_sklearn_wrapper.py new file mode 100644 index 0000000..a4fc22d --- /dev/null +++ b/example/clustering_sklearn_wrapper.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import absolute_import, division, print_function, unicode_literals + +""" +Using Clustering +======================================== + +This is a simple example that illustrates Clustering service usage. + +""" + +from sklearn.datasets import make_blobs + +from jubakit.wrapper.clustering import KMeans, GMM, DBSCAN + +# make blob dataset using sklearn API. 
+X, y = make_blobs(n_samples=200, centers=3, n_features=2, random_state=42) + +# launch clustering instance +clusterings = [ + KMeans(k=3, bucket_size=200, embedded=False), + GMM(k=3, bucket_size=200, embedded=False), + DBSCAN(eps=2.0, bucket_size=200, embedded=False) +] + +for clustering in clusterings: + # fit and predict + y_pred = clustering.fit_predict(X) + # print result + labels = set(y_pred) + label_counts = {} + for label in labels: + label_counts[label] = y_pred.count(label) + print('{0}: {1}'.format( + clustering.__class__.__name__, + label_counts)) + # stop clustering service + clustering.stop() + diff --git a/jubakit/_version.py b/jubakit/_version.py index 7cf8a70..7db2873 100644 --- a/jubakit/_version.py +++ b/jubakit/_version.py @@ -1 +1 @@ -VERSION = (0, 5, 2) +VERSION = (0, 5, 3) diff --git a/jubakit/clustering.py b/jubakit/clustering.py new file mode 100644 index 0000000..a590f74 --- /dev/null +++ b/jubakit/clustering.py @@ -0,0 +1,307 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import uuid + +import jubatus +import jubatus.embedded + +from .base import GenericSchema, BaseDataset, BaseService, GenericConfig, Utils +from .loader.array import ArrayLoader, ZipArrayLoader +from .loader.sparse import SparseMatrixLoader +from .loader.chain import ValueMapChainLoader, MergeChainLoader +from .compat import * + +class Schema(GenericSchema): + """ + Schema for Clustering service. + """ + + ID = 'i' + + def __init__(self, mapping, fallback=None): + self._id_key = self._get_unique_mapping(mapping, fallback, self.ID, 'ID', True) + super(Schema, self).__init__(mapping, fallback) + + def transform(self, row): + """ + Clustering schema transforms the row into Datum, its associated ID. 
+ """ + row_id = row.get(self._id_key, None) + if row_id is not None: + row_id = unicode_t(row_id) + else: + row_id = unicode_t(uuid.uuid4()) + d = self._transform_as_datum(row, None, [self._id_key]) + return (row_id, d) + +class Dataset(BaseDataset): + """ + Dataset for Clustering service. + """ + + @classmethod + def _predict(cls, row): + return Schema.predict(row, False) + + @classmethod + def _from_loader(cls, data_loader, ids, static): + if ids is None: + loader = data_loader + schema = Schema({}, Schema.NUMBER) + else: + id_loader = ZipArrayLoader(_id=ids) + loader = MergeChainLoader(data_loader, id_loader) + schema = Schema({'_id': Schema.ID}, Schema.NUMBER) + return Dataset(loader, schema, static) + + @classmethod + def from_data(cls, data, ids=None, feature_names=None, static=True): + """ + Converts two arrays or a sparse matrix data and its associated id array to Dataset. + + Parameters + ---------- + data : array or scipy 2-D sparse matrix of shape [n_samples, n_features] + ids : array of shape [n_samples], optional + feature_names : array of shape [n_features], optional + """ + + if hasattr(data, 'todense'): + return cls.from_matrix(data, ids, feature_names, static) + else: + return cls.from_array(data, ids, feature_names, static) + + @classmethod + def from_array(cls, data, ids=None, feature_names=None, static=True): + """ + Converts two arrays (data and its associated targets) to Dataset. + + Parameters + ---------- + data : array of shape [n_samples, n_features] + ids : array of shape [n_samples], optional + feature_names : array of shape [n_features], optional + """ + + data_loader = ArrayLoader(data, feature_names) + return cls._from_loader(data_loader, ids, static) + + @classmethod + def from_matrix(cls, data, ids=None, feature_names=None, static=True): + """ + Converts a sparse matrix data and its associated target array to Dataset. 
+ + Parameters + ---------- + + data : scipy 2-D sparse matrix of shape [n_samples, n_features] + ids : array of shape [n_samples], optional + feature_names : array of shape [n_features], optional + """ + + data_loader = SparseMatrixLoader(data, feature_names) + return cls._from_loader(data_loader, ids, static) + + def get_ids(self): + """ + Returns labels of each record in the dataset. + """ + + if not self._static: + raise RuntimeError('non-static datasets cannot fetch list of ids') + for (idx, (row_id, d)) in self: + yield row_id + +class Clustering(BaseService): + """ + Clustering service. + """ + + @classmethod + def name(cls): + return 'clustering' + + @classmethod + def _client_class(cls): + return jubatus.clustering.client.Clustering + + @classmethod + def _embedded_class(cls): + return jubatus.embedded.Clustering + + def push(self, dataset): + """ + Add data points. + """ + + cli = self._client() + for (idx, (row_id, d)) in dataset: + if row_id is None: + raise RuntimeError('each row must have `id`.') + result = cli.push([jubatus.clustering.types.IndexedPoint(row_id, d)]) + yield (idx, row_id, result) + + def get_revision(self): + """ + Return revision of clusters + """ + + cli = self._client() + return cli.get_revision() + + def get_core_members(self, light=False): + """ + Returns coreset of cluster in datum. + """ + + cli = self._client() + method = self._get_method() + if light: + return cli.get_core_members_light() + else: + return cli.get_core_members() + + def get_k_center(self): + """ + Return k cluster centers. + """ + + cli = self._client() + method = self._get_method() + if method not in ('kmeans', 'gmm'): + raise RuntimeError('{0} is not supported'.format(method)) + return cli.get_k_center() + + def get_nearest_center(self, dataset): + """ + Returns nearest cluster center without adding points to cluster. 
+ """ + + cli = self._client() + method = self._get_method() + if method not in ('kmeans', 'gmm'): + raise RuntimeError('{0} is not supported'.format(method)) + + for (idx, (row_id, d)) in dataset: + result = cli.get_nearest_center(d) + yield (idx, row_id, result) + + def get_nearest_members(self, dataset, light=False): + """ + Returns nearest summary of cluster(coreset) from each point. + """ + + cli = self._client() + method = self._get_method() + if method not in ('kmeans', 'gmm'): + raise RuntimeError('{0} is not supported'.format(method)) + + for (idx, (row_id, d)) in dataset: + if light: + result = cli.get_nearest_members_light(d) + else: + result = cli.get_nearest_members(d) + yield (idx, row_id, result) + + def _get_method(self): + method = None + if self._embedded: + config = json.loads(self._backend.model.get_config()) + method = config['method'] + else: + if 'method' in self._backend.config: + method = self._backend.config['method'] + return method + +class Config(GenericConfig): + """ + Configulation to run Clustering service. 
+ """ + + def __init__(self, method=None, parameter=None, + compressor_method=None, compressor_parameter=None, + converter=None): + super(Config, self).__init__(method, parameter, converter) + if compressor_method is not None: + self['compressor_method'] = compressor_method + default_compressor_parameter = \ + self._default_compressor_parameter(compressor_method) + if default_compressor_parameter is None: + if 'compressor_parameter' in self: + del self['compressor_parameter'] + else: + self['compressor_parameter'] = default_compressor_parameter + + if compressor_parameter is not None: + if 'compressor_parameter' in self: + self['compressor_parameter'].update(compressor_parameter) + else: + self['compressor_parameter'] = compressor_parameter + + @classmethod + def _default(cls, cfg): + super(Config, cls)._default(cfg) + + compressor_method = cls._default_compressor_method() + compressor_parameter = cls._default_compressor_parameter(compressor_method) + + if compressor_method is not None: + cfg['compressor_method'] = compressor_method + if compressor_parameter is not None: + cfg['compressor_parameter'] = compressor_parameter + + return cfg + + @classmethod + def _default_method(cls): + return 'kmeans' + + @classmethod + def _default_compressor_method(cls): + return 'simple' + + @classmethod + def _default_parameter(cls, method): + if method in ('kmeans', 'gmm'): + return { + 'k': 3, + 'seed': 0 + } + elif method in ('dbscan'): + return { + 'eps': 0.2, + 'min_core_point': 3 + } + else: + raise RuntimeError('unknown method: {0}'.format(method)) + + @classmethod + def _default_compressor_parameter(cls, method): + if method in ('simple'): + return { + 'bucket_size': 100 + } + elif method in ('compressive'): + return { + 'bucket_size': 100, + 'bucket_length': 2, + 'compressed_bucket_size': 100, + 'bicriteria_base_size': 10, + 'forgetting_factor': 0.0, + 'forgetting_threshold': 0.5, + 'seed': 0 + } + else: + raise RuntimeError('unknown method: {0}'.format(method)) + + 
@classmethod + def methods(cls): + return ['kmeans', 'gmm', 'dbscan'] + + @classmethod + def compressor_methods(cls): + return ['simple', 'compressive'] + diff --git a/jubakit/test/test_clustering.py b/jubakit/test/test_clustering.py new file mode 100644 index 0000000..b02a667 --- /dev/null +++ b/jubakit/test/test_clustering.py @@ -0,0 +1,332 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import, division, print_function, unicode_literals + +from unittest import TestCase + +try: + import numpy as np + from scipy.sparse import csr_matrix +except ImportError: + pass + +from jubakit.clustering import Schema, Dataset, Clustering, Config +from jubakit.compat import * + +from . import requireSklearn, requireEmbedded +from .stub import * + +class SchemaTest(TestCase): + def test_simple(self): + schema = Schema({ + 'id': Schema.ID, + 'k1': Schema.STRING, + 'k2': Schema.NUMBER + }) + (row_id, d) = schema.transform({'id': 'user001', 'k1': 'abc', 'k2': '123'}) + + self.assertEqual(row_id, 'user001') + self.assertEqual({'k1': 'abc'}, dict(d.string_values)) + self.assertEqual({'k2': 123}, dict(d.num_values)) + + def test_without_id(self): + # schema without id can be defined + Schema({ + 'k1': Schema.STRING, + }) + + def test_illegal_id(self): + # schema with multiple IDs + self.assertRaises(RuntimeError, Schema, { + 'k1': Schema.ID, + 'k2': Schema.ID, + }) + + # schema fallback set to id + self.assertRaises(RuntimeError, Schema, { + 'k1': Schema.ID + }, Schema.ID) + +class DatasetTest(TestCase): + def test_simple(self): + loader = StubLoader() + schema = Schema({'v': Schema.ID}) + ds = Dataset(loader, schema) + for (idx, (label, d)) in ds: + self.assertEqual(unicode_t(idx+1), label) + self.assertEqual(0, len(d.string_values)) + self.assertEqual(0, len(d.num_values)) + self.assertEqual(0, len(d.binary_values)) + self.assertEqual(['1','2','3'], list(ds.get_ids())) + + def test_predict(self): + loader = StubLoader() + dataset = Dataset(loader) + 
self.assertEqual(['v', 1.0], dataset[0][1].num_values[0]) + + def test_from_data(self): + # load from array format + ds = Dataset.from_data( + [ [10, 20, 30], [20, 10, 50], [40, 10, 30]], # data + ['i1', 'i2', 'i3'], # ids + ['k1', 'k2', 'k3'] # feature names + ) + + expected_k1s = [10, 20, 40] + expected_ids = ['i1', 'i2', 'i3'] + actual_k1s = [] + actual_ids = [] + for (idx, (row_id, d)) in ds: + actual_k1s.append(dict(d.num_values).get('k1', None)) + actual_ids.append(row_id) + + self.assertEqual(expected_k1s, actual_k1s) + self.assertEqual(expected_ids, actual_ids) + + # load from scipy.sparse format + ds = Dataset.from_data( + self._create_matrix(), # data + ['i1', 'i2', 'i3'], # ids + [ 'k1', 'k2', 'k3'], # feature_names + ) + + expected_k1s = [1, None, 4] + expected_k3s = [2, 3, 6] + expected_ids = ['i1', 'i2', 'i3'] + actual_k1s = [] + actual_k3s = [] + actual_ids = [] + for (idx, (row_id, d)) in ds: + actual_k1s.append(dict(d.num_values).get('k1', None)) + actual_k3s.append(dict(d.num_values).get('k3', None)) + actual_ids.append(row_id) + + self.assertEqual(expected_k1s, actual_k1s) + self.assertEqual(expected_k3s, actual_k3s) + self.assertEqual(expected_ids, actual_ids) + + def test_from_array(self): + ds = Dataset.from_array( + [ [10, 20, 30], [20, 10, 50], [40, 10, 30]], # data + ['i1', 'i2', 'i3'], # ids + ['k1', 'k2', 'k3'] # feature names + ) + + expected_k1s = [10, 20, 40] + expected_ids = ['i1', 'i2', 'i3'] + actual_k1s = [] + actual_ids = [] + for (idx, (row_id, d)) in ds: + actual_k1s.append(dict(d.num_values).get('k1', None)) + actual_ids.append(row_id) + + self.assertEqual(expected_k1s, actual_k1s) + self.assertEqual(expected_ids, actual_ids) + + def test_from_array_without_ids(self): + ds = Dataset.from_array( + [ [10, 20, 30], [20, 10, 50], [40, 10, 30]], # data + feature_names=['k1', 'k2', 'k3'] # feature names + ) + + expected_k1s = [10, 20, 40] + actual_k1s = [] + actual_ids = [] + for (idx, (row_id, d)) in ds: + 
actual_k1s.append(dict(d.num_values).get('k1', None)) + actual_ids.append(row_id) + self.assertEqual(expected_k1s, actual_k1s) + self.assertEqual(len(actual_ids), 3) + + @requireSklearn + def test_from_matrix(self): + ds = Dataset.from_matrix( + self._create_matrix(), # data + ['i1', 'i2', 'i3'], # ids + ['k1', 'k2', 'k3'] # feature names + ) + + expected_k1s = [1, None, 4] + expected_k3s = [2, 3, 6] + expected_ids = ['i1', 'i2', 'i3'] + actual_k1s = [] + actual_k3s = [] + actual_ids = [] + for (idx, (row_id, d)) in ds: + actual_k1s.append(dict(d.num_values).get('k1', None)) + actual_k3s.append(dict(d.num_values).get('k3', None)) + actual_ids.append(row_id) + + self.assertEqual(expected_k1s, actual_k1s) + self.assertEqual(expected_k3s, actual_k3s) + self.assertEqual(expected_ids, actual_ids) + + def test_get_ids(self): + ds = Dataset.from_array( + [ [10, 20, 30], [20, 10, 50], [40, 10, 30]], # data + ['i1', 'i2', 'i3'], # ids + static=True + ) + actual_ids = [] + expected_ids = ['i1', 'i2', 'i3'] + for row_id in ds.get_ids(): + actual_ids.append(row_id) + self.assertEqual(expected_ids, actual_ids) + + ds = Dataset.from_array( + [ [10, 20, 30], [20, 10, 50], [40, 10, 30]], # data + ['i1', 'i2', 'i3'], # ids + static=False + ) + self.assertRaises(RuntimeError, list, ds.get_ids()) + + def _create_matrix(self): + """ + Create a sparse matrix: + + [[1, 0, 2], + [0, 0, 3], + [4, 5, 6]] + """ + row = np.array([0, 0, 1, 2, 2, 2]) + col = np.array([0, 2, 2, 0, 1, 2]) + data = np.array([1, 2, 3, 4, 5, 6]) + return csr_matrix((data, (row, col)), shape=(3, 3)) + +class ClusteringTest(TestCase): + def test_simple(self): + clustering = Clustering() + clustering.stop() + + @requireEmbedded + def test_embedded(self): + clustering = Clustering.run(Config(), embedded=True) + clustering.stop() + + def test_push(self): + clustering = Clustering.run(Config()) + dataset = self._make_stub_dataset() + for (idx, row_id, result) in clustering.push(dataset): + self.assertEqual(result, True) 
+ clustering.stop() + + def test_get_revision(self): + clustering = Clustering.run(Config()) + self.assertEqual(0, clustering.get_revision()) + clustering.stop() + + def test_get_core_members(self): + dataset = self._make_stub_dataset() + config = Config(method='kmeans', compressor_parameter={"bucket_size": 5}) + clustering = self._make_stub_clustering(config, dataset) + clustering.get_core_members(light=False) + clustering.get_core_members(light=True) + clustering.stop() + + def test_get_k_center(self): + def func(clustering, dataset): + clustering.get_k_center() + self._test_func_with_legal_and_illegal_config(func) + + def test_get_nearest_center(self): + def func(clustering, dataset): + for _ in clustering.get_nearest_center(dataset): pass + self._test_func_with_legal_and_illegal_config(func) + + def test_get_nearest_members(self): + def func1(clustering, dataset): + for _ in clustering.get_nearest_members(dataset, light=False): pass + self._test_func_with_legal_and_illegal_config(func1) + + def func2(clustering, dataset): + for _ in clustering.get_nearest_members(dataset, light=True): pass + self._test_func_with_legal_and_illegal_config(func2) + + def _test_func_with_legal_and_illegal_config(self, func): + dataset = self._make_stub_dataset() + # test illegal method + config = Config(method='dbscan', compressor_parameter={"bucket_size": 5}) + clustering = self._make_stub_clustering(config, dataset) + self.assertRaises(RuntimeError, lambda: func(clustering, dataset)) + clustering.stop() + + # test legal method + config = Config(method='kmeans', compressor_parameter={"bucket_size": 5}) + clustering = self._make_stub_clustering(config, dataset) + func(clustering, dataset) + clustering.stop() + + def _make_stub_clustering(self, config, dataset): + dataset = self._make_stub_dataset() + clustering = Clustering.run(config) + for _ in clustering.push(dataset): pass + return clustering + + def _make_stub_dataset(self): + ids = ['id1', 'id2', 'id3', 'id4', 'id5'] + X = [ 
+ [0, 0, 0], + [1, 1, 1], + [2, 2, 2], + [3, 3, 3], + [4, 4, 4] + ] + dataset = Dataset.from_array(X, ids=ids) + return dataset + + +class ConfigTest(TestCase): + def test_simple(self): + config = Config() + self.assertEqual('kmeans', config['method']) + self.assertEqual({'k': 3, 'seed': 0}, config['parameter']) + self.assertEqual('simple', config['compressor_method']) + self.assertEqual({'bucket_size': 100}, config.get('compressor_parameter')) + + def test_methods(self): + config = Config() + self.assertTrue(isinstance(config.methods(), list)) + + def test_compressor_methods(self): + config = Config() + self.assertTrue(isinstance(config.compressor_methods(), list)) + + def test_illegal_comporessor_method(self): + self.assertRaises(RuntimeError, + Config._default_compressor_parameter, + 'invalid_compressor_method') + + def test_default(self): + config = Config.default() + self.assertEqual('kmeans', config['method']) + self.assertEqual('simple', config['compressor_method']) + + def test_method_params(self): + self.assertTrue('k' in Config(method='kmeans')['parameter']) + self.assertTrue('seed' in Config(method='kmeans')['parameter']) + self.assertTrue('k' in Config(method='gmm')['parameter']) + self.assertTrue('seed' in Config(method='gmm')['parameter']) + self.assertTrue('eps' in Config(method='dbscan')['parameter']) + self.assertTrue('min_core_point' in Config(method='dbscan')['parameter']) + + def test_compressor_params(self): + self.assertTrue('bucket_size' in + Config(compressor_method='simple')['compressor_parameter']) + self.assertTrue('bucket_size' in + Config(compressor_method='compressive')['compressor_parameter']) + self.assertTrue('bucket_length' in + Config(compressor_method='compressive')['compressor_parameter']) + self.assertTrue('compressed_bucket_size' in + Config(compressor_method='compressive')['compressor_parameter']) + self.assertTrue('bicriteria_base_size' in + Config(compressor_method='compressive')['compressor_parameter']) + 
self.assertTrue('forgetting_factor' in
+ Config(compressor_method='compressive')['compressor_parameter'])
+ self.assertTrue('seed' in
+ Config(compressor_method='compressive')['compressor_parameter'])
+ config = Config(compressor_method='simple',
+ compressor_parameter={'bucket_size': 10})
+ self.assertEqual(10, config['compressor_parameter']['bucket_size'])
+
+ def test_invalid_method(self):
+ self.assertRaises(RuntimeError, Config._default_parameter, 'invalid_method')
diff --git a/jubakit/test/wrapper/test_clustering.py b/jubakit/test/wrapper/test_clustering.py
new file mode 100644
index 0000000..639306d
--- /dev/null
+++ b/jubakit/test/wrapper/test_clustering.py
@@ -0,0 +1,133 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+from unittest import TestCase
+
+try:
+ import numpy as np
+except ImportError:
+ pass
+
+from jubakit.wrapper.clustering import KMeans, GMM, DBSCAN
+from . import requireEmbedded
+
+
+class KMeansTest(TestCase):
+
+ def test_simple(self):
+ clustering = KMeans(embedded=False)
+ clustering.stop()
+
+ @requireEmbedded
+ def test_embedded(self):
+ clustering = KMeans(embedded=True)
+
+ def test_init(self):
+ clustering = KMeans(embedded=False)
+ self.assertEqual(2, clustering.k)
+ self.assertEqual('simple', clustering.compressor_method)
+ self.assertEqual(100, clustering.bucket_size)
+ self.assertEqual(100, clustering.compressed_bucket_size)
+ self.assertEqual(10, clustering.bicriteria_base_size)
+ self.assertEqual(2, clustering.bucket_length)
+ self.assertEqual(0.0, clustering.forgetting_factor)
+ self.assertEqual(0.5, clustering.forgetting_threshold)
+ self.assertEqual(0, clustering.seed)
+ self.assertTrue(not clustering.embedded)
+ clustering.stop()
+
+ def test_method(self):
+ clustering = KMeans(embedded=False)
+ self.assertEqual('kmeans', clustering._method())
+ clustering.stop()
+
+ def test_make_compressor_parameter(self):
+ clustering = 
KMeans(compressor_method='simple', embedded=False)
+ compressor_parameter = {'bucket_size': 100}
+ self.assertEqual(compressor_parameter,
+ clustering._make_compressor_parameter('simple'))
+ clustering.stop()
+
+ clustering = KMeans(compressor_method='compressive', embedded=False)
+ compressor_parameter = {
+ 'bucket_size': 100,
+ 'compressed_bucket_size': 100,
+ 'bicriteria_base_size': 10,
+ 'bucket_length': 2,
+ 'forgetting_factor': 0.0,
+ 'forgetting_threshold': 0.5,
+ 'seed': 0
+ }
+ self.assertEqual(compressor_parameter,
+ clustering._make_compressor_parameter('compressive'))
+ clustering.stop()
+
+ def test_fit(self):
+ X = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]])
+ clustering = KMeans(k=10, embedded=False)
+ self.assertRaises(RuntimeWarning, clustering.fit, X)
+ clustering.stop()
+
+ clustering = KMeans(k=5, embedded=False)
+ self.assertTrue(not clustering.fitted)
+ clustering.fit(X)
+ self.assertTrue(clustering.fitted)
+
+ def test_predict(self):
+ X = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]])
+ clustering = KMeans(k=5, embedded=False)
+ self.assertRaises(RuntimeError, clustering.predict, X)
+ clustering.fit(X)
+ y_pred = clustering.predict(X)
+ self.assertEqual(len(y_pred), X.shape[0])
+
+ def test_fit_predict(self):
+ X = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4]])
+ clustering = KMeans(k=5, embedded=False)
+ self.assertTrue(not clustering.fitted)
+ y_pred = clustering.fit_predict(X)
+ self.assertTrue(clustering.fitted)
+ self.assertEqual(len(y_pred), X.shape[0])
+
+
+class GMMTest(TestCase):
+
+ def test_simple(self):
+ clustering = GMM(embedded=False)
+ clustering.stop()
+
+ @requireEmbedded
+ def test_embedded(self):
+ clustering = GMM(embedded=True)
+
+ def test_method(self):
+ clustering = GMM(embedded=False)
+ self.assertEqual('gmm', clustering._method())
+ clustering.stop()
+
+class DBSCANTest(TestCase):
+
+ def test_simple(self):
+ clustering = DBSCAN(embedded=False)
+ clustering.stop()
+
+ @requireEmbedded
+ def 
test_embedded(self):
+ clustering = DBSCAN(embedded=True)
+
+ def test_init(self):
+ clustering = DBSCAN(embedded=False)
+ self.assertEqual(0.2, clustering.eps)
+ self.assertEqual(3, clustering.min_core_point)
+ self.assertEqual('simple', clustering.compressor_method)
+ self.assertEqual(100, clustering.bucket_size)
+ self.assertEqual(100, clustering.compressed_bucket_size)
+ self.assertEqual(10, clustering.bicriteria_base_size)
+ self.assertEqual(2, clustering.bucket_length)
+ self.assertEqual(0.0, clustering.forgetting_factor)
+ self.assertEqual(0.5, clustering.forgetting_threshold)
+ self.assertEqual(0, clustering.seed)
+ self.assertTrue(not clustering.embedded)
+ clustering.stop()
+
diff --git a/jubakit/wrapper/clustering.py b/jubakit/wrapper/clustering.py
new file mode 100644
index 0000000..301a8e0
--- /dev/null
+++ b/jubakit/wrapper/clustering.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import numpy as np
+from sklearn.base import BaseEstimator, ClusterMixin, TransformerMixin
+from ..clustering import Clustering, Config, Dataset
+
+
+class BaseJubatusClustering(BaseEstimator, ClusterMixin):
+ """
+ scikit-learn Wrapper for Jubatus Clustering.
+ """ + + def __init__(self, compressor_method='simple', + bucket_size=100, compressed_bucket_size=100, + bicriteria_base_size=10, bucket_length=2, + forgetting_factor=0.0, forgetting_threshold=0.5, + seed=0, embedded=True): + """ + Creates a base class for Jubatus Clustering + """ + self.compressor_method = compressor_method + self.bucket_size = bucket_size + self.compressed_bucket_size = compressed_bucket_size + self.bicriteria_base_size = bicriteria_base_size + self.bucket_length = bucket_length + self.forgetting_factor = forgetting_factor + self.forgetting_threshold = forgetting_threshold + self.seed = seed + self.embedded = embedded + self.compressor_parameter = \ + self._make_compressor_parameter(self.compressor_method) + self.fitted = False + + def _launch_clustering(self): + """ + Launch Jubatus Clustering + """ + raise NotImplementedError() + + def _make_compressor_parameter(self, compressor_method): + if compressor_method == 'simple': + return { + 'bucket_size': self.bucket_size, + } + elif compressor_method == 'compressive': + return { + 'bucket_size': self.bucket_size, + 'compressed_bucket_size': self.compressed_bucket_size, + 'bicriteria_base_size': self.bicriteria_base_size, + 'bucket_length': self.bucket_length, + 'forgetting_factor': self.forgetting_factor, + 'forgetting_threshold': self.forgetting_threshold, + 'seed': self.seed + } + else: + raise NotImplementedError() + + def fit_predict(self, X, y=None): + """ + Construct clustering model and + Predict the closest cluster each sample in X belongs to. 
+ """ + ids = list(range(len(X))) + dataset = Dataset.from_data(X, ids=ids) + self._launch_clustering() + self.clustering_.clear() + for _ in self.clustering_.push(dataset): + pass + self.fitted = True + clusters = self.clustering_.get_core_members(light=True) + labels = ['None'] * len(ids) + for cluster_id, cluster in enumerate(clusters): + for point in cluster: + labels[int(point.id)] = cluster_id + return labels + + def stop(self): + self.clustering_.stop() + + def clear(self): + self.clustering_.clear() + + +class BaseKFixedClustering(BaseJubatusClustering): + + def __init__(self, k=2, compressor_method='simple', + bucket_size=100, compressed_bucket_size=100, + bicriteria_base_size=10, bucket_length=2, + forgetting_factor=0.0, forgetting_threshold=0.5, + seed=0, embedded=True): + super(BaseKFixedClustering, self).__init__( + compressor_method, bucket_size, compressed_bucket_size, bicriteria_base_size, + bucket_length, forgetting_factor, forgetting_threshold, seed, embedded) + self.k = k + + def _method(self): + raise NotImplementedError() + + def _launch_clustering(self): + self.method = self._method() + self.parameter = { + 'k': self.k, + 'seed': self.seed + } + self.config_ = Config(method=self.method, parameter=self.parameter, + compressor_method=self.compressor_method, + compressor_parameter=self.compressor_parameter) + self.clustering_ = Clustering.run(config=self.config_, + embedded=self.embedded) + + def fit(self, X, y=None): + """ + Construct clustering model. + """ + if len(X) < self.k: + raise RuntimeWarning("At least k={0} points are needed \ + but {1} points given".format(self.k, len(X))) + dataset = Dataset.from_data(X) + self._launch_clustering() + self.clustering_.clear() + for _ in self.clustering_.push(dataset): pass + self.fitted = True + return self + + def predict(self, X): + """ + Predict the closest cluster each sample in X belongs to. 
+ """ + if not self.fitted: + raise RuntimeError("clustering model not fitted yet.") + dataset = Dataset.from_data(X) + y_pred = [] + mappings = {} + count = 0 + for idx, row_id, result in self.clustering_.get_nearest_center(dataset): + if result not in mappings: + mappings[result] = count + y_pred.append(count) + count += 1 + else: + y_pred.append(mappings[result]) + return y_pred + + +class KMeans(BaseKFixedClustering): + + def _method(self): + return 'kmeans' + + +class GMM(BaseKFixedClustering): + + def _method(self): + return 'gmm' + + +class DBSCAN(BaseJubatusClustering): + + def __init__(self, eps=0.2, min_core_point=3, + bucket_size=100, compressed_bucket_size=100, + bicriteria_base_size=10, bucket_length=2, + forgetting_factor=0.0, forgetting_threshold=0.5, + seed=0, embedded=True): + super(DBSCAN, self).__init__('simple', bucket_size, + compressed_bucket_size, bicriteria_base_size, + bucket_length, forgetting_factor, + forgetting_threshold, seed, embedded) + self.eps = eps + self.min_core_point = min_core_point + + def _launch_clustering(self): + self.method = 'dbscan' + self.parameter = { + 'eps': self.eps, + 'min_core_point': self.min_core_point + } + self.config_ = Config(method=self.method, parameter=self.parameter, + compressor_method=self.compressor_method, + compressor_parameter=self.compressor_parameter) + self.clustering_ = Clustering.run(config=self.config_, + embedded=self.embedded) +