Skip to content

Commit

Permalink
Integrate dask-expr and make CI happy (dask#980)
Browse files (browse the repository at this point in the history)
  • Loading branch information
milesgranger authored Mar 7, 2024
1 parent b5640cb commit fa5a583
Show file tree
Hide file tree
Showing 45 changed files with 329 additions and 113 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: pre-commit/action@v3.0.0
- uses: actions/checkout@v4.1.1
- uses: actions/setup-python@v5
with:
python-version: '3.9'
- uses: pre-commit/action@v3.0.1
4 changes: 3 additions & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ jobs:
matrix:
# os: ["windows-latest", "ubuntu-latest", "macos-latest"]
os: ["ubuntu-latest"]
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11"]
query-planning: [true, false]

env:
PYTHON_VERSION: ${{ matrix.python-version }}
PARALLEL: "true"
COVERAGE: "true"
DASK_DATAFRAME__QUERY_PLANNING: ${{ matrix.query-planning }}

steps:
- name: Checkout source
Expand Down
13 changes: 8 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
repos:
- repo: https://github.com/python/black
rev: 22.3.0
- repo: https://github.com/psf/black
rev: 23.12.1
hooks:
- id: black
language_version: python3
args:
- --target-version=py39
- repo: https://github.com/pycqa/flake8
rev: 3.7.9
rev: 7.0.0
hooks:
- id: flake8
language_version: python3
- repo: https://github.com/timothycrosley/isort
rev: 4.3.21
args: ["--ignore=E501,W503,E203,E741,E731"]
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
language_version: python3
5 changes: 4 additions & 1 deletion ci/environment-3.10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ channels:
- conda-forge
- defaults
dependencies:
- dask
- dask-glm
- multipledispatch >=0.4.9
- mypy
Expand All @@ -21,3 +20,7 @@ dependencies:
- scipy
- sparse
- toolz
- pip
- pip:
- git+https://github.com/dask-contrib/dask-expr
- git+https://github.com/dask/dask
9 changes: 6 additions & 3 deletions ci/environment-3.8.yaml → ci/environment-3.11.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
name: dask-ml-3.8
name: dask-ml-3.11
channels:
- conda-forge
- defaults
dependencies:
- dask
- dask-glm
- multipledispatch >=0.4.9
- mypy
Expand All @@ -16,8 +15,12 @@ dependencies:
- pytest
- pytest-cov
- pytest-mock
- python=3.8.*
- python=3.11.*
- scikit-learn >=1.2.0
- scipy
- sparse
- toolz
- pip
- pip:
- git+https://github.com/dask-contrib/dask-expr
- git+https://github.com/dask/dask
5 changes: 4 additions & 1 deletion ci/environment-3.9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ channels:
- conda-forge
- defaults
dependencies:
- dask
- dask-glm
- multipledispatch >=0.4.9
- mypy
Expand All @@ -21,3 +20,7 @@ dependencies:
- scipy
- sparse
- toolz
- pip
- pip:
- git+https://github.com/dask-contrib/dask-expr
- git+https://github.com/dask/dask
34 changes: 26 additions & 8 deletions ci/environment-docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,21 @@ channels:
dependencies:
- black
- coverage
- graphviz
- heapdict
- ipykernel
- ipython
- multipledispatch
- mypy
- nbsphinx
- nomkl
- nose
- numba
- numpy
- numpydoc
- pandas
- psutil
- python=3.10
- sortedcontainers
- scikit-learn >=1.2.0
- scipy
- sparse
- sphinx
- sphinx_rtd_theme
- sphinx-gallery
- tornado
- toolz
- zict
Expand All @@ -35,5 +28,30 @@ dependencies:
- dask-glm
- dask-xgboost
- pip:
- dask-sphinx-theme >=3.0.0
- graphviz
- numpydoc
- sphinx>=4.0.0,<5.0.0
- dask-sphinx-theme>=3.0.0
- sphinx-click
- sphinx-copybutton
- sphinx-remove-toctrees
- sphinx_autosummary_accessors
- sphinx-tabs
- sphinx-design
- jupyter_sphinx
# FIXME: `sphinxcontrib-*` pins are a workaround until we have sphinx>=5.
# See https://github.com/dask/dask-sphinx-theme/issues/68.
- sphinxcontrib-applehelp>=1.0.0,<1.0.7
- sphinxcontrib-devhelp>=1.0.0,<1.0.6
- sphinxcontrib-htmlhelp>=2.0.0,<2.0.5
- sphinxcontrib-serializinghtml>=1.1.0,<1.1.10
- sphinxcontrib-qthelp>=1.0.0,<1.0.7
- toolz
- cloudpickle>=1.5.0
- pandas>=1.4.0
- dask-expr
- fsspec
- scipy
- pytest
- pytest-check-links
- requests-cache
2 changes: 1 addition & 1 deletion dask_ml/_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def fit(
shuffle_blocks=True,
random_state=None,
assume_equal_chunks=False,
**kwargs
**kwargs,
):
"""Fit scikit learn model against dask arrays
Expand Down
7 changes: 3 additions & 4 deletions dask_ml/ensemble/_blockwise.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import dask
import dask.array as da
import dask.dataframe as dd
import numpy as np
import sklearn.base
from sklearn.utils.validation import check_is_fitted

from ..base import ClassifierMixin, RegressorMixin
from ..utils import check_array
from ..utils import check_array, is_frame_base


class BlockwiseBase(sklearn.base.BaseEstimator):
Expand Down Expand Up @@ -62,7 +61,7 @@ def _predict(self, X):
dtype=np.dtype(dtype),
chunks=chunks,
)
elif isinstance(X, dd._Frame):
elif is_frame_base(X):
meta = np.empty((0, len(self.classes_)), dtype=dtype)
combined = X.map_partitions(
_predict_stack, estimators=self.estimators_, meta=meta
Expand Down Expand Up @@ -184,7 +183,7 @@ def _collect_probas(self, X):
chunks=chunks,
meta=meta,
)
elif isinstance(X, dd._Frame):
elif is_frame_base(X):
# TODO: replace with a _predict_proba_stack version.
# This currently raises; dask.dataframe doesn't like map_partitions that
# return new axes.
Expand Down
67 changes: 45 additions & 22 deletions dask_ml/linear_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,60 @@
import numpy as np
from multipledispatch import dispatch

if getattr(dd, "_dask_expr_enabled", lambda: False)():
import dask_expr

@dispatch(dd._Frame)
def exp(A):
return da.exp(A)
@dispatch(dask_expr.FrameBase)
def exp(A):
return da.exp(A)

@dispatch(dask_expr.FrameBase)
def absolute(A):
return da.absolute(A)

@dispatch(dd._Frame)
def absolute(A):
return da.absolute(A)
@dispatch(dask_expr.FrameBase)
def sign(A):
return da.sign(A)

@dispatch(dask_expr.FrameBase)
def log1p(A):
return da.log1p(A)

@dispatch(dd._Frame)
def sign(A):
return da.sign(A)
@dispatch(dask_expr.FrameBase) # noqa: F811
def add_intercept(X): # noqa: F811
columns = X.columns
if "intercept" in columns:
raise ValueError("'intercept' column already in 'X'")
return X.assign(intercept=1)[["intercept"] + list(columns)]

else:

@dispatch(dd._Frame)
def log1p(A):
return da.log1p(A)
@dispatch(dd._Frame)
def exp(A):
return da.exp(A)

@dispatch(dd._Frame)
def absolute(A):
return da.absolute(A)

@dispatch(np.ndarray)
def add_intercept(X):
@dispatch(dd._Frame)
def sign(A):
return da.sign(A)

@dispatch(dd._Frame)
def log1p(A):
return da.log1p(A)

@dispatch(dd._Frame) # noqa: F811
def add_intercept(X): # noqa: F811
columns = X.columns
if "intercept" in columns:
raise ValueError("'intercept' column already in 'X'")
return X.assign(intercept=1)[["intercept"] + list(columns)]


@dispatch(np.ndarray) # noqa: F811
def add_intercept(X): # noqa: F811
return _add_intercept(X)


Expand All @@ -53,14 +84,6 @@ def add_intercept(X): # noqa: F811
return X.map_blocks(_add_intercept, dtype=X.dtype, chunks=chunks)


@dispatch(dd.DataFrame) # noqa: F811
def add_intercept(X): # noqa: F811
columns = X.columns
if "intercept" in columns:
raise ValueError("'intercept' column already in 'X'")
return X.assign(intercept=1)[["intercept"] + list(columns)]


@dispatch(np.ndarray) # noqa: F811
def lr_prob_stack(prob): # noqa: F811
return np.vstack([1 - prob, prob]).T
Expand Down
6 changes: 3 additions & 3 deletions dask_ml/metrics/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def pairwise_distances(
Y: ArrayLike,
metric: Union[str, Callable[[ArrayLike, ArrayLike], float]] = "euclidean",
n_jobs: Optional[int] = None,
**kwargs: Any
**kwargs: Any,
):
if isinstance(Y, da.Array):
raise TypeError("`Y` must be a numpy array")
Expand All @@ -62,7 +62,7 @@ def pairwise_distances(
dtype=float,
chunks=chunks,
metric=metric,
**kwargs
**kwargs,
)


Expand Down Expand Up @@ -203,7 +203,7 @@ def pairwise_kernels(
metric: Union[str, Callable[[ArrayLike, ArrayLike], float]] = "linear",
filter_params: bool = False,
n_jobs: Optional[int] = 1,
**kwds
**kwds,
):
from sklearn.gaussian_process.kernels import Kernel as GPKernel

Expand Down
1 change: 0 additions & 1 deletion dask_ml/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def mean_squared_log_error(
multioutput: Optional[str] = "uniform_average",
compute: bool = True,
) -> ArrayLike:

result = mean_squared_error(
np.log1p(y_true),
np.log1p(y_pred),
Expand Down
1 change: 0 additions & 1 deletion dask_ml/model_selection/_hyperband.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,6 @@ def _get_meta(
SHAs: Dict[int, SuccessiveHalvingSearchCV],
key: Callable[[int, int], str],
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:

meta_ = []
history_ = {}
for bracket in brackets:
Expand Down
1 change: 0 additions & 1 deletion dask_ml/model_selection/_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,6 @@ def __init__(
predict_proba_meta=None,
transform_meta=None,
):

self.n_initial_parameters = n_initial_parameters
self.decay_rate = decay_rate
self.fits_per_score = fits_per_score
Expand Down
5 changes: 2 additions & 3 deletions dask_ml/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@
_check_param_grid = None

if SK_VERSION <= packaging.version.parse("0.21.dev0"):

_RETURN_TRAIN_SCORE_DEFAULT = "warn"

def handle_deprecated_train_score(results, return_train_score):
Expand Down Expand Up @@ -414,7 +413,7 @@ def do_fit_and_score(
xtest = X_test + (n,)
ytest = y_test + (n,)

for (name, m) in fit_ests:
for name, m in fit_ests:
dsk[(score_name, m, n)] = (
score,
(name, m, n),
Expand Down Expand Up @@ -879,7 +878,7 @@ def _do_featureunion(

fit_steps = []
tr_Xs = []
for (step_name, step) in est.transformer_list:
for step_name, step in est.transformer_list:
fits, out_Xs = _do_fit_step(
dsk,
next_token,
Expand Down
2 changes: 1 addition & 1 deletion dask_ml/preprocessing/_block_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
func: Callable[..., Union[ArrayLike, DataFrameType]],
*,
validate: bool = False,
**kw_args: Any
**kw_args: Any,
):
self.func: Callable[..., Union[ArrayLike, DataFrameType]] = func
self.validate = validate
Expand Down
Loading

0 comments on commit fa5a583

Please sign in to comment.