Skip to content

Commit

Permalink
Merge branch 'main' into justinchu/ir-builder-ir
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored Apr 23, 2024
2 parents 5d753ed + c979aae commit 060009e
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 146 deletions.
1 change: 0 additions & 1 deletion .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ exclude_patterns = [
'onnxscript/optimizer/constant_folding.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
'onnxscript/rewriter/function_rule.py', # FIXME
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
'onnxscript/optimizer/fold_constants_v0.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
Expand Down
2 changes: 0 additions & 2 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def rewrite(
function_rewrite_rules: Sequence[type[FunctionRewriteRule]] = (),
pattern_rewrite_rules: Sequence[PatternRewriteRule] = (),
) -> onnx.ModelProto:
print(f"len(value_info): {len(model.graph.value_info)}")
model_ir = ir.serde.deserialize_model(model)
if function_rewrite_rules:
for rule_cls in function_rewrite_rules:
Expand All @@ -37,5 +36,4 @@ def rewrite(
model = ir.serde.serialize_model(model_ir)
remove_unused.remove_unused_nodes(model)
remove_unused_function.remove_unused_functions(model)
print(f"len(value_info): {len(model.graph.value_info)}")
return model
100 changes: 43 additions & 57 deletions onnxscript/rewriter/function_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import functools
import logging
from typing import Callable

import onnx
from packaging import version

import onnxscript
from onnxscript import ir
from onnxscript._legacy_ir import visitor
from onnxscript.rewriter import pattern

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -38,7 +38,7 @@ def parse_domain(function_domain: str) -> tuple[str, version.Version | None]:
class VersionController:
def __init__(self):
# A dispatch table for rewrite implementation based on the function package version.
self.dispatch_table: dict[tuple[version.Version, version.Version], callable] = {}
self.dispatch_table: dict[tuple[version.Version, version.Version], Callable] = {}

def register_version(
self,
Expand Down Expand Up @@ -66,7 +66,7 @@ def deco(func):

return deco

def dispatch(self, version: version.Version | None) -> callable | None:
def dispatch(self, version: version.Version | None) -> Callable | None:
if version is None:
if len(self.dispatch_table) == 1:
return next(iter(self.dispatch_table.values()))
Expand Down Expand Up @@ -94,12 +94,11 @@ class FunctionRewriteRule(pattern.RewriteRule):

_opset_imports: dict[str, int]
onnx_opset: onnxscript.values.Opset
_function_shape_env: visitor.FunctionShapeEnv

def __init__(self, opset: onnxscript.values.Opset = onnxscript.opset18) -> None:
def __init__(self, opset: onnxscript.values.Opset = onnxscript.opset18) -> None: # type: ignore[has-type]
self.onnx_opset = opset

def _match_function(self, function: onnx.FunctionProto, pkg_name: str) -> bool:
def _match_function(self, function: ir.Function, pkg_name: str) -> bool:
# TODO: Consolidate more checks from `compose_new_function` to here.
if pkg_name != self.PACKAGE_NAME:
logger.info(
Expand All @@ -111,7 +110,6 @@ def _match_function(self, function: onnx.FunctionProto, pkg_name: str) -> bool:
pkg_name,
)
return False

if isinstance(self.FUNCTION_KEYWORD, str):
return function.name.find(self.FUNCTION_KEYWORD) != -1
elif isinstance(self.FUNCTION_KEYWORD, tuple):
Expand All @@ -130,27 +128,17 @@ def _find_node_contains_key_in_name(
return None

def _find_node_by_type(
self, function: onnx.FunctionProto, domain: str, op_type: str
) -> onnx.NodeProto | None:
self, function: ir.Function, domain: str, op_type: str
) -> ir.Node | None:
# Repeat
for node in function.node:
for node in function:
if node.domain == domain and node.op_type == op_type:
return node
return None

def _find_constant_node(
self, function: onnx.FunctionProto, value_name: str
) -> onnx.NodeProto | None:
# Potentially repeat, utility function.
for node in function.node:
for output in node.output:
if output == value_name:
return node
return None

def compose_new_function(
self, old_function: onnx.FunctionProto, pkg_version: version.Version | None
) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]:
self, old_function: ir.Function, pkg_version: version.Version | None
) -> ir.Function:
"""Compose a new function from the old function.
Returns:
Expand All @@ -159,21 +147,23 @@ def compose_new_function(
Raises:
FunctionRewriteError: If the rewrite fails.
"""
func = self._version_controller.dispatch(pkg_version)
# self._version_controller is created in the subclass
func = self._version_controller.dispatch(pkg_version) # type: ignore[attr-defined]
if func is not None:
return func(self, old_function)
new_function = func(self, old_function)
return new_function
raise FunctionRewriteError(
f"No rewrite implementation for package version {pkg_version}."
)

def try_rewrite_function(
self, function: onnx.FunctionProto, model: onnx.ModelProto
) -> bool:
self, function: ir.Function
) -> tuple[ir.OperatorIdentifier, ir.Function] | None:
try:
pkg_name, pkg_version = parse_domain(function.domain)
except FunctionRewriteError as e:
logger.warning("Could not parse domain: %s", e)
return False
return None

