diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index c6ba2ab0c711..c9628ba8d7f1 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1245,6 +1245,7 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False): Whether to print messages during construction. """ self.handle = None + self.network = False self.__need_reload_eval_info = True self.__train_data_name = "training" self.__attr = {} @@ -1288,6 +1289,20 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False): self.__is_predicted_cur_iter = [False] self.__get_eval_info() self.pandas_categorical = train_set.pandas_categorical + """set network if necessary""" + if "machines" in params: + machines = params["machines"] + if isinstance(machines, string_type): + num_machines = len(machines.split(',')) + elif isinstance(machines, (list, set)): + num_machines = len(machines) + machines = ','.join(machines) + else: + raise ValueError("Invalid machines in params.") + self.set_network(machines, + local_listen_port=params.get("local_listen_port", 12400), + listen_time_out=params.get("listen_time_out", 120), + num_machines=params.get("num_machines", num_machines)) elif model_file is not None: """Prediction task""" out_num_iterations = ctypes.c_int(0) @@ -1308,6 +1323,8 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False): raise TypeError('Need at least one training dataset or model file to create booster instance') def __del__(self): + if self.network: + self.free_network() if self.handle is not None: _safe_call(_LIB.LGBM_BoosterFree(self.handle)) @@ -1351,6 +1368,32 @@ def _free_buffer(self): self.__inner_predict_buffer = [] self.__is_predicted_cur_iter = [] + def set_network(self, machines, local_listen_port=12400, + listen_time_out=120, num_machines=1): + """Set the network configuration. + + Parameters + ---------- + machines: list, set or string + Names of machines. + local_listen_port: int, optional (default=12400) + TCP listen port for local machines. + listen_time_out: int, optional (default=120) + Socket time-out in minutes. + num_machines: int, optional (default=1) + The number of machines for parallel learning application. + """ + _safe_call(_LIB.LGBM_NetworkInit(c_str(machines), + ctypes.c_int(local_listen_port), + ctypes.c_int(listen_time_out), + ctypes.c_int(num_machines))) + self.network = True + + def free_network(self): + """Free Network.""" + _safe_call(_LIB.LGBM_NetworkFree()) + self.network = False + def set_train_data_name(self, name): """Set the name to the training Dataset.