Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] C-API fix; attribute accessors #1166

Merged
merged 5 commits into from
May 7, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ S3method(predict,xgb.Booster)
S3method(predict,xgb.Booster.handle)
S3method(setinfo,xgb.DMatrix)
S3method(slice,xgb.DMatrix)
export("xgb.attr<-")
export(getinfo)
export(print.xgb.DMatrix)
export(setinfo)
export(slice)
export(xgb.DMatrix)
export(xgb.DMatrix.save)
export(xgb.attr)
export(xgb.create.features)
export(xgb.cv)
export(xgb.dump)
Expand Down
74 changes: 74 additions & 0 deletions R-package/R/xgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,77 @@ predict.xgb.Booster.handle <- function(object, ...) {
ret <- predict(bst, ...)
return(ret)
}


#' Accessors for serializable attributes of a model.
#'
#' These methods allow manipulating the key-value attribute strings of an
#' xgboost model.
#'
#' @param object Object of class \code{xgb.Booster} or \code{xgb.Booster.handle}.
#' @param which a non-empty character string specifying which attribute is to be accessed.
#'        Only the first element is used. A \code{NULL}, \code{NA}, zero-length or
#'        empty value is rejected with the error \code{"invalid attribute name"}.
#' @param value a value of an attribute. Non-character values are converted to character.
#'        When length of a \code{value} vector is more than one, only the first element is used.
#'
#' @details
#' Note that the xgboost model attributes are a separate concept from the attributes in R.
#' Specifically, they refer to key-value strings that can be attached to an xgboost model
#' and stored within the model's binary representation.
#' In contrast, any R-attribute assigned to an R-object of \code{xgb.Booster} class
#' would not be saved by \code{xgb.save}, since xgboost model is an external memory object
#' and its serialization is handled externally.
#'
#' Also note that the attribute setter would usually work more efficiently for \code{xgb.Booster.handle}
#' than for \code{xgb.Booster}, since only just a handle would need to be copied.
#'
#' @return
#' \code{xgb.attr} returns either a string value of an attribute
#' or \code{NULL} if an attribute wasn't stored in a model.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' train <- agaricus.train
#'
#' bst <- xgboost(data = train$data, label = train$label, max.depth = 2,
#'                eta = 1, nthread = 2, nround = 2, objective = "binary:logistic")
#'
#' xgb.attr(bst, "my_attribute") <- "my attribute value"
#' print(xgb.attr(bst, "my_attribute"))
#'
#' xgb.save(bst, 'xgb.model')
#' bst1 <- xgb.load('xgb.model')
#' print(xgb.attr(bst1, "my_attribute"))
#'
#' @rdname xgb.attr
#' @export
xgb.attr <- function(object, which) {
  # Reduce the name to a single string once. NULL and zero-length input both
  # index to NA here, so one is.na() test covers NULL, NA and character(0).
  attr_name <- as.character(which)[1]
  # `||` keeps the condition scalar and never hands an NA to `if`.
  if (is.na(attr_name) || nchar(attr_name) == 0) stop("invalid attribute name")
  handle <- xgb.get.handle(object, "xgb.attr")
  .Call("XGBoosterGetAttr_R", handle, attr_name, PACKAGE = "xgboost")
}

#' @rdname xgb.attr
#' @export
`xgb.attr<-` <- function(object, which, value) {
  # Replacement form of the attribute accessor: stores `value` under the key
  # `which` in the model and returns the (possibly refreshed) `object`.
  #
  # Reduce the name to a single string once. NULL and zero-length input both
  # index to NA here, so one is.na() test covers NULL, NA and character(0).
  attr_name <- as.character(which)[1]
  if (is.na(attr_name) || nchar(attr_name) == 0) stop("invalid attribute name")
  handle <- xgb.get.handle(object, "xgb.attr")
  # TODO: setting NULL value to remove an attribute
  # NOTE(review): a NULL or NA `value` reaches the C wrapper as NA_character_;
  # confirm how XGBoosterSetAttr_R is expected to treat it.
  .Call("XGBoosterSetAttr_R", handle, attr_name, as.character(value)[1],
        PACKAGE = "xgboost")
  # An xgb.Booster that carries a raw copy of the model must have that copy
  # regenerated, otherwise it would not include the attribute just set.
  if (inherits(object, "xgb.Booster") && !is.null(object$raw)) {
    object$raw <- xgb.save.raw(object$handle)
  }
  object
}

