Skip to content

Commit

Permalink
[doc] Add typing to dask demos. (dmlc#10207)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Apr 22, 2024
1 parent 3fbb221 commit 59d7b8d
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 9 deletions.
3 changes: 2 additions & 1 deletion demo/dask/cpu_survival.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

import os

import dask.array as da
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster

from xgboost import dask as dxgb
from xgboost.dask import DaskDMatrix


def main(client):
def main(client: Client) -> da.Array:
# Load an example survival data from CSV into a Dask data frame.
# The Veterans' Administration Lung Cancer Trial
# The Statistical Analysis of Failure Time Data by Kalbfleisch J. and Prentice R (1980)
Expand Down
2 changes: 1 addition & 1 deletion demo/dask/cpu_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from xgboost.dask import DaskDMatrix


def main(client):
def main(client: Client) -> None:
# generate some random data for demonstration
m = 100000
n = 100
Expand Down
16 changes: 11 additions & 5 deletions demo/dask/dask_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
====================================
"""

from typing import Any

import numpy as np
from dask.distributed import Client, LocalCluster
from dask_ml.datasets import make_regression
Expand All @@ -13,7 +15,7 @@
from xgboost.dask import DaskDMatrix


def probability_for_going_backward(epoch):
def probability_for_going_backward(epoch: int) -> float:
return 0.999 / (1.0 + 0.05 * np.log(1.0 + epoch))


Expand All @@ -23,7 +25,9 @@ class CustomEarlyStopping(xgb.callback.TrainingCallback):
In the beginning, allow the metric to become worse with a probability of 0.999.
As boosting progresses, the probability should be adjusted downward"""

def __init__(self, *, validation_set, target_metric, maximize, seed):
def __init__(
self, *, validation_set: str, target_metric: str, maximize: bool, seed: int
) -> None:
self.validation_set = validation_set
self.target_metric = target_metric
self.maximize = maximize
Expand All @@ -34,15 +38,17 @@ def __init__(self, *, validation_set, target_metric, maximize, seed):
else:
self.better = lambda x, y: x < y

def after_iteration(self, model, epoch, evals_log):
def after_iteration(
self, model: Any, epoch: int, evals_log: xgb.callback.TrainingCallback.EvalsLog
) -> bool:
metric_history = evals_log[self.validation_set][self.target_metric]
if len(metric_history) < 2 or self.better(
metric_history[-1], metric_history[-2]
):
return False # continue training
p = probability_for_going_backward(epoch)
go_backward = self.rng.choice(2, size=(1,), replace=True, p=[1 - p, p]).astype(
np.bool
np.bool_
)[0]
print(
"The validation metric went into the wrong direction. "
Expand All @@ -54,7 +60,7 @@ def after_iteration(self, model, epoch, evals_log):
return True # stop training


def main(client):
def main(client: Client) -> None:
m = 100000
n = 100
X, y = make_regression(n_samples=m, n_features=n, chunks=200, random_state=0)
Expand Down
2 changes: 1 addition & 1 deletion demo/dask/sklearn_cpu_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from xgboost import dask as dxgb


def main(client):
def main(client: Client) -> dxgb.Booster:
# generate some random data for demonstration
n = 100
m = 10000
Expand Down
2 changes: 1 addition & 1 deletion demo/dask/sklearn_gpu_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from xgboost import dask as dxgb


def main(client):
def main(client: Client) -> dxgb.Booster:
# generate some random data for demonstration
n = 100
m = 1000000
Expand Down
1 change: 1 addition & 0 deletions tests/ci_build/lint_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class LintersPaths:
"tests/test_distributed/test_gpu_with_spark/test_data.py",
"tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",
# demo
"demo/dask/",
"demo/json-model/json_parser.py",
"demo/guide-python/external_memory.py",
"demo/guide-python/sklearn_examples.py",
Expand Down

0 comments on commit 59d7b8d

Please sign in to comment.