Skip to content

Commit

Permalink
Merge branch 'main' into justinchu/ser-initializer
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored Aug 8, 2024
2 parents 6388d95 + 9bae2b5 commit 64e740a
Show file tree
Hide file tree
Showing 32 changed files with 1,120 additions and 697 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
- main
- 'gh/**/base' # ghstack base branches
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
merge_group:

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ on:

# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:
merge_group:

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorial/rewriter/examples/erfgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def erf_gelu_pattern_2(op, x):


def gelu(op, x: ir.Value):
return op.Gelu(x, domain="com.microsoft")
return op.Gelu(x, _domain="com.microsoft")


####################################
Expand Down
4 changes: 4 additions & 0 deletions onnxscript/backend/onnx_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True):
"cannot import module, import_module does not work",
),
skip("^test_bitwise_not_3d", "cannot import module, import_module does not work"),
skip(
"^test_resize_upsample_scales_linear_half_pixel_symmetric",
"cannot import module, import_module does not work",
),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,18 @@ def _add_attribute_to_torchscript_node(
return node.fs_(key, list(value)) # type: ignore[arg-type]
if isinstance(value[0], int):
return node.is_(key, list(value)) # type: ignore[attr-defined]
raise TypeError(f"Unsupported sequence type '{type(value)}' for attribute '{key}'")
raise TypeError(f"Unsupported attribute type '{type(value)}' for attribute '{key}'")
raise TypeError(
f"Unsupported sequence type '{type(value)}' for attribute '{key}' in "
f"node={node!r}, value is {value!r}"
)
if "TensorProtoDataType" in str(type(value)):
# torch._C._onnx.TensorProtoDataType
return node.i_(key, int(value))

raise TypeError(
f"Unsupported attribute type '{type(value)}' for attribute '{key}' "
f"in node={node!r}, value is {value!r}"
)


@runtime_typing.checked
Expand Down
Loading

0 comments on commit 64e740a

Please sign in to comment.