Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

boolean_mask_assign with start_axis #16886

Merged
merged 1 commit into from
Dec 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,15 @@ inline std::string dev_type_string(const int dev_type) {
return "unknown";
}

/*!
 * \brief Fetch a string attribute value from a node's attribute dictionary.
 * \param attrs node attributes containing the string->string attr dict
 * \param attr_name key to look up in attrs.dict
 * \param default_val value returned when attr_name is not present
 * \return the stored value for attr_name, or default_val if the key is absent
 */
inline std::string attr_value_string(const nnvm::NodeAttrs& attrs,
                                     const std::string& attr_name,
                                     const std::string& default_val = "") {
  // Single lookup instead of find() followed by at(), which would search the
  // map twice on the hit path.
  const auto it = attrs.dict.find(attr_name);
  return it == attrs.dict.end() ? default_val : it->second;
}

/*! \brief get string representation of the operator stypes */
inline std::string operator_stype_string(const nnvm::NodeAttrs& attrs,
const int dev_mask,
Expand Down Expand Up @@ -463,10 +472,10 @@ inline std::string operator_stype_string(const nnvm::NodeAttrs& attrs,

/*! \brief get string representation of the operator */
inline std::string operator_string(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
std::string result = "";
std::vector<int> in_stypes;
std::vector<int> out_stypes;
Expand Down
75 changes: 60 additions & 15 deletions src/operator/numpy/np_boolean_mask_assign.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief CPU implementation of Boolean Mask Assign
*/

#include "../../common/utils.h"
#include "../contrib/boolean_mask-inl.h"

namespace mxnet {
Expand Down Expand Up @@ -88,14 +89,16 @@ struct BooleanAssignCPUKernel {
const size_t idx_size,
const size_t leading,
const size_t middle,
const size_t valid_num,
const size_t trailing,
DType* tensor) {
// binary search for the turning point
size_t mid = bin_search(idx, idx_size, i);
// final answer is in mid
for (size_t l = 0; l < leading; ++l) {
for (size_t t = 0; t < trailing; ++t) {
data[(l * middle + mid) * trailing + t] = (scalar) ? tensor[0] : tensor[i];
data[(l * middle + mid) * trailing + t] =
(scalar) ? tensor[0] : tensor[(l * valid_num + i) * trailing + t];
}
}
}
Expand All @@ -106,19 +109,47 @@ bool BooleanAssignShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *out_attrs) {
CHECK(in_attrs->size() == 2U || in_attrs->size() == 3U);
CHECK_EQ(out_attrs->size(), 1U);
CHECK(shape_is_known(in_attrs->at(0)) && shape_is_known(in_attrs->at(1)))
<< "shape of both input and mask should be known";
const TShape& dshape = in_attrs->at(0);
const TShape& mshape = in_attrs->at(1);
const int start_axis = std::stoi(common::attr_value_string(attrs, "start_axis", "0"));

// mask should have the same shape as the input
SHAPE_ASSIGN_CHECK(*in_attrs, 1, dshape);
for (int i = 0; i < mshape.ndim(); ++i) {
CHECK_EQ(dshape[i + start_axis], mshape[i])
<< "boolean index did not match indexed array along dimension " << i + start_axis
<< "; dimension is " << dshape[i + start_axis] << " but corresponding boolean dimension is "
<< mshape[i];
}

// check if output shape is the same as the input data
SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);

// for tensor version, the tensor should have less than 1 dimension
if (in_attrs->size() == 3U) {
CHECK_LE(in_attrs->at(2).ndim(), 1U)
<< "boolean array indexing assignment requires a 0 or 1-dimensional input, input has "
<< in_attrs->at(2).ndim() <<" dimensions";
if (mshape.ndim() == dshape.ndim()) {
CHECK_LE(in_attrs->at(2).ndim(), 1U)
<< "boolean array indexing assignment requires a 0 or 1-dimensional input, input has "
<< in_attrs->at(2).ndim() <<" dimensions";
} else {
const TShape& vshape = in_attrs->at(2);
if (vshape.Size() > 1) {
for (int i = 0; i < dshape.ndim(); ++i) {
if (i < start_axis) {
CHECK_EQ(dshape[i], vshape[i])
<< "shape mismatch of value with input at dimension " << i
<< "; dimension is " << dshape[i] << " but corresponding value dimension is "
<< vshape[i];
}
if (i >= start_axis + mshape.ndim()) {
CHECK_EQ(dshape[i], vshape[i - mshape.ndim() + 1])
<< "shape mismatch of value with input at dimension " << i
<< "; dimension is " << dshape[i] << " but corresponding value dimension is "
<< vshape[i - mshape.ndim() + 1];
}
}
}
}
}

return shape_is_known(out_attrs->at(0));
Expand Down Expand Up @@ -170,22 +201,26 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
Stream<cpu>* s = ctx.get_stream<cpu>();

const TBlob& data = inputs[0];
const TShape& dshape = data.shape_;
const TBlob& mask = inputs[1];
const TShape& mshape = mask.shape_;
const int start_axis = std::stoi(common::attr_value_string(attrs, "start_axis", "0"));
// Get valid_num
size_t valid_num = 0;
size_t mask_size = mask.shape_.Size();
std::vector<size_t> prefix_sum(mask_size + 1, 0);
MSHADOW_TYPE_SWITCH(mask.type_flag_, MType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(mask.type_flag_, MType, {
valid_num = GetValidNumCPU(mask.dptr<MType>(), prefix_sum.data(), mask_size);
});
// If there's no True in mask, return directly
if (valid_num == 0) return;

if (inputs.size() == 3U) {
const TShape& vshape = inputs[2].shape_;
if (inputs[2].shape_.Size() != 1) {
// tensor case, check tensor size with the valid_num
CHECK_EQ(static_cast<size_t>(valid_num), inputs[2].shape_.Size())
<< "boolean array indexing assignment cannot assign " << inputs[2].shape_.Size()
CHECK_EQ(static_cast<size_t>(valid_num), vshape[start_axis])
<< "boolean array indexing assignment cannot assign " << vshape
<< " input values to the " << valid_num << " output values where the mask is true"
<< std::endl;
}
Expand All @@ -195,21 +230,29 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
size_t middle = mask_size;
size_t trailing = 1U;

for (int i = 0; i < dshape.ndim(); ++i) {
if (i < start_axis) {
leading *= dshape[i];
}
if (i >= start_axis + mshape.ndim()) {
trailing *= dshape[i];
}
}

if (inputs.size() == 3U) {
MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
if (inputs[2].shape_.Size() == 1) {
Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
s, valid_num, data.dptr<DType>(), prefix_sum.data(), prefix_sum.size(),
leading, middle, trailing, inputs[2].dptr<DType>());
leading, middle, valid_num, trailing, inputs[2].dptr<DType>());
} else {
Kernel<BooleanAssignCPUKernel<false>, cpu>::Launch(
s, valid_num, data.dptr<DType>(), prefix_sum.data(), prefix_sum.size(),
leading, middle, trailing, inputs[2].dptr<DType>());
leading, middle, valid_num, trailing, inputs[2].dptr<DType>());
}
});
} else {
CHECK(attrs.dict.find("value") != attrs.dict.end())
<< "value needs be provided";
CHECK(attrs.dict.find("value") != attrs.dict.end()) << "value needs be provided";
MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
s, valid_num, data.dptr<DType>(), prefix_sum.data(), prefix_sum.size(),
Expand Down Expand Up @@ -240,7 +283,8 @@ NNVM_REGISTER_OP(_npi_boolean_mask_assign_scalar)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "input")
.add_argument("mask", "NDArray-or-Symbol", "mask")
.add_argument("value", "float", "value to be assigned to masked positions");
.add_argument("value", "float", "value to be assigned to masked positions")
.add_argument("start_axis", "int", "starting axis of boolean mask");

