Skip to content

Commit

Permalink
[tests] replace pytest.parametrize (#4377)
Browse files Browse the repository at this point in the history
* replace pytest.parametrize

* add informative message for assert
  • Loading branch information
StrikerRUS authored Jun 15, 2021
1 parent a592316 commit c738c83
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,8 +1252,8 @@ def generate_trainset_for_monotone_constraints_tests(x3_to_category=True):
return trainset


@pytest.mark.parametrize("test_with_interaction_constraints", [True, False])
def test_monotone_constraints(test_with_interaction_constraints):
@pytest.mark.parametrize("test_with_categorical_variable", [True, False])
def test_monotone_constraints(test_with_categorical_variable):
def is_increasing(y):
return (np.diff(y) >= 0.0).all()

Expand Down Expand Up @@ -1316,10 +1316,12 @@ def has_interaction(treef):

return not has_interaction_flag.any()

for test_with_categorical_variable in [True, False]:
trainset = generate_trainset_for_monotone_constraints_tests(
test_with_categorical_variable
)
trainset = generate_trainset_for_monotone_constraints_tests(
test_with_categorical_variable
)
for test_with_interaction_constraints in [True, False]:
error_msg = ("Model not correctly constrained "
f"(test_with_interaction_constraints={test_with_interaction_constraints})")
for monotone_constraints_method in ["basic", "intermediate", "advanced"]:
params = {
"min_data": 20,
Expand All @@ -1333,7 +1335,7 @@ def has_interaction(treef):
constrained_model = lgb.train(params, trainset)
assert is_correctly_constrained(
constrained_model, test_with_categorical_variable
)
), error_msg
if test_with_interaction_constraints:
feature_sets = [["Column_0"], ["Column_1"], "Column_2"]
assert are_interactions_enforced(constrained_model, feature_sets)
Expand Down Expand Up @@ -1399,8 +1401,9 @@ def test_monotone_penalty_max():
}

unconstrained_model = lgb.train(params_unconstrained_model, trainset_unconstrained_model, 10)
unconstrained_model_predictions = unconstrained_model.\
predict(x3_negatively_correlated_with_y.reshape(-1, 1))
unconstrained_model_predictions = unconstrained_model.predict(
x3_negatively_correlated_with_y.reshape(-1, 1)
)

for monotone_constraints_method in ["basic", "intermediate", "advanced"]:
params_constrained_model["monotone_constraints_method"] = monotone_constraints_method
Expand Down

0 comments on commit c738c83

Please sign in to comment.