if pkg_version is None and not pkg_name.startswith("onnxscript"):
logger.warning(
Expand All @@ -185,57 +175,53 @@ def try_rewrite_function(
)

if not self._match_function(function, pkg_name):
return False
return None
logger.info(
"Rule %s matched function %s::%s",
self.__class__.__name__,
function.domain,
function.name,
)

try:
new_function, opset_imports = self.compose_new_function(function, pkg_version)
new_function = self.compose_new_function(function, pkg_version)
except FunctionRewriteError as e:
logger.warning("Could not rewrite function: %s", e)
return False

nodes = new_function.node
return None

del function.input[:]
function.input.extend(new_function.input)
del function.output[:]
function.output.extend(new_function.output)
new_function.name = function.name
new_function.domain = function.domain

del function.node[:]
function.node.extend(nodes)
for new_opset in opset_imports:
function.opset_import.append(new_opset)
if new_opset.domain not in self._opset_imports:
model.opset_import.append(new_opset)
return True
return function.identifier(), new_function

def try_rewrite(self, model: ir.Model, value) -> bool:
raise NotImplementedError(
"Use `try_rewrite_function` instead for function based rewrites."
)

def lookup(self, function: onnx.FunctionProto, value_name: str) -> ir.Value | None:
return self._function_shape_env.lookup(function, value_name)

def apply_to_model(
self, model: ir.Model, *, commute: bool = False
) -> tuple[int, ir.Model]:
del commute # unused
model_proto: onnx.ModelProto = ir.serde.serialize_model(model)
self._function_shape_env = visitor.FunctionShapeEnv()
self._function_shape_env.load_from_model_proto(model_proto)
self._opset_imports = {x.domain: x.version for x in model_proto.opset_import}

rewrite_count = 0
for function in model_proto.functions:
rewrite_count += self.try_rewrite_function(function, model_proto)
model = ir.serde.deserialize_model(model_proto)
return rewrite_count, model

old_function_to_new_function: dict[ir.OperatorIdentifier, ir.Function] = {}
for function in model.functions.values():
rewrite_or_none = self.try_rewrite_function(function)
if rewrite_or_none is not None:
old_function_to_new_function[rewrite_or_none[0]] = rewrite_or_none[1]
model = self.update_to_new_function(model, old_function_to_new_function)
return len(old_function_to_new_function), model

def update_to_new_function(
self,
model: ir.Model,
old_function_to_new_function: dict[ir.OperatorIdentifier, ir.Function],
) -> ir.Model:
for old_function_id, new_function_ir in old_function_to_new_function.items():
model.functions[old_function_id] = new_function_ir
for new_opset, opset_version in new_function_ir.opset_imports.items():
if new_opset not in model.opset_imports:
model.opset_imports[new_opset] = opset_version
return model

def count_matches(self, model, *, commute: bool = False) -> int:
raise NotImplementedError()
Expand Down
9 changes: 3 additions & 6 deletions onnxscript/rewriter/onnxruntime/transformers/biassplitgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import logging

import onnx

import onnxscript
from onnxscript import ir
from onnxscript.rewriter import function_rule

logger = logging.getLogger(__name__)
Expand All @@ -16,9 +15,7 @@ class GegluRewriteRule(function_rule.FunctionRewriteRule):
_version_controller = function_rule.VersionController()

@_version_controller.register_version() # type: ignore[misc]
def _fusion(
self, function: onnx.FunctionProto
) -> tuple[onnx.FunctionProto, tuple[onnx.OperatorSetIdProto]]:
def _fusion(self, function: ir.Function) -> ir.Function:
del function # Unused
op = self.onnx_opset
msft_opset = onnxscript.values.Opset("com.microsoft", 1)
Expand All @@ -29,4 +26,4 @@ def ggelu(input, weight, bias):
return msft_opset.BiasSplitGelu(matmul_input, bias)

function_proto = onnxscript.script(default_opset=op)(ggelu).to_function_proto() # type: ignore[arg-type]
return function_proto, (onnx.helper.make_operatorsetid("com.microsoft", 1),)
return ir.serde.deserialize_function(function_proto)
12 changes: 4 additions & 8 deletions onnxscript/rewriter/onnxruntime/transformers/fastgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import logging

import onnx

import onnxscript
from onnxscript import ir
from onnxscript.rewriter import function_rule

logger = logging.getLogger(__name__)
Expand All @@ -16,16 +15,13 @@ class GeluRewriteRule(function_rule.FunctionRewriteRule):
_version_controller = function_rule.VersionController()

@_version_controller.register_version()
def _fusion(
self, function: onnx.FunctionProto
) -> tuple[onnx.FunctionProto, list[onnx.OperatorSetIdProto]]:
def _fusion(self, function: ir.Function) -> ir.Function:
del function # Unused
op = self.onnx_opset
msft_opset = onnxscript.values.Opset("com.microsoft", 1)

def gelu(input):
return msft_opset.FastGelu(input)

return onnxscript.script(default_opset=op)(gelu).to_function_proto(), (
onnx.helper.make_operatorsetid("com.microsoft", 1),
)
function_proto = onnxscript.script(default_opset=op)(gelu).to_function_proto()
return ir.serde.deserialize_function(function_proto)
23 changes: 10 additions & 13 deletions onnxscript/rewriter/onnxruntime/transformers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@

import logging

import onnx
from onnx import numpy_helper

import onnxscript
from onnxscript.rewriter import function_rule
from onnxscript import ir
from onnxscript.rewriter import _ir_utils, function_rule

logger = logging.getLogger(__name__)

Expand All @@ -17,19 +15,17 @@ class LNRewriteRule(function_rule.FunctionRewriteRule):
_version_controller = function_rule.VersionController()

@_version_controller.register_version()
def _fusion( # type: ignore[misc]
self, function: onnx.FunctionProto
) -> tuple[onnx.FunctionProto, list[onnx.OperatorSetIdProto]]:
def _fusion(self, function: ir.Function) -> ir.Function:
# TODO(bowbao): Might be more desirable to annotate as attribute in nn.Module
aten_add_node = self._find_node_by_type(function, "", "Add")
if aten_add_node is None:
raise function_rule.FunctionRewriteError("Could not find Add node")

eps_node = self._find_constant_node(function, aten_add_node.input[1])
if eps_node is None:
raise function_rule.FunctionRewriteError("Could not find eps node")

eps = numpy_helper.to_array(eps_node.attribute[0].t).item()
eps_ir_value = _ir_utils.propagate_const_value(aten_add_node.inputs[1])
eps_numpy_value = _ir_utils.get_numpy_from_ir_value(eps_ir_value)
if eps_numpy_value is None:
raise function_rule.FunctionRewriteError("Could not find eps")
eps = eps_numpy_value.item()
logger.info("eps: %s", eps)

# TODO(ORT): SimplifiedLayerNormalization in ort is defined under onnx domain.
Expand All @@ -42,4 +38,5 @@ def ln(input, weight):
input, weight, axis=-1, epsilon=eps, stash_type=1
)

return onnxscript.script(default_opset=op)(ln).to_function_proto(), []
function_proto = onnxscript.script(default_opset=op)(ln).to_function_proto()
return ir.serde.deserialize_function(function_proto)
Loading

0 comments on commit 060009e

Please sign in to comment.