diff --git a/paddle/fluid/operators/cast_op.cc b/paddle/fluid/operators/cast_op.cc index a5ec47d84fe42..72f8cb04f2de3 100644 --- a/paddle/fluid/operators/cast_op.cc +++ b/paddle/fluid/operators/cast_op.cc @@ -63,13 +63,27 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker { } }; +class CastOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + // CastOp kernel's device type is decided by input tensor place + kt.place_ = ctx.Input("X")->place(); + return kt; + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; using CPU = paddle::platform::CPUDeviceContext; -REGISTER_OP_WITH_KERNEL(cast, ops::CastOpGradMaker, ops::CastOpInferShape, - ops::CastOpProtoMaker); +REGISTER_OPERATOR(cast, ops::CastOp, ops::CastOpGradMaker, + ops::CastOpInferShape, ops::CastOpProtoMaker); REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel, ops::CastOpKernel, ops::CastOpKernel, diff --git a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py index e75a6529e9fa2..00a6f7c237d58 100644 --- a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py @@ -19,6 +19,7 @@ import paddle.fluid as fluid import paddle.fluid.layers as layers import paddle.fluid.framework as framework +import paddle.fluid.core as core def exponential_decay(learning_rate, @@ -81,6 +82,16 @@ def piecewise_decay(global_step, boundaries, values): class TestLearningRateDecay(unittest.TestCase): def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for place in places: + self.check_decay_with_place(place, python_decay_fn, fluid_decay_fn, + kwargs) + + def check_decay_with_place(self, place, python_decay_fn, fluid_decay_fn, + kwargs): + decayed_lr = fluid_decay_fn(**kwargs) place = fluid.CPUPlace()