[BUG] Dask-ML ParallelPostFit prediction fails on empty partitions #911

Open · VibhuJawa opened this issue Mar 25, 2022 · 0 comments · May be fixed by #912
Dask-ML ParallelPostFit prediction fails on empty partitions

Minimal Complete Verifiable Example:

from sklearn.linear_model import LogisticRegression
import dask.dataframe as dd
from dask_ml.wrappers import ParallelPostFit
from dask_ml.utils import assert_eq_ar  # array-equality test helper used below
import pandas as pd

df = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8], "y": [True, False] * 4})
ddf = dd.from_pandas(df, npartitions=4)

clf = ParallelPostFit(LogisticRegression())
clf = clf.fit(df[["x"]], df["y"])

# The filter keeps only x in {1, 2, 3, 4}, so the partitions holding
# {5, 6} and {7, 8} become empty.
ddf_with_empty_part = ddf[ddf.x < 5][["x"]]
result = clf.predict(ddf_with_empty_part).compute()

expected = clf.estimator.predict(ddf_with_empty_part.compute())

assert_eq_ar(result, expected)
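
The empty partitions can be confirmed directly. Continuing from the example above, a quick check with map_partitions shows the per-partition row counts:

# Prints the row count of each partition after the filter; for the
# 8-row / 4-partition layout above the counts are 2, 2, 0, 0, i.e.
# the last two partitions are empty.
print(ddf_with_empty_part.map_partitions(len).compute())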

TRACE

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [3], in <cell line: 1>()
----> 1 result.compute()

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/base.py:290, in DaskMethodsMixin.compute(self, **kwargs)
    266 def compute(self, **kwargs):
    267     """Compute this dask collection
    268 
    269     This turns a lazy Dask collection into its in-memory equivalent.
   (...)
    288     dask.base.compute
    289     """
--> 290     (result,) = compute(self, traverse=False, **kwargs)
    291     return result

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/base.py:573, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
    570     keys.append(x.__dask_keys__())
    571     postcomputes.append(x.__dask_postcompute__())
--> 573 results = schedule(dsk, keys, **kwargs)
    574 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/threaded.py:81, in get(dsk, result, cache, num_workers, pool, **kwargs)
     78     elif isinstance(pool, multiprocessing.pool.Pool):
     79         pool = MultiprocessingPoolExecutor(pool)
---> 81 results = get_async(
     82     pool.submit,
     83     pool._max_workers,
     84     dsk,
     85     result,
     86     cache=cache,
     87     get_id=_thread_get_id,
     88     pack_exception=pack_exception,
     89     **kwargs,
     90 )
     92 # Cleanup pools associated to dead threads
     93 with pools_lock:

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/local.py:506, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
    504         _execute_task(task, data)  # Re-execute locally
    505     else:
--> 506         raise_exception(exc, tb)
    507 res, worker_id = loads(res_info)
    508 state["cache"][key] = res

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/local.py:314, in reraise(exc, tb)
    312 if exc.__traceback__ is not tb:
    313     raise exc.with_traceback(tb)
--> 314 raise exc

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/local.py:219, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
    217 try:
    218     task, data = loads(task_info)
--> 219     result = _execute_task(task, data)
    220     id = get_id()
    221     result = dumps((result, id))

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/optimization.py:969, in SubgraphCallable.__call__(self, *args)
    967 if not len(args) == len(self.inkeys):
    968     raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 969 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/core.py:149, in get(dsk, out, cache)
    147 for key in toposort(dsk):
    148     task = dsk[key]
--> 149     result = _execute_task(task, cache)
    150     cache[key] = result
    151 result = _execute_task(out, cache)

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
    115     func, args = arg[0], arg[1:]
    116     # Note: Don't assign the subtask results to a variable. numpy detects
    117     # temporaries by their reference count and can execute certain
    118     # operations in-place.
--> 119     return func(*(_execute_task(a, cache) for a in args))
    120 elif not ishashable(arg):
    121     return arg

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/utils.py:39, in apply(func, args, kwargs)
     37 def apply(func, args, kwargs=None):
     38     if kwargs:
---> 39         return func(*args, **kwargs)
     40     else:
     41         return func(*args)

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/dataframe/core.py:6259, in apply_and_enforce(*args, **kwargs)
   6257 func = kwargs.pop("_func")
   6258 meta = kwargs.pop("_meta")
-> 6259 df = func(*args, **kwargs)
   6260 if is_dataframe_like(df) or is_series_like(df) or is_index_like(df):
   6261     if not len(df):

File ~/dask_ml_dev/dask-ml/dask_ml/wrappers.py:630, in _predict(part, estimator)
    629 def _predict(part, estimator):
--> 630     return estimator.predict(part)

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/sklearn/linear_model/_base.py:425, in LinearClassifierMixin.predict(self, X)
    411 def predict(self, X):
    412     """
    413     Predict class labels for samples in X.
    414 
   (...)
    423         Vector containing the class labels for each sample.
    424     """
--> 425     scores = self.decision_function(X)
    426     if len(scores.shape) == 1:
    427         indices = (scores > 0).astype(int)

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/sklearn/linear_model/_base.py:407, in LinearClassifierMixin.decision_function(self, X)
    387 """
    388 Predict confidence scores for samples.
    389 
   (...)
    403     this class would be predicted.
    404 """
    405 check_is_fitted(self)
--> 407 X = self._validate_data(X, accept_sparse="csr", reset=False)
    408 scores = safe_sparse_dot(X, self.coef_.T, dense_output=True) + self.intercept_
    409 return scores.ravel() if scores.shape[1] == 1 else scores

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/sklearn/base.py:566, in BaseEstimator._validate_data(self, X, y, reset, validate_separately, **check_params)
    564     raise ValueError("Validation should be done on X, y or both.")
    565 elif not no_val_X and no_val_y:
--> 566     X = check_array(X, **check_params)
    567     out = X
    568 elif no_val_X and not no_val_y:

File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/sklearn/utils/validation.py:805, in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator)
    803     n_samples = _num_samples(array)
    804     if n_samples < ensure_min_samples:
--> 805         raise ValueError(
    806             "Found array with %d sample(s) (shape=%s) while a"
    807             " minimum of %d is required%s."
    808             % (n_samples, array.shape, ensure_min_samples, context)
    809         )
    811 if ensure_min_features > 0 and array.ndim == 2:
    812     n_features = array.shape[1]

ValueError: Found array with 0 sample(s) (shape=(0, 1)) while a minimum of 1 is required.

Anything else we need to know?:

Related Issue: dask-contrib/dask-sql#414
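
Per the trace, the failure happens in _predict in dask_ml/wrappers.py, which calls estimator.predict on a zero-row partition and trips scikit-learn's minimum-sample check in check_array. One possible stopgap is to short-circuit empty partitions before the estimator sees them. This is a minimal sketch only, assuming a fitted classifier that exposes classes_; the name _predict_guarded is hypothetical, and this is not necessarily the approach taken in #912:

import numpy as np

def _predict_guarded(part, estimator):
    # Hypothetical variant of dask_ml.wrappers._predict: return an empty
    # prediction for empty partitions instead of calling the estimator.
    if len(part) == 0:
        # Assumes a fitted classifier; classes_.dtype keeps the empty
        # result's dtype consistent with non-empty partitions (bool here).
        return np.empty(0, dtype=estimator.classes_.dtype)
    return estimator.predict(part)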
