Skip to content

Commit

Permalink
add cuda generator (#26786) (#27014)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaoxuefeng6 authored Sep 4, 2020
1 parent 09ede3b commit 0a9f9f9
Show file tree
Hide file tree
Showing 13 changed files with 523 additions and 18 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib

cc_library(save_load_util SRCS save_load_util DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
cc_library(generator SRCS generator.cc)
cc_library(generator SRCS generator.cc DEPS enforce place)

# Get the current working branch
execute_process(
Expand Down
53 changes: 53 additions & 0 deletions paddle/fluid/framework/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,46 @@ limitations under the License. */
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace framework {

const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(int64_t device_id) {
#ifdef PADDLE_WITH_CUDA

static int64_t num_cuda_devices = -1;
static std::once_flag num_devices_init_flag;
static std::deque<std::once_flag> cuda_device_flags;
static std::vector<std::shared_ptr<Generator>> default_cuda_generators;

std::call_once(num_devices_init_flag, []() {
num_cuda_devices = paddle::platform::GetCUDADeviceCount();
cuda_device_flags.resize(num_cuda_devices);
default_cuda_generators.resize(num_cuda_devices);
});
if (device_id < 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"cuda device id shoule be greater than 0"));
}

std::call_once(cuda_device_flags[device_id], [device_id]() {
default_cuda_generators[device_id] =
std::make_shared<Generator>(GetRandomSeed(), device_id);
VLOG(4) << "initial seed: "
<< default_cuda_generators[device_id]->GetCurrentSeed();
});
return default_cuda_generators[device_id];
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"getDefaultCUDAGenerator only support in CUDA place"));
#endif
}

const std::shared_ptr<Generator>& DefaultCPUGenerator() {
static auto default_cpu_generator =
std::make_shared<Generator>(GetRandomSeed());
Expand Down Expand Up @@ -103,6 +139,7 @@ uint64_t Generator::Seed() {
void Generator::SetCurrentSeed(uint64_t seed) {
std::lock_guard<std::mutex> lock(this->mu_);
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
std::seed_seq seq({seed});
this->engine_->seed(seq);
}
Expand All @@ -123,6 +160,22 @@ uint64_t Generator::Random64() {
return (*engine)();
}

std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
uint64_t increament_offset) {
uint64_t cur_offset = this->state_.thread_offset;
#ifdef PADDLE_WITH_CUDA
std::lock_guard<std::mutex> lock(this->mu_);

this->state_.thread_offset += increament_offset;

#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Increment Offset only support in CUDA place"));
#endif
return std::make_pair(static_cast<int>(this->state_.current_seed),
cur_offset);
}

void Generator::SetIsInitPy(bool is_init_py) {
this->is_init_py_ = is_init_py;
VLOG(4) << "SetIsInitPy:" << this->is_init_py_;
Expand Down
22 changes: 22 additions & 0 deletions paddle/fluid/framework/generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ static uint64_t GetRandomSeed() {
struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
uint64_t thread_offset = 0;
std::mt19937_64 cpu_engine;
};

Expand All @@ -49,6 +50,7 @@ struct Generator {
this->state_.cpu_engine = *engine;
this->state_.device = -1;
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
Expand All @@ -59,11 +61,25 @@ struct Generator {
this->state_.cpu_engine = *engine;
this->state_.device = -1;
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
this->is_init_py_ = true; // TODO(zhiqiu): remove it in future
}
Generator(uint64_t seed, uint64_t device_id) {
std::seed_seq seq({seed});
auto engine = std::make_shared<std::mt19937_64>(seq);
this->state_.cpu_engine = *engine;
this->state_.device = device_id;
this->state_.current_seed = seed;
this->state_.thread_offset = 0;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
this->is_init_py_ = false; // TODO(zhiqiu): remove it in future
}

Generator(const Generator& other) = delete;

// get random state
Expand All @@ -83,8 +99,11 @@ struct Generator {

uint64_t Random64();

std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t increament_offset);

void SetIsInitPy(bool);
bool GetIsInitPy() const;
uint64_t get_device_id() { return this->state_.device; }

private:
GeneratorState state_;
Expand All @@ -105,5 +124,8 @@ std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine();

std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);

const std::shared_ptr<Generator>& GetDefaultCUDAGenerator(
int64_t device_id = -1);

} // namespace framework
} // namespace paddle
1 change: 0 additions & 1 deletion paddle/fluid/operators/bernoulli_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/bernoulli_op.h"
Expand Down
47 changes: 47 additions & 0 deletions paddle/fluid/operators/dropout_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,42 @@ __global__ void RandomGeneratorWithSeed(const size_t n, const int* seed,
}
}