NNVM_REGISTER_OP(_npi_boolean_mask_assign_tensor)
.describe(R"code(Tensor version of boolean assign)code" ADD_FILELINE)
Expand All @@ -264,7 +308,8 @@ NNVM_REGISTER_OP(_npi_boolean_mask_assign_tensor)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "input")
.add_argument("mask", "NDArray-or-Symbol", "mask")
.add_argument("value", "NDArray-or-Symbol", "assignment");
.add_argument("value", "NDArray-or-Symbol", "assignment")
.add_argument("start_axis", "int", "starting axis of boolean mask");

} // namespace op
} // namespace mxnet
51 changes: 39 additions & 12 deletions src/operator/numpy/np_boolean_mask_assign.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/

#include <cub/cub.cuh>
#include "../../common/utils.h"
#include "../contrib/boolean_mask-inl.h"

namespace mxnet {
Expand Down Expand Up @@ -70,13 +71,17 @@ struct BooleanAssignGPUKernel {
const size_t idx_size,
const size_t leading,
const size_t middle,
const size_t valid_num,
const size_t trailing,
const DType val) {
// binary search for the turning point
size_t m = i / trailing % middle;
size_t m = i / trailing % valid_num;
size_t l = i / trailing / valid_num;
size_t mid = bin_search(idx, idx_size, m);
// final answer is in mid
data[i + (mid - m) * trailing] = val;
// i = l * valid_num * trailing + m * trailing + t
// dst = l * middle * trailing + mid * trailing + t
data[i + (l * (middle - valid_num) + (mid - m)) * trailing] = val;
}

template<typename DType>
Expand All @@ -86,13 +91,20 @@ struct BooleanAssignGPUKernel {
const size_t idx_size,
const size_t leading,
const size_t middle,
const size_t valid_num,
const size_t trailing,
DType* tensor) {
// binary search for the turning point
size_t m = i / trailing % middle;
size_t m = i / trailing % valid_num;
size_t l = i / trailing / valid_num;
size_t mid = bin_search(idx, idx_size, m);
size_t dst = i + (l * (middle - valid_num) + (mid - m)) * trailing;
// final answer is in mid
data[i + (mid - m) * trailing] = (scalar) ? tensor[0] : tensor[m];
if (scalar) {
data[dst] = tensor[0];
} else {
data[dst] = tensor[i];
}
}
};

