
gather_nd: check bound and wrap negative indices #17208

Merged · 3 commits · Jan 10, 2020
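In effect, this patch makes `gather_nd` wrap negative indices the way NumPy does and raise `IndexError` for out-of-range ones. A minimal sketch of the intended user-facing behavior (assuming an MXNet build that includes this patch):

```python
import mxnet as mx

data = mx.nd.array([[0, 1, 2], [3, 4, 5]])

# Negative indices now wrap, NumPy-style: the index tuple (0, -1) means data[0, -1].
print(mx.nd.gather_nd(data, mx.nd.array([[0], [-1]])))  # -> [2.]

# Out-of-bound indices now raise IndexError instead of reading out of range.
try:
    mx.nd.gather_nd(data, mx.nd.array([[0], [3]]))
    mx.nd.waitall()  # gather_nd is asynchronous; force the kernel to run
except IndexError as e:
    print(e)  # index 3 is out of bounds for axis 1 with size 3
```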
32 changes: 1 addition & 31 deletions include/mxnet/c_api_error.h
@@ -54,33 +54,6 @@
on_exit_api(); \
return 0; // NOLINT(*)

//--------------------------------------------------------
// Error handling mechanism
// -------------------------------------------------------
// Standard error message format, {} means optional
//--------------------------------------------------------
// {error_type:} {message0}
// {message1}
// {message2}
// {Stack trace:} // stack traces follow by this line
// {trace 0} // two spaces in the begining.
// {trace 1}
// {trace 2}
//--------------------------------------------------------
/*!
* \brief Normalize error message
*
* Parse them header generated by by LOG(FATAL) and CHECK
* and reformat the message into the standard format.
*
* This function will also merge all the stack traces into
* one trace and trim them.
*
* \param err_msg The error message.
* \return normalized message.
*/
std::string NormalizeError(std::string err_msg);

/*!
* \brief Set the last error message needed by C API
* \param msg The error message to set.
@@ -91,10 +64,7 @@ void MXAPISetLastError(const char* msg);
* \param e the exception
* \return the return value of API after exception is handled
*/
inline int MXAPIHandleException(const std::exception &e) {
  MXAPISetLastError(NormalizeError(e.what()).c_str());
  return -1;
}
int MXAPIHandleException(const std::exception &e);

namespace mxnet {
extern void on_enter_api(const char *function);
1 change: 1 addition & 0 deletions python/mxnet/error.py
@@ -55,3 +55,4 @@ def __init__(self, msg):
register_error("ValueError", ValueError)
register_error("TypeError", TypeError)
register_error("AttributeError", AttributeError)
register_error("IndexError", IndexError)
7 changes: 6 additions & 1 deletion src/c_api/c_api_error.cc
@@ -89,7 +89,7 @@ std::string NormalizeError(std::string err_msg) {
if (!getline(is, file_name, ':')) {
return false;
} else {
if (is.peek() == '\\') {
if (is.peek() == '\\' || is.peek() == '/') {
// windows path
if (!getline(is, line, ':')) return false;
file_name = file_name + ':' + line;
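The parser reads the file name up to the first `:`; if the next character is a path separator, that `:` belonged to a drive letter ("C:\..." or "C:/...") rather than to the line number, which is what this hunk fixes. A Python sketch of the parsing rule (illustration only, not MXNet code):

```python
def split_location(loc):
    # loc looks like "src/foo.cc:42", "C:\\mxnet\\foo.cc:42", or "C:/mxnet/foo.cc:42"
    name, _, rest = loc.partition(':')
    if rest[:1] in ('\\', '/'):          # the ':' was part of a drive prefix
        path, _, line = rest.partition(':')
        return name + ':' + path, int(line)
    return name, int(rest)

assert split_location('src/foo.cc:42') == ('src/foo.cc', 42)
assert split_location('C:/mxnet/foo.cc:42') == ('C:/mxnet/foo.cc', 42)
```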
@@ -192,6 +192,11 @@ std::string NormalizeError(std::string err_msg) {
}
#endif

int MXAPIHandleException(const std::exception &e) {
  MXAPISetLastError(NormalizeError(e.what()).c_str());
  return -1;
}

const char *MXGetLastError() {
return NNGetLastError();
}
65 changes: 64 additions & 1 deletion src/operator/tensor/indexing_op.cc
@@ -435,6 +435,65 @@ inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const bool deterministic,
});
}

/*!
 * \brief Check whether any of the indices is out of bounds.
 * \param s the stream
 * \param idx_ptr the indices on the stream
 * \param N the number of indices in an axis
 * \param M the number of axes to examine
 * \param mshape the array that stores the shape of each dimension
 * \param is_valid_dim_ptr temporary workspace that holds any out-of-bound indices
 */
template<typename DType>
void GatherNDCheckBoundCPU(mshadow::Stream<cpu> *s, const DType* idx_ptr, index_t N,
                           index_t M, const mshadow::Shape<10> mshape, DType* is_valid_dim_ptr) {
  using namespace mxnet_op;
  Kernel<set_zero, cpu>::Launch(s, M, is_valid_dim_ptr);
  Kernel<is_valid_check_gather_nd, cpu>::Launch(s, M, is_valid_dim_ptr, idx_ptr, N, mshape);
  for (int m = 0; m < M; m++) {
    if (is_valid_dim_ptr[m] > mshape[m] - 1 || is_valid_dim_ptr[m] < -mshape[m]) {
      LOG(FATAL) << "IndexError: index " << is_valid_dim_ptr[m] << " is out of bounds for axis "
                 << m << " with size " << mshape[m];
    }
  }
}

void GatherNDForwardCPU(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  if (req[0] == kNullOp) return;
  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
  const mxnet::TShape& dshape = inputs[0].shape_;
  const mxnet::TShape& ishape = inputs[1].shape_;
  int M = ishape[0];
  int N = ishape.Size() / M;
  int K = dshape.ProdShape(M, dshape.ndim());
  mshadow::Shape<10> strides;
  mshadow::Shape<10> mshape;
  for (int i = M-1, stride = K; i >= 0; stride *= dshape[i], --i) {
    strides[i] = stride;
    mshape[i] = dshape[i];
  }
  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type switch
    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type switch
      // check whether indices are out of bounds
      IType* idx_ptr = inputs[1].dptr<IType>();
      Tensor<cpu, 1, IType> workspace =
        ctx.requested[0].get_space_typed<cpu, 1, IType>(Shape1(M), s);
      IType* is_valid_dim_ptr = reinterpret_cast<IType*>(workspace.dptr_);
      GatherNDCheckBoundCPU(s, idx_ptr, N, M, mshape, is_valid_dim_ptr);
      Kernel<gather_nd, cpu>::Launch(
        s, N, req[0], N, M, K, strides, mshape, outputs[0].dptr<DType>(),
        inputs[0].dptr<DType>(), inputs[1].dptr<IType>());
    });
  });
}
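To make the layout arithmetic above concrete, here is a Python sketch (not part of the patch) of how `M`, `N`, `K`, and the strides are derived from the data and index shapes:

```python
import numpy as np

def gather_nd_layout(dshape, ishape):
    M = ishape[0]                      # leading data axes addressed by the indices
    N = int(np.prod(ishape)) // M      # number of index tuples
    K = int(np.prod(dshape[M:]))       # elements copied per index tuple
    strides, stride = [0] * M, K
    for i in range(M - 1, -1, -1):     # row-major strides over the indexed axes
        strides[i] = stride
        stride *= dshape[i]
    return M, N, K, strides

# data of shape (2, 3, 4) indexed by indices of shape (2, 5):
print(gather_nd_layout((2, 3, 4), (2, 5)))  # (2, 5, 4, [12, 4])
```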

template<typename DType, typename IType>
inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
@@ -872,7 +931,11 @@ Examples::
})
.set_attr<mxnet::FInferShape>("FInferShape", GatherNDShape)
.set_attr<nnvm::FInferType>("FInferType", GatherNDType)
.set_attr<FCompute>("FCompute<cpu>", GatherNDForward<cpu>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", GatherNDForwardCPU)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
auto p = nnvm::Node::Create();
68 changes: 66 additions & 2 deletions src/operator/tensor/indexing_op.cu
@@ -33,7 +33,6 @@ namespace op {

/*! \brief If there are out-of-bound indices, out will be assigned to 1.
*/

struct is_valid_check {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, char* out, const DType* data,
@@ -437,6 +436,71 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const bool deterministic,
});
}

/*!
 * \brief Check whether any of the indices is out of bounds.
 * \param s the stream
 * \param idx_ptr the indices on the stream
 * \param N the number of indices in an axis
 * \param M the number of axes to examine
 * \param mshape the array that stores the shape of each dimension
 * \param is_valid_dim_ptr temporary workspace that holds any out-of-bound indices
 */
template<typename DType>
void GatherNDCheckBoundGPU(mshadow::Stream<gpu> *s, const DType* idx_ptr, index_t N,
                           index_t M, const mshadow::Shape<10> mshape, DType* is_valid_dim_ptr) {
  using namespace mxnet_op;
  Kernel<set_zero, gpu>::Launch(s, M, is_valid_dim_ptr);
  Kernel<is_valid_check_gather_nd, gpu>::Launch(s, M, is_valid_dim_ptr, idx_ptr, N, mshape);

  std::vector<DType> is_valid_dim(M);
  CUDA_CALL(cudaMemcpyAsync(is_valid_dim.data(), is_valid_dim_ptr, sizeof(DType)*M,
                            cudaMemcpyDeviceToHost, mshadow::Stream<gpu>::GetStream(s)));
  CUDA_CALL(cudaStreamSynchronize(mshadow::Stream<gpu>::GetStream(s)));
  for (int m = 0; m < M; m++) {
    if (is_valid_dim[m] > mshape[m] - 1 || is_valid_dim[m] < -mshape[m]) {
      LOG(FATAL) << "IndexError: index " << is_valid_dim[m] << " is out of bounds for axis "
                 << m << " with size " << mshape[m];
    }
  }
}

void GatherNDForwardGPU(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  if (req[0] == kNullOp) return;
  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
  const mxnet::TShape& dshape = inputs[0].shape_;
  const mxnet::TShape& ishape = inputs[1].shape_;
  int M = ishape[0];
  int N = ishape.Size() / M;
  int K = dshape.ProdShape(M, dshape.ndim());
  mshadow::Shape<10> strides;
  mshadow::Shape<10> mshape;
  for (int i = M-1, stride = K; i >= 0; stride *= dshape[i], --i) {
    strides[i] = stride;
    mshape[i] = dshape[i];
  }
  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type switch
    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type switch
      // check whether indices are out of bounds
      IType* idx_ptr = inputs[1].dptr<IType>();
      Tensor<gpu, 1, IType> workspace =
        ctx.requested[0].get_space_typed<gpu, 1, IType>(Shape1(M), s);
      IType* is_valid_dim_ptr = reinterpret_cast<IType*>(workspace.dptr_);
      GatherNDCheckBoundGPU(s, idx_ptr, N, M, mshape, is_valid_dim_ptr);
      Kernel<gather_nd, gpu>::Launch(
        s, N, req[0], N, M, K, strides, mshape, outputs[0].dptr<DType>(),
        inputs[0].dptr<DType>(), inputs[1].dptr<IType>());
    });
  });
}

struct backward_gather_nd_gpu {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(index_t i, index_t N, index_t M, index_t K,
@@ -813,7 +877,7 @@ NNVM_REGISTER_OP(one_hot)
.set_attr<FCompute>("FCompute<gpu>", OneHotOpForward<gpu>);

NNVM_REGISTER_OP(gather_nd)
.set_attr<FCompute>("FCompute<gpu>", GatherNDForward<gpu>);
.set_attr<FCompute>("FCompute<gpu>", GatherNDForwardGPU);

NNVM_REGISTER_OP(scatter_nd)
.set_attr<FCompute>("FCompute<gpu>", ScatterNDForward<gpu>);
83 changes: 37 additions & 46 deletions src/operator/tensor/indexing_op.h
@@ -303,7 +303,7 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
}

/*! \brief name the struct TakeNonzeroAxis for general take when
* axis is not zero, use TakeZeroAxisGPU or TakeZeroAxisCPU for axis zero
* axis is not zero, use TakeZeroAxisGPU or TakeZeroAxisCPU for axis zero
*/
template<bool clip = true>
struct TakeNonzeroAxis {
@@ -1272,6 +1272,42 @@ void OneHotOpForward(const nnvm::NodeAttrs& attrs,
});
}

struct gather_nd {
  template<typename DType, typename IType>
  MSHADOW_XINLINE static void Map(index_t i, OpReqType req, index_t N, index_t M, index_t K,
                                  const mshadow::Shape<10> strides,
                                  const mshadow::Shape<10> mshape,
                                  DType* out, const DType* data,
                                  const IType* indices) {
    index_t offset = 0;
    for (index_t j = 0; j < M; ++j) {
      // wrap negative indices: (idx + dim) % dim maps -1 to dim-1, and so on
      offset += strides[j] * (static_cast<index_t>(indices[j*N + i] + mshape[j]) % mshape[j]);
    }
    for (index_t j = 0; j < K; ++j) {
      KERNEL_ASSIGN(out[i*K + j], req, data[offset+j]);
    }
  }
};
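The wrapping in the offset computation is `(idx + mshape[j]) % mshape[j]`: it maps a valid negative index to its non-negative equivalent and leaves in-range indices unchanged, relying on the bound check having already rejected anything outside `[-mshape[j], mshape[j])`. A quick sanity sketch:

```python
def wrap_index(idx, dim):
    # valid range after the bound check: -dim <= idx <= dim - 1
    return (idx + dim) % dim

assert wrap_index(-1, 3) == 2  # last element
assert wrap_index(-3, 3) == 0  # most negative valid index
assert wrap_index(2, 3) == 2   # in-range indices are unchanged
```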

/*!
 * \brief If any index along a dimension is out of bounds, the entry for
 * that dimension is set to the offending (out-of-bound) index.
 */
struct is_valid_check_gather_nd {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* is_valid_dim_ptr, const DType* idx_ptr,
                                  const index_t N, const mshadow::Shape<10> mshape) {
    index_t n = N - 1;
    while (n >= 0) {
      if (idx_ptr[i*N + n] < -mshape[i] || idx_ptr[i*N + n] > mshape[i] - 1) {
        is_valid_dim_ptr[i] = idx_ptr[i*N + n];
        break;
      }
      n--;
    }
  }
};
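A Python reference (a sketch, not part of the patch) of the combined check: the kernel above records one offending index per axis, and the host loop in `GatherNDCheckBoundCPU`/`GPU` then reports it:

```python
def check_bounds(indices, mshape):
    # indices has shape (M, N); axis m of the data has size mshape[m]
    for m, row in enumerate(indices):
        for idx in row:
            if idx < -mshape[m] or idx > mshape[m] - 1:
                raise IndexError(f"index {idx} is out of bounds for axis {m} "
                                 f"with size {mshape[m]}")

check_bounds([[0, 1], [1, 2]], [2, 3])   # fine: everything in range
check_bounds([[0, 1], [1, -4]], [2, 3])  # raises IndexError for axis 1
```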

inline bool GatherNDShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
@@ -1315,51 +1351,6 @@ inline bool GatherNDType(const nnvm::NodeAttrs& attrs,
return true;
}

struct gather_nd {
  template<typename DType, typename IType>
  MSHADOW_XINLINE static void Map(index_t i, OpReqType req, index_t N, index_t M, index_t K,
                                  const mshadow::Shape<10> strides,
                                  DType* out, const DType* data,
                                  const IType* indices) {
    index_t offset = 0;
    for (index_t j = 0; j < M; ++j) {
      offset += strides[j] * static_cast<index_t>(indices[j*N + i]);
    }
    for (index_t j = 0; j < K; ++j) {
      KERNEL_ASSIGN(out[i*K + j], req, data[offset+j]);
    }
  }
};

template<typename xpu>
void GatherNDForward(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
                     const std::vector<TBlob>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  if (req[0] == kNullOp) return;
  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  const mxnet::TShape& dshape = inputs[0].shape_;
  const mxnet::TShape& ishape = inputs[1].shape_;
  int M = ishape[0];
  int N = ishape.Size() / M;
  int K = dshape.ProdShape(M, dshape.ndim());
  mshadow::Shape<10> strides;
  for (int i = M-1, stride = K; i >= 0; stride *= dshape[i], --i) strides[i] = stride;
  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {  // output data type switch
    MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {  // indices data type switch
      Kernel<gather_nd, xpu>::Launch(
        s, N, req[0], N, M, K, strides, outputs[0].dptr<DType>(),
        inputs[0].dptr<DType>(), inputs[1].dptr<IType>());
    });
  });
}


struct ScatterNDParam : public dmlc::Parameter<ScatterNDParam> {
mxnet::TShape shape;
DMLC_DECLARE_PARAMETER(ScatterNDParam) {
34 changes: 32 additions & 2 deletions tests/python/unittest/test_operator.py
@@ -7166,6 +7166,36 @@ def check(data, idx):
idx = mx.nd.array([[0, 0, 0, 0]], dtype='int32')
assert (mx.nd._internal._backward_gather_nd(data, idx, shape=(1,)).asscalar() == data.asnumpy().sum())

@with_seed()
def test_gather_nd_check_bound():
    # check that out-of-bound indices raise IndexError
    data = mx.nd.array([[0, 1, 2], [3, 4, 5]])
    indices1 = mx.nd.array([[0, 1, 0], [0, 1, 3]])
    indices2 = mx.nd.array([[0, 1, 0], [0, 1, -5]])
    try:
        mx.nd.gather_nd(data, indices1)
        mx.nd.waitall()
    except IndexError:
        # expected: IndexError: index 3 is out of bounds for axis 1 with size 3
        pass

    try:
        mx.nd.gather_nd(data, indices2)
        mx.nd.waitall()
    except IndexError:
        # expected: IndexError: index -5 is out of bounds for axis 1 with size 3
        pass

    # check that negative indices are wrapped correctly
    indices1 = mx.nd.array([[0, 1, -1], [0, 1, -2]])
    indices2 = mx.nd.array([[0, 1, 1], [0, 1, 1]])
    data1 = mx.nd.gather_nd(data, indices1)
    data2 = mx.nd.gather_nd(data, indices2)
    assert_almost_equal(data1, data2, rtol=1e-5, atol=1e-5)
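As a usage note: the `try/except` blocks above also pass when no error is raised; a variant that asserts the failure actually happens could look like this (a sketch, assuming pytest is available in the test environment):

```python
import pytest

def test_gather_nd_out_of_bounds_raises():
    data = mx.nd.array([[0, 1, 2], [3, 4, 5]])
    bad = mx.nd.array([[0, 1, 0], [0, 1, 3]])  # 3 is out of bounds for axis 1 (size 3)
    with pytest.raises(IndexError):
        mx.nd.gather_nd(data, bad)
        mx.nd.waitall()  # force the asynchronous kernel to execute
```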


def compare_forw_backw_unary_op(
name, forward_mxnet_call, forward_numpy_call,
backward_numpy_call, shape, input_low, input_high, rtol, atol,
@@ -7821,7 +7851,7 @@ def check_bilinear_resize_align_corners_op():

x = np.array(data, dtype=np.float32).reshape(img_shape)
x_nd = mx.nd.array(x)

y0 = np.array(expected_data[0]).reshape((1, 1, target_height, target_width))
y0_nd = mx.nd.contrib.BilinearResize2D(x_nd, height=target_height, width=target_width, mode='size', align_corners=False)
assert_almost_equal(y0, y0_nd.asnumpy(), atol=1e-3)
@@ -9586,7 +9616,7 @@ def convert_bias(F, k_bias, v_bias, num_heads):
q_proj = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False,
num_hidden=qkv_units, no_bias=False)
att_score = mx.sym.contrib.interleaved_matmul_encdec_qk(
q_proj, kv_proj, heads=num_heads)
q_proj, kv_proj, heads=num_heads)
att_score = att_score + sonde
weighted_value = mx.sym.contrib.interleaved_matmul_encdec_valatt(
kv_proj, att_score, heads=num_heads)