diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 6563daa44776..d3e0c23c616c 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -45,6 +45,7 @@ // tress #include "../src/tree/split_evaluator.cc" +#include "../src/tree/param.cc" #include "../src/tree/tree_model.cc" #include "../src/tree/tree_updater.cc" #include "../src/tree/updater_colmaker.cc" diff --git a/src/tree/param.cc b/src/tree/param.cc new file mode 100644 index 000000000000..8049501ea094 --- /dev/null +++ b/src/tree/param.cc @@ -0,0 +1,81 @@ +/*! + * Copyright by Contributors 2019 + */ +#include +#include +#include + +#include "param.h" + +namespace std { +std::istream &operator>>(std::istream &is, std::vector &t) { + t.clear(); + // get ( + while (true) { + char ch = is.peek(); + if (isdigit(ch)) { + int idx; + if (is >> idx) { + t.emplace_back(idx); + } + return is; + } + is.get(); + if (ch == '(') { + break; + } + if (!isspace(ch)) { + is.setstate(std::ios::failbit); + return is; + } + } + int idx; + std::vector tmp; + while (true) { + char ch = is.peek(); + if (isspace(ch)) { + is.get(); + } else { + break; + } + } + if (is.peek() == ')') { + is.get(); + return is; + } + while (is >> idx) { + tmp.push_back(idx); + char ch; + do { + ch = is.get(); + } while (isspace(ch)); + if (ch == 'L') { + ch = is.get(); + } + if (ch == ',') { + while (true) { + ch = is.peek(); + if (isspace(ch)) { + is.get(); + continue; + } + if (ch == ')') { + is.get(); + break; + } + break; + } + if (ch == ')') { + break; + } + } else if (ch == ')') { + break; + } else { + is.setstate(std::ios::failbit); + return is; + } + } + t = std::move(tmp); + return is; +} +} // namespace std diff --git a/src/tree/param.h b/src/tree/param.h index ded36a313397..7c858c975284 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -495,63 +495,7 @@ inline std::ostream &operator<<(std::ostream &os, const std::vector &t) { return os; } -inline std::istream &operator>>(std::istream &is, std::vector &t) { - // get ( - while (true) { - char ch = is.peek(); - if (isdigit(ch)) { - int idx; - if (is >> idx) { - t.assign(&idx, &idx + 1); - } - return is; - } - is.get(); - if (ch == '(') { - break; - } - if (!isspace(ch)) { - is.setstate(std::ios::failbit); - return is; - } - } - int idx; - std::vector tmp; - while (is >> idx) { - tmp.push_back(idx); - char ch; - do { - ch = is.get(); - } while (isspace(ch)); - if (ch == 'L') { - ch = is.get(); - } - if (ch == ',') { - while (true) { - ch = is.peek(); - if (isspace(ch)) { - is.get(); - continue; - } - if (ch == ')') { - is.get(); - break; - } - break; - } - if (ch == ')') { - break; - } - } else if (ch == ')') { - break; - } else { - is.setstate(std::ios::failbit); - return is; - } - } - t.assign(tmp.begin(), tmp.end()); - return is; -} +std::istream &operator>>(std::istream &is, std::vector &t); } // namespace std #endif // XGBOOST_TREE_PARAM_H_ diff --git a/tests/cpp/tree/test_param.cc b/tests/cpp/tree/test_param.cc index 3f4e50ba245c..b4cc4005e3ad 100644 --- a/tests/cpp/tree/test_param.cc +++ b/tests/cpp/tree/test_param.cc @@ -74,6 +74,12 @@ TEST(Param, VectorStreamRead) { ss << "(3,2,1"; ss >> vals_in; EXPECT_NE(vals_in, vals); + + vals_in.clear(); ss.flush(); ss.clear(); ss.str(""); + vals_in.emplace_back(3); + ss << "( )"; + ss >> vals_in; + ASSERT_TRUE(ss.good()); } TEST(Param, SplitEntry) {