Skip to content

Commit

Permalink
Fix parsing empty vector in parameter. (#5087)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Dec 5, 2019
1 parent f5e13dc commit df9bdbb
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 57 deletions.
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

// tress
#include "../src/tree/split_evaluator.cc"
#include "../src/tree/param.cc"
#include "../src/tree/tree_model.cc"
#include "../src/tree/tree_updater.cc"
#include "../src/tree/updater_colmaker.cc"
Expand Down
81 changes: 81 additions & 0 deletions src/tree/param.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*!
* Copyright by Contributors 2019
*/
#include <iostream>
#include <vector>
#include <utility>

#include "param.h"

namespace std {
std::istream &operator>>(std::istream &is, std::vector<int> &t) {
t.clear();
// get (
while (true) {
char ch = is.peek();
if (isdigit(ch)) {
int idx;
if (is >> idx) {
t.emplace_back(idx);
}
return is;
}
is.get();
if (ch == '(') {
break;
}
if (!isspace(ch)) {
is.setstate(std::ios::failbit);
return is;
}
}
int idx;
std::vector<int> tmp;
while (true) {
char ch = is.peek();
if (isspace(ch)) {
is.get();
} else {
break;
}
}
if (is.peek() == ')') {
is.get();
return is;
}
while (is >> idx) {
tmp.push_back(idx);
char ch;
do {
ch = is.get();
} while (isspace(ch));
if (ch == 'L') {
ch = is.get();
}
if (ch == ',') {
while (true) {
ch = is.peek();
if (isspace(ch)) {
is.get();
continue;
}
if (ch == ')') {
is.get();
break;
}
break;
}
if (ch == ')') {
break;
}
} else if (ch == ')') {
break;
} else {
is.setstate(std::ios::failbit);
return is;
}
}
t = std::move(tmp);
return is;
}
} // namespace std
58 changes: 1 addition & 57 deletions src/tree/param.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,63 +495,7 @@ inline std::ostream &operator<<(std::ostream &os, const std::vector<int> &t) {
return os;
}

inline std::istream &operator>>(std::istream &is, std::vector<int> &t) {
// get (
while (true) {
char ch = is.peek();
if (isdigit(ch)) {
int idx;
if (is >> idx) {
t.assign(&idx, &idx + 1);
}
return is;
}
is.get();
if (ch == '(') {
break;
}
if (!isspace(ch)) {
is.setstate(std::ios::failbit);
return is;
}
}
int idx;
std::vector<int> tmp;
while (is >> idx) {
tmp.push_back(idx);
char ch;
do {
ch = is.get();
} while (isspace(ch));
if (ch == 'L') {
ch = is.get();
}
if (ch == ',') {
while (true) {
ch = is.peek();
if (isspace(ch)) {
is.get();
continue;
}
if (ch == ')') {
is.get();
break;
}
break;
}
if (ch == ')') {
break;
}
} else if (ch == ')') {
break;
} else {
is.setstate(std::ios::failbit);
return is;
}
}
t.assign(tmp.begin(), tmp.end());
return is;
}
std::istream &operator>>(std::istream &is, std::vector<int> &t);
} // namespace std

#endif // XGBOOST_TREE_PARAM_H_
6 changes: 6 additions & 0 deletions tests/cpp/tree/test_param.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ TEST(Param, VectorStreamRead) {
ss << "(3,2,1";
ss >> vals_in;
EXPECT_NE(vals_in, vals);

vals_in.clear(); ss.flush(); ss.clear(); ss.str("");
vals_in.emplace_back(3);
ss << "( )";
ss >> vals_in;
ASSERT_TRUE(ss.good());
}

TEST(Param, SplitEntry) {
Expand Down

0 comments on commit df9bdbb

Please sign in to comment.