Skip to content

Commit

Permalink
added test for warnings about unchangeable params
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Dec 1, 2019
1 parent ec6aab4 commit c3c913b
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# coding: utf-8
import os
import sys
import tempfile
import unittest

Expand Down Expand Up @@ -280,3 +281,77 @@ def check_asserts(data):
lgb_data.set_weight(sequence)
lgb_data.set_init_score(sequence)
check_asserts(lgb_data)

def test_dataset_update_params_warning(self):
default_params = {"max_bin": 100,
"max_bin_by_feature": [20, 10],
"bin_construct_sample_cnt": 10000,
"min_data_in_bin": 1,
"use_missing": False,
"zero_as_missing": False,
"sparse_threshold": 1,
"categorical_feature": [0],
"feature_pre_filter": True,
"min_data_in_leaf": 8}
unchangeable_params = {"max_bin": 150,
"max_bin_by_feature": [30, 5],
"bin_construct_sample_cnt": 5000,
"min_data_in_bin": 2,
"use_missing": True,
"zero_as_missing": True,
"sparse_threshold": 0.4,
"categorical_feature": [0, 1],
"feature_pre_filter": False,
"min_data_in_leaf": 3}
stdout_backup = os.dup(sys.__stdout__.fileno())
log_filename = 'log.txt'

with open(log_filename, 'a') as stdout_redirect:
os.dup2(stdout_redirect.fileno(), sys.__stdout__.fileno())
X = np.random.random((100, 2))
y = np.random.random(100)
lgb_data = lgb.Dataset(X, y)
lgb.train(default_params, lgb_data, num_boost_round=5)
with open(log_filename, 'r') as stdout_log:
log = stdout_log.read()
self.assertEqual(log.find('[LightGBM] [Warning] Cannot change '), -1)
self.assertEqual(log.find('[LightGBM] [Warning] Reducing `min_data_in_leaf` '
'with `feature_pre_filter=true` may cause '), -1)

with open(log_filename, 'a') as stdout_redirect:
os.dup2(stdout_redirect.fileno(), sys.__stdout__.fileno())
default_params["min_data_in_leaf"] += 1
lgb.train(default_params, lgb_data, num_boost_round=5)
with open(log_filename, 'r') as stdout_log:
log = stdout_log.read()
self.assertEqual(log.find('[LightGBM] [Warning] Cannot change '), -1)
self.assertEqual(log.find('[LightGBM] [Warning] Reducing `min_data_in_leaf` '
'with `feature_pre_filter=true` may cause '), -1)

with open(log_filename, 'a') as stdout_redirect:
os.dup2(stdout_redirect.fileno(), sys.__stdout__.fileno())
default_params["feature_pre_filter"] = False
lgb_data_new = lgb.Dataset(X, y, params=default_params).construct()
default_params["min_data_in_leaf"] -= 2
lgb.train(default_params, lgb_data_new, num_boost_round=5)
with open(log_filename, 'r') as stdout_log:
log = stdout_log.read()
self.assertEqual(log.find('[LightGBM] [Warning] Cannot change '), -1)
self.assertEqual(log.find('[LightGBM] [Warning] Reducing `min_data_in_leaf` '
'with `feature_pre_filter=true` may cause '), -1)

with open(log_filename, 'a') as stdout_redirect:
os.dup2(stdout_redirect.fileno(), sys.__stdout__.fileno())
lgb.train(unchangeable_params, lgb_data, num_boost_round=5)
with open(log_filename, 'r') as stdout_log:
log = stdout_log.read()
for param in unchangeable_params.keys():
if param == "min_data_in_leaf":
self.assertNotEqual(log.find('[LightGBM] [Warning] Reducing `min_data_in_leaf` '
'with `feature_pre_filter=true` may cause '), -1)
else:
self.assertNotEqual(log.find('[LightGBM] [Warning] Cannot '
'change {} '.format(param)), -1)

os.dup2(stdout_backup, sys.__stdout__.fileno())
os.remove(log_filename)

0 comments on commit c3c913b

Please sign in to comment.