Skip to content

Commit

Permalink
Merge pull request #10380 from panyx0718/dist_timeline
Browse files Browse the repository at this point in the history
timeline for distributed training
  • Loading branch information
panyx0718 authored May 7, 2018
2 parents 0c51888 + d1ea74d commit dce0732
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 52 deletions.
28 changes: 22 additions & 6 deletions benchmark/cluster/vgg16/vgg16_fluid.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def str2bool(v):
type=str,
default="",
help="Comma-separated list of hostname:port pairs")
parser.add_argument(
"--profile", action='store_true', help="If set, profile a few steps.")

# Flags for defining the tf.train.Server
parser.add_argument(
Expand Down Expand Up @@ -183,8 +185,8 @@ def train_loop(exe, trainer_prog):
start_time = time.time()
num_samples = 0
train_pass_acc.reset()
for batch_id, data in enumerate(train_reader()):
ts = time.time()

def run_step(batch_id, data):
img_data = np.array(
map(lambda x: x[0].reshape(data_shape), data)).astype(
"float32")
Expand All @@ -196,14 +198,28 @@ def train_loop(exe, trainer_prog):
feed={"pixel": img_data,
"label": y_data},
fetch_list=[avg_cost, batch_acc, batch_size])
return loss, acc, b_size

if args.profile and args.task_index == 0:
# warmup.
for batch_id, data in enumerate(train_reader()):
if batch_id > 5: break
run_step(batch_id, data)
with profiler.profiler('All', 'total', '/tmp/profile_vgg'):
for batch_id, data in enumerate(train_reader()):
if batch_id > 5: break
run_step(batch_id, data)

for batch_id, data in enumerate(train_reader()):
ts = time.time()
loss, acc, b_size = run_step(batch_id, data)
iters += 1
num_samples += len(data)
train_pass_acc.add(value=acc, weight=b_size)
print(
"Task:%d Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
"Speed = %.2f img/s " % (args.task_index, pass_id, iters,
loss, acc,
len(data) / (time.time() - ts))
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
"Speed = %.2f img/s" % (pass_id, iters, loss, acc,
len(data) / (time.time() - ts))
) # The accuracy is the accumulation of batches, but not the current batch.

pass_elapsed = time.time() - start_time
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/operators/detail/send_recv.proto
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ message VariableMessage {
bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
// If true, the ps server starts profiling. When profile switches from
// true to false, the ps server stops profiling and writes a profile to
// /tmp/profile_ps_*.
bool profile = 11;
}