template <typename T, typename MaskType>
__global__ void RandomGeneratorWithGenerator(const size_t n, uint64_t seed,
const float dropout_prob,
const T* src, MaskType* mask_data,
T* dst, bool is_upscale_in_train,
uint64_t increment) {
curandStatePhilox4_32_10_t state;
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int step_size = 0;

MaskType mask;
T dest;
for (; idx < n; idx += blockDim.x * gridDim.x) {
T s = src[idx];
if (step_size == 0) {
curand_init(seed, idx, increment, &state);
step_size = blockDim.x * gridDim.x;
} else {
curand_init(seed, idx, increment, &state);
}
if (curand_uniform(&state) < dropout_prob) {
mask = 0;
dest = 0;
} else {
mask = 1;
if (is_upscale_in_train) {
dest = s / static_cast<T>(1.0f - dropout_prob);
} else {
dest = s;
}
}
mask_data[idx] = mask;
dst[idx] = dest;
}
}

// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
Expand Down Expand Up @@ -150,6 +186,17 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
}

int device_id = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace())
.GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
if (gen_cuda->GetIsInitPy() && (!context.Attr<bool>("fix_seed"))) {
auto seed_offset = gen_cuda->IncrementOffset(1);
RandomGeneratorWithGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_offset.first, dropout_prob, x_data, mask_data, y_data,
upscale_in_train, seed_offset.second);
return;
}

RandomGenerator<T, uint8_t><<<grid, threads, 0, stream>>>(
size, seed_data, dropout_prob, x_data, mask_data, y_data,
upscale_in_train);
Expand Down
60 changes: 53 additions & 7 deletions paddle/fluid/operators/gaussian_random_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fill_constant_op.h"
Expand All @@ -24,15 +25,20 @@ template <typename T>
struct GaussianGenerator {
T mean_, std_;
unsigned int seed_;
unsigned int offset_ = 0;

__host__ __device__ GaussianGenerator(T mean, T std, int seed)
: mean_(mean), std_(std), seed_(seed) {}

__host__ __device__ GaussianGenerator(T mean, T std, int seed, int offset)
: mean_(mean), std_(std), seed_(seed), offset_(offset) {}

__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::normal_distribution<T> dist(mean_, std_);
rng.discard(n);
unsigned int new_n = n + offset_;
rng.discard(new_n);
return dist(rng);
}
};
Expand All @@ -43,9 +49,11 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
Expand All @@ -56,9 +64,27 @@ class GPUGaussianRandomKernel : public framework::OpKernel<T> {
T* data = tensor->mutable_data<T>(context.GetPlace());

int64_t size = tensor->numel();
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));

int device_id =
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int offset_step = 100;
// NOTE(xuefeng): Currently, we let offset step fixed to avoid
// unexpected results which may cause ut fail.
// we will fix this in future.
int gen_offset = offset_step * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed_offset.first, gen_offset));
} else {
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
}
}
};

Expand All @@ -69,17 +95,37 @@ class GPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = tensor->numel();
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));

int device_id =
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int offset_step = 100;
// NOTE(xuefeng): Currently, we let offset step fixed to avoid
// unexpected results which may cause ut fail.
// we will fix this in future.
int gen_offset = offset_step * seed_offset.second;
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed_offset.first,
seed_offset.second));
} else {
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GaussianGenerator<T>(mean, std, seed));
}
}
};
} // namespace operators
Expand Down
11 changes: 10 additions & 1 deletion paddle/fluid/operators/randint_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/uniform_random_op.h"

Expand Down Expand Up @@ -49,15 +50,23 @@ class GPURandintKernel : public framework::OpKernel<T> {

int64_t size = out->numel();
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));

/*
std::minstd_rand engine;
if (seed == 0) {
std::random_device rd;
seed = rd();
}
engine.seed(seed);
*/

std::uniform_int_distribution<> dist(context.Attr<int>("low"),
context.Attr<int>("high") - 1);
for (int64_t i = 0; i < size; ++i) data[i] = dist(engine);
auto engine = framework::GetCPURandomEngine(seed);

for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}

if (platform::is_gpu_place(context.GetPlace())) {
// Copy tensor to out
Expand Down
Loading

0 comments on commit 0a9f9f9

Please sign in to comment.