[Dy2St][AMP] add should_auto_cast attribute for each operator (Padd…
SigureMo authored Nov 13, 2023
1 parent 03d33df commit 3a7fd89
Showing 5 changed files with 263 additions and 0 deletions.
42 changes: 42 additions & 0 deletions python/paddle/base/framework.py
@@ -2935,6 +2935,9 @@ def __init__(
# attr for static graph mode cuda graph
self._cuda_graph_attr = _current_cuda_graph_mode

# attr for whether this OP should be auto-cast in AMP mode
self._should_auto_cast: bool = True

op_maker = core.op_proto_and_checker_maker

if op_maker.kOpRoleAttrName() not in op_attrs:
@@ -3692,6 +3695,25 @@ def dist_attr(self, dist_attr):
"""
self.desc.dist_attr = dist_attr

def set_auto_cast(self, auto_cast):
"""
Set auto cast attribute of this Operator.
Args:
auto_cast(bool): True if this Operator should be auto-cast in AMP mode.
"""
self._should_auto_cast = auto_cast

@property
def should_auto_cast(self):
"""
Get auto cast attribute of this Operator.
Returns:
bool: True if this Operator should be auto-cast in AMP mode.
"""
return self._should_auto_cast


@signature_safe_contextmanager
def _stride_in_no_check_dy2st_diff():
@@ -6323,6 +6345,7 @@ def clone(self, for_test=False):
p._copy_param_info_from(self)
p._copy_data_info_from(self, pruned_origin_block_id_map)
p._copy_dist_param_info_from(self)
p._copy_operator_info_from(self)
return p

def _prune(self, targets):
@@ -6446,6 +6469,7 @@ def _prune_with_input(self, feeded_var_names, targets):
res._copy_param_info_from(self)
res._copy_data_info_from(self, pruned_origin_block_id_map)
res._copy_dist_param_info_from(self)
res._copy_operator_info_from(self)

return res

@@ -6961,6 +6985,24 @@ def _copy_data_info_from(self, other, pruned_origin_block_id_map=None):
if other_var.stop_gradient:
var.stop_gradient = True

def _copy_operator_info_from(self, other: "Program"):
"""
Copy the Operator information from another Program.
Args:
other(Program): The source Program.
Returns:
None
"""
if not isinstance(other, Program):
raise TypeError(
f"Function Program._copy_operator_info_from() needs to pass in a source Program, but received {type(other)}"
)
for dst_block, src_block in zip(self.blocks, other.blocks):
for dst_op, src_op in zip(dst_block.ops, src_block.ops):
dst_op.set_auto_cast(src_op.should_auto_cast)

def list_vars(self):
"""
Get all Tensors from this Program. An iterable object is returned.
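
The framework.py changes above can be exercised in isolation. Below is a minimal sketch (not part of the commit), assuming a hand-built static-graph program: every Operator starts with should_auto_cast=True, the flag is flipped with set_auto_cast(), and Program.clone() now carries it over through _copy_operator_info_from().

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data("x", [10, 10], "float32")
    y = paddle.matmul(x, x)

op = main.global_block().ops[-1]   # the matmul op, the last op appended
assert op.should_auto_cast         # defaults to True
op.set_auto_cast(False)            # mark this op to be skipped by AMP casting
cloned = main.clone()
assert not cloned.global_block().ops[-1].should_auto_cast  # preserved by clone()
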
38 changes: 38 additions & 0 deletions python/paddle/jit/dy2static/convert_operators.py
@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import re
from contextlib import contextmanager

import paddle
from paddle.autograd.py_layer import PyLayerMeta
from paddle.base.data_feeder import convert_dtype
from paddle.base.dygraph.base import _convert_into_variable, in_to_static_mode
from paddle.base.framework import Variable, core, default_main_program
from paddle.pir import OpResult
from paddle.static.amp.fp16_utils import AmpOptions

from .py_layer import StaticPyLayer
from .utils import (
@@ -77,6 +81,9 @@ def convert_load(x):
if new_var is not None:
return new_var

if x is paddle.amp.auto_cast:
return convert_auto_cast

return x


@@ -805,6 +812,37 @@ def convert_pop(target, *args):
return _run_python_pop(target, *args)


@contextmanager
def convert_auto_cast(
enable=True,
custom_white_list=None,
custom_black_list=None,
level='O1',
dtype='float16',
use_promote=True,
):
from .program_translator import ProgramTranslator

if enable:
raise NotImplementedError("Does not support local switching on amp now")

amp_records = ProgramTranslator.get_instance()._amp_records
main_program = paddle.static.default_main_program()
current_block_idx = main_program.current_block_idx
current_block = main_program.current_block()
start_op_idx = len(current_block.ops)
amp_options = AmpOptions(
enable, custom_white_list, custom_black_list, level, dtype, use_promote
)
yield
end_op_idx = len(current_block.ops)
if current_block_idx not in amp_records:
amp_records[current_block_idx] = []
amp_records[current_block_idx].append(
(amp_options, start_op_idx, end_op_idx)
)


def _run_paddle_pop(array, *args):
if len(args) == 0:
idx = -1
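
For reference, a hedged sketch (not from the commit) of what convert_auto_cast records when its body runs while a static program is being built: it snapshots the op count of the current block before and after the with-body and appends an (AmpOptions, start_op_idx, end_op_idx) record keyed by the block index; enable=True is rejected for now.

import paddle
from paddle.jit.dy2static.convert_operators import convert_auto_cast
from paddle.jit.dy2static.program_translator import ProgramTranslator

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data("x", [4, 4], "float32")
    with convert_auto_cast(enable=False):
        y = paddle.matmul(x, x)   # ops built inside the with-body fall in the recorded range

records = ProgramTranslator.get_instance()._amp_records
print(records)   # e.g. {0: [(AmpOptions(enable=False, ...), <start_op_idx>, <end_op_idx>)]}
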
8 changes: 8 additions & 0 deletions python/paddle/jit/dy2static/program_translator.py
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import collections
import inspect
import os
import threading
import warnings
import weakref
from typing import TYPE_CHECKING

import paddle.pir.core as ir_static
from paddle import decomposition
@@ -65,6 +68,9 @@
unwrap,
)

if TYPE_CHECKING:
from paddle.static.amp.fp16_utils import AmpOptions

__all__ = []

# For each traced function, we set `max_traced_program_count` = 10 to consider caching performance.
@@ -1286,6 +1292,7 @@ def from_func_spec(
)

new_name_generator = UniqueNameGenerator()
ProgramTranslator.get_instance()._amp_records.clear()

with framework.program_guard(main_program, startup_program):
with _to_static_mode_guard_(is_to_static=True), UniqueNameGuard(
@@ -1763,6 +1770,7 @@ def __init__(self):
self._program_cache = ProgramCache()
self._params_recorder = ParametersRecorder()
self._inplace_map = InplaceMap()
self._amp_records: dict[int, list[tuple[AmpOptions, int, int]]] = {}
self.enable_to_static = True

def enable(self, enable_to_static):
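
A small sketch of the lifecycle implied by the hunks above, assuming a layer decorated with @paddle.jit.to_static has just been converted: the records live on the ProgramTranslator singleton and are cleared at the start of the next from_func_spec() call, so they must be consumed (for example by prepare_op_should_auto_cast) before the next conversion.

from paddle.jit.dy2static.program_translator import ProgramTranslator

translator = ProgramTranslator.get_instance()
records = translator._amp_records   # dict[int, list[tuple[AmpOptions, int, int]]]
for block_idx, ranges in records.items():
    for amp_options, start_op_idx, end_op_idx in ranges:
        print(block_idx, amp_options.enable, start_op_idx, end_op_idx)
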
47 changes: 47 additions & 0 deletions python/paddle/static/amp/fp16_utils.py
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from dataclasses import dataclass

import numpy as np

@@ -40,6 +43,16 @@
_fp16_guard_pattern = "__use_fp16__"


@dataclass
class AmpOptions:
enable: bool
custom_white_list: list[str] | None
custom_black_list: list[str] | None
level: str
dtype: str
use_promote: bool


def _rename_arg(op, old_name, new_name):
"""
If an op has old_name input and output, rename these input
@@ -586,6 +599,40 @@ def process_op_input_and_outputs(op, block, global_block, dtype):
return low_precison_var_names


def map_block(block, fn, parent_op=None):
fn(block, parent_op)
program = block.program
for op in block.ops:
if not op.has_attr("sub_block"):
continue
sub_block = program.blocks[op.attr("sub_block").id]
map_block(sub_block, fn, op)


def prepare_op_should_auto_cast(
program: paddle.static.Program,
amp_records: dict[int, list[tuple[AmpOptions, int, int]]],
):
amp_enable_op_map: dict[paddle.static.Operator, bool] = {}

def fill_amp_enable_op_map(block, parent_op):
block_idx = block.idx
ops = block.ops
for op in ops:
# The top level should be FP16
current_op_amp_options = amp_enable_op_map.get(parent_op, True)
if block_idx in amp_records:
for amp_options, start, end in amp_records[block_idx]:
if op.idx in range(start, end):
current_op_amp_options = amp_options.enable
break
amp_enable_op_map[op] = current_op_amp_options

map_block(program.global_block(), fill_amp_enable_op_map)
for op, enable in amp_enable_op_map.items():
op.set_auto_cast(enable)


def cast_model_to_fp16(
program,
amp_lists=None,
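
Putting the pieces together, here is a hedged end-to-end sketch that mirrors the new unittest below (legacy, non-PIR program IR assumed; TinyLayer is an illustrative layer, not part of the commit): after tracing, prepare_op_should_auto_cast stamps each operator with the enable flag of the recorded range it falls in, defaulting to True.

import paddle
from paddle.jit.dy2static.program_translator import ProgramTranslator
from paddle.static.amp.fp16_utils import prepare_op_should_auto_cast


class TinyLayer(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self._fc = paddle.nn.Linear(10, 10)

    @paddle.jit.to_static(full_graph=True)
    def forward(self, x):
        x = self._fc(x)
        with paddle.amp.auto_cast(False):   # ops built here should end up with should_auto_cast=False
            x = x.astype("float32") * 2
        return x + 1


layer = TinyLayer()
concrete_program, _ = layer.forward.get_concrete_program(paddle.randn([10, 10]))
program = concrete_program.main_program
prepare_op_should_auto_cast(program, ProgramTranslator.get_instance()._amp_records)
for block in program.blocks:
    for op in block.ops:
        print(block.idx, op.type, op.should_auto_cast)
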
128 changes: 128 additions & 0 deletions test/dygraph_to_static/test_local_cast.py
@@ -0,0 +1,128 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE(SigureMo): This unittest does NOT need to run in PIR mode. Don't import Dy2StTestBase.

import unittest

import paddle
from paddle.jit.dy2static.program_translator import ProgramTranslator
from paddle.static.amp.fp16_utils import prepare_op_should_auto_cast


class LocalAutoCastLayer1(paddle.nn.Layer):
def __init__(self):
super().__init__()
self._fc = paddle.nn.Linear(10, 10)

@paddle.jit.to_static(full_graph=True)
def forward(self, x):
x = self._fc(x)
y = self._fc(x) * 2
with paddle.amp.auto_cast(False):
x = x.astype("float32")
y = y.astype("float32")
if x[0][0] > 1:
x = x + y
else:
x = x - y
x = x * 2

return x + 1


class LocalAutoCastLayer2(paddle.nn.Layer):
def __init__(self):
super().__init__()
self._fc = paddle.nn.Linear(10, 10)

@paddle.jit.to_static(full_graph=True)
def forward(self, x):
with paddle.amp.auto_cast(False):
x = x.astype("float32")
x = self._fc(x)
y = self._fc(x) * 2
if x[0][0] > 1:
x = x + y
else:
x = x - y
x = x * 2

return x + 1


class TestLocalCast(unittest.TestCase):
def get_auto_cast_ops_info_from_program(self, program):
auto_cast_ops_info = []
for block in program.blocks:
current_block_should_auto_cast = []
auto_cast_ops_info.append(current_block_should_auto_cast)
for op in block.ops:
current_block_should_auto_cast.append(op.should_auto_cast)
return auto_cast_ops_info

def should_auto_cast_for_each_ops(self, layer, input):
concrete_program, _ = layer.forward.get_concrete_program(input)
program = concrete_program.main_program
prepare_op_should_auto_cast(
program, ProgramTranslator.get_instance()._amp_records
)
auto_cast_ops_info = self.get_auto_cast_ops_info_from_program(program)
paddle.enable_static()
cloned_program = program.clone()
paddle.disable_static()
cloned_auto_cast_ops_info = self.get_auto_cast_ops_info_from_program(
cloned_program
)
self.assertEqual(auto_cast_ops_info, cloned_auto_cast_ops_info)
return auto_cast_ops_info

def test_should_auto_cast_1(self):
layer = LocalAutoCastLayer1()
input = paddle.randn([10, 10])
expected = [
# Part of the ops are inside the auto_cast(False) block
[
True, True, True, True, True,
False, False, False, False, False, False, False, False, False, False, False,
True,
],
# The whole if branch is inside the auto_cast(False) block
[False, False],
# The whole else branch is inside the auto_cast(False) block
[False, False, False],
] # fmt: skip
actual = self.should_auto_cast_for_each_ops(layer, input)
self.assertEqual(expected, actual)

def test_should_auto_cast_2(self):
layer = LocalAutoCastLayer2()
input = paddle.randn([10, 10])
expected = [
# Part of the ops are inside the auto_cast(False) block
[
False, False, False, False, False, False,
True, True, True, True, True, True, True, True, True, True,
],
# The whole if branch is outside the auto_cast(False) block
[True, True],
# The whole else branch is outside the auto_cast(False) block
[True, True, True],
] # fmt: skip
actual = self.should_auto_cast_for_each_ops(layer, input)
self.assertEqual(expected, actual)


if __name__ == '__main__':
unittest.main()