# Return a valid handle out of either xgb.Booster.handle or xgb.Booster
# internal utility function
#
# object: an xgb.Booster (its `handle` element is used) or an
#         xgb.Booster.handle (used as is).
# caller: name of the calling function, used as a prefix of error messages.
#
# Stops when `object` is of neither class, or when the resulting handle is
# NULL or is reported as a null pointer by XGCheckNullPtr_R.
xgb.get.handle <- function(object, caller = "") {
  # inherits() copes with objects that carry more than one class, where
  # switch(class(object), ...) would fail on a length > 1 selector.
  handle <- if (inherits(object, "xgb.Booster")) {
    object$handle
  } else if (inherits(object, "xgb.Booster.handle")) {
    object
  } else {
    stop(caller, ": argument must be either xgb.Booster or xgb.Booster.handle")
  }
  # `||` short-circuits, so the native check is never called with a NULL handle.
  if (is.null(handle) ||
      .Call("XGCheckNullPtr_R", handle, PACKAGE = "xgboost")) {
    stop(caller, ": invalid xgb.Booster.handle")
  }
  handle
}
53 changes: 53 additions & 0 deletions R-package/man/xgb.attr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 46 additions & 10 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,16 @@ SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
return ret;
}

// R-callable wrapper around XGDMatrixSaveBinary.
// NOTE(review): both a `void` and a `SEXP` signature line appear here; this
// looks like the removed and added lines of a rendered diff - confirm that
// only the `SEXP` form exists in the real file.
void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
R_API_BEGIN();
// handle is unwrapped from its external pointer, fname is passed as a C
// string and silent as an int.
CHECK_CALL(XGDMatrixSaveBinary(R_ExternalPtrAddr(handle),
CHAR(asChar(fname)),
asInteger(silent)));
R_API_END();
// No result is produced; R_NilValue is returned.
return R_NilValue;
}

void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
R_API_BEGIN();
int len = length(array);
const char *name = CHAR(asChar(field));
Expand All @@ -167,6 +168,7 @@ void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
BeginPtr(vec), len));
}
R_API_END();
return R_NilValue;
}

SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
Expand Down Expand Up @@ -227,23 +229,25 @@ SEXP XGBoosterCreate_R(SEXP dmats) {
return ret;
}

// R-callable wrapper around XGBoosterSetParam: passes one name/value pair,
// both taken as C strings, to the booster behind `handle`.
// NOTE(review): both a `void` and a `SEXP` signature line appear here, and the
// two argument lines are repeated; this looks like the removed and added lines
// of a rendered diff - confirm against the real file.
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
R_API_BEGIN();
CHECK_CALL(XGBoosterSetParam(R_ExternalPtrAddr(handle),
CHAR(asChar(name)),
CHAR(asChar(val))));
CHAR(asChar(name)),
CHAR(asChar(val))));
R_API_END();
// No result is produced; R_NilValue is returned.
return R_NilValue;
}

// R-callable wrapper around XGBoosterUpdateOneIter.
// NOTE(review): both a `void` and a `SEXP` signature line appear here; this
// looks like the removed and added lines of a rendered diff - confirm that
// only the `SEXP` form exists in the real file.
void XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
R_API_BEGIN();
// handle and dtrain are both unwrapped from external pointers; iter is
// passed as an int.
CHECK_CALL(XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
R_ExternalPtrAddr(dtrain)));
R_API_END();
// No result is produced; R_NilValue is returned.
return R_NilValue;
}

void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
R_API_BEGIN();
CHECK_EQ(length(grad), length(hess))
<< "gradient and hess must have same length";
Expand All @@ -259,6 +263,7 @@ void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
BeginPtr(tgrad), BeginPtr(thess),
len));
R_API_END();
return R_NilValue;
}

SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
Expand Down Expand Up @@ -305,24 +310,27 @@ SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP ntree_lim
return ret;
}

// R-callable wrapper around XGBoosterLoadModel; fname is passed as a C string.
// NOTE(review): both a `void` and a `SEXP` signature line appear here; this
// looks like the removed and added lines of a rendered diff - confirm that
// only the `SEXP` form exists in the real file.
void XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
R_API_BEGIN();
CHECK_CALL(XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));
R_API_END();
// No result is produced; R_NilValue is returned.
return R_NilValue;
}

// R-callable wrapper around XGBoosterSaveModel; fname is passed as a C string.
// NOTE(review): both a `void` and a `SEXP` signature line appear here; this
// looks like the removed and added lines of a rendered diff - confirm that
// only the `SEXP` form exists in the real file.
void XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
R_API_BEGIN();
CHECK_CALL(XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));
R_API_END();
// No result is produced; R_NilValue is returned.
return R_NilValue;
}

// R-callable wrapper around XGBoosterLoadModelFromBuffer.
// NOTE(review): both a `void` and a `SEXP` signature line appear here; this
// looks like the removed and added lines of a rendered diff - confirm that
// only the `SEXP` form exists in the real file.
void XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
R_API_BEGIN();
// The buffer is the data of the R raw vector and its size is length(raw).
CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle),
RAW(raw),
length(raw)));
R_API_END();
// No result is produced; R_NilValue is returned.
return R_NilValue;
}

