diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 60f955399877..1bbdfa63160f 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -374,7 +374,7 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs, : param.axis); CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param.axis; - const int channelCount = dshape[channelAxis]; + const index_t channelCount = dshape[channelAxis]; in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount)); in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount)); diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc index c3ccd0d7a6bc..11178b358c2d 100644 --- a/src/operator/nn/layer_norm.cc +++ b/src/operator/nn/layer_norm.cc @@ -51,7 +51,7 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs, CHECK(axis >= 0 && axis < dshape.ndim()) << "Channel axis out of range: axis=" << param.axis; - const int channelCount = dshape[axis]; + const index_t channelCount = dshape[axis]; SHAPE_ASSIGN_CHECK(*in_shape, layernorm::kGamma,