Skip to content

Commit

Permalink
[python-package] add more type hints on Dataset (#5431)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Aug 23, 2022
1 parent 78f95e4 commit 01774bb
Showing 1 changed file with 32 additions and 14 deletions.
46 changes: 32 additions & 14 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,10 @@ def _dump_pandas_categorical(pandas_categorical, file_name=None):
return pandas_str


def _load_pandas_categorical(file_name=None, model_str=None):
def _load_pandas_categorical(
file_name: Optional[Union[str, Path]] = None,
model_str: Optional[str] = None
) -> Optional[str]:
pandas_key = 'pandas_categorical:'
offset = -len(pandas_key)
if file_name is not None:
Expand Down Expand Up @@ -1879,7 +1882,15 @@ def construct(self) -> "Dataset":
self.feature_name = self.get_feature_name()
return self

def create_valid(self, data, label=None, weight=None, group=None, init_score=None, params=None):
def create_valid(
self,
data,
label=None,
weight=None,
group=None,
init_score=None,
params: Optional[Dict[str, Any]] = None
) -> "Dataset":
"""Create validation data aligned with current Dataset.
Parameters
Expand Down Expand Up @@ -1966,7 +1977,7 @@ def save_binary(self, filename: Union[str, Path]) -> "Dataset":
c_str(str(filename))))
return self

def _update_params(self, params):
def _update_params(self, params: Optional[Dict[str, Any]]) -> "Dataset":
if not params:
return self
params = deepcopy(params)
Expand Down Expand Up @@ -1999,7 +2010,11 @@ def _reverse_update_params(self) -> "Dataset":
self.params_back_up = None
return self

def set_field(self, field_name, data):
def set_field(
self,
field_name: str,
data
) -> "Dataset":
"""Set property into the Dataset.
Parameters
Expand Down Expand Up @@ -2135,7 +2150,10 @@ def set_categorical_feature(
raise LightGBMError("Cannot set categorical feature after freed raw data, "
"set free_raw_data=False when construct Dataset to avoid this.")

def _set_predictor(self, predictor):
def _set_predictor(
self,
predictor: Optional[_InnerPredictor]
) -> "Dataset":
"""Set predictor for continued training.
It is not recommended for users to call this function.
Expand All @@ -2156,7 +2174,7 @@ def _set_predictor(self, predictor):
"set free_raw_data=False when construct Dataset to avoid this.")
return self

def set_reference(self, reference):
def set_reference(self, reference: "Dataset") -> "Dataset":
"""Set reference Dataset.
Parameters
Expand Down Expand Up @@ -2207,7 +2225,7 @@ def set_feature_name(self, feature_name: List[str]) -> "Dataset":
ctypes.c_int(len(feature_name))))
return self

def set_label(self, label):
def set_label(self, label) -> "Dataset":
"""Set label of Dataset.
Parameters
Expand All @@ -2227,7 +2245,7 @@ def set_label(self, label):
self.label = self.get_field('label') # original values can be modified at cpp side
return self

def set_weight(self, weight):
def set_weight(self, weight) -> "Dataset":
"""Set weight of each instance.
Parameters
Expand All @@ -2249,7 +2267,7 @@ def set_weight(self, weight):
self.weight = self.get_field('weight') # original values can be modified at cpp side
return self

def set_init_score(self, init_score):
def set_init_score(self, init_score) -> "Dataset":
"""Set init score of Booster to start from.
Parameters
Expand All @@ -2268,7 +2286,7 @@ def set_init_score(self, init_score):
self.init_score = self.get_field('init_score') # original values can be modified at cpp side
return self

def set_group(self, group):
def set_group(self, group) -> "Dataset":
"""Set group size of Dataset (used for ranking).
Parameters
Expand Down Expand Up @@ -2330,7 +2348,7 @@ def get_feature_name(self) -> List[str]:
ptr_string_buffers))
return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)]

def get_label(self):
def get_label(self) -> Optional[np.ndarray]:
"""Get the label of the Dataset.
Returns
Expand All @@ -2342,7 +2360,7 @@ def get_label(self):
self.label = self.get_field('label')
return self.label

def get_weight(self):
def get_weight(self) -> Optional[np.ndarray]:
"""Get the weight of the Dataset.
Returns
Expand All @@ -2354,7 +2372,7 @@ def get_weight(self):
self.weight = self.get_field('weight')
return self.weight

def get_init_score(self):
def get_init_score(self) -> Optional[np.ndarray]:
"""Get the initial score of the Dataset.
Returns
Expand Down Expand Up @@ -2473,7 +2491,7 @@ def feature_num_bin(self, feature: Union[int, str]) -> int:
else:
raise LightGBMError("Cannot get feature_num_bin before construct dataset")

def get_ref_chain(self, ref_limit=100):
def get_ref_chain(self, ref_limit: int = 100) -> Set["Dataset"]:
"""Get a chain of Dataset objects.
Starts with r, then goes to r.reference (if exists),
Expand Down

0 comments on commit 01774bb

Please sign in to comment.