forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
【Hackathon 6th Fundable Projects 3 No.374】fluid operator tdm_sampler (P…
- Loading branch information
Showing
7 changed files
with
406 additions
and
178 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
328 changes: 163 additions & 165 deletions
328
paddle/fluid/operators/tdm_sampler_op.h → paddle/phi/kernels/cpu/tdm_sampler_kernel.cc
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Gather every C++ source file in this directory, with paths expressed
# relative to the directory itself.
file(GLOB func_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")

# Device (.cu) sources are only globbed for CUDA or ROCm builds; in any other
# configuration this file does not set func_cu_srcs.
if(WITH_GPU OR WITH_ROCM)
  file(GLOB func_cu_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cu")
endif()

# Pass both lists to the project's collect_srcs helper (defined elsewhere).
collect_srcs(kernels_srcs SRCS ${func_cc_srcs} ${func_cu_srcs})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/funcs/math/sampler.h" | ||
|
||
#include <glog/logging.h> | ||
|
||
#include "paddle/phi/core/generator.h" | ||
|
||
namespace phi { | ||
namespace math { | ||
|
||
// Out-of-line definition of the virtual destructor declared in the class.
// The defaulted body is enough: the base only stores plain scalar members.
Sampler::~Sampler() = default;
|
||
// Sets up a sampler that draws uniformly from the closed interval [0, range].
// inv_range_ caches the matching per-value probability 1 / (range + 1).
// The Sampler base is constructed first, so seed_ is already resolved when the
// random engine is requested.
UniformSampler::UniformSampler(int64_t range, unsigned int seed)
    : Sampler(range, seed),
      inv_range_(1.0f / (range + 1)),  // NOLINT
      random_engine_(phi::GetCPURandomEngine(seed_)),
      dist_(std::make_shared<std::uniform_int_distribution<>>(0, range)) {}
|
||
// Draws a single integer by feeding the engine obtained at construction into
// the uniform integer distribution.
int64_t UniformSampler::Sample() const {
  auto &engine = *random_engine_;
  return (*dist_)(engine);
}
|
||
// Every outcome is equally likely, so `value` is not consulted: the
// reciprocal cached by the constructor is returned as-is.
float UniformSampler::Probability(int64_t value) const {
  return inv_range_;
}
|
||
// Sets up a log-uniform sampler. log_range_ caches ln(range + 1), which both
// Sample() and Probability() reuse; the real-valued distribution supplies the
// uniform variate that Sample() transforms.
LogUniformSampler::LogUniformSampler(int64_t range, unsigned int seed)
    : Sampler(range, seed),
      log_range_(log(range + 1)),  // NOLINT
      random_engine_(phi::GetCPURandomEngine(seed_)),
      dist_(std::make_shared<std::uniform_real_distribution<>>(0, 1)) {}
|
||
int64_t LogUniformSampler::Sample() const { | ||
// Got Log Uniform distribution from uniform distribution by | ||
// inverse_transform_sampling method | ||
// More details: | ||
// https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler/ | ||
auto cur_random = (*dist_)(*random_engine_); | ||
const int64_t value = static_cast<int64_t>(exp(cur_random * log_range_)) - 1; | ||
// Mathematically, value should be <= range_, but might not be due to some | ||
// floating point roundoff, so we mod by range_. | ||
return value % range_; | ||
} | ||
|
||
float LogUniformSampler::Probability(int64_t value) const { | ||
// Given f(x) = 1/[(x+1) * log_range_] | ||
// The value's probability is integral of f(x) from value to (value + 1) | ||
// More details: | ||
// https://wanghaoshuang.github.io/2017/11/Log-uniform-distribution-sampler | ||
return (log((value + 2.0) / (value + 1.0))) / log_range_; // NOLINT | ||
} | ||
|
||
// Stores the caller-provided tables (the pointers are kept, not copied) and
// prepares the two distributions used by Sample(): an integer one over
// [0, range] to pick a slot and a real one over [0, 1) for the accept test.
CustomSampler::CustomSampler(int64_t range,
                             const float *probabilities,
                             const int *alias,
                             const float *alias_probabilities,
                             unsigned int seed)
    : Sampler(range, seed),
      alias_probs_(alias_probabilities),
      alias_(alias),
      probs_(probabilities),
      random_engine_(phi::GetCPURandomEngine(seed_)),
      real_dist_(std::make_shared<std::uniform_real_distribution<>>(0, 1)),
      int_dist_(std::make_shared<std::uniform_int_distribution<>>(0, range)) {}
|
||
int64_t CustomSampler::Sample() const { | ||
auto index = (*int_dist_)(*random_engine_); | ||
auto p = (*real_dist_)(*random_engine_); | ||
if (p > alias_probs_[index]) { | ||
int alias = alias_[index]; | ||
|
||
if (alias == exceptional_val) { | ||
LOG(WARNING) << "WARNING: CustomSampler get alias " << exceptional_val; | ||
return index; | ||
} | ||
|
||
return alias; | ||
} else { | ||
return index; | ||
} | ||
} | ||
|
||
// Looks up the probability of `value` in the caller-provided table. No bounds
// check is performed, so `value` must be a valid index into that table.
float CustomSampler::Probability(int64_t value) const {
  return probs_[value];
}
|
||
} // namespace math | ||
} // namespace phi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <cstdint> | ||
#include <memory> | ||
#include <random> | ||
#include <vector> | ||
|
||
#include "paddle/phi/core/enforce.h" | ||
|
||
namespace phi { | ||
namespace math { | ||
|
||
// TODO(wanghaoshuang): Support for GPU | ||
|
||
/** | ||
* Sample integers from [0, range). | ||
*/ | ||
class Sampler { | ||
public: | ||
explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) { | ||
PADDLE_ENFORCE_GT( | ||
range, | ||
0, | ||
phi::errors::InvalidArgument( | ||
"Range should be greater than 0, but received %d.", range)); | ||
if (seed == 0) { | ||
std::random_device r; | ||
seed_ = r(); | ||
} else { | ||
seed_ = seed; | ||
} | ||
} | ||
|
||
virtual ~Sampler(); | ||
|
||
// Sample a single value | ||
virtual int64_t Sample() const = 0; | ||
|
||
// The probability that a single call to Sample() returns the given value. | ||
virtual float Probability(int64_t value) const = 0; | ||
|
||
int64_t range() { return range_; } | ||
|
||
protected: | ||
const int64_t range_; | ||
unsigned int seed_; | ||
}; | ||
|
||
/** | ||
* Sample integers from [0, range). | ||
* And the distribution function is: | ||
* P(x) = 1 / range | ||
*/ | ||
class UniformSampler : public Sampler { | ||
public: | ||
explicit UniformSampler(int64_t range, unsigned int seed = 0UL); | ||
|
||
~UniformSampler() override {} | ||
|
||
int64_t Sample() const override; | ||
|
||
float Probability(int64_t value) const override; | ||
|
||
private: | ||
const float inv_range_; | ||
std::shared_ptr<std::mt19937_64> random_engine_; | ||
std::shared_ptr<std::uniform_int_distribution<>> dist_; | ||
}; | ||
|
||
/** | ||
* Sample integers from [0, range). | ||
* And the distribution function is: | ||
* P(x) = (1/ln(range+1)) * ln(1 + 1/(x + 1)) | ||
*/ | ||
class LogUniformSampler : public Sampler { | ||
public: | ||
explicit LogUniformSampler(int64_t range, unsigned int seed = 0UL); | ||
|
||
~LogUniformSampler() override {} | ||
|
||
int64_t Sample() const override; | ||
|
||
float Probability(int64_t value) const override; | ||
|
||
private: | ||
const float log_range_; | ||
std::shared_ptr<std::mt19937_64> random_engine_; | ||
std::shared_ptr<std::uniform_real_distribution<>> dist_; | ||
}; | ||
|
||
/** | ||
* Sample integers from [0, range) from custom distribution. | ||
*/ | ||
class CustomSampler : public Sampler { | ||
public: | ||
explicit CustomSampler(int64_t range, | ||
const float* probabilities, | ||
const int* alias, | ||
const float* alias_probabilities, | ||
unsigned int seed = 0UL); | ||
|
||
~CustomSampler() override {} | ||
|
||
int64_t Sample() const override; | ||
|
||
float Probability(int64_t value) const override; | ||
|
||
private: | ||
const float* alias_probs_; | ||
const int* alias_; | ||
const float* probs_; | ||
const int exceptional_val = -1; | ||
std::shared_ptr<std::mt19937_64> random_engine_; | ||
std::shared_ptr<std::uniform_real_distribution<>> real_dist_; | ||
std::shared_ptr<std::uniform_int_distribution<>> int_dist_; | ||
}; | ||
|
||
} // namespace math | ||
} // namespace phi |