Commit

- lints
mbs-octoml committed Jul 14, 2022
1 parent bbe11e2 commit 96631de
Showing 3 changed files with 49 additions and 48 deletions.
4 changes: 3 additions & 1 deletion python/tvm/relay/collage/__init__.py
@@ -14,4 +14,6 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from .collage import MEASURE_NUMBER, MEASURE_REPEAT, WARMUP_MIN_REPEAT_MS, CostEstimator
+"""relay.collage exports"""
+from .collage import MEASURE_NUMBER, MEASURE_REPEAT, WARMUP_MIN_REPEAT_MS, CostEstimator, \
+    MockEstimator
72 changes: 39 additions & 33 deletions python/tvm/relay/collage/collage.py
@@ -18,16 +18,18 @@
 """Mostly helper methods which interface the main C++ Collage implementation with Python.
 See relay.transform.CollagePartition for the main Collage entrypoint."""
 
-import tvm
-from tvm._ffi.registry import register_func, register_object
-from tvm.runtime import Object
-from . import _ffi_api
-import numpy as np
 import logging
 import os
 import math
 import tempfile
+
+import numpy as np
+
+import tvm
+from tvm._ffi.registry import register_func, register_object
+from tvm.runtime import Object
+from . import _ffi_api
 
 # Parameters to use when estimating latency (of both partitions and overall models).
 MEASURE_NUMBER = 20
 MEASURE_REPEAT = 5
@@ -50,22 +52,23 @@ def __init__(self, target_costs):
         self.__init_handle_by_constructor__(_ffi_api.MockEstimator, target_costs)
 
 
-def arg_for(type, device):
-    """Returns a test argument of type on device"""
-    assert isinstance(type, tvm.ir.TensorType)
+def arg_for(arg_type, device):
+    """Returns a test argument of Relay arg_type on device"""
+    assert isinstance(arg_type, tvm.ir.TensorType)
     return tvm.nd.array(
-        np.random.uniform(-1.0, 1.0, size=type.concrete_shape).astype(type.dtype), device=device
+        np.random.uniform(-1.0, 1.0, size=arg_type.concrete_shape).astype(arg_type.dtype),
+        device=device,
     )
 
 
-def vm_estimate_seconds(device, vm, func_name, args):
-    """Returns the estimated latency, in seconds, of running func_name with args on the given vm."""
+def vm_estimate_seconds(device, the_vm, func_name, args):
+    """Returns the estimated latency, in seconds, of running func_name with args on the_vm."""
     # Warmup
