Skip to content

Commit

Permalink
Merge branch 'main' into xiaowu/fixBug(embedding_bag)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaowuhu committed Nov 20, 2023
2 parents 2666f56 + 10f9a1f commit 95a24dd
Show file tree
Hide file tree
Showing 33 changed files with 1,144 additions and 279 deletions.
4 changes: 2 additions & 2 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ init_command = [
]

[[linter]]
code = 'BLACK-ISORT'
code = 'RUFF-FORMAT'
include_patterns = [
'**/*.py',
]
Expand All @@ -82,7 +82,7 @@ command = [
'-m',
'lintrunner_adapters',
'run',
'black_isort_linter',
'ruff_format_linter',
'--',
'@{{PATHSFILE}}'
]
Expand Down
4 changes: 1 addition & 3 deletions docs/examples/04_plot_eager_mode_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@


@script()
def linear(
A: FLOAT["N", "K"], W: FLOAT["K", "M"], Bias: FLOAT["M"]
) -> FLOAT["N", "M"]: # noqa: F821
def linear(A: FLOAT["N", "K"], W: FLOAT["K", "M"], Bias: FLOAT["M"]) -> FLOAT["N", "M"]: # noqa: F821
T1 = op.MatMul(A, W)
T2 = op.Add(T1, Bias)
Y = op.Relu(T2)
Expand Down
4 changes: 2 additions & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
"pyyaml",
)
ONNX = "onnx==1.14.1"
ONNX_RUNTIME = "onnxruntime==1.16.0"
PYTORCH = "torch==2.0.1"
ONNX_RUNTIME = "onnxruntime==1.16.1"
PYTORCH = "torch==2.1.0"
ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = (
"flatbuffers",
"coloredlogs",
Expand Down
3 changes: 1 addition & 2 deletions onnxscript/_internal/param_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ def separate_input_attributes_from_arguments(
else:
onnx_attributes[param.name] = kwargs[param.name]
elif (
param.is_attribute
and param.default is not values._EmptyDefault # pylint: disable=protected-access
param.is_attribute and param.default is not values._EmptyDefault # pylint: disable=protected-access
):
# User did not provide the attribute
if fill_defaults:
Expand Down
3 changes: 1 addition & 2 deletions onnxscript/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
import onnx

import onnxscript
from onnxscript import irbuilder, onnx_types, sourceinfo
from onnxscript import irbuilder, onnx_types, sourceinfo, values
from onnxscript import type_annotation as ta
from onnxscript import values
from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation

PY_VERSION_GE_39 = ast_utils.PY_VERSION_GE_39
Expand Down
6 changes: 5 additions & 1 deletion onnxscript/converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import onnx
import onnxruntime as ort
import pytest
from onnxruntime.capi.onnxruntime_pybind11_state import (
Fail,
InvalidArgument,
Expand Down Expand Up @@ -270,7 +271,10 @@ def test_renaming(self):

self.validate_save(renaming, shape_inference=False)

@unittest.skip(reason="TypeError: val must be numeric not <class 'NoneType'>")
@pytest.mark.xfail(
strict=True,
reason="default_opset must be specified in script for functions that do not contain any use of an ONNX op",
)
def test_opt_output(self):
from onnxscript.tests.models import opt_output

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class TestDeduceTypeConstraints(unittest.TestCase):
"_aten_embedding_bag_onnx",
"_aten_embedding_bag_1d_padding_idx_onnx",
)
_SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ()

@parameterized.parameterized.expand(
((op,) for op in torch_lib_onnx_functions_from_registry()),
Expand All @@ -41,11 +40,13 @@ def test_deduce_type_constraints_does_not_crash_for_onnx_function(
):
if onnx_function.name in self._SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN:
self.skipTest("Unimplemented: function contains loop or scan node.")
if onnx_function.name in self._SKIP_FUNCTIONS_WITH_NESTED_FUNCTION:
self.skipTest("Unimplemented: function contains nested function.")
signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
onnx_function
)
try:
signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
onnx_function
)
except NotImplementedError as e:
if "Nested function" in str(e):
self.skipTest("Unimplemented: function contains nested function.")
logger.info(
"Original signature: %s%s",
onnx_function.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,8 @@
import os
import re
import textwrap
from pathlib import Path
from typing import Any, Dict, List, Sequence

import black
import isort
import torch
import torchgen.gen
import torchgen.model
Expand Down Expand Up @@ -319,15 +316,6 @@ def main(args: argparse.Namespace) -> None:
)
py_module.accept(cg.PythonWriter(f))

# Format the generated files so that they pass linting.
# line_length=95 is to match the lintrunner rules.
isort.file(output_path)
black.format_file_in_place(
Path(output_path),
fast=True,
mode=black.Mode(line_length=95),
write_back=black.WriteBack.YES,
)
print("Done.")


Expand Down
3 changes: 3 additions & 0 deletions onnxscript/function_libs/torch_lib/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Shared constants for the library."""

DOMAIN = "pkg.onnxscript.torch_lib"
41 changes: 41 additions & 0 deletions onnxscript/function_libs/torch_lib/_flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Experimental flags.
NOTE: These flags are experimental only. Any flag here can be removed at any
time without notice.
"""

import logging
import os

logger = logging.getLogger(__name__)


def _load_boolean_flag(
name: str,
*,
this_will: str,
deprecated: bool = False,
) -> bool:
"""Load a boolean flag from environment variable.
Args:
name: The name of the environment variable.
this_will: A string that describes what this flag will do.
deprecated: Whether this flag is deprecated.
"""
state = os.getenv(name) == "1"
if state:
if deprecated:
logger.error(
"Experimental flag %s is deprecated. Please remove it from your environment.",
name,
)
else:
logger.warning("Experimental flag %s is enabled. This will %s.", name, this_will)
return state


EXPERIMENTAL_INITIALIZERS_AS_INPUTS: bool = _load_boolean_flag(
"TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS",
this_will="make initializers as inputs to the model graph",
)
29 changes: 25 additions & 4 deletions onnxscript/function_libs/torch_lib/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from onnxscript import evaluator
from onnxscript import tensor as onnxscript_tensor
from onnxscript._internal import param_manipulation, runtime_typing
from onnxscript.function_libs.torch_lib import _flags
from onnxscript.function_libs.torch_lib.ops import common as common_ops

__all__ = [
"TorchScriptTensor",
Expand Down Expand Up @@ -198,7 +200,7 @@ def symbolic_value(self) -> torch.Value:
def _unwrap_tensor_to_torch_value(
value: Union[
ValidArgumentType, Mapping[str, ValidArgumentType], Sequence[ValidArgumentType]
]
],
) -> Union[
ValidTorchValueType,
Dict[str, ValidTorchValueType],
Expand Down Expand Up @@ -363,6 +365,16 @@ def _tensor_rawdata_size(tensor: torch.Tensor) -> int:
return tensor.numel() * tensor.element_size()


def _shared_functions() -> list[onnx.FunctionProto]:
"""Hack to always include the share ops."""

# TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
return [
common_ops.Rank.to_function_proto(),
common_ops.IsScalar.to_function_proto(),
]


class TorchScriptGraph:
def __init__(
self,
Expand Down Expand Up @@ -717,7 +729,6 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func
opset_imports=onnx_model.opset_import,
doc_string=onnx_model.doc_string,
)
# TODO: onnx.checker.check_function(onnx_function)?
return onnx_function

@runtime_typing.checked
Expand All @@ -740,13 +751,15 @@ def to_model_proto(
large_model = initializers_size > _LARGE_MODEL_SIZE_THRESHOLD

export_kwargs: dict[str, Any] = dict(
initializers=self.initializers if include_initializers else {},
initializers=self.initializers
if include_initializers and not _flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS
else {},
onnx_opset_version=opset_version,
dynamic_axes={},
defer_weight_export=False,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
strip_doc_string=False,
keep_initializers_as_inputs=False,
keep_initializers_as_inputs=_flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS,
custom_opsets={},
add_node_names=True,
node_attr_to_name={},
Expand Down Expand Up @@ -786,6 +799,7 @@ def to_model_proto(
onnx_model = onnx.load_from_string(proto)

onnx_model.functions.extend(function_proto_dict.values())
onnx_model.functions.extend(_shared_functions())

# `_export_onnx` only exports opset_imports that is visible to it. It does not
# export opset_imports for nested functions, since it does not have access to
Expand All @@ -800,6 +814,13 @@ def to_model_proto(
for domain, version in unique_custom_domains.items()
]
)
# Include the library shared opset domain
# TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
onnx_model.opset_import.append(
onnx.helper.make_opsetid(
common_ops.common_opset.domain, common_ops.common_opset.version
)
)

try:
if not cache_model_to_disk:
Expand Down
2 changes: 0 additions & 2 deletions onnxscript/function_libs/torch_lib/graph_building_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
import onnxscript.testing
from onnxscript import FLOAT, evaluator
from onnxscript import opset18 as op
from onnxscript._internal import version_utils
from onnxscript.function_libs.torch_lib import graph_building, ops


@unittest.skipIf(version_utils.torch_older_than("2.0"), "torchscript in 1.13 not supported")
class TestTorchScriptTracingEvaluator(unittest.TestCase):
def setUp(self):
self.opset_version = 18
Expand Down
56 changes: 56 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Common operators shared in the torchlib library."""

import onnxscript
import onnxscript.values
from onnxscript import BOOL, INT64
from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib import _constants, tensor_typing
from onnxscript.function_libs.torch_lib.tensor_typing import RealType
from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT

COMPLEX64_TYPE = COMPLEX64.dtype
COMPLEX128_TYPE = COMPLEX128.dtype

DOMAIN = f"{_constants.DOMAIN}.common"

common_opset = onnxscript.values.Opset(domain=DOMAIN, version=1)


@onnxscript.script(common_opset)
def Rank(input: tensor_typing.TTensor) -> INT64:
"""Take the rank of the input tensor."""

return op.Size(op.Shape(input))


@onnxscript.script(common_opset)
def IsScalar(input: tensor_typing.TTensor) -> BOOL:
"""Return whether the input has rank 0, or is a scalar."""

return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))


def cast_to(a: RealType, dtype: int) -> RealType:
"""Cast input to dtype while handling complex types."""

# Traced function because different if branches return different dtypes
# which is not supported in an ONNX function
if dtype == COMPLEX128_TYPE:
# Cast to the real representation of the complex type
casted = op.Cast(a, to=DOUBLE.dtype)
# Create a complex number
real_part = op.Unsqueeze(casted, axes=[-1])
imag_part = op.Expand(op.Cast(0.0, to=DOUBLE.dtype), op.Shape(real_part))
result = op.Concat(real_part, imag_part, axis=-1)
elif dtype == COMPLEX64_TYPE:
# Cast to the real representation of the complex type
casted = op.Cast(a, to=FLOAT.dtype)
# Create a complex number
real_part = op.Unsqueeze(casted, axes=[-1])
imag_part = op.Expand(0.0, op.Shape(real_part))
result = op.Concat(real_part, imag_part, axis=-1)
else:
# Cast to real numbers
result = op.Cast(a, to=dtype)

return result
Loading

0 comments on commit 95a24dd

Please sign in to comment.