Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize verbs with static memory registration and multi request buffers #38

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 64 additions & 18 deletions tensorflow_networking/verbs/BUILD
Original file line number Diff line number Diff line change
@@ -1,13 +1,41 @@
# Description:
# Verbs RDMA communication interfaces and implementations for TensorFlow.

package(default_visibility = [
"//tensorflow_networking:__subpackages__",
])
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cuda_library")

licenses(["notice"]) # Apache 2.0
# For platform specific build config
load(
"@org_tensorflow//tensorflow/core/platform:default/build_config.bzl",
"tf_proto_library_cc",
)

load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cuda_library")
load(
"@org_tensorflow//tensorflow:tensorflow.bzl",
"tf_cc_binary",
"tf_cc_test",
"tf_cuda_library",
)

load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests")

# For platform specific build config
load(
"@org_tensorflow//tensorflow/core/platform:default/build_config.bzl",
"tf_kernel_tests_linkstatic",
)

load(
"@org_tensorflow//tensorflow/core/platform:default/build_config_root.bzl",
"tf_cuda_tests_tags",
)

package(
default_visibility = [
"//tensorflow_networking:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)

exports_files(["LICENSE"])

Expand All @@ -19,12 +47,6 @@ filegroup(
]),
)

# For platform specific build config
load(
"@org_tensorflow//tensorflow/core:platform/default/build_config.bzl",
"tf_proto_library_cc",
)

tf_proto_library_cc(
name = "verbs_service_proto",
srcs = ["verbs_service.proto"],
Expand All @@ -43,6 +65,10 @@ cc_library(
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:lib",
],
linkopts = select({
"@org_tensorflow//tensorflow:with_verbs_support": ["-libverbs"],
"//conditions:default": [],
}),
)