Expand Down Expand Up @@ -166,28 +178,34 @@ void NumpyBooleanAssignForwardGPU(const nnvm::NodeAttrs& attrs,
Stream<gpu>* s = ctx.get_stream<gpu>();

const TBlob& data = inputs[0];
const TShape& dshape = data.shape_;
const TBlob& mask = inputs[1];
const TShape& mshape = mask.shape_;
const int start_axis = std::stoi(common::attr_value_string(attrs, "start_axis", "0"));

// Get valid_num
size_t mask_size = mask.shape_.Size();
size_t valid_num = 0;
size_t* prefix_sum = nullptr;
if (mask_size != 0) {
MSHADOW_TYPE_SWITCH(mask.type_flag_, MType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(mask.type_flag_, MType, {
prefix_sum = GetValidNumGPU<MType>(ctx, mask.dptr<MType>(), mask_size);
});
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
CUDA_CALL(cudaMemcpyAsync(&valid_num, &prefix_sum[mask_size], sizeof(size_t),
cudaMemcpyDeviceToHost, stream));
CUDA_CALL(cudaStreamSynchronize(stream));
}

// If there's no True in mask, return directly
if (valid_num == 0) return;

if (inputs.size() == 3U) {
const TShape& vshape = inputs[2].shape_;
if (inputs[2].shape_.Size() != 1) {
// tensor case, check tensor size with the valid_num
CHECK_EQ(static_cast<size_t>(valid_num), inputs[2].shape_.Size())
<< "boolean array indexing assignment cannot assign " << inputs[2].shape_.Size()
CHECK_EQ(static_cast<size_t>(valid_num), vshape[start_axis])
<< "boolean array indexing assignment cannot assign " << vshape
<< " input values to the " << valid_num << " output values where the mask is true"
<< std::endl;
}
Expand All @@ -197,27 +215,36 @@ void NumpyBooleanAssignForwardGPU(const nnvm::NodeAttrs& attrs,
size_t middle = mask_size;
size_t trailing = 1U;

for (int i = 0; i < dshape.ndim(); ++i) {
if (i < start_axis) {
leading *= dshape[i];
}
if (i >= start_axis + mshape.ndim()) {
trailing *= dshape[i];
}
}

if (inputs.size() == 3U) {
if (inputs[2].shape_.Size() == 1) {
MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
Kernel<BooleanAssignGPUKernel<true>, gpu>::Launch(
s, leading * valid_num * trailing, data.dptr<DType>(), prefix_sum, mask_size + 1,
leading, middle, trailing, inputs[2].dptr<DType>());
leading, middle, valid_num, trailing, inputs[2].dptr<DType>());
});
} else {
MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
Kernel<BooleanAssignGPUKernel<false>, gpu>::Launch(
s, leading * valid_num * trailing, data.dptr<DType>(), prefix_sum, mask_size + 1,
leading, middle, trailing, inputs[2].dptr<DType>());
leading, middle, valid_num, trailing, inputs[2].dptr<DType>());
});
}
} else {
CHECK(attrs.dict.find("value") != attrs.dict.end())
<< "value is not provided";
CHECK(attrs.dict.find("value") != attrs.dict.end()) << "value is not provided";
double value = std::stod(attrs.dict.at("value"));
MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
Kernel<BooleanAssignGPUKernel<true>, gpu>::Launch(
s, leading * valid_num * trailing, data.dptr<DType>(), prefix_sum, mask_size + 1,
leading, middle, trailing, static_cast<DType>(std::stod(attrs.dict.at("value"))));
leading, middle, valid_num, trailing, static_cast<DType>(value));
});
}
}
Expand Down
Loading