diff --git a/R-package/src/lightgbm_R.cpp b/R-package/src/lightgbm_R.cpp index 7c5b5387dda9..a19a22a56a9e 100644 --- a/R-package/src/lightgbm_R.cpp +++ b/R-package/src/lightgbm_R.cpp @@ -28,15 +28,13 @@ #define R_API_BEGIN() \ try { #define R_API_END() } \ - catch(std::exception& ex) { LGBM_SetLastError(ex.what()); return R_NilValue;} \ - catch(std::string& ex) { LGBM_SetLastError(ex.c_str()); return R_NilValue; } \ - catch(...) { LGBM_SetLastError("unknown exception"); return R_NilValue;} \ - return R_NilValue; + catch(std::exception& ex) { LGBM_SetLastError(ex.what()); } \ + catch(std::string& ex) { LGBM_SetLastError(ex.c_str()); } \ + catch(...) { LGBM_SetLastError("unknown exception"); } #define CHECK_CALL(x) \ if ((x) != 0) { \ Rf_error(LGBM_GetLastError()); \ - return R_NilValue; \ } using LightGBM::Common::Split; @@ -54,19 +52,20 @@ SEXP LGBM_DatasetCreateFromFile_R(SEXP filename, SEXP parameters, SEXP reference) { SEXP ret; - R_API_BEGIN(); DatasetHandle handle = nullptr; DatasetHandle ref = nullptr; if (!Rf_isNull(reference)) { ref = R_ExternalPtrAddr(reference); } - CHECK_CALL(LGBM_DatasetCreateFromFile(CHAR(Rf_asChar(filename)), CHAR(Rf_asChar(parameters)), - ref, &handle)); + const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename))); + const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); + R_API_BEGIN(); + CHECK_CALL(LGBM_DatasetCreateFromFile(filename_ptr, parameters_ptr, ref, &handle)); + R_API_END(); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); - UNPROTECT(1); + UNPROTECT(3); return ret; - R_API_END(); } SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr, @@ -78,27 +77,27 @@ SEXP LGBM_DatasetCreateFromCSC_R(SEXP indptr, SEXP parameters, SEXP reference) { SEXP ret; - R_API_BEGIN(); const int* p_indptr = INTEGER(indptr); const int* p_indices = INTEGER(indices); const double* p_data = REAL(data); - int64_t nindptr = static_cast(Rf_asInteger(num_indptr)); int64_t ndata = static_cast(Rf_asInteger(nelem)); int64_t nrow = static_cast(Rf_asInteger(num_row)); + const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); DatasetHandle handle = nullptr; DatasetHandle ref = nullptr; if (!Rf_isNull(reference)) { ref = R_ExternalPtrAddr(reference); } + R_API_BEGIN(); CHECK_CALL(LGBM_DatasetCreateFromCSC(p_indptr, C_API_DTYPE_INT32, p_indices, p_data, C_API_DTYPE_FLOAT64, nindptr, ndata, - nrow, CHAR(Rf_asChar(parameters)), ref, &handle)); + nrow, parameters_ptr, ref, &handle)); + R_API_END(); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); - UNPROTECT(1); + UNPROTECT(2); return ret; - R_API_END(); } SEXP LGBM_DatasetCreateFromMat_R(SEXP data, @@ -107,22 +106,23 @@ SEXP LGBM_DatasetCreateFromMat_R(SEXP data, SEXP parameters, SEXP reference) { SEXP ret; - R_API_BEGIN(); int32_t nrow = static_cast(Rf_asInteger(num_row)); int32_t ncol = static_cast(Rf_asInteger(num_col)); double* p_mat = REAL(data); + const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); DatasetHandle handle = nullptr; DatasetHandle ref = nullptr; if (!Rf_isNull(reference)) { ref = R_ExternalPtrAddr(reference); } + R_API_BEGIN(); CHECK_CALL(LGBM_DatasetCreateFromMat(p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, - CHAR(Rf_asChar(parameters)), ref, &handle)); + parameters_ptr, ref, &handle)); + R_API_END(); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); - UNPROTECT(1); + UNPROTECT(2); return ret; - R_API_END(); } SEXP LGBM_DatasetGetSubset_R(SEXP handle, @@ -130,7 +130,6 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle, SEXP len_used_row_indices, SEXP parameters) { SEXP ret; - R_API_BEGIN(); int32_t len = static_cast(Rf_asInteger(len_used_row_indices)); std::vector idxvec(len); // convert from one-based to zero-based index @@ -138,36 +137,41 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle, for (int32_t i = 0; i < len; ++i) { idxvec[i] = static_cast(INTEGER(used_row_indices)[i] - 1); } + const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); DatasetHandle res = nullptr; + R_API_BEGIN(); CHECK_CALL(LGBM_DatasetGetSubset(R_ExternalPtrAddr(handle), - idxvec.data(), len, CHAR(Rf_asChar(parameters)), + idxvec.data(), len, parameters_ptr, &res)); + R_API_END(); ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue)); R_RegisterCFinalizerEx(ret, _DatasetFinalizer, TRUE); - UNPROTECT(1); + UNPROTECT(2); return ret; - R_API_END(); } SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle, SEXP feature_names) { - R_API_BEGIN(); - auto vec_names = Split(CHAR(Rf_asChar(feature_names)), '\t'); + auto vec_names = Split(CHAR(PROTECT(Rf_asChar(feature_names))), '\t'); std::vector vec_sptr; int len = static_cast(vec_names.size()); for (int i = 0; i < len; ++i) { vec_sptr.push_back(vec_names[i].c_str()); } + R_API_BEGIN(); CHECK_CALL(LGBM_DatasetSetFeatureNames(R_ExternalPtrAddr(handle), vec_sptr.data(), len)); R_API_END(); + UNPROTECT(1); + return R_NilValue; } SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { SEXP feature_names; - R_API_BEGIN(); int len = 0; + R_API_BEGIN(); CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len)); + R_API_END(); const size_t reserved_string_size = 256; std::vector> names(len); std::vector ptr_names(len); @@ -177,12 +181,14 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { } int out_len; size_t required_string_size; + R_API_BEGIN(); CHECK_CALL( LGBM_DatasetGetFeatureNames( R_ExternalPtrAddr(handle), len, &out_len, reserved_string_size, &required_string_size, ptr_names.data())); + R_API_END(); // if any feature names were larger than allocated size, // allow for a larger size and try again if (required_string_size > reserved_string_size) { @@ -190,6 +196,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { names[i].resize(required_string_size); ptr_names[i] = names[i].data(); } + R_API_BEGIN(); CHECK_CALL( LGBM_DatasetGetFeatureNames( R_ExternalPtrAddr(handle), @@ -198,6 +205,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { required_string_size, &required_string_size, ptr_names.data())); + R_API_END(); } CHECK_EQ(len, out_len); feature_names = PROTECT(Rf_allocVector(STRSXP, len)); @@ -206,15 +214,17 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { } UNPROTECT(1); return feature_names; - R_API_END(); } SEXP LGBM_DatasetSaveBinary_R(SEXP handle, SEXP filename) { + const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename))); R_API_BEGIN(); CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle), - CHAR(Rf_asChar(filename)))); + filename_ptr)); R_API_END(); + UNPROTECT(1); + return R_NilValue; } SEXP LGBM_DatasetFree_R(SEXP handle) { @@ -224,15 +234,16 @@ SEXP LGBM_DatasetFree_R(SEXP handle) { R_ClearExternalPtr(handle); } R_API_END(); + return R_NilValue; } SEXP LGBM_DatasetSetField_R(SEXP handle, SEXP field_name, SEXP field_data, SEXP num_element) { - R_API_BEGIN(); int len = Rf_asInteger(num_element); - const char* name = CHAR(Rf_asChar(field_name)); + const char* name = CHAR(PROTECT(Rf_asChar(field_name))); + R_API_BEGIN(); if (!strcmp("group", name) || !strcmp("query", name)) { std::vector vec(len); #pragma omp parallel for schedule(static, 512) if (len >= 1024) @@ -251,18 +262,19 @@ SEXP LGBM_DatasetSetField_R(SEXP handle, CHECK_CALL(LGBM_DatasetSetField(R_ExternalPtrAddr(handle), name, vec.data(), len, C_API_DTYPE_FLOAT32)); } R_API_END(); + UNPROTECT(1); + return R_NilValue; } SEXP LGBM_DatasetGetField_R(SEXP handle, SEXP field_name, SEXP field_data) { - R_API_BEGIN(); - const char* name = CHAR(Rf_asChar(field_name)); + const char* name = CHAR(PROTECT(Rf_asChar(field_name))); int out_len = 0; int out_type = 0; const void* res; + R_API_BEGIN(); CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type)); - if (!strcmp("group", name) || !strcmp("query", name)) { auto p_data = reinterpret_cast(res); // convert from boundaries to size @@ -284,29 +296,37 @@ SEXP LGBM_DatasetGetField_R(SEXP handle, } } R_API_END(); + UNPROTECT(1); + return R_NilValue; } SEXP LGBM_DatasetGetFieldSize_R(SEXP handle, SEXP field_name, SEXP out) { - R_API_BEGIN(); - const char* name = CHAR(Rf_asChar(field_name)); + const char* name = CHAR(PROTECT(Rf_asChar(field_name))); int out_len = 0; int out_type = 0; const void* res; + R_API_BEGIN(); CHECK_CALL(LGBM_DatasetGetField(R_ExternalPtrAddr(handle), name, &out_len, &res, &out_type)); if (!strcmp("group", name) || !strcmp("query", name)) { out_len -= 1; } INTEGER(out)[0] = out_len; R_API_END(); + UNPROTECT(1); + return R_NilValue; } SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params, SEXP new_params) { + const char* old_params_ptr = CHAR(PROTECT(Rf_asChar(old_params))); + const char* new_params_ptr = CHAR(PROTECT(Rf_asChar(new_params))); R_API_BEGIN(); - CHECK_CALL(LGBM_DatasetUpdateParamChecking(CHAR(Rf_asChar(old_params)), CHAR(Rf_asChar(new_params)))); + CHECK_CALL(LGBM_DatasetUpdateParamChecking(old_params_ptr, new_params_ptr)); R_API_END(); + UNPROTECT(2); + return R_NilValue; } SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) { @@ -315,6 +335,7 @@ SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) { CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow)); INTEGER(out)[0] = nrow; R_API_END(); + return R_NilValue; } SEXP LGBM_DatasetGetNumFeature_R(SEXP handle, @@ -324,6 +345,7 @@ SEXP LGBM_DatasetGetNumFeature_R(SEXP handle, CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature)); INTEGER(out)[0] = nfeature; R_API_END(); + return R_NilValue; } // --- start Booster interfaces @@ -339,45 +361,49 @@ SEXP LGBM_BoosterFree_R(SEXP handle) { R_ClearExternalPtr(handle); } R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterCreate_R(SEXP train_data, SEXP parameters) { SEXP ret; - R_API_BEGIN(); + const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); BoosterHandle handle = nullptr; - CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), CHAR(Rf_asChar(parameters)), &handle)); + R_API_BEGIN(); + CHECK_CALL(LGBM_BoosterCreate(R_ExternalPtrAddr(train_data), parameters_ptr, &handle)); + R_API_END(); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); - UNPROTECT(1); + UNPROTECT(2); return ret; - R_API_END(); } SEXP LGBM_BoosterCreateFromModelfile_R(SEXP filename) { SEXP ret; - R_API_BEGIN(); int out_num_iterations = 0; + const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename))); BoosterHandle handle = nullptr; - CHECK_CALL(LGBM_BoosterCreateFromModelfile(CHAR(Rf_asChar(filename)), &out_num_iterations, &handle)); + R_API_BEGIN(); + CHECK_CALL(LGBM_BoosterCreateFromModelfile(filename_ptr, &out_num_iterations, &handle)); + R_API_END(); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); - UNPROTECT(1); + UNPROTECT(2); return ret; - R_API_END(); } SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) { SEXP ret; - R_API_BEGIN(); int out_num_iterations = 0; + const char* model_str_ptr = CHAR(PROTECT(Rf_asChar(model_str))); BoosterHandle handle = nullptr; - CHECK_CALL(LGBM_BoosterLoadModelFromString(CHAR(Rf_asChar(model_str)), &out_num_iterations, &handle)); + R_API_BEGIN(); + CHECK_CALL(LGBM_BoosterLoadModelFromString(model_str_ptr, &out_num_iterations, &handle)); + R_API_END(); ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); - UNPROTECT(1); + UNPROTECT(2); return ret; - R_API_END(); } SEXP LGBM_BoosterMerge_R(SEXP handle, @@ -385,6 +411,7 @@ SEXP LGBM_BoosterMerge_R(SEXP handle, R_API_BEGIN(); CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle))); R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterAddValidData_R(SEXP handle, @@ -392,6 +419,7 @@ SEXP LGBM_BoosterAddValidData_R(SEXP handle, R_API_BEGIN(); CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data))); R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterResetTrainingData_R(SEXP handle, @@ -399,13 +427,17 @@ SEXP LGBM_BoosterResetTrainingData_R(SEXP handle, R_API_BEGIN(); CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data))); R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterResetParameter_R(SEXP handle, SEXP parameters) { + const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); R_API_BEGIN(); - CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), CHAR(Rf_asChar(parameters)))); + CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), parameters_ptr)); R_API_END(); + UNPROTECT(1); + return R_NilValue; } SEXP LGBM_BoosterGetNumClasses_R(SEXP handle, @@ -415,6 +447,7 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle, CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class)); INTEGER(out)[0] = num_class; R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) { @@ -422,6 +455,7 @@ SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) { R_API_BEGIN(); CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished)); R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle, @@ -439,12 +473,14 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle, } CHECK_CALL(LGBM_BoosterUpdateOneIterCustom(R_ExternalPtrAddr(handle), tgrad.data(), thess.data(), &is_finished)); R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) { R_API_BEGIN(); CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle))); R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, @@ -454,6 +490,7 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration)); INTEGER(out)[0] = out_iteration; R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, @@ -462,6 +499,7 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, double* ptr_ret = REAL(out_result); CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret)); R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle, @@ -470,14 +508,15 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle, double* ptr_ret = REAL(out_result); CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret)); R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { SEXP eval_names; - R_API_BEGIN(); int len; + R_API_BEGIN(); CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len)); - + R_API_END(); const size_t reserved_string_size = 128; std::vector> names(len); std::vector ptr_names(len); @@ -488,12 +527,14 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { int out_len; size_t required_string_size; + R_API_BEGIN(); CHECK_CALL( LGBM_BoosterGetEvalNames( R_ExternalPtrAddr(handle), len, &out_len, reserved_string_size, &required_string_size, ptr_names.data())); + R_API_END(); // if any eval names were larger than allocated size, // allow for a larger size and try again if (required_string_size > reserved_string_size) { @@ -501,6 +542,7 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { names[i].resize(required_string_size); ptr_names[i] = names[i].data(); } + R_API_BEGIN(); CHECK_CALL( LGBM_BoosterGetEvalNames( R_ExternalPtrAddr(handle), @@ -509,6 +551,7 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { required_string_size, &required_string_size, ptr_names.data())); + R_API_END(); } CHECK_EQ(out_len, len); eval_names = PROTECT(Rf_allocVector(STRSXP, len)); @@ -517,7 +560,6 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { } UNPROTECT(1); return eval_names; - R_API_END(); } SEXP LGBM_BoosterGetEval_R(SEXP handle, @@ -531,6 +573,7 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle, CHECK_CALL(LGBM_BoosterGetEval(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret)); CHECK_EQ(out_len, len); R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterGetNumPredict_R(SEXP handle, @@ -541,6 +584,7 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle, CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len)); INTEGER(out)[0] = static_cast(len); R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterGetPredict_R(SEXP handle, @@ -551,6 +595,7 @@ SEXP LGBM_BoosterGetPredict_R(SEXP handle, int64_t out_len; CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret)); R_API_END(); + return R_NilValue; } int GetPredictType(SEXP is_rawscore, SEXP is_leafidx, SEXP is_predcontrib) { @@ -577,12 +622,17 @@ SEXP LGBM_BoosterPredictForFile_R(SEXP handle, SEXP num_iteration, SEXP parameter, SEXP result_filename) { - R_API_BEGIN(); + const char* data_filename_ptr = CHAR(PROTECT(Rf_asChar(data_filename))); + const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter))); + const char* result_filename_ptr = CHAR(PROTECT(Rf_asChar(result_filename))); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); - CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), CHAR(Rf_asChar(data_filename)), - Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), - CHAR(Rf_asChar(result_filename)))); + R_API_BEGIN(); + CHECK_CALL(LGBM_BoosterPredictForFile(R_ExternalPtrAddr(handle), data_filename_ptr, + Rf_asInteger(data_has_header), pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, + result_filename_ptr)); R_API_END(); + UNPROTECT(3); + return R_NilValue; } SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle, @@ -600,6 +650,7 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), &len)); INTEGER(out_len)[0] = static_cast(len); R_API_END(); + return R_NilValue; } SEXP LGBM_BoosterPredictForCSC_R(SEXP handle, @@ -616,23 +667,24 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle, SEXP num_iteration, SEXP parameter, SEXP out_result) { - R_API_BEGIN(); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); - const int* p_indptr = INTEGER(indptr); const int32_t* p_indices = reinterpret_cast(INTEGER(indices)); const double* p_data = REAL(data); - int64_t nindptr = static_cast(Rf_asInteger(num_indptr)); int64_t ndata = static_cast(Rf_asInteger(nelem)); int64_t nrow = static_cast(Rf_asInteger(num_row)); double* ptr_ret = REAL(out_result); int64_t out_len; + const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter))); + R_API_BEGIN(); CHECK_CALL(LGBM_BoosterPredictForCSC(R_ExternalPtrAddr(handle), p_indptr, C_API_DTYPE_INT32, p_indices, p_data, C_API_DTYPE_FLOAT64, nindptr, ndata, - nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret)); + nrow, pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, &out_len, ptr_ret)); R_API_END(); + UNPROTECT(1); + return R_NilValue; } SEXP LGBM_BoosterPredictForMat_R(SEXP handle, @@ -646,75 +698,82 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle, SEXP num_iteration, SEXP parameter, SEXP out_result) { - R_API_BEGIN(); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); - int32_t nrow = static_cast(Rf_asInteger(num_row)); int32_t ncol = static_cast(Rf_asInteger(num_col)); - const double* p_mat = REAL(data); double* ptr_ret = REAL(out_result); + const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter))); int64_t out_len; + R_API_BEGIN(); CHECK_CALL(LGBM_BoosterPredictForMat(R_ExternalPtrAddr(handle), p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR, - pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), CHAR(Rf_asChar(parameter)), &out_len, ptr_ret)); - + pred_type, Rf_asInteger(start_iteration), Rf_asInteger(num_iteration), parameter_ptr, &out_len, ptr_ret)); R_API_END(); + UNPROTECT(1); + return R_NilValue; } SEXP LGBM_BoosterSaveModel_R(SEXP handle, SEXP num_iteration, SEXP feature_importance_type, SEXP filename) { + const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename))); R_API_BEGIN(); - CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), CHAR(Rf_asChar(filename)))); + CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr)); R_API_END(); + UNPROTECT(1); + return R_NilValue; } SEXP LGBM_BoosterSaveModelToString_R(SEXP handle, SEXP num_iteration, SEXP feature_importance_type) { SEXP model_str; - R_API_BEGIN(); int64_t out_len = 0; int64_t buf_len = 1024 * 1024; int num_iter = Rf_asInteger(num_iteration); int importance_type = Rf_asInteger(feature_importance_type); std::vector inner_char_buf(buf_len); + R_API_BEGIN(); CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data())); + R_API_END(); // if the model string was larger than the initial buffer, allocate a bigger buffer and try again if (out_len > buf_len) { inner_char_buf.resize(out_len); + R_API_BEGIN(); CHECK_CALL(LGBM_BoosterSaveModelToString(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data())); + R_API_END(); } model_str = PROTECT(Rf_allocVector(STRSXP, 1)); SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data())); UNPROTECT(1); return model_str; - R_API_END(); } SEXP LGBM_BoosterDumpModel_R(SEXP handle, SEXP num_iteration, SEXP feature_importance_type) { SEXP model_str; - R_API_BEGIN(); int64_t out_len = 0; int64_t buf_len = 1024 * 1024; int num_iter = Rf_asInteger(num_iteration); int importance_type = Rf_asInteger(feature_importance_type); std::vector inner_char_buf(buf_len); + R_API_BEGIN(); CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, buf_len, &out_len, inner_char_buf.data())); + R_API_END(); // if the model string was larger than the initial buffer, allocate a bigger buffer and try again if (out_len > buf_len) { inner_char_buf.resize(out_len); + R_API_BEGIN(); CHECK_CALL(LGBM_BoosterDumpModel(R_ExternalPtrAddr(handle), 0, num_iter, importance_type, out_len, &out_len, inner_char_buf.data())); + R_API_END(); } model_str = PROTECT(Rf_allocVector(STRSXP, 1)); SET_STRING_ELT(model_str, 0, Rf_mkChar(inner_char_buf.data())); UNPROTECT(1); return model_str; - R_API_END(); } // .Call() calls