diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 85bafbef2b63e..1af0d447da29c 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -56,6 +56,8 @@ message MpConfig {
   optional bool sync_grad= 2 [ default = false ];
   optional bool sync_moment= 3 [ default = false ];
   optional string sync_mode= 4 [ default = 'broadcast' ];
+  // Broadcast mp input data
+  optional bool need_broadcast_data=8 [default = true];
 }
 
 message PpConfig {
diff --git a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
index 883533d8e1724..13546a02b5bd2 100755
--- a/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/tensor_parallel.py
@@ -42,5 +42,10 @@ def _prepare_for_model(self):
         logger.info("mp's parameters is ready")
 
     def _pre_forward(self, *inputs, **kwargs):
-        logger.debug("mp start broadcast input data")
-        return broadcast_input_data(self._hcg, *inputs, **kwargs)
+        need_broadcast_data = True
+        if self._strategy is not None:
+            mp_configs = self._strategy.hybrid_configs["mp_configs"]
+            need_broadcast_data = mp_configs.need_broadcast_data
+        if need_broadcast_data:
+            logger.debug("mp start broadcast input data")
+            return broadcast_input_data(self._hcg, *inputs, **kwargs)
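
For context, a rough usage sketch (not part of the diff above): assuming hybrid_configs accepts a nested "mp_configs" entry in the same way as the existing sync_* options, a user could turn off the per-step input broadcast across the tensor-parallel group like this:

    import paddle
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 2,
        "pp_degree": 1,
        # Skip broadcast_input_data in _pre_forward; intended for the case
        # where every mp rank already feeds identical input data.
        "mp_configs": {"need_broadcast_data": False},
    }
    fleet.init(is_collective=True, strategy=strategy)
    # model = fleet.distributed_model(model)  # TensorParallel then skips the broadcast

With the default (need_broadcast_data = true) the behavior is unchanged, so existing users are unaffected.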