-    vm.benchmark(
+    the_vm.benchmark(
         device, repeat=1, number=1, min_repeat_ms=WARMUP_MIN_REPEAT_MS, func_name=func_name, **args
     )
     # One more time, with feeling
-    return vm.benchmark(
+    return the_vm.benchmark(
         device,
         repeat=MEASURE_REPEAT,
         number=MEASURE_NUMBER,
@@ -85,11 +88,11 @@ def estimate_seconds(mod, target):
         # Build the module.
         logging.info("Compiling module to estimate")
         exe = tvm.relay.vm.compile(mod, target)
-    except RuntimeError as e:
+    except RuntimeError as err:
         # A build failure indicates the partition is not supported.
         # eg trying to build an nn.batch_norm on GPU, which has no schedule since we assume it
         # is only ever used with a tuple projection which is rewritten away.
-        logging.info(f"Assigning module infinite cost since unable to build: {e}")
+        logging.info("Assigning module infinite cost since unable to build: %s", err)
         return math.inf
 
     # Finalize compilation
@@ -102,35 +105,35 @@ def estimate_seconds(mod, target):
     exe = tvm.runtime.vm.Executable.load_exec(code, lib)
 
     # Benchmark the module.
-    vm = tvm.runtime.vm.VirtualMachine(exe, device)
+    the_vm = tvm.runtime.vm.VirtualMachine(exe, device)
     func_name = "main"
     main_args = {v.name_hint: arg_for(v.checked_type, device) for v in mod[func_name].params}
    logging.info("Benchmarking module to estimate")
-    profile = vm_estimate_seconds(device, vm, func_name, main_args)
-    logging.info(f"profile: {profile}")
+    profile = vm_estimate_seconds(device, the_vm, func_name, main_args)
+    logging.info("profile: %s", profile)
     return profile.median  # seconds
 
 
-make_labelled_dfpattern_partition_rule = tvm._ffi.get_global_func(
-    "relay.collage.make_labelled_dfpattern_partition_rule"
+MakeLabelledDFPatternPartitionRule = tvm._ffi.get_global_func(
+    "relay.collage.MakeLabelledDFPatternPartitionRule"
 )
-make_labelled_dfpattern_partition_rule_with_predicate = tvm._ffi.get_global_func(
-    "relay.collage.make_labelled_dfpattern_partition_rule_with_predicate"
+MakeLabelledDFPatternPartitionRuleWithPredicate = tvm._ffi.get_global_func(
+    "relay.collage.MakeLabelledDFPatternPartitionRuleWithPredicate"
 )
-make_pattern_byoc_partition_rule = tvm._ffi.get_global_func(
-    "relay.collage.make_pattern_byoc_partition_rule"
+MakePatternBYOCPartitionRule = tvm._ffi.get_global_func(
+    "relay.collage.MakePatternBYOCPartitionRule"
 )
 
 
-def make_labelled_dfpattern_partition_rule_wrapper(compiler, tuple):
+def make_labelled_dfpattern_partition_rule_wrapper(compiler, pattern_tuple):
     """Returns a DFPatternPartitionRule representing one (label, pattern, predicate) entry from
     the pattern table for external codegen compiler"""
-    if len(tuple) == 2:
-        rule_name, dataflow_pattern = tuple
-        return make_labelled_dfpattern_partition_rule(compiler, rule_name, dataflow_pattern)
+    if len(pattern_tuple) == 2:
+        rule_name, dataflow_pattern = pattern_tuple
+        return MakeLabelledDFPatternPartitionRule(compiler, rule_name, dataflow_pattern)
     else:
-        rule_name, dataflow_pattern, predicate = tuple
-        return make_labelled_dfpattern_partition_rule_with_predicate(
+        rule_name, dataflow_pattern, predicate = pattern_tuple
+        return MakeLabelledDFPatternPartitionRuleWithPredicate(
             compiler, rule_name, dataflow_pattern, predicate
         )
 
@@ -143,9 +146,12 @@ def make_byoc_partition_rule(compiler):
         pattern_table is not None
     ), f"No pattern table entry was found for BYOC compiler {compiler}"
     logging.info(
-        f"Converting {len(pattern_table)} rules for {compiler} for use in pattern style BYOC lowering/codegen"
+        "Converting %s rules for %s for use in pattern style BYOC lowering/codegen",
+        len(pattern_table),
+        compiler,
     )
     sub_rules = [
-        make_labelled_dfpattern_partition_rule_wrapper(compiler, tuple) for tuple in pattern_table
+        make_labelled_dfpattern_partition_rule_wrapper(compiler, pattern_tuple)
+        for pattern_tuple in pattern_table
     ]
-    return make_pattern_byoc_partition_rule(compiler, sub_rules)
+    return MakePatternBYOCPartitionRule(compiler, sub_rules)
21 changes: 7 additions & 14 deletions src/relay/collage/gather_partition_specs.cc
@@ -162,20 +162,13 @@ PartitionRule MakePatternBYOCPartitionRule(const std::string& compiler,
   return PrimitivePartitionRule("", std::move(valid));
 }
 
-TVM_REGISTER_GLOBAL("relay.collage.make_labelled_dfpattern_partition_rule")
-    .set_body_typed([](String compiler, String rule_name, DFPattern dataflow_pattern) {
-      return MakeLabelledDFPatternPartitionRule(std::move(compiler), std::move(rule_name),
-                                                std::move(dataflow_pattern));
-    });
-
-TVM_REGISTER_GLOBAL("relay.collage.make_labelled_dfpattern_partition_rule_with_predicate")
-    .set_body_typed([](String compiler, String rule_name, DFPattern dataflow_pattern,
-                       TPatternPredicate predicate) {
-      return MakeLabelledDFPatternPartitionRule(std::move(compiler), std::move(rule_name),
-                                                std::move(dataflow_pattern), std::move(predicate));
-    });
-
-TVM_REGISTER_GLOBAL("relay.collage.make_pattern_byoc_partition_rule")
+TVM_REGISTER_GLOBAL("relay.collage.MakeLabelledDFPatternPartitionRule")
+    .set_body_typed(MakeLabelledDFPatternPartitionRule);
+
+TVM_REGISTER_GLOBAL("relay.collage.MakeLabelledDFPatternPartitionRuleWithPredicate")
+    .set_body_typed(MakeLabelledDFPatternPartitionRule);
+
+TVM_REGISTER_GLOBAL("relay.collage.MakePatternBYOCPartitionRule")
     .set_body_typed(MakePatternBYOCPartitionRule);
 
 /*!