[DistDialect] Pir add mp precision alignment ut (PaddlePaddle#63770)
* update ut

* run loss success

* fix code style

* tiny update

* for tmp

* fix interpreter problem

* update

* tiny update

* fix code style

* fix comments

* tinyfix

* tinyfix

* fix conflicts

* update
wentaoyu authored and runzhech committed Apr 30, 2024
1 parent 5e18604 commit 9242795
Showing 7 changed files with 245 additions and 2 deletions.
@@ -17,6 +17,7 @@
#include <map>
#include <vector>

#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"

PD_DECLARE_bool(new_executor_sequential_run);
@@ -57,6 +58,11 @@ class DependencyBuilder {

  void ShareDependencyFrom(const DependencyBuilder& src);

  bool IsSameDeviceContext(size_t op1, size_t op2) const {
    return &((*instructions_)[op1].DeviceContext()) ==
           &((*instructions_)[op2].DeviceContext());
  }

 protected:
  void AddDependencyForCoalesceTensorOp();
  virtual void AddDependencyForCommunicationOp();
@@ -116,6 +122,11 @@ class PirDependencyBuilder : public DependencyBuilder {

  void ShareDependencyFrom(const PirDependencyBuilder& src);

  bool IsSameDeviceContext(size_t op1, size_t op2) const {
    return &((instructions_)[op1]->DeviceContext()) ==
           &((instructions_)[op2]->DeviceContext());
  }

 private:
  void AddDependencyForCommunicationOp() override;

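Note: the new IsSameDeviceContext helper compares the addresses of the two instructions' DeviceContext objects, i.e. it is a pointer-identity test, not a value comparison. A minimal sketch of the same check in Python, assuming a hypothetical ctx_of mapping from instruction id to its context object (not part of the actual change):

    # `is` mirrors the C++ pointer comparison: the contexts must be the
    # same object, not merely equal in value.
    def is_same_device_context(ctx_of, op1, op2):
        return ctx_of[op1] is ctx_of[op2]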
@@ -601,7 +601,9 @@ void shrink_event_info(
  std::set<size_t> unnecessary_waiter_instr_ids;
  for (size_t cur_instr_id : waiter_instr_ids) {
    for (size_t next_instr_id : waiter_instr_ids) {
      if (dependency_builder.OpHappensBefore(cur_instr_id, next_instr_id)) {
      if (dependency_builder.OpHappensBefore(cur_instr_id, next_instr_id) &&
          dependency_builder.IsSameDeviceContext(cur_instr_id,
                                                 next_instr_id)) {
        unnecessary_waiter_instr_ids.insert(next_instr_id);
        break;
      }
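Note: previously a waiter instruction was pruned from an event's waiter set whenever another waiter happened before it; this change additionally requires the two waiters to share a device context, since a happens-before edge across different streams does not by itself guarantee the event has already been waited on. A hedged sketch of the tightened rule, with hypothetical happens_before and same_device_context predicates standing in for the DependencyBuilder methods:

    def unnecessary_waiters(waiter_ids, happens_before, same_device_context):
        unnecessary = set()
        for cur in waiter_ids:
            for nxt in waiter_ids:
                # An ordering edge alone is no longer enough: nxt's wait is
                # redundant only if cur also runs on the same device context.
                if happens_before(cur, nxt) and same_device_context(cur, nxt):
                    unnecessary.add(nxt)
                    break
        return unnecessary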
13 changes: 13 additions & 0 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -703,6 +703,7 @@ def _parallel_pir(self, mode):
        dense_program = paddle.base.libpaddle.pir.apply_dist2dense_pass(
            dist_program
        )

        self._pir_dense_main_progs[mode] = dense_program
        self._pir_dist_main_progs[mode] = dist_program

@@ -1022,6 +1023,17 @@ def _init_comm(self):
        for process_group in all_process_groups:
            process_group.instantiate()

    def _init_lr(self):
        buffer_tensor = global_scope().var("learning_rate_0").get_tensor()
        if not isinstance(self._optimizer._learning_rate, float):
            raise TypeError(
                "learning rate should be float, got %s here"
                % type(self._optimizer._learning_rate)
            )
        buffer_tensor.set(
            np.float32(self._optimizer._learning_rate), self._place
        )

    def _initialize(self, mode, init_parameters=True):
        self._place = _get_device()
        if isinstance(self._place, paddle.framework.CUDAPlace):
@@ -1042,6 +1054,7 @@ def _initialize(self, mode, init_parameters=True):
                self._pir_dense_main_progs[mode], self._place
            )

            self._init_lr()
            return

        if self._strategy.seed:
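Note: _init_lr writes the optimizer's scalar learning rate into the learning_rate_0 scope variable before the first static run, and rejects anything that is not a plain float (e.g. an LR scheduler). A minimal standalone sketch of that guard, with a hypothetical write_buffer callable standing in for buffer_tensor.set:

    import numpy as np

    def init_lr_buffer(write_buffer, learning_rate, place):
        # Only a plain float can be written into the scope buffer.
        if not isinstance(learning_rate, float):
            raise TypeError(
                "learning rate should be float, got %s here" % type(learning_rate)
            )
        write_buffer(np.float32(learning_rate), place)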
@@ -58,7 +58,7 @@ def reshard(
        paddle.pir.set_insertion_point_after(op)
        group = new_process_group(src_mesh.process_ids)
        reduced_value = paddle._pir_ops.c_allreduce_sum_(
            op_value, group.id, False, False
            op_value, group.id, True, False
        )

        # set dist type and dist attr
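Note: the reshard fix flips the third positional argument of paddle._pir_ops.c_allreduce_sum_ from False to True. Assuming the positional arguments map to (x, ring_id, use_calc_stream, use_model_parallel) — an inference from similar Paddle collective ops, not confirmed by this diff — the annotated call would read:

    # Parameter names in the comments are an assumption, not taken from the diff.
    reduced_value = paddle._pir_ops.c_allreduce_sum_(
        op_value,   # tensor reduced in place
        group.id,   # ring id of the communicator group
        True,       # assumed use_calc_stream: run on the calculation stream
        False,      # assumed use_model_parallel
    )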
5 changes: 5 additions & 0 deletions test/auto_parallel/pir/CMakeLists.txt
@@ -12,9 +12,14 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_pir_mse_spmd MODULES test_mse_spmd_rule ENVS
                  FLAGS_enable_pir_api=1)
  py_test_modules(test_mlp MODULES test_mlp ENVS FLAGS_enable_pir_api=1)
  py_test_modules(
    test_semi_auto_parallel_dist_to_static_pir MODULES
    test_semi_auto_parallel_dist_to_static_pir ENVS FLAGS_enable_pir_api=1)
  py_test_modules(test_reshard MODULES test_reshard ENVS FLAGS_enable_pir_api=1)
  py_test_modules(test_learning_rate MODULES test_learning_rate ENVS
                  FLAGS_enable_pir_api=1)
  set_tests_properties(test_mlp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT
                                           60)
  set_tests_properties(test_semi_auto_parallel_dist_to_static_pir
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 30)
endif()
168 changes: 168 additions & 0 deletions test/auto_parallel/pir/semi_auto_parallel_dist_to_static_mlp_pir.py
@@ -0,0 +1,168 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random

import numpy as np
from test_to_static_pir_program import DemoNet

import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.framework import _current_expected_place
from paddle.io import DataLoader

BATCH_SIZE = 4
BATCH_NUM = 4
SEQ_LEN = 2
IMAGE_SIZE = 16
CLASS_NUM = 8


def create_numpy_like_random(name):
    return paddle.ParamAttr(
        name=name, initializer=paddle.nn.initializer.Uniform(0, 1)
    )


class RandomDataset(paddle.io.Dataset):
    def __init__(self, images, labels, num_samples, return_dict=False):
        self.images = images
        self.labels = labels
        self.num_samples = num_samples
        self.return_dict = return_dict

    def __getitem__(self, idx):
        if self.return_dict:
            return {
                "image": self.images[idx],
                "label": self.labels[idx],
            }
        else:
            return self.images[idx], self.labels[idx]

    def __len__(self):
        return self.num_samples


class TestSimpleNetForSemiAutoParallel:
    def __init__(self):
        self._seed = eval(os.getenv("seed"))
        self._ckpt_path = os.getenv("ckpt_path")
        self.mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
        self._in_pir_mode = paddle.base.framework.get_flags(
            "FLAGS_enable_pir_api"
        )["FLAGS_enable_pir_api"]

    def set_random_seed(self, seed):
        random.seed(seed)
        np.random.seed(seed)
        paddle.seed(seed)

    def create_data_loader(self, return_dict=False):
        images = np.random.rand(BATCH_SIZE, IMAGE_SIZE).astype('float32')
        labels = np.random.rand(BATCH_SIZE, CLASS_NUM).astype('float32')
        dataset = RandomDataset(images, labels, BATCH_SIZE, return_dict)
        loader = DataLoader(dataset, batch_size=BATCH_SIZE)
        return loader

    def run_dy2static(self, layer, opt, dist_loader):
        # create loss
        loss_fn = nn.MSELoss()
        # static training
        dist_model = dist.to_static(layer, dist_loader, loss_fn, opt)
        loss_list = []
        dist_model.train()

        if self._in_pir_mode:
            mode = "train"

            dist_model._engine._has_prepared[mode] = True
            dist_model._mode = mode
            dist_model._engine._mode = mode
            paddle.disable_static()
            dist_model._engine._initialize(mode)
            dist_model._engine._executor = paddle.static.Executor(
                _current_expected_place()
            )
            dist_model._engine._init_comm()

        for epoch in range(5):
            for batch_id, data in enumerate(dist_loader()):
                if isinstance(data, dict):
                    image = data['image']
                    label = data['label']
                else:
                    image, label = data
                loss = dist_model(image, label)
                loss_list.append(loss)

        return np.array(loss_list), dist_model

    def run_dynamic(self, layer, opt, dist_loader, is_recompute=False):
        # create loss
        loss_fn = nn.MSELoss()
        loss_list = []
        for epoch in range(5):
            for batch_id, data in enumerate(dist_loader()):
                if isinstance(data, dict):
                    image = data['image']
                    label = data['label']
                else:
                    image, label = data
                if is_recompute:
                    image.stop_gradient = False
                out = layer(image)
                loss = loss_fn(out, label)
                loss_list.append(loss.numpy())
                loss.backward()

                opt.step()
                opt.clear_grad()
        return np.array(loss_list)

    def test_mp_demo_net(self):
        paddle.disable_static()
        self.set_random_seed(self._seed)
        data_loader = self.create_data_loader()

        self.set_random_seed(self._seed)
        dy_layer = DemoNet(self.mesh)
        dy_opt = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=dy_layer.parameters()
        )

        paddle.base.set_flags({'FLAGS_enable_pir_api': 1})
        self.set_random_seed(self._seed)
        dy2static_layer = DemoNet(self.mesh)
        dy2static_opt = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=dy2static_layer.parameters()
        )
        dist_dataloader = dist.shard_dataloader(
            dataloader=data_loader,
            meshes=[self.mesh],
        )
        dy2static_losses, dist_model = self.run_dy2static(
            dy2static_layer, dy2static_opt, dist_dataloader
        )
        dy_losses = self.run_dynamic(dy_layer, dy_opt, dist_dataloader)
        np.testing.assert_array_equal(dy_losses, dy2static_losses)

    def run_test_case(self):
        self.test_mp_demo_net()


if __name__ == '__main__':
    TestSimpleNetForSemiAutoParallel().run_test_case()
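Note: the test asserts exact agreement between the dynamic-graph and dist-to-static losses with np.testing.assert_array_equal — that exactness is the "precision alignment" in the PR title; a tolerance-based allclose check would be weaker. A tiny self-contained illustration:

    import numpy as np

    a = np.array([0.5, 0.25], dtype=np.float32)
    np.testing.assert_array_equal(a, a.copy())    # passes: values are identical
    # np.testing.assert_array_equal(a, a + 1e-7)  # would raise AssertionError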
@@ -0,0 +1,44 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

import collective.test_communication_api_base as test_base


class TestSemiAutoParallelStaticDecorate(test_base.CommunicationTestDistBase):
    def setUp(self):
        super().setUp(
            num_of_devices=2,
            timeout=300,
        )
        self._default_envs = {"dtype": "float32", "seed": "2023"}
        self._changeable_envs = {"backend": ["gpu"]}

    def test_mlp(self):
        envs_list = test_base.gen_product_envs_list(
            {"dtype": "float32", "seed": "2023"}, {"backend": ["gpu"]}
        )
        for envs in envs_list:
            ckpt_path_tmp = tempfile.TemporaryDirectory()
            envs["ckpt_path"] = ckpt_path_tmp.name
            self.run_test_case(
                "semi_auto_parallel_dist_to_static_mlp_pir.py",
                user_defined_envs=envs,
            )
            ckpt_path_tmp.cleanup()


if __name__ == "__main__":
    unittest.main()
