
Commit

Merge pull request apache#65 from piiswrong/master
multi output softmax
tqchen committed Oct 24, 2015
2 parents f00b208 + d298f4f commit ded43f1
Showing 4 changed files with 128 additions and 0 deletions.
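The change adds 3-D overloads of Softmax and SoftmaxGrad on both backends: softmax is taken over dimension 1 (the class axis) of a (rows, classes, positions) tensor, dimensions 0 and 2 are treated as independent, and the gradient consumes one class index per (row, position) pair from a 2-D label tensor. A minimal CPU sketch of how a caller might exercise the new overloads follows; sizes and variable names are illustrative and not part of the commit.

#include "mshadow/tensor.h"
using namespace mshadow;

int main() {
  InitTensorEngine<cpu>();
  // energy: (rows=2, classes=3, positions=4); softmax normalizes the class axis
  Tensor<cpu, 3, float> energy = NewTensor<cpu>(Shape3(2, 3, 4), 1.0f);
  Tensor<cpu, 3, float> prob   = NewTensor<cpu>(Shape3(2, 3, 4), 0.0f);
  Tensor<cpu, 2, float> label  = NewTensor<cpu>(Shape2(2, 4), 0.0f);  // class index per (row, position)
  Tensor<cpu, 3, float> grad   = NewTensor<cpu>(Shape3(2, 3, 4), 0.0f);
  Softmax(prob, energy);             // new 3-D forward overload
  SoftmaxGrad(grad, prob, label);    // new 3-D gradient overload
  FreeSpace(&energy); FreeSpace(&prob); FreeSpace(&label); FreeSpace(&grad);
  ShutdownTensorEngine<cpu>();
  return 0;
}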
71 changes: 71 additions & 0 deletions mshadow/cuda/tensor_gpu-inl.cuh
@@ -288,6 +288,77 @@ inline void SoftmaxGrad(Tensor<gpu, 2, DType> &dst,
expr::MakePlan(label),
dst.size(1));
}

template<typename DType>
__global__ void Softmax3DGradKernel(Tensor<gpu, 3, DType> dst,
                                    const Tensor<gpu, 3, DType> src,
                                    const Tensor<gpu, 2, DType> label) {
  // __global__ parameters cannot be references, so the tensor views are
  // passed by value; one block handles row y, one thread handles position n
  const index_t xmax = dst.size(1);
  const int y = blockIdx.x;
  const int n = threadIdx.x;

  if (n < dst.size(2)) {
    // subtract one from the predicted probability of the labeled class
    const int k = static_cast<int>(label[y][n]);
    for (index_t i = 0; i < xmax; ++i) {
      if (i == k) {
        dst[y][i][n] = src[y][i][n] - 1.0f;
      } else {
        dst[y][i][n] = src[y][i][n];
      }
    }
  }
}

template<typename DType>
__global__ void Softmax3DKernel(Tensor<gpu, 3, DType> dst,
                                const Tensor<gpu, 3, DType> src) {
  // one block per row y; each thread normalizes one column n along the class axis
  const index_t xmax = dst.size(1);
  const int y = blockIdx.x;
  const int n = threadIdx.x;

  if (n < dst.size(2)) {
    // subtract the column maximum before exponentiating for numerical stability
    DType smax = src[y][0][n];
    for (index_t i = 1; i < xmax; ++i) {
      smax = max(smax, src[y][i][n]);
    }
    DType ssum = 0.0f;
    for (index_t i = 0; i < xmax; ++i) {
      DType p = expf(src[y][i][n] - smax);
      ssum += p;
      dst[y][i][n] = p;
    }
    for (index_t i = 0; i < xmax; ++i) {
      dst[y][i][n] /= ssum;
    }
  }
}

template<typename DType>
inline void Softmax(Tensor<gpu, 3, DType> &dst,
                    const Tensor<gpu, 3, DType> &src) {
  // one block per row of dst; threads within the block cover the last
  // dimension, so this path assumes dst.size(2) <= kBaseThreadNum
  dim3 dimBlock(kBaseThreadNum);
  dim3 dimGrid(dst.size(0));
  CHECK_EQ(dst.shape_, src.shape_) << "Softmax: shape mismatch";
  CheckLaunchParam(dimGrid, dimBlock, "Softmax");
  cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
  Softmax3DKernel<DType><<<dimGrid, dimBlock, 0, stream>>>(dst, src);
}


template<typename DType>
inline void SoftmaxGrad(Tensor<gpu, 3, DType> &dst,
                        const Tensor<gpu, 3, DType> &src,
                        const Tensor<gpu, 2, DType> &label) {
  // same launch layout as the 3-D Softmax: one block per row, one thread
  // per position along the last dimension
  dim3 dimBlock(kBaseThreadNum);
  dim3 dimGrid(dst.size(0));
  CHECK_EQ(dst.shape_, src.shape_) << "SoftmaxGrad: shape mismatch";
  CHECK_EQ(dst.size(0), label.size(0)) << "SoftmaxGrad: label shape mismatch";
  CHECK_EQ(dst.size(2), label.size(1)) << "SoftmaxGrad: label shape mismatch";
  CheckLaunchParam(dimGrid, dimBlock, "SoftmaxGrad");
  cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
  Softmax3DGradKernel<DType><<<dimGrid, dimBlock, 0, stream>>>(dst, src, label);
}

} // namespace cuda
} // namespace mshadow
#endif // MSHADOW_CUDA_TENSOR_GPU_INL_CUH_
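Because the kernels index the last dimension with threadIdx.x only, a single block covers at most kBaseThreadNum positions; columns beyond that are silently skipped. A hedged caller-side guard is sketched below (dst and src stand for the 3-D tensors being passed in, not new identifiers in the file).

// sketch: reject shapes the one-block-per-row kernels cannot cover
CHECK_LE(dst.size(2), static_cast<index_t>(cuda::kBaseThreadNum))
    << "3-D Softmax: last dimension exceeds one thread block";
cuda::Softmax(dst, src);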
4 changes: 4 additions & 0 deletions mshadow/tensor_blob.h
@@ -490,6 +490,10 @@ class TBlob {
  inline index_t size(index_t idx) const {
    return shape_[idx];
  }
  /*! \brief total number of elements in the tensor */
  inline index_t Size(void) const {
    return shape_.Size();
  }
  /*!
   * \brief fetch the tensor, with respect to a specific dimension;
   * if dim does not match the stored dimension, an error will be issued
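The added TBlob::Size() simply forwards to Shape::Size(), i.e. the product of all dimensions. A small sketch of the intended use (the tensor here is illustrative):

Tensor<cpu, 3, float> t = NewTensor<cpu>(Shape3(2, 3, 4), 0.0f);
TBlob blob(t);                  // TBlob can be constructed from a Tensor
index_t total = blob.Size();    // 2 * 3 * 4 = 24, same as t.shape_.Size()
FreeSpace(&t);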
40 changes: 40 additions & 0 deletions mshadow/tensor_cpu-inl.h
@@ -282,6 +282,24 @@ inline void SoftmaxGrad(Tensor<cpu, 2, DType> dst,
}
}

template<typename DType>
inline void SoftmaxGrad(Tensor<cpu, 3, DType> dst,
                        const Tensor<cpu, 3, DType> &src,
                        const Tensor<cpu, 2, DType> &label) {
  for (index_t n = 0; n < dst.size(2); ++n) {
    for (index_t y = 0; y < dst.size(0); ++y) {
      // label holds the target class index for each (row, position) pair
      const index_t k = static_cast<index_t>(label[y][n]);
      for (index_t x = 0; x < dst.size(1); ++x) {
        if (x == k) {
          dst[y][k][n] = src[y][k][n] - 1.0f;
        } else {
          dst[y][x][n] = src[y][x][n];
        }
      }
    }
  }
}
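The "subtract one at the label index" rule is the usual softmax-plus-negative-log-likelihood gradient (the loss itself lives outside this file, so this is context rather than something the commit states): with p = softmax(z) over the class axis and label k,

\frac{\partial}{\partial z_x}\left(-\log p_k\right) = p_x - \mathbf{1}[x = k]

which is exactly src[y][x][n] minus one when x equals the label and src[y][x][n] otherwise.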

template<typename DType>
inline void Softmax(Tensor<cpu, 2, DType> dst,
const Tensor<cpu, 2, DType> &energy) {
@@ -291,6 +309,28 @@ inline void Softmax(Tensor<cpu, 2, DType> dst,
}
}

template<typename DType>
inline void Softmax(Tensor<cpu, 3, DType> dst,
                    const Tensor<cpu, 3, DType> &energy) {
  CHECK_EQ(dst.shape_, energy.shape_) << "Softmax: shape mismatch";
  for (index_t y = 0; y < dst.size(0); ++y) {
    for (index_t n = 0; n < dst.size(2); ++n) {
      DType mmax = energy[y][0][n];
      for (index_t x = 1; x < dst.size(1); ++x) {
        if (mmax < energy[y][x][n]) mmax = energy[y][x][n];
      }
      DType sum = 0.0f;
      for (index_t x = 0; x < dst.size(1); ++x) {
        dst[y][x][n] = std::exp(energy[y][x][n] - mmax);
        sum += dst[y][x][n];
      }
      for (index_t x = 0; x < dst.size(1); ++x) {
        dst[y][x][n] /= sum;
      }
    }
  }
}
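Both the CPU and GPU forward paths subtract the per-column maximum before exponentiating; this is the standard overflow guard and leaves the result unchanged, since for any constant m

\mathrm{softmax}(z)_x = \frac{e^{z_x - m}}{\sum_j e^{z_j - m}}

and with m = max_j z_j the largest exponent is e^0 = 1, so expf/std::exp cannot overflow.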

template<typename DType>
inline DType VDot(const Tensor<cpu, 1, DType> &lhs,
const Tensor<cpu, 1, DType> &rhs) {
13 changes: 13 additions & 0 deletions mshadow/tensor_gpu-inl.h
@@ -159,13 +159,26 @@ inline void Softmax(Tensor<gpu, 2, DType> dst,
  cuda::Softmax(dst, src);
}

template<typename DType>
inline void Softmax(Tensor<gpu, 3, DType> dst,
                    const Tensor<gpu, 3, DType> &src) {
  cuda::Softmax(dst, src);
}

template<typename DType>
inline void SoftmaxGrad(Tensor<gpu, 2, DType> dst,
                        const Tensor<gpu, 2, DType> &src,
                        const Tensor<gpu, 1, DType> &label) {
  cuda::SoftmaxGrad(dst, src, label);
}

template<typename DType>
inline void SoftmaxGrad(Tensor<gpu, 3, DType> dst,
                        const Tensor<gpu, 3, DType> &src,
                        const Tensor<gpu, 2, DType> &label) {
  cuda::SoftmaxGrad(dst, src, label);
}

} // namespace mshadow
#endif // __CUDACC__
#endif // MSHADOW_TENSOR_GPU_INL_H_
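With matching cpu and gpu overloads in place, device-templated code can call the 3-D softmax without device-specific branches. A hedged sketch of such a caller (SoftmaxForward3D is an illustrative name, not part of the commit):

// sketch: device-generic forwarding helper built on the new overloads
template<typename xpu, typename DType>
inline void SoftmaxForward3D(Tensor<xpu, 3, DType> out,
                             const Tensor<xpu, 3, DType> &in) {
  Softmax(out, in);  // resolves to the cpu or gpu overload added here
}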
