Skip to content

Commit

Permalink
Implement get_group. (#7564)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jan 15, 2022
1 parent 52277cc commit 13b0fa4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
12 changes: 11 additions & 1 deletion python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,10 +884,20 @@ def get_base_margin(self) -> np.ndarray:
Returns
-------
base_margin : float
base_margin
"""
return self.get_float_info('base_margin')

def get_group(self) -> np.ndarray:
"""Get the group of the DMatrix.
Returns
-------
group
"""
group_ptr = self.get_uint_info("group_ptr")
return np.diff(group_ptr)

def num_row(self) -> int:
"""Get the number of rows in the DMatrix.
Expand Down
4 changes: 4 additions & 0 deletions tests/python/test_dmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,10 @@ def test_get_info(self):
dtrain.get_float_info('base_margin')
dtrain.get_uint_info('group_ptr')

group_len = np.array([2, 3, 4])
dtrain.set_group(group_len)
np.testing.assert_equal(group_len, dtrain.get_group())

def test_qid(self):
rows = 100
cols = 10
Expand Down

0 comments on commit 13b0fa4

Please sign in to comment.