Softmax omp support (apache#184)
* add omp support for softmax

* fix lint issue

* remove omp for mapplan

* remove omp for mapplan

* remove MKL guard
zhenlinluo authored and tqchen committed Dec 17, 2016
1 parent 6ff0804 commit 50ca991
Showing 1 changed file with 6 additions and 1 deletion.
mshadow/tensor_cpu-inl.h: 6 additions, 1 deletion (file mode changed 100644 → 100755)
@@ -145,7 +145,6 @@ inline void MapPlan(TRValue<R, cpu, dim, DType> *dst,
                     const expr::Plan<E, DType> &plan) {
   Shape<2> shape = expr::ShapeCheck<dim, R>::Check(dst->self()).FlatTo2D();
   expr::Plan<R, DType> dplan = expr::MakePlan(dst->self());
-  // #pragma omp parallel for
   // temp remove openmp, as default setting throttles CPU
   for (index_t y = 0; y < shape[0]; ++y) {
     for (index_t x = 0; x < shape[1]; ++x) {
@@ -279,6 +278,7 @@ template<typename DType>
 inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
                         const Tensor<cpu, 2, DType> &src,
                         const Tensor<cpu, 1, DType> &label) {
+  #pragma omp parallel for
   for (index_t y = 0; y < dst.size(0); ++y) {
     const index_t k = static_cast<int>(label[y]);
     for (index_t x = 0; x < dst.size(1); ++x) {
@@ -296,6 +296,7 @@ inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
                         const Tensor<cpu, 2, DType> &src,
                         const Tensor<cpu, 1, DType> &label,
                         const DType &ignore_label) {
+  #pragma omp parallel for
   for (index_t y = 0; y < dst.size(0); ++y) {
     const index_t k = static_cast<int>(label[y]);
     for (index_t x = 0; x < dst.size(1); ++x) {
@@ -316,6 +317,7 @@ template<typename DType>
 inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
                         const Tensor<cpu, 3, DType> &src,
                         const Tensor<cpu, 2, DType> &label) {
+  #pragma omp parallel for
   for (index_t n = 0; n < dst.size(2); ++n) {
     for (index_t y = 0; y < dst.size(0); ++y) {
       const index_t k = static_cast<int>(label[y][n]);
@@ -335,6 +337,7 @@ inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
                         const Tensor<cpu, 3, DType> &src,
                         const Tensor<cpu, 2, DType> &label,
                         const DType &ignore_label) {
+  #pragma omp parallel for
   for (index_t n = 0; n < dst.size(2); ++n) {
     for (index_t y = 0; y < dst.size(0); ++y) {
       const index_t k = static_cast<int>(label[y][n]);
@@ -359,6 +362,7 @@ template<typename DType>
 inline void Softmax(Tensor<cpu, 2, DType> dst,
                     const Tensor<cpu, 2, DType> &energy) {
   CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
+  #pragma omp parallel for
   for (index_t y = 0; y < dst.size(0); ++y) {
     Softmax(dst[y], energy[y]);
   }
@@ -368,6 +372,7 @@ template<typename DType>
 inline void Softmax(Tensor<cpu, 3, DType> dst,
                     const Tensor<cpu, 3, DType> &energy) {
   CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
+  #pragma omp parallel for
   for (index_t y = 0; y < dst.size(0); ++y) {
     for (index_t n = 0; n < dst.size(2); ++n) {
       DType mmax = energy[y][0][n];
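Each added pragma follows the same pattern: the outer loop runs over rows (or batch entries) that touch disjoint slices of dst, so a bare #pragma omp parallel for is safe with no locks or reductions. Below is a minimal standalone sketch of that row-parallel softmax; it is not mshadow code, and softmax_rows, its parameters, and the flat row-major layout are assumptions for illustration.

// Sketch only (not mshadow code): the row-parallel softmax pattern
// this commit applies, on an assumed flat row-major buffer.
#include <cmath>
#include <cstddef>

void softmax_rows(float *dst, const float *energy,
                  std::ptrdiff_t rows, std::ptrdiff_t cols) {
  // Rows are independent: each iteration reads and writes only row y,
  // so the outer loop parallelizes without synchronization.
  #pragma omp parallel for
  for (std::ptrdiff_t y = 0; y < rows; ++y) {
    const float *in = energy + y * cols;
    float *out = dst + y * cols;
    float mmax = in[0];  // subtract the row max for numerical stability
    for (std::ptrdiff_t x = 1; x < cols; ++x) {
      if (in[x] > mmax) mmax = in[x];
    }
    float sum = 0.0f;
    for (std::ptrdiff_t x = 0; x < cols; ++x) {
      out[x] = std::exp(in[x] - mmax);
      sum += out[x];
    }
    for (std::ptrdiff_t x = 0; x < cols; ++x) {
      out[x] /= sum;
    }
  }
}

Compiled without OpenMP (no -fopenmp on GCC/Clang, no /openmp on MSVC), the pragma is ignored and the loop runs serially, which is presumably why the hunks add no #ifdef guard.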

0 comments on commit 50ca991

Please sign in to comment.
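The MapPlan hunk goes the other way: it deletes the commented-out pragma rather than re-enabling it, and the surviving comment records why ("default setting throttles CPU", i.e. an unbounded thread team oversubscribes the machine for cheap element-wise loops). If tuning rather than removal were wanted, one standard knob is the team size; here is a hedged sketch using stock omp.h calls, none of which appear in this commit.

// Not part of the commit: capping the OpenMP team size, a common
// remedy when default thread counts throttle the rest of the process.
#include <cstdio>
#include <omp.h>

int main() {
  omp_set_num_threads(4);  // same effect as: export OMP_NUM_THREADS=4
  #pragma omp parallel
  {
    #pragma omp single  // print once, from a single thread in the team
    std::printf("team size: %d\n", omp_get_num_threads());
  }
  return 0;
}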