-
Notifications
You must be signed in to change notification settings - Fork 184
/
test_objective_functions.py
118 lines (101 loc) · 4.11 KB
/
test_objective_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import unittest
from typing import Dict, List, Tuple
import numpy as np
import onnxruntime
import pandas as pd
from onnx import ModelProto
from onnx.defs import onnx_opset_version
from onnxconverter_common.onnx_ex import DEFAULT_OPSET_NUMBER
from onnxconverter_common.data_types import DoubleTensorType, TensorType
from onnxmltools import convert_lightgbm
from onnxruntime import InferenceSession
from pandas.core.frame import DataFrame
from lightgbm import LGBMRegressor
# Dimensions of the synthetic data set shared by every sub-test.
_N_ROWS = 10_000
_N_COLS = 10

# Tolerance of the comparison: predictions are compared to _N_DECIMALS
# decimals, and at least _FRAC of all rows have to agree.
_N_DECIMALS = 5
_FRAC = 0.9997

# Random features and targets, drawn once at import time from NumPy's
# global (unseeded) generator, so the data differs between runs.
_X = pd.DataFrame(np.random.random((_N_ROWS, _N_COLS)))
_Y = pd.Series(np.random.random(_N_ROWS))

# pandas dtype name -> tensor type class used to declare the ONNX input.
_DTYPE_MAP: Dict[str, TensorType] = dict(float64=DoubleTensorType)

# Never request an opset newer than what the installed onnx package offers.
TARGET_OPSET = min(onnx_opset_version(), DEFAULT_OPSET_NUMBER)
class ObjectiveTest(unittest.TestCase):
    """
    Convert LGBMRegressor models trained with different objectives to ONNX
    and compare the ONNX predictions with the LightGBM predictions.
    """

    # Objectives passed to LGBMRegressor, one sub-test each.
    _objectives: Tuple[str, ...] = ("regression", "poisson", "gamma", "quantile")

    @staticmethod
    def _calc_initial_types(X: DataFrame) -> List[Tuple[str, TensorType]]:
        """
        Build the ``initial_types`` argument of the converter from *X*.

        :param X: input matrix; all columns must share a single dtype
        :return: one-element list ``[("input", tensor_type(X.shape))]``
        :raises RuntimeError: if the columns of *X* have more than one dtype
        :raises KeyError: if *X* has no columns or its dtype is missing
            from ``_DTYPE_MAP``
        """
        dtypes = {str(dtype) for dtype in X.dtypes}
        if len(dtypes) > 1:
            raise RuntimeError(
                "Test expects homogenous input matrix. "
                f"Found multiple dtypes: {dtypes}."
            )
        dtype = dtypes.pop()
        tensor_type = _DTYPE_MAP[dtype]
        return [("input", tensor_type(X.shape))]

    @staticmethod
    def _predict_with_onnx(model: ModelProto, X: DataFrame) -> np.ndarray:
        """
        Run *model* on *X* with onnxruntime on the CPU provider.

        :param model: converted ONNX model
        :param X: input matrix fed to the single model input
        :return: first column of the first model output
        :raises RuntimeError: if the model does not have exactly one input
        """
        session = InferenceSession(
            model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        output_names = [s_output.name for s_output in session.get_outputs()]
        input_names = [s_input.name for s_input in session.get_inputs()]
        # "!= 1" also rejects a model without inputs, which would otherwise
        # surface as an unhelpful IndexError below.
        if len(input_names) != 1:
            raise RuntimeError(
                f"Test expects one input. "
                f"Found {len(input_names)} inputs: {input_names}."
            )
        (input_name,) = input_names
        return session.run(output_names, {input_name: X.values})[0][:, 0]

    @staticmethod
    def _assert_almost_equal(
        actual: np.ndarray, desired: np.ndarray, decimal: int = 7, frac: float = 1.0
    ) -> None:
        """
        Assert that almost all rows in actual and desired
        are almost equal to each other.

        Similar to np.testing.assert_almost_equal but allows to define
        a fraction of rows to be almost equal instead of expecting
        all rows to be almost equal.

        Explicit ``raise`` statements are used instead of ``assert`` so the
        check is still performed when Python runs with ``-O``.

        :param actual: values to check
        :param desired: reference values, same shape as *actual*
        :param decimal: a value matches if ``abs(actual - desired)``
            is at most ``10**-decimal``
        :param frac: minimum fraction of rows that have to match
        :raises ValueError: if *frac* is outside ``[0, 1]``
        :raises AssertionError: if the shapes differ, there is nothing to
            compare, or fewer than *frac* of the rows match
        """
        if not 0 <= frac <= 1:
            raise ValueError("frac must be in range(0, 1).")

        actual = np.atleast_1d(np.asarray(actual))
        desired = np.atleast_1d(np.asarray(desired))
        # Without this check, e.g. (n,) against (n, 1) would silently
        # broadcast to (n, n) and distort the fraction of matching rows.
        if actual.shape != desired.shape:
            raise AssertionError(
                f"Shape mismatch: actual has shape {actual.shape}, "
                f"desired has shape {desired.shape}."
            )

        n_rows = len(actual)
        if n_rows == 0:
            raise AssertionError("There are no rows to compare.")

        is_close = np.abs(actual - desired) <= (10**-decimal)
        if is_close.ndim > 1:
            # A row only counts as equal if every value in it is close.
            is_close = is_close.reshape(n_rows, -1).all(axis=1)
        success_abs = int(is_close.sum())
        success_rel = success_abs / n_rows
        if success_rel < frac:
            raise AssertionError(
                f"Only {success_abs} out of {n_rows} "
                f"rows are almost equal to {decimal} decimals."
            )

    @unittest.skipIf(
        tuple(int(ver) for ver in onnxruntime.__version__.split(".")[:2]) < (1, 3),
        "not supported in this library version",
    )
    def test_objective(self):
        """
        Test if a LGBMRegressor with a certain objective (e.g. 'poisson')
        can be converted to ONNX and whether the ONNX graph and the
        original model produce almost equal predictions.

        Note that this test is a bit flaky because of precision
        differences between ONNX and LightGBM and therefore sometimes
        fails randomly. In these cases, a retry should resolve the issue.
        """
        for objective in self._objectives:
            with self.subTest(X=_X, objective=objective):
                regressor = LGBMRegressor(objective=objective, num_thread=1)
                regressor.fit(_X, _Y)
                regressor_onnx: ModelProto = convert_lightgbm(
                    regressor,
                    initial_types=self._calc_initial_types(_X),
                    target_opset=TARGET_OPSET,
                )
                y_pred = regressor.predict(_X)
                y_pred_onnx = self._predict_with_onnx(regressor_onnx, _X)
                self._assert_almost_equal(
                    y_pred,
                    y_pred_onnx,
                    decimal=_N_DECIMALS,
                    frac=_FRAC,
                )
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()