cc_library(
Expand All @@ -52,9 +78,10 @@ cc_library(
deps = [
":grpc_verbs_service_impl",
":rdma_mgr",
":rdma",
":verbs_service_proto_cc",
"@org_tensorflow//tensorflow:grpc++",
"@org_tensorflow//tensorflow/core:lib",
#"@org_tensorflow//tensorflow/core:lib_internal",
"@org_tensorflow//tensorflow/core/distributed_runtime:session_mgr",
"@org_tensorflow//tensorflow/core/distributed_runtime/rpc:async_service_interface",
"@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_call",
Expand All @@ -77,6 +104,7 @@ cc_library(
name = "grpc_verbs_client",
srcs = ["grpc_verbs_client.cc"],
hdrs = ["grpc_verbs_client.h"],
copts = ["-Og", "-g3"],
deps = [
":grpc_verbs_service_impl",
":verbs_service_proto_cc",
Expand All @@ -90,49 +118,66 @@ cc_library(
cc_library(
name = "rdma_rendezvous_mgr",
srcs = ["rdma_rendezvous_mgr.cc"],
hdrs = ["rdma_rendezvous_mgr.h"],
hdrs = ["rdma_rendezvous_mgr.h", "rdma.h"],
copts = ["-Og", "-g3"],
deps = [
":rdma_mgr",
":verbs_util",
"@org_tensorflow//tensorflow/core",
#"@org_tensorflow//tensorflow/core:core_cpu_internal",
#"@org_tensorflow//tensorflow/core:gpu_runtime",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
"@org_tensorflow//tensorflow/core/distributed_runtime:worker_env",
#"@org_tensorflow//tensorflow/core/distributed_runtime:worker_cache_partial",
],
)

tf_cuda_library(
name = "rdma_mgr",
srcs = ["rdma_mgr.cc"],
hdrs = ["rdma_mgr.h"],
hdrs = ["rdma_mgr.h", "rdma.h"],
copts = ["-Og", "-g3"],
deps = [
":grpc_verbs_client",
":rdma",
#":rdma",
":verbs_util",
":verbs_service_proto_cc",
"@org_tensorflow//tensorflow/core",
#"@org_tensorflow//tensorflow/core:core_cpu_internal",
"@org_tensorflow//tensorflow/core:lib",
#"@org_tensorflow//tensorflow/core:lib_internal",
"@org_tensorflow//tensorflow/core/distributed_runtime:session_mgr",
"@org_tensorflow//tensorflow/core/distributed_runtime:worker_env",
"@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
],
)


tf_cuda_library(
name = "rdma",
srcs = ["rdma.cc"],
hdrs = ["rdma.h"],
linkopts = ["-libverbs"],
linkopts = select({
"@org_tensorflow//tensorflow:with_verbs_support": ["-libverbs"],
"//conditions:default": [],
}),
copts = ["-Og", "-g3"],
deps = [
":rdma_mgr",
":grpc_verbs_client",
":verbs_service_proto_cc",
":verbs_util",
"@org_tensorflow//tensorflow/core",
#"@org_tensorflow//tensorflow/core:core_cpu_internal",
"@org_tensorflow//tensorflow/core:framework",
#"@org_tensorflow//tensorflow/core:framework_internal",
#"@org_tensorflow//tensorflow/core:gpu_runtime",
"@org_tensorflow//tensorflow/core:lib",
#"@org_tensorflow//tensorflow/core:lib_internal",
"@org_tensorflow//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"@org_tensorflow//tensorflow/core/distributed_runtime:session_mgr",
"@org_tensorflow//tensorflow/core/distributed_runtime:worker_env",
"@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_channel",
],
)

Expand All @@ -151,3 +196,4 @@ cc_library(
],
alwayslink = 1,
)

82 changes: 0 additions & 82 deletions tensorflow_networking/verbs/Dockerfile

This file was deleted.

13 changes: 0 additions & 13 deletions tensorflow_networking/verbs/docker_howto.txt

This file was deleted.

33 changes: 33 additions & 0 deletions tensorflow_networking/verbs/grpc_verbs_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,39 @@ Status GrpcVerbsClient::GetRemoteAddress(const GetRemoteAddressRequest* request,
return GetRemoteAddress(&call_options, request, response);
}


Status GrpcVerbsClient::ReqDriverMessage(CallOptions* call_options,
const DriverMessageReq* request,
DriverMessageResp* response) {
::grpc::ClientContext ctx;
ctx.set_fail_fast(false);
SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub_->ReqDriverMessage(&ctx, *request, response));
}

Status GrpcVerbsClient::ReqDriverMessage(const DriverMessageReq* request,
DriverMessageResp* response) {
CallOptions call_options;
call_options.SetTimeout(-1); // no time out
return ReqDriverMessage(&call_options, request, response);
}

Status GrpcVerbsClient::ReqPleSendOrCheck(CallOptions* call_options,
const PleSendOrCheckReq* request,
PleSendOrCheckResp* response) {
::grpc::ClientContext ctx;
ctx.set_fail_fast(false);
SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub_->ReqPleSendOrCheck(&ctx, *request, response));
}

Status GrpcVerbsClient::ReqPleSendOrCheck(const PleSendOrCheckReq* request,
PleSendOrCheckResp* response) {
CallOptions call_options;
call_options.SetTimeout(-1); // no time out
return ReqPleSendOrCheck(&call_options, request, response);
}

void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx,
int64 time_in_ms) {
if (time_in_ms > 0) {
Expand Down
17 changes: 15 additions & 2 deletions tensorflow_networking/verbs/grpc_verbs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_
#define TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_

#include "tensorflow_networking/verbs/grpc_verbs_service_impl.h"
#include "tensorflow_networking/verbs/verbs_service.pb.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow_networking/verbs/grpc_verbs_service_impl.h"
#include "tensorflow_networking/verbs/verbs_service.pb.h"

namespace tensorflow {

Expand All @@ -37,6 +37,19 @@ class GrpcVerbsClient {
Status GetRemoteAddress(const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response);

Status ReqDriverMessage(CallOptions* call_options,
const DriverMessageReq* request,
DriverMessageResp* response);
Status ReqDriverMessage(const DriverMessageReq* request,
DriverMessageResp* response);

Status ReqPleSendOrCheck(CallOptions* call_options,
const PleSendOrCheckReq* request,
PleSendOrCheckResp* response);

Status ReqPleSendOrCheck(const PleSendOrCheckReq* request,
PleSendOrCheckResp* response);

private:
std::unique_ptr<grpc::VerbsService::Stub> stub_;

Expand Down
Loading