message VoidMessage {}
8 changes: 8 additions & 0 deletions paddle/fluid/operators/detail/sendrecvop_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/bytebuffer_stream.h"
#include "paddle/fluid/operators/detail/proto_encoder_helper.h"
#include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {
Expand All @@ -45,6 +46,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void* payload = nullptr;
size_t payload_size;
ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
// Note: normally the profiler is enabled in 1 trainer, hence only
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
if (platform::ShouldSendProfileState()) {
e.WriteBool(VarMsg::kProfileFieldNumber, platform::IsProfileEnabled());
}
e.WriteString(VarMsg::kVarnameFieldNumber, name);
if (var->IsType<framework::LoDTensor>()) {
e.WriteUint64(VarMsg::kTypeFieldNumber, 0);
Expand Down
22 changes: 21 additions & 1 deletion paddle/fluid/operators/detail/variable_response.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/platform/profiler.h"

#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
Expand Down Expand Up @@ -427,7 +428,26 @@ int VariableResponse::Parse(Source* source) {
meta_.set_out_varname(temp);
break;
}

case sendrecv::VariableMessage::kProfileFieldNumber: {
bool profiling;
if (!input.ReadRaw(reinterpret_cast<void*>(&profiling), 1)) {
return tag;
}
meta_.set_profile(profiling);
int64_t listener_id = platform::ListenerId();
if (listener_id <= 0) {
break;
}
if (profiling && !platform::IsProfileEnabled()) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
} else if (!profiling && platform::IsProfileEnabled()) {
// TODO(panyx0718): Should we allow customizing the output directory?
platform::DisableProfiler(
platform::EventSortingKey::kDefault,
string::Sprintf("/tmp/profile_ps_%lld", listener_id));
}
break;
}
default: {
// Unknown tag, return unknown error.
return -1;
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/operators/listen_and_serv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>

#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -294,6 +295,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,

void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
// Mark this process as a PS so that it decides when to profile by listening to the trainer.
platform::SetProfileListener();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
Expand Down
35 changes: 26 additions & 9 deletions paddle/fluid/platform/profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/profiler.h"

#include <sys/time.h>
#include <time.h>
#include <algorithm>
#include <iomanip>
#include <limits>
#include <map>
#include <mutex> // NOLINT
#include <random>
#include <string>
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
Expand All @@ -33,6 +36,9 @@ namespace platform {

struct EventList;

static int64_t profiler_lister_id = 0;
static bool should_send_profile_state = false;

// The profiler state, the initial value is ProfilerState::kDisabled
static ProfilerState g_state = ProfilerState::kDisabled;
// The thread local event list only can be accessed by the specific thread
Expand Down Expand Up @@ -219,13 +225,12 @@ void EnableProfiler(ProfilerState state) {
PADDLE_ENFORCE(state != ProfilerState::kDisabled,
"Can't enbale profling, since the input state is ",
"ProfilerState::kDisabled");
PADDLE_ENFORCE(g_state == ProfilerState::kDisabled,
"The profiling state should be disabled when calling ",
"EnableProfiler.");
g_state = state;
if (g_state == ProfilerState::kAll) {
GetDeviceTracer()->Enable();
if (state == g_state) {
return;
}
g_state = state;
should_send_profile_state = true;
GetDeviceTracer()->Enable();
#ifdef PADDLE_WITH_CUDA
if (g_state == ProfilerState::kCUDA) {
// Generate some dummy events first to reduce the startup overhead.
Expand Down Expand Up @@ -435,21 +440,33 @@ void ParseEvents(const std::vector<std::vector<Event>>& events,

void DisableProfiler(EventSortingKey sorted_key,
const std::string& profile_path) {
PADDLE_ENFORCE(g_state != ProfilerState::kDisabled,
"Can't disable profiling, since it's not starting.");
if (g_state == ProfilerState::kDisabled) return;
// Mark the profiling stop.
Mark("_stop_profiler_", nullptr);

std::vector<std::vector<Event>> all_events = GetAllEvents();
ParseEvents(all_events, sorted_key);
ResetProfiler();
DeviceTracer* tracer = GetDeviceTracer();
if (g_state == ProfilerState::kAll && tracer && tracer->IsEnabled()) {
if (tracer->IsEnabled()) {
tracer->Disable();
tracer->GenProfile(profile_path);
}
g_state = ProfilerState::kDisabled;
should_send_profile_state = true;
}

// Reports whether profiling is currently active, i.e. the global profiler
// state is anything other than ProfilerState::kDisabled.
bool IsProfileEnabled() {
  const bool disabled = (g_state == ProfilerState::kDisabled);
  return !disabled;
}
// Returns the flag that EnableProfiler() and DisableProfiler() raise when
// they change the profiler state, telling the caller that the state should
// be sent along with outgoing messages.
bool ShouldSendProfileState() {
  const bool state_changed = should_send_profile_state;
  return state_changed;
}

// Marks the current process as a profile listener by assigning it a random,
// strictly positive id. ListenerId() returns 0 until this has been called.
void SetProfileListener() {
  // Seed a 64-bit engine from the system entropy source.
  std::random_device entropy;
  std::mt19937_64 rng(entropy());
  // Draw directly in int64_t. The distribution used to be typed on
  // std::mt19937::result_type (uint_fast32_t), so on platforms where that
  // type is 32 bits wide the int64_t upper bound was silently narrowed.
  std::uniform_int_distribution<int64_t> id_dist(
      1, std::numeric_limits<int64_t>::max());
  profiler_lister_id = id_dist(rng);
}
// Returns the id chosen by SetProfileListener(), or 0 if this process was
// never marked as a listener.
int64_t ListenerId() {
  const int64_t id = profiler_lister_id;
  return id;
}

} // namespace platform
} // namespace paddle
8 changes: 8 additions & 0 deletions paddle/fluid/platform/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,13 @@ void ResetProfiler();
void DisableProfiler(EventSortingKey sorted_key,
const std::string& profile_path);

// Test if the profiler is currently enabled.
bool IsProfileEnabled();
// Whether the trainer should send profiling state to PS.
bool ShouldSendProfileState();
// Mark current process as PS by assigning a listener id.
void SetProfileListener();
int64_t ListenerId();

} // namespace platform
} // namespace paddle
90 changes: 54 additions & 36 deletions tools/timeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
'--profile_path', type=str, default='', help='Input profile file name.')
'--profile_path',
type=str,
default='',
help='Input profile file name. If there are multiple file, the format '
'should be trainer1=file1,trainer2=file2,ps=file3')
parser.add_argument(
'--timeline_path', type=str, default='', help='Output timeline file name.')
args = parser.parse_args()
Expand Down Expand Up @@ -108,8 +112,8 @@ def format_to_string(self, pretty=False):


class Timeline(object):
def __init__(self, profile_pb):
self._profile_pb = profile_pb
def __init__(self, profile_dict):
self._profile_dict = profile_dict
self._pid = 0
self._devices = dict()
self._chrome_trace = _ChromeTraceFormatter()
Expand All @@ -120,35 +124,37 @@ def _allocate_pid(self):
return cur_pid

def _allocate_pids(self):
for event in self._profile_pb.events:
if event.type == profiler_pb2.Event.CPU:
if (event.device_id, "CPU") not in self._devices:
pid = self._allocate_pid()
self._devices[(event.device_id, "CPU")] = pid
self._chrome_trace.emit_pid("cpu:block:%d" %
(event.device_id), pid)
elif event.type == profiler_pb2.Event.GPUKernel:
if (event.device_id, "GPUKernel") not in self._devices:
pid = self._allocate_pid()
self._devices[(event.device_id, "GPUKernel")] = pid
self._chrome_trace.emit_pid("gpu:%d" % (event.device_id),
pid)
for k, profile_pb in self._profile_dict.iteritems():
for event in profile_pb.events:
if event.type == profiler_pb2.Event.CPU:
if (k, event.device_id, "CPU") not in self._devices:
pid = self._allocate_pid()
self._devices[(k, event.device_id, "CPU")] = pid
self._chrome_trace.emit_pid("%s:cpu:block:%d" %
(k, event.device_id), pid)
elif event.type == profiler_pb2.Event.GPUKernel:
if (k, event.device_id, "GPUKernel") not in self._devices:
pid = self._allocate_pid()
self._devices[(k, event.device_id, "GPUKernel")] = pid
self._chrome_trace.emit_pid("%s:gpu:%d" %
(k, event.device_id), pid)

def _allocate_events(self):
for event in self._profile_pb.events:
if event.type == profiler_pb2.Event.CPU:
type = "CPU"
elif event.type == profiler_pb2.Event.GPUKernel:
type = "GPUKernel"
pid = self._devices[(event.device_id, type)]
args = {'name': event.name}
if event.memcopy.bytes > 0:
args = {'mem_bytes': event.memcopy.bytes}
# TODO(panyx0718): Chrome tracing only handles ms. However, some
# ops takes micro-seconds. Hence, we keep the ns here.
self._chrome_trace.emit_region(
event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
event.sub_device_id, 'Op', event.name, args)
for k, profile_pb in self._profile_dict.iteritems():
for event in profile_pb.events:
if event.type == profiler_pb2.Event.CPU:
type = "CPU"
elif event.type == profiler_pb2.Event.GPUKernel:
type = "GPUKernel"
pid = self._devices[(k, event.device_id, type)]
args = {'name': event.name}
if event.memcopy.bytes > 0:
args = {'mem_bytes': event.memcopy.bytes}
# TODO(panyx0718): Chrome tracing only handles ms. However, some
# ops takes micro-seconds. Hence, we keep the ns here.
self._chrome_trace.emit_region(
event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
event.sub_device_id, 'Op', event.name, args)

def generate_chrome_trace(self):
self._allocate_pids()
Expand All @@ -163,11 +169,23 @@ def generate_chrome_trace(self):
if args.timeline_path:
timeline_path = args.timeline_path

with open(profile_path, 'r') as f:
profile_s = f.read()
profile_pb = profiler_pb2.Profile()
profile_pb.ParseFromString(profile_s)

tl = Timeline(profile_pb)
# Load every profile named by --profile_path into a dict that maps a display
# name to its parsed profiler_pb2.Profile message.
#   * A single entry is treated as a plain file path and stored as 'trainer'.
#   * Multiple comma-separated entries must each look like "name=file".
profile_paths = profile_path.split(',')
profile_dict = dict()
# Compare the number of entries, not the length of the raw string: the old
# check `len(profile_path) == 1` was only true for a one-character path, so a
# single plain path fell through to the "name=file" branch and failed there.
if len(profile_paths) == 1:
    # The profile is a serialized protobuf, so read it as bytes.
    with open(profile_paths[0], 'rb') as f:
        profile_s = f.read()
        profile_pb = profiler_pb2.Profile()
        profile_pb.ParseFromString(profile_s)
    profile_dict['trainer'] = profile_pb
else:
    for entry in profile_paths:
        # Split on the first '=' only so the file name may itself contain '='.
        k, v = entry.split('=', 1)
        with open(v, 'rb') as f:
            profile_s = f.read()
            profile_pb = profiler_pb2.Profile()
            profile_pb.ParseFromString(profile_s)
        profile_dict[k] = profile_pb

tl = Timeline(profile_dict)
with open(timeline_path, 'w') as f:
f.write(tl.generate_chrome_trace())

0 comments on commit dce0732

Please sign in to comment.