Refactor with black #1009

Merged · 5 commits · Aug 1, 2023
11 changes: 3 additions & 8 deletions .azure-pipelines/linux-conda-CI.yml
@@ -243,10 +243,6 @@ jobs:
fi
displayName: 'install onnx'

- script: |
pip install flake8
displayName: 'install flake8'

- script: |
pip install $(onnxrt.version)
displayName: 'install onnxruntime'
@@ -334,11 +330,10 @@ jobs:
displayName: 'pytest-onnxmltools'
condition: eq(variables['run.example'], '1')

# Check ruff after the tests to get more feedback.
# It is checked before the tests on the windows build.
- script: |
flake8 skl2onnx tests tests_onnxmltools
displayName: 'flake8'
python -m pip install ruff
ruff skl2onnx tests tests_onnxmltools
displayName: 'ruff'

- script: |
if [ '$(onnx.target_opset)' != '' ]
10 changes: 3 additions & 7 deletions .azure-pipelines/win32-conda-CI.yml
@@ -157,13 +157,9 @@ jobs:

- script: |
call activate skl2onnxEnvironment
pip install flake8
displayName: 'install flake8'

- script: |
call activate skl2onnxEnvironment
flake8 skl2onnx tests tests_onnxmltools
displayName: 'flake8'
python -m pip install ruff
ruff skl2onnx tests tests_onnxmltools
displayName: 'ruff'

- script: |
call activate skl2onnxEnvironment
16 changes: 16 additions & 0 deletions .github/workflows/black-ruff.yml
@@ -0,0 +1,16 @@
name: Black Format Checker
on: [push, pull_request]
jobs:
black-format-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: psf/black@stable
with:
options: "--diff --check"
src: "."
ruff-format-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: chartboost/ruff-action@v1
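
The workflow above only enforces the checks in CI. As an editorial illustration (not part of this PR), the same checks can be reproduced locally with a small helper; the script name, the use of `subprocess`, and the exact directories are assumptions mirroring the CI configuration shown in this diff:

```python
# check_format.py -- hypothetical local helper, not part of this PR.
# Runs the same tools the CI enforces: black in --diff --check mode over the
# repository, and ruff over the directories the Azure pipelines lint.
import subprocess
import sys

COMMANDS = [
    [sys.executable, "-m", "black", "--diff", "--check", "."],
    [sys.executable, "-m", "ruff", "skl2onnx", "tests", "tests_onnxmltools"],
]

for cmd in COMMANDS:
    print("+", " ".join(cmd))
    # check=True aborts on the first failure, mirroring a failing CI job.
    subprocess.run(cmd, check=True)
```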
10 changes: 7 additions & 3 deletions README.md
@@ -2,14 +2,18 @@

<p align="center"><img width="50%" src="docs/logo_main.png" /></p>

| Linux | Windows |
|-------|---------|
| [![Build Status](https://dev.azure.com/onnxmltools/sklearn-onnx/_apis/build/status/sklearn-onnx-linux-conda-ci?branchName=master)](https://dev.azure.com/onnxmltools/sklearn-onnx/_build/latest?definitionId=5?branchName=master) | [![Build Status](https://dev.azure.com/onnxmltools/sklearn-onnx/_apis/build/status/sklearn-onnx-win32-conda-ci?branchName=master)](https://dev.azure.com/onnxmltools/sklearn-onnx/_build/latest?definitionId=5?branchName=master)|
[![Build Status Linux](https://dev.azure.com/onnxmltools/sklearn-onnx/_apis/build/status%2Fonnx.sklearn-onnx.linux.CI?branchName=refs%2Fpull%2F1009%2Fmerge)](https://dev.azure.com/onnxmltools/sklearn-onnx/_build/latest?definitionId=21&branchName=refs%2Fpull%2F1009%2Fmerge)

[![Build Status Windows](https://dev.azure.com/onnxmltools/sklearn-onnx/_apis/build/status%2Fonnx.sklearn-onnx.win.CI?branchName=refs%2Fpull%2F1009%2Fmerge)](https://dev.azure.com/onnxmltools/sklearn-onnx/_build/latest?definitionId=22&branchName=refs%2Fpull%2F1009%2Fmerge)

[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)

## Introduction
*sklearn-onnx* converts [scikit-learn](https://scikit-learn.org/stable/) models to [ONNX](https://github.com/onnx/onnx).
Once in the ONNX format, you can use tools like [ONNX Runtime](https://github.com/Microsoft/onnxruntime) for high performance scoring.
All converters are tested with [onnxruntime](https://onnxruntime.ai/).
Any external converter can be registered to convert scikit-learn pipelines that include models or transformers coming from external libraries.
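
To make the conversion flow described above concrete, here is a minimal sketch (an editorial illustration, not part of the README diff; it assumes a CPU-only onnxruntime install and uses the iris dataset purely as an example):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from onnxruntime import InferenceSession

# Train any scikit-learn estimator.
X, y = load_iris(return_X_y=True)
model = DecisionTreeClassifier(max_depth=3).fit(X, y)

# Declare the input name and shape of the ONNX graph, then convert.
initial_types = [("X", FloatTensorType([None, X.shape[1]]))]
onx = convert_sklearn(model, initial_types=initial_types)

# Score the converted model with onnxruntime.
sess = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
pred = sess.run(None, {"X": X.astype(np.float32)})[0]
print(pred[:5])
```

The benchmark scripts touched by this PR rely on the same `convert_sklearn` / `InferenceSession` pattern.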

## Documentation
Full documentation including tutorials is available at [https://onnx.ai/sklearn-onnx/](https://onnx.ai/sklearn-onnx/).
115 changes: 69 additions & 46 deletions benchmarks/bench_plot_onnxruntime_decision_tree.py
@@ -14,6 +14,7 @@
import pandas
from sklearn import config_context
from sklearn.tree import DecisionTreeClassifier

try:
# scikit-learn >= 0.22
from sklearn.utils._testing import ignore_warnings
@@ -29,14 +30,18 @@
# Implementations to benchmark.
##############################


def fcts_model(X, y, max_depth):
"DecisionTreeClassifier."
rf = DecisionTreeClassifier(max_depth=max_depth)
rf.fit(X, y)

initial_types = [('X', FloatTensorType([None, X.shape[1]]))]
onx = convert_sklearn(rf, initial_types=initial_types,
options={DecisionTreeClassifier: {'zipmap': False}})
initial_types = [("X", FloatTensorType([None, X.shape[1]]))]
onx = convert_sklearn(
rf,
initial_types=initial_types,
options={DecisionTreeClassifier: {"zipmap": False}},
)
f = BytesIO()
f.write(onx.SerializeToString())
content = f.getvalue()
@@ -51,30 +56,29 @@ def predict_skl_predict_proba(X, model=rf):
return rf.predict_proba(X)

def predict_onnxrt_predict(X, sess=sess):
return sess.run(outputs[:1], {'X': X})[0]
return sess.run(outputs[:1], {"X": X})[0]

def predict_onnxrt_predict_proba(X, sess=sess):
return sess.run(outputs[1:], {'X': X})[0]
return sess.run(outputs[1:], {"X": X})[0]

return {'predict': (predict_skl_predict,
predict_onnxrt_predict),
'predict_proba': (predict_skl_predict_proba,
predict_onnxrt_predict_proba)}
return {
"predict": (predict_skl_predict, predict_onnxrt_predict),
"predict_proba": (predict_skl_predict_proba, predict_onnxrt_predict_proba),
}


##############################
# Benchmarks
##############################


def allow_configuration(**kwargs):
return True


def bench(n_obs, n_features, max_depths, methods,
repeat=10, verbose=False):
def bench(n_obs, n_features, max_depths, methods, repeat=10, verbose=False):
res = []
for nfeat in n_features:

ntrain = 100000
X_train = np.empty((ntrain, nfeat))
X_train[:, :] = rand(ntrain, nfeat)[:, :].astype(np.float32)
@@ -88,15 +92,12 @@ def bench(n_obs, n_features, max_depths, methods,

for n in n_obs:
for method in methods:

fct1, fct2 = fcts[method]

if not allow_configuration(
n=n, nfeat=nfeat, max_depth=max_depth):
if not allow_configuration(n=n, nfeat=nfeat, max_depth=max_depth):
continue

obs = dict(n_obs=n, nfeat=nfeat,
max_depth=max_depth, method=method)
obs = dict(n_obs=n, nfeat=nfeat, max_depth=max_depth, method=method)

# creates different inputs to avoid caching in any ways
Xs = []
@@ -143,11 +144,11 @@ def bench(n_obs, n_features, max_depths, methods,
# Plots.
##############################


def plot_results(df, verbose=False):
nrows = max(len(set(df.max_depth)) * len(set(df.n_obs)), 2)
ncols = max(len(set(df.method)), 2)
fig, ax = plt.subplots(nrows, ncols,
figsize=(ncols * 4, nrows * 4))
fig, ax = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))
pos = 0
row = 0
for n_obs in sorted(set(df.n_obs)):
@@ -156,31 +157,49 @@ def plot_results(df, verbose=False):
for method in sorted(set(df.method)):
a = ax[row, pos]
if row == ax.shape[0] - 1:
a.set_xlabel("N features", fontsize='x-small')
a.set_xlabel("N features", fontsize="x-small")
if pos == 0:
a.set_ylabel(
"Time (s) n_obs={}\nmax_depth={}".format(
n_obs, max_depth),
fontsize='x-small')

color = 'b'
subset = df[(df.method == method) & (df.n_obs == n_obs) &
(df.max_depth == max_depth)]
"Time (s) n_obs={}\nmax_depth={}".format(n_obs, max_depth),
fontsize="x-small",
)

color = "b"
subset = df[
(df.method == method)
& (df.n_obs == n_obs)
& (df.max_depth == max_depth)
]
if subset.shape[0] == 0:
continue
subset = subset.sort_values("nfeat")
if verbose:
print(subset)
label = "skl"
subset.plot(x="nfeat", y="time_skl", label=label, ax=a,
logx=True, logy=True, c=color, style='--')
subset.plot(
x="nfeat",
y="time_skl",
label=label,
ax=a,
logx=True,
logy=True,
c=color,
style="--",
)
label = "ort"
subset.plot(x="nfeat", y="time_ort", label=label, ax=a,
logx=True, logy=True, c=color)

a.legend(loc=0, fontsize='x-small')
subset.plot(
x="nfeat",
y="time_ort",
label=label,
ax=a,
logx=True,
logy=True,
c=color,
)

a.legend(loc=0, fontsize="x-small")
if row == 0:
a.set_title("method={}".format(method), fontsize='x-small')
a.set_title("method={}".format(method), fontsize="x-small")
pos += 1
row += 1

@@ -190,13 +209,14 @@ def plot_results(df, verbose=False):
@ignore_warnings(category=FutureWarning)
def run_bench(repeat=100, verbose=False):
n_obs = [1, 10, 100, 1000, 10000, 100000]
methods = ['predict', 'predict_proba']
methods = ["predict", "predict_proba"]
n_features = [1, 5, 10, 20, 50, 100, 200]
max_depths = [2, 5, 10, 20]

start = time()
results = bench(n_obs, n_features, max_depths, methods,
repeat=repeat, verbose=verbose)
results = bench(
n_obs, n_features, max_depths, methods, repeat=repeat, verbose=verbose
)
end = time()

results_df = pandas.DataFrame(results)
@@ -207,21 +227,24 @@ def run_bench(repeat=100, verbose=False):
return results_df


if __name__ == '__main__':
if __name__ == "__main__":
from datetime import datetime
import sklearn
import numpy
import onnx
import onnxruntime
import skl2onnx
df = pandas.DataFrame([
{"name": "date", "version": str(datetime.now())},
{"name": "numpy", "version": numpy.__version__},
{"name": "scikit-learn", "version": sklearn.__version__},
{"name": "onnx", "version": onnx.__version__},
{"name": "onnxruntime", "version": onnxruntime.__version__},
{"name": "skl2onnx", "version": skl2onnx.__version__},
])

df = pandas.DataFrame(
[
{"name": "date", "version": str(datetime.now())},
{"name": "numpy", "version": numpy.__version__},
{"name": "scikit-learn", "version": sklearn.__version__},
{"name": "onnx", "version": onnx.__version__},
{"name": "onnxruntime", "version": onnxruntime.__version__},
{"name": "skl2onnx", "version": skl2onnx.__version__},
]
)
df.to_csv("bench_plot_onnxruntime_decision_tree.time.csv", index=False)
print(df)
df = run_bench(verbose=True)