Skip to content

Commit

Permalink
fix by code review
Browse files Browse the repository at this point in the history
  • Loading branch information
2742195759 committed Mar 14, 2022
1 parent dbbeafe commit 310578d
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 12 deletions.
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/cumprod_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/cumprod_kernel.h"

#include "paddle/fluid/operators/math/inclusive_scan.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cumprod_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/cumprod.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/isclose_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
// limitations under the License.

#include "paddle/phi/kernels/isclose_kernel.h"
#include "paddle/phi/kernels/impl/isclose_kernel_impl.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/isclose_kernel_impl.h"

PD_REGISTER_KERNEL(
isclose, GPU, ALL_LAYOUT, phi::IscloseKernel, float, double) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/kldiv_loss_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/kldiv_loss_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/kldiv_loss_grad_kernel_impl.h"
#include "paddle/phi/kernels/kldiv_loss_grad_kernel.h"
PD_REGISTER_KERNEL(
kldiv_loss_grad, GPU, ALL_LAYOUT, phi::KLDivLossGradKernel, float, double) {
}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/kldiv_loss_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/kldiv_loss_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/kldiv_loss_kernel_impl.h"
#include "paddle/phi/kernels/kldiv_loss_kernel.h"
PD_REGISTER_KERNEL(
kldiv_loss, GPU, ALL_LAYOUT, phi::KLDivLossKernel, float, double) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/lgamma_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/lgamma_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lgamma_grad_kernel_impl.h"
#include "paddle/phi/kernels/lgamma_grad_kernel.h"
PD_REGISTER_KERNEL(
lgamma_grad, GPU, ALL_LAYOUT, phi::LgammaGradKernel, float, double) {}
3 changes: 2 additions & 1 deletion paddle/phi/kernels/gpu/lgamma_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/lgamma_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/lgamma_kernel.h"

namespace phi {
template <typename T>
Expand Down
32 changes: 26 additions & 6 deletions paddle/phi/ops/compat/isclose_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,32 @@
namespace phi {

KernelSignature IscloseOpArgumentMapping(const ArgumentMappingContext& ctx) {
std::string rtol_map_str = (ctx.HasInput("Rtol") ? "Rtol" : "rtol");
std::string atol_map_str = (ctx.HasInput("Atol") ? "Atol" : "atol");
return KernelSignature("isclose",
{"Input", "Other"},
{rtol_map_str, atol_map_str, "equal_nan"},
{"Out"});
if (ctx.HasInput("Rtol")) {
if (ctx.HasInput("Atol")) {
return KernelSignature("isclose",
{"Input", "Other"},
{"Rtol", "Atol", "equal_nan"},
{"Out"});

} else {
return KernelSignature("isclose",
{"Input", "Other"},
{"Rtol", "atol", "equal_nan"},
{"Out"});
}
} else {
if (ctx.HasInput("Atol")) {
return KernelSignature("isclose",
{"Input", "Other"},
{"rtol", "Atol", "equal_nan"},
{"Out"});
} else {
return KernelSignature("isclose",
{"Input", "Other"},
{"rtol", "atol", "equal_nan"},
{"Out"});
}
}
}

} // namespace phi
Expand Down

0 comments on commit 310578d

Please sign in to comment.