diff --git a/src/frontend/lightgbm.cc b/src/frontend/lightgbm.cc index 7aa9d3ab..e5c37df9 100644 --- a/src/frontend/lightgbm.cc +++ b/src/frontend/lightgbm.cc @@ -554,25 +554,19 @@ inline std::unique_ptr ParseStream(std::istream& fi) { tree.AddChilds(new_id); if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) { - // categorical + // categorical split const int cat_idx = static_cast(lgb_tree.threshold[old_id]); const std::vector left_categories = BitsetToList(lgb_tree.cat_threshold.data() + lgb_tree.cat_boundaries[cat_idx], lgb_tree.cat_boundaries[cat_idx + 1] - lgb_tree.cat_boundaries[cat_idx]); - const bool missing_value_to_zero = missing_type != MissingType::kNaN; + // For categorical splits, we ignore the missing type field. NaNs always get mapped to + // the right child node. bool default_left = false; - if (missing_value_to_zero) { - // If missing_value_to_zero flag is true, all missing values get mapped to 0.0, so - // we need to override the default_left flag - default_left - = (std::find(left_categories.begin(), left_categories.end(), - static_cast(0)) != left_categories.end()); - } tree.SetCategoricalSplit(new_id, split_index, default_left, left_categories, false); } else { - // numerical + // numerical split const auto threshold = static_cast(lgb_tree.threshold[old_id]); bool default_left = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask); diff --git a/tests/examples/sparse_categorical/sparse_categorical.test.margin b/tests/examples/sparse_categorical/sparse_categorical.test.margin index d5823a0e..ee45f9de 100644 --- a/tests/examples/sparse_categorical/sparse_categorical.test.margin +++ b/tests/examples/sparse_categorical/sparse_categorical.test.margin @@ -1,3 +1,3 @@ -0.036716360109071977 --0.040451733276673375 +-0.088410677452890038 0.21990291345117738 diff --git a/tests/python/test_gtil.py b/tests/python/test_gtil.py index 0a2e9196..2d3aa6d5 100644 --- a/tests/python/test_gtil.py +++ b/tests/python/test_gtil.py @@ -340,5 +340,8 @@ def test_lightgbm_sparse_categorical_model(): X, _ = load_svmlight_file(dataset_db[dataset].dtest, zero_based=True, n_features=tl_model.num_feature) expected_pred = load_txt(dataset_db[dataset].expected_margin) - out_pred = treelite.gtil.predict(tl_model, X.toarray(), pred_margin=True) + # GTIL doesn't yet support sparse matrix; so use NaN to represent missing values + Xa = X.toarray() + Xa[Xa == 0] = 'nan' + out_pred = treelite.gtil.predict(tl_model, Xa, pred_margin=True) np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5) diff --git a/tests/python/test_lightgbm_integration.py b/tests/python/test_lightgbm_integration.py index de766632..7ec1da62 100644 --- a/tests/python/test_lightgbm_integration.py +++ b/tests/python/test_lightgbm_integration.py @@ -290,3 +290,29 @@ def test_constant_tree(): model_path = _qualify_path('lightgbm_constant_tree', 'model_with_constant_tree.txt') model = treelite.Model.load(model_path, model_format='lightgbm') assert model.num_tree == 2 + + +@pytest.mark.parametrize('toolchain', os_compatible_toolchains()) +def test_nan_handling_with_categorical_splits(tmpdir, toolchain): + """Test that NaN inputs are handled correctly in categorical splits""" + + # Test case taken from https://github.com/dmlc/treelite/issues/277 + X = np.array(30 * [[1]] + 30 * [[2]] + 30 * [[0]]) + y = np.array(60 * [5] + 30 * [10]) + train_data = lightgbm.Dataset(X, label=y, categorical_feature=[0]) + bst = lightgbm.train({}, train_data, 1) + + model_path = os.path.join(tmpdir, 'dummy_categorical.txt') + libpath = os.path.join(tmpdir, 'dummy_categorical_lgb' + _libext()) + + input_with_nan = np.array([[np.NaN], [0.0]]) + + lgb_pred = bst.predict(input_with_nan) + bst.save_model(model_path) + + model = treelite.Model.load(model_path, model_format='lightgbm') + model.export_lib(toolchain=toolchain, libpath=libpath) + predictor = treelite_runtime.Predictor(libpath) + dmat = treelite_runtime.DMatrix(input_with_nan) + tl_pred = predictor.predict(dmat) + np.testing.assert_almost_equal(tl_pred, lgb_pred)