SEXP XGBoosterModelToRaw_R(SEXP handle) {
Expand Down Expand Up @@ -360,3 +368,31 @@ SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats) {
return out;
}

// R-callable wrapper around XGBoosterGetAttr.
// Looks up the attribute called `name` on the booster behind `handle`.
// Returns a length-one character vector holding the value when the lookup
// reports success, and R_NilValue otherwise.
SEXP XGBoosterGetAttr_R(SEXP handle, SEXP name) {
  SEXP out;
  R_API_BEGIN();
  const char *attr_value;
  int found;
  CHECK_CALL(XGBoosterGetAttr(R_ExternalPtrAddr(handle),
                              CHAR(asChar(name)),
                              &attr_value,
                              &found));
  if (!found) {
    // Protected only so that the single UNPROTECT below is balanced on
    // both branches.
    out = PROTECT(R_NilValue);
  } else {
    out = PROTECT(allocVector(STRSXP, 1));
    SET_STRING_ELT(out, 0, mkChar(attr_value));
  }
  UNPROTECT(1);
  R_API_END();
  return out;
}

// R-callable wrapper around XGBoosterSetAttr.
// Stores `val` under the key `name` on the booster behind `handle`; both are
// passed on as C strings. Always returns R_NilValue.
SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val) {
  R_API_BEGIN();
  void *booster = R_ExternalPtrAddr(handle);
  const char *key = CHAR(asChar(name));
  const char *value = CHAR(asChar(val));
  CHECK_CALL(XGBoosterSetAttr(booster, key, value));
  R_API_END();
  return R_NilValue;
}
44 changes: 35 additions & 9 deletions R-package/src/xgboost_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,18 @@ XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset);
* \param handle an instance of data matrix
* \param fname file name
* \param silent print statistics when saving
* \return R_NilValue
*/
XGB_DLL void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent);
XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent);

/*!
* \brief set information to dmatrix
* \param handle an instance of data matrix
* \param field field name, can be label, weight
* \param array pointer to float vector
* \return R_NilValue
*/
XGB_DLL void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array);
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array);

/*!
* \brief get info vector from matrix
Expand Down Expand Up @@ -104,16 +106,18 @@ XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats);
* \param handle handle
* \param name parameter name
* \param val value of parameter
* \return R_NilValue
*/
XGB_DLL void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val);
XGB_DLL SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val);

/*!
* \brief update the model in one round using dtrain
* \param handle handle
* \param iter current iteration rounds
* \param dtrain training data
* \return R_NilValue
*/
XGB_DLL void XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain);
XGB_DLL SEXP XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain);

/*!
* \brief update the model, by directly specifying gradient and second order gradient,
Expand All @@ -122,16 +126,17 @@ XGB_DLL void XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain);
* \param dtrain training data
* \param grad gradient statistics
* \param hess second order gradient statistics
* \return R_NilValue
*/
XGB_DLL void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess);
XGB_DLL SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess);

/*!
* \brief get evaluation statistics for xgboost
* \param handle handle
* \param iter current iteration rounds
* \param dmats list of handles to dmatrices
* \param evname name of evaluation
* \return the string containing evaluation stati
* \return the string containing evaluation stats
*/
XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames);

Expand All @@ -147,21 +152,24 @@ XGB_DLL SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP n
* \brief load model from existing file
* \param handle handle
* \param fname file name
* \return R_NilValue
*/
XGB_DLL void XGBoosterLoadModel_R(SEXP handle, SEXP fname);
XGB_DLL SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname);

/*!
* \brief save model into existing file
* \param handle handle
* \param fname file name
* \return R_NilValue
*/
XGB_DLL void XGBoosterSaveModel_R(SEXP handle, SEXP fname);
XGB_DLL SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname);

/*!
* \brief load model from raw array
* \param handle handle
* \return R_NilValue
*/
XGB_DLL void XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw);
XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw);

/*!
* \brief save model into R's raw array
Expand All @@ -177,4 +185,22 @@ XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle);
* \param with_stats whether dump statistics of splits
*/
XGB_DLL SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats);

/*!
* \brief get learner attribute value
* \param handle handle
* \param name attribute name
* \return character containing attribute value
*/
XGB_DLL SEXP XGBoosterGetAttr_R(SEXP handle, SEXP name);

/*!
* \brief set learner attribute value
* \param handle handle
* \param name attribute name
* \param val attribute value
* \return R_NilValue
*/
XGB_DLL SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val);

#endif // XGBOOST_WRAPPER_R_H_ // NOLINT(*)
Loading