Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: unify PADDLE_ENFORCE for cublas, cudnn, curand #2883

Merged
merged 12 commits into from
Jul 19, 2017
9 changes: 4 additions & 5 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
# ddim lib
cc_library(enforce SRCS enforce.cc DEPS glog)
cc_test(enforce_test SRCS enforce_test.cc DEPS enforce)
cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)

cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory)
cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)

cc_test(variable_test SRCS variable_test.cc)
cc_test(scope_test SRCS scope_test.cc)

proto_library(attr_type SRCS attr_type.proto)
proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)

cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)

cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc enforce)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)

py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/attr_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/framework/enforce.h"
#include "paddle/platform/enforce.h"

namespace paddle {
namespace framework {
Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/ddim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
#include "paddle/platform/enforce.h"

namespace paddle {
namespace framework {
Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/ddim.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License. */
#include <stdexcept>
#include <vector>
#include "paddle/framework/dim.h"
#include "paddle/framework/enforce.h"
#include "paddle/platform/enforce.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
Expand Down
15 changes: 0 additions & 15 deletions paddle/framework/enforce.cc

This file was deleted.

75 changes: 0 additions & 75 deletions paddle/framework/enforce.h

This file was deleted.

2 changes: 1 addition & 1 deletion paddle/framework/net_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@ TEST(OpKernel, all) {
ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt);

ASSERT_THROW(net->AddOp(op2), paddle::framework::EnforceNotMet);
ASSERT_THROW(net->AddOp(op2), std::runtime_error);
}
10 changes: 5 additions & 5 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ TEST(OpRegistry, IllegalAttr) {
try {
paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) {
} catch (std::runtime_error& err) {
caught = true;
std::string msg = "larger_than check fail";
const char* err_msg = err.what();
Expand Down Expand Up @@ -138,7 +138,7 @@ TEST(OpRegistry, CustomChecker) {
try {
paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) {
} catch (std::runtime_error& err) {
caught = true;
std::string msg = "Attribute 'test_attr' is required!";
const char* err_msg = err.what();
Expand All @@ -157,7 +157,7 @@ TEST(OpRegistry, CustomChecker) {
try {
paddle::framework::OperatorPtr op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) {
} catch (std::runtime_error& err) {
caught = true;
std::string msg = "'test_attr' must be even!";
const char* err_msg = err.what();
Expand Down Expand Up @@ -196,7 +196,7 @@ TEST(ProtoMaker, DuplicatedAttr) {
pd::OpProto op_proto;
pd::OpAttrChecker op_checker;
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet);
ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
}

class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
Expand All @@ -212,5 +212,5 @@ TEST(ProtoMaker, DuplicatedInOut) {
pd::OpProto op_proto;
pd::OpAttrChecker op_checker;
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet);
ASSERT_THROW(proto_maker.Validate(), std::runtime_error);
}
2 changes: 1 addition & 1 deletion paddle/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ limitations under the License. */
#include <memory>
#include <typeindex>
#include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"

Expand Down
4 changes: 2 additions & 2 deletions paddle/framework/tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ TEST(Tensor, DataAssert) {
bool caught = false;
try {
src_tensor.data<double>();
} catch (paddle::framework::EnforceNotMet err) {
} catch (std::runtime_error& err) {
caught = true;
std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first.";
Expand Down Expand Up @@ -107,7 +107,7 @@ TEST(Tensor, ShareDataFrom) {
bool caught = false;
try {
dst_tensor.ShareDataFrom<float>(src_tensor);
} catch (EnforceNotMet err) {
} catch (std::runtime_error& err) {
caught = true;
std::string msg =
"Tenosr holds no memory. Call Tensor::mutable_data first.";
Expand Down
5 changes: 2 additions & 3 deletions paddle/memory/detail/system_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ limitations under the License. */

#include "paddle/memory/detail/system_allocator.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/error.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/gpu_info.h"

#include <stdlib.h> // for malloc and free
Expand Down Expand Up @@ -128,8 +128,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) {
// process is terminating, in which case we don't care if
// cudaFree succeeds.
if (err != cudaErrorCudartUnloading) {
platform::throw_on_error(err,
"cudaFree{Host} failed in GPUAllocator::Free.");
PADDLE_ENFORCE(err, "cudaFree{Host} failed in GPUAllocator::Free.");
}
}

Expand Down
2 changes: 2 additions & 0 deletions paddle/platform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ cc_test(place_test SRCS place_test.cc DEPS place glog gflags)

add_subdirectory(dynload)

cc_test(enforce_test SRCS enforce_test.cc)

IF(WITH_GPU)
set(GPU_CTX_DEPS dynload_cuda dynamic_loader)
ELSE()
Expand Down
1 change: 0 additions & 1 deletion paddle/platform/cpu_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ limitations under the License. */
#endif

#include "gflags/gflags.h"
#include "paddle/platform/error.h"

DEFINE_double(fraction_of_cpu_memory_to_use, 1,
"Default use 100% of CPU memory for PaddlePaddle,"
Expand Down
56 changes: 25 additions & 31 deletions paddle/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ limitations under the License. */

#pragma once

#include "paddle/framework/enforce.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"

#ifndef PADDLE_ONLY_CPU
#include "paddle/platform/dynload/cublas.h"
#include "paddle/platform/dynload/cudnn.h"
#include "paddle/platform/dynload/curand.h"
#include "paddle/platform/error.h"
#include "paddle/platform/gpu_info.h"
#define EIGEN_USE_GPU
#endif
Expand Down Expand Up @@ -71,8 +72,7 @@ class CUDADeviceContext : public DeviceContext {
public:
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
GPUPlaceGuard guard(gpu_place_);
paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
"cudaStreamCreate failed");
PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PADDLE_ENFORCE means the condition must be true. But CUXXXX_STATUS_SUCCESS is zero, which evaluates to false. So it seems that these lines should be

PADDLE_ENFORCE(cudaStreamCreate(&stream_) == CUDA_SUCCESS);

Copy link
Contributor Author

@gangliao gangliao Jul 15, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion, PADDLE_ENFORCE(condition, ...) means the condition must execute correctly. Regardless of whether the condition returns 0 or 1, the internals of PADDLE_ENFORCE can figure out how to deal with it.

Copy link
Member

@QiJune QiJune Jul 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @reyoung. I checked how glog does it: CHECK verifies that the condition is true.
We can also implement CUDNN_ENFORCE/CUBLAS_ENFORCE on top of PADDLE_ENFORCE to make writing the related code more convenient.

Copy link
Collaborator

@wangkuiyi wangkuiyi Jul 17, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that

PADDLE_ENFORCE(cudaStreamCreate(&stream_))

is shorter than

PADDLE_ENFORCE(cudaStreamCreate(&stream_) == CUDA_SUCCESS);

and that it is compatible with the use of the word "enforce" in English, I support adopting @gangliao's proposal -- to make PADDLE_ENFORCE handle the condition.

I agree that the English word CHECK should be followed by a boolean value that is expected to be true. My point is that ENFORCE should be followed by an action, which is hopefully successful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, @wangkuiyi

PADDLE_ENFORCE(cudaStreamCreate(&stream_) == CUDA_SUCCESS);

will also prevent us from printing appropriate error information for cuda, cudnn, cublas, and curand.

eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}
Expand All @@ -83,8 +83,8 @@ class CUDADeviceContext : public DeviceContext {
}

void Wait() {
paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
"cudaStreamSynchronize failed");
PADDLE_ENFORCE(cudaStreamSynchronize(stream_),
"cudaStreamSynchronize failed");
}

cudaStream_t stream() { return stream_; }
Expand All @@ -94,25 +94,23 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t cublas_handle() {
if (!blas_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_) ==
CUBLAS_STATUS_SUCCESS,
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_),
"cublasCreate failed");
PADDLE_ENFORCE(paddle::platform::dynload::cublasSetStream(
blas_handle_, stream_) == CUBLAS_STATUS_SUCCESS,
"cublasSetStream failed");
PADDLE_ENFORCE(
paddle::platform::dynload::cublasSetStream(blas_handle_, stream_),
"cublasSetStream failed");
}
return blas_handle_;
}

cudnnHandle_t cudnn_handle() {
if (!dnn_handle_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_) ==
CUDNN_STATUS_SUCCESS,
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_),
"cudnnCreate failed");
PADDLE_ENFORCE(paddle::platform::dynload::cudnnSetStream(
dnn_handle_, stream_) == CUDNN_STATUS_SUCCESS,
"cudnnSetStream failed");
PADDLE_ENFORCE(
paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_),
"cudnnSetStream failed");
}
return dnn_handle_;
}
Expand All @@ -121,43 +119,39 @@ class CUDADeviceContext : public DeviceContext {
if (!rand_generator_) {
GPUPlaceGuard guard(gpu_place_);
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
CURAND_STATUS_SUCCESS,
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
"curandCreateGenerator failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS,
rand_generator_, random_seed_),
"curandSetPseudoRandomGeneratorSeed failed");
PADDLE_ENFORCE(paddle::platform::dynload::curandSetStream(
rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
"curandSetStream failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandSetStream(rand_generator_, stream_),
"curandSetStream failed");
}
return rand_generator_;
}

~CUDADeviceContext() {
Wait();
if (blas_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_) ==
CUBLAS_STATUS_SUCCESS,
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_),
"cublasDestroy failed");
}

if (dnn_handle_) {
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_) ==
CUDNN_STATUS_SUCCESS,
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_),
"cudnnDestroy failed");
}

if (rand_generator_) {
PADDLE_ENFORCE(paddle::platform::dynload::curandDestroyGenerator(
rand_generator_) == CURAND_STATUS_SUCCESS,
"curandDestroyGenerator failed");
PADDLE_ENFORCE(
paddle::platform::dynload::curandDestroyGenerator(rand_generator_),
"curandDestroyGenerator failed");
}
eigen_stream_.reset();
eigen_device_.reset();
paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
"cudaStreamDestroy failed");
PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
}

private:
Expand Down
2 changes: 1 addition & 1 deletion paddle/platform/dynload/dynamic_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License. */
#include <string>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/framework/enforce.h"
#include "paddle/platform/enforce.h"

DEFINE_string(cudnn_dir, "",
"Specify path for loading libcudnn.so. For instance, "
Expand Down
Loading