diff --git a/tensorflow/core/kernels/training_ali_ops.cc b/tensorflow/core/kernels/training_ali_ops.cc
index dc0d4de6a4c..790ac90428f 100644
--- a/tensorflow/core/kernels/training_ali_ops.cc
+++ b/tensorflow/core/kernels/training_ali_ops.cc
@@ -481,6 +481,140 @@ TF_CALL_float(REGISTER_GPU_KERNELS);
 #endif  // TENSORFLOW_USE_GPU_EV
 #endif  // GOOGLE_CUDA
 
+#if GOOGLE_CUDA
+#if TENSORFLOW_USE_GPU_EV
+template <typename Device, typename T, typename Tindex, typename Tstep>
+class KvSparseApplyAdamOpGPU : public OpKernel {
+ public:
+  explicit KvSparseApplyAdamOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+  }
+
+  void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
+    auto locks = MaybeLockEmbeddingVariableInputMutexesInOrder(
+        ctx, use_exclusive_lock_, {0, 1, 2, 3, 4});
+    EmbeddingVarGPU<Tindex, T>* var = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputEmbeddingVarGPU(ctx, 0, &var));
+    core::ScopedUnref unref_var(var);
+
+    EmbeddingVarGPU<Tindex, T>* m = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputEmbeddingVarGPU(ctx, 1, &m));
+    core::ScopedUnref unref_m(m);
+
+    EmbeddingVarGPU<Tindex, T>* v = nullptr;
+    OP_REQUIRES_OK(ctx, GetInputEmbeddingVarGPU(ctx, 2, &v));
+    core::ScopedUnref unref_v(v);
+
+    Tensor beta1_power;
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+                            ctx, 3, use_exclusive_lock_, true, &beta1_power));
+
+    Tensor beta2_power;
+    OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+                            ctx, 4, use_exclusive_lock_, true, &beta2_power));
+    OP_REQUIRES(
+        ctx, beta1_power.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", requested_input(3)));
+    OP_REQUIRES(
+        ctx, beta2_power.IsInitialized(),
+        errors::FailedPrecondition(
+            "Attempting to use uninitialized variables: ", requested_input(4)));
+
+    const Tensor& lr = ctx->input(5);
+    const Tensor& beta1 = ctx->input(6);
+    const Tensor& beta2 = ctx->input(7);
+    const Tensor& epsilon = ctx->input(8);
+    const Tensor& grad = ctx->input(9);
+    const Tensor& indices = ctx->input(10);
+    const Tensor& global_step = ctx->input(11);
+
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsScalar(lr.shape()),
+        errors::InvalidArgument("lr is not a scalar: ",
+                                lr.shape().DebugString()));
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsScalar(beta1.shape()),
+        errors::InvalidArgument("beta1 is not a scalar: ",
+                                beta1.shape().DebugString()));
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsScalar(beta2.shape()),
+        errors::InvalidArgument("beta2 is not a scalar: ",
+                                beta2.shape().DebugString()));
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+        errors::InvalidArgument("epsilon is not a scalar: ",
+                                epsilon.shape().DebugString()));
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsVector(indices.shape()),
+        errors::InvalidArgument("indices must be one-dimensional"));
+
+    int64 inner_dim = 1;
+    TensorShape var_shape({var->ValueLen()});
+    for (int d = 0; d < var_shape.dims(); d++) {
+      OP_REQUIRES(
+          ctx, var_shape.dim_size(d) == grad.dim_size(d + 1),
+          errors::InvalidArgument(strings::StrCat(
+              "var and grad must match in dimension ", d + 1)));
+      inner_dim *= grad.dim_size(d + 1);
+    }
+    OP_REQUIRES(
+        ctx, inner_dim > 0,
+        errors::InvalidArgument(
+            "Inner dimension should be greater than zero."));
+
+    OP_REQUIRES(
+        ctx, IsLegacyScalar(global_step.shape()),
+        errors::InvalidArgument(
+            "global_step is not a scalar: ", global_step.shape().DebugString()));
+
+    const Tindex N = indices.dim_size(0);
+    OP_REQUIRES(
+        ctx, grad.dim_size(0) == N,
+        errors::InvalidArgument(
+            "grad must be the same size as indices in the first dimension."));
+
+    const Device& device = ctx->eigen_device<Device>();
+    OP_REQUIRES_OK(ctx,
+        functor::KvSparseApplyAdamAsync<Device, T, Tindex, Tstep>()(
+            device, var, m, v, beta1_power.scalar<T>(), beta2_power.scalar<T>(),
+            indices.vec<Tindex>(), grad.flat_outer_dims<T>(), lr.scalar<T>(),
+            beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
+            global_step.scalar<Tstep>(), false, inner_dim,
+            ctx->get_allocator(AllocatorAttributes())));
+    MaybeForwardRefInputToRefOutput(ctx, 0, 0);
+  }
+
+ private:
+  bool use_exclusive_lock_;
+};
+
+#define REGISTER_KERNELS(D, T, Tindices, Tstep)                      \
+  REGISTER_KERNEL_BUILDER(Name("KvResourceSparseApplyAdam")          \
+                              .Device(DEVICE_##D)                    \
+                              .HostMemory("global_step")             \
+                              .TypeConstraint<T>("T")                \
+                              .TypeConstraint<Tindices>("Tindices")  \
+                              .TypeConstraint<Tstep>("Tstep"),       \
+                          KvSparseApplyAdamOpGPU<D##Device, T, Tindices, Tstep>);
+
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#define REGISTER_GPU_KERNEL(T)             \
+  REGISTER_KERNELS(GPU, T, int32, int32);  \
+  REGISTER_KERNELS(GPU, T, int32, int64);  \
+  REGISTER_KERNELS(GPU, T, int64, int32);  \
+  REGISTER_KERNELS(GPU, T, int64, int64);
+
+TF_CALL_float(REGISTER_GPU_KERNEL);
+TF_CALL_double(REGISTER_GPU_KERNEL);
+
+#undef REGISTER_GPU_KERNEL
+#endif  // End of GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+#undef REGISTER_KERNELS
+
+#endif  // TENSORFLOW_USE_GPU_EV
+#endif  // GOOGLE_CUDA
+
 // Note, this op works on cpu only.
 template <typename Device, typename T, typename Tindex>
 class KvSparseApplyFtrlOp : public OpKernel {
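For reference, below is a minimal CPU-side sketch of the per-row sparse Adam update that the kernel above hands off to functor::KvSparseApplyAdamAsync. It is illustrative only and not the DeepRec implementation: the function name, container types, and float-only signature are local to the sketch, and it assumes the standard bias-corrected Adam update applied to the embedding rows selected by indices.

// Illustrative sketch only (not part of the diff): dense row-major buffers
// stand in for the EmbeddingVarGPU storage; only rows named in `indices`
// are touched, mirroring the sparse update the GPU functor performs.
#include <cmath>
#include <cstdint>
#include <vector>

// var/m/v are row-major [num_rows, dim]; grad is [indices.size(), dim].
// Row indices[i] of var/m/v is updated with gradient row i.
void SparseApplyAdamReference(std::vector<float>* var, std::vector<float>* m,
                              std::vector<float>* v,
                              const std::vector<float>& grad,
                              const std::vector<int64_t>& indices, int64_t dim,
                              float lr, float beta1, float beta2, float epsilon,
                              float beta1_power, float beta2_power) {
  // Bias-corrected step size: lr * sqrt(1 - beta2^t) / (1 - beta1^t).
  const float alpha =
      lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const int64_t row = indices[i];
    for (int64_t d = 0; d < dim; ++d) {
      const float g = grad[i * dim + d];
      float& mi = (*m)[row * dim + d];
      float& vi = (*v)[row * dim + d];
      mi += (1.0f - beta1) * (g - mi);      // m <- beta1*m + (1-beta1)*g
      vi += (1.0f - beta2) * (g * g - vi);  // v <- beta2*v + (1-beta2)*g^2
      (*var)[row * dim + d] -= alpha * mi / (std::sqrt(vi) + epsilon);
    }
  }
}

The GPU kernel performs the same per-row arithmetic, but on the rows the embedding variable resolves for each key and with beta1_power/beta2_power, global_step, and the other hyperparameters passed in as tensors, as shown in the functor call in the hunk above.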