Skip to content

Commit

Permalink
More tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Nov 16, 2021
1 parent f232d7e commit 0257923
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/python/test_dmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def set_base_margin_info(DType, DMatrixT, tm: str):
xgb.train({"tree_method": tm}, Xy)

if not hasattr(X, "iloc"):
# column major matrix
got = DType(Xy.get_base_margin().reshape(50, 2))
assert (got == base_margin).all()

Expand All @@ -39,6 +40,20 @@ def set_base_margin_info(DType, DMatrixT, tm: str):
got = DType(Xy.get_base_margin().reshape(2, 50))
assert (got == base_margin.T).all()

# Row vs col vec.
base_margin = y
Xy.set_base_margin(base_margin)
bm_col = Xy.get_base_margin()
Xy.set_base_margin(base_margin.reshape(1, base_margin.size))
bm_row = Xy.get_base_margin()
assert (bm_row == bm_col).all()

# type
base_margin = base_margin.astype(np.float64)
Xy.set_base_margin(base_margin)
bm_f64 = Xy.get_base_margin()
assert (bm_f64 == bm_col).all()


class TestDMatrix:
def test_warn_missing(self):
Expand Down

0 comments on commit 0257923

Please sign in to comment.