[Op] Implement GPU version of KvResourceSparseApplyAdam. #535

Open · wants to merge 1 commit into base: main
134 changes: 134 additions & 0 deletions tensorflow/core/kernels/training_ali_ops.cc
@@ -481,6 +481,140 @@ TF_CALL_float(REGISTER_GPU_KERNELS);
#endif // TENSORFLOW_USE_GPU_EV
#endif // GOOGLE_CUDA

#if GOOGLE_CUDA
#if TENSORFLOW_USE_GPU_EV
template <typename Device, typename T, typename Tindex, typename Tstep>
class KvSparseApplyAdamOpGPU : public OpKernel {
public:
explicit KvSparseApplyAdamOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
}

void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
auto locks = MaybeLockEmbeddingVariableInputMutexesInOrder<Tindex, T>(
ctx, use_exclusive_lock_, {0, 1, 2, 3, 4});
EmbeddingVarGPU<Tindex, T>* var = nullptr;
OP_REQUIRES_OK(ctx, GetInputEmbeddingVarGPU(ctx, 0, &var));
core::ScopedUnref unref_var(var);

EmbeddingVarGPU<Tindex, T>* m = nullptr;
OP_REQUIRES_OK(ctx, GetInputEmbeddingVarGPU(ctx, 1, &m));
core::ScopedUnref unref_m(m);

EmbeddingVarGPU<Tindex, T>* v = nullptr;
OP_REQUIRES_OK(ctx, GetInputEmbeddingVarGPU(ctx, 2, &v));
core::ScopedUnref unref_v(v);

Tensor beta1_power;
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
ctx, 3, use_exclusive_lock_, true, &beta1_power));

Tensor beta2_power;
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
ctx, 4, use_exclusive_lock_, true, &beta2_power));
OP_REQUIRES(
ctx, beta1_power.IsInitialized(),
errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", requested_input(3)));
OP_REQUIRES(
ctx, beta2_power.IsInitialized(),
errors::FailedPrecondition(
"Attempting to use uninitialized variables: ", requested_input(4)));

const Tensor& lr = ctx->input(5);
const Tensor& beta1 = ctx->input(6);
const Tensor& beta2 = ctx->input(7);
const Tensor& epsilon = ctx->input(8);
const Tensor& grad = ctx->input(9);
const Tensor& indices = ctx->input(10);
const Tensor& global_step = ctx->input(11);

OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(lr.shape()),
errors::InvalidArgument("lr is not a scalar: ",
lr.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(beta1.shape()),
errors::InvalidArgument("beta1 is not a scalar: ",
beta1.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(beta2.shape()),
errors::InvalidArgument("beta2 is not a scalar: ",
beta2.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
errors::InvalidArgument("epsilon is not a scalar: ",
epsilon.shape().DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsVector(indices.shape()),
errors::InvalidArgument("indices must be one-dimensional"));

int64 inner_dim = 1;
TensorShape var_shape({var->ValueLen()});
for (int d = 0; d < var_shape.dims(); d++) {
OP_REQUIRES(
ctx, var_shape.dim_size(d) == grad.dim_size(d + 1),
errors::InvalidArgument(strings::StrCat(
"var and grad must match in dimension ", d + 1)));
inner_dim *= grad.dim_size(d + 1);
}
OP_REQUIRES(
ctx, inner_dim > 0,
errors::InvalidArgument(
"Inner dimension should be greater than zero."));

OP_REQUIRES(
ctx, IsLegacyScalar(global_step.shape()),
errors::InvalidArgument(
"global_step is not a scalar: ", global_step.shape().DebugString()));

const Tindex N = indices.dim_size(0);
OP_REQUIRES(
ctx, grad.dim_size(0) == N,
errors::InvalidArgument(
"grad must be the same size as indices in the first dimension."));

const Device& device = ctx->eigen_device<Device>();
OP_REQUIRES_OK(ctx,
functor::KvSparseApplyAdamAsync<Device, T, Tindex, Tstep>()(
Member:
Is the AdamAsync functor being called directly here?

Collaborator Author:
Yes. I noticed that the logic in the AdamAsync functor is the same as Adam's; AdamAsync only adds the apply_sparse_rmsprop parameter. So I decided to reuse the functor's logic, the difference being that the call made from Adam sets apply_sparse_rmsprop to false.

Member:
Hmm, I'd suggest writing a new functor::KvSparseApplyAdam rather than reusing this one, to avoid someone changing the AdamAsync functor and breaking the correctness of the logic here.

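A minimal sketch of the dedicated functor::KvSparseApplyAdam suggested above. It is hypothetical and not part of this diff: the parameter list is inferred from the call site below and relies on the headers this file already includes, so it would need to be checked against the existing KvSparseApplyAdamAsync declaration. The only intended difference is dropping the apply_sparse_rmsprop flag, so the Adam path no longer depends on the AdamAsync functor.

namespace functor {
// Declaration only; a GPU implementation would mirror the Adam branch of
// KvSparseApplyAdamAsync (i.e. the apply_sparse_rmsprop == false path).
template <typename Device, typename T, typename Tindex, typename Tstep>
struct KvSparseApplyAdam {
  Status operator()(const Device& d,
                    EmbeddingVarGPU<Tindex, T>* var,
                    EmbeddingVarGPU<Tindex, T>* m,
                    EmbeddingVarGPU<Tindex, T>* v,
                    typename TTypes<T>::Scalar beta1_power,
                    typename TTypes<T>::Scalar beta2_power,
                    typename TTypes<Tindex>::ConstVec indices,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar beta1,
                    typename TTypes<T>::ConstScalar beta2,
                    typename TTypes<Tstep>::ConstScalar global_step,
                    int64 inner_dim,
                    Allocator* alloc);
};
}  // namespace functor

Keeping a separate declaration lets the Adam kernel evolve independently of AdamAsync, which is the point of the suggestion; KvSparseApplyAdamOpGPU would then call functor::KvSparseApplyAdam without the boolean flag.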
device, var, m, v, beta1_power.scalar<T>(), beta2_power.scalar<T>(),
indices.vec<Tindex>(), grad.flat_outer_dims<T>(), lr.scalar<T>(),
beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
global_step.scalar<Tstep>(), /*apply_sparse_rmsprop=*/false, inner_dim,
ctx->get_allocator(AllocatorAttributes())));
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}

private:
bool use_exclusive_lock_;
};

#define REGISTER_KERNELS(D, T, Tindices, Tstep) \
REGISTER_KERNEL_BUILDER(Name("KvResourceSparseApplyAdam") \
.Device(DEVICE_##D) \
.HostMemory("global_step") \
.TypeConstraint<T>("T") \
.TypeConstraint<Tindices>("Tindices") \
.TypeConstraint<Tstep>("Tstep"), \
KvSparseApplyAdamOpGPU<D##Device, T, Tindices, Tstep>);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNEL(T) \
REGISTER_KERNELS(GPU, T, int32, int32); \
REGISTER_KERNELS(GPU, T, int32, int64); \
REGISTER_KERNELS(GPU, T, int64, int32); \
REGISTER_KERNELS(GPU, T, int64, int64);

TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);

#undef REGISTER_GPU_KERNEL
#endif // End of GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#undef REGISTER_KERNELS

#endif // TENSORFLOW_USE_GPU_EV
#endif // GOOGLE_CUDA

// Note, this op works on cpu only.
template <typename Device, typename TKey, typename T, bool has_l2_shrinkage>
class KvSparseApplyFtrlOp : public OpKernel {