[AutoTVM] Introducing multi_filter into ConfigSpace autotvm #12545

Merged
merged 2 commits on Sep 22, 2022
330 changes: 321 additions & 9 deletions python/tvm/autotvm/task/space.py
@@ -30,6 +30,7 @@
import functools
import math
from collections import namedtuple, OrderedDict
from random import randrange
import numpy as np

from tvm.te import schedule, thread_axis
@@ -665,13 +666,17 @@ def __init__(self):
self.space_map = OrderedDict() # name -> space
self._collect = True
self._length = None
self._range_length = None
self._dims = None
self._entity_map = OrderedDict() # name -> entity
self._constraints = []
self.errors = []
self.code_hash = None
self.flop = 0
self.cost = None
self.is_fallback = False
self._shared_filter = None
self._shared_filter_cache = None

@staticmethod
def axis(var):
@@ -714,18 +719,19 @@ def define_split(self, name, axis, policy="factors", **kwargs):
the total number of axes after the split (`int`).
``no_tail``:
should we only include divisible numbers as split factors (`bool`).
`candidate``:
``candidate``:
(policy=candidate) manual candidate list (`List`).

Examples
--------
>>> # use custom candidates
>>> cfg.define_split('tile_x', x, policy='candidate', candidate=[[1, 4, 4], [4, 1, 4]])
>>> cfg.define_split('tile_x', x, policy='candidate', num_outputs=3,
>>> candidate=[[1, 4, 4], [4, 1, 4]])

>>> # use a filter that only accepts split schemes whose innermost tile is less than 4
>>> cfg.define_split('tile_y', y, policy='factors', filter=lambda x: x.size[-1] <= 4)
>>> cfg.define_split('tile_y', y, policy='factors', num_outputs=3,
>>> filter=lambda x: x.size[-1] <= 4)
"""

axes = [axis]
return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs)

@@ -822,11 +828,300 @@ def valid(self):
"""
return not bool(self.errors)

def is_index_valid(self, index):
"""Checks if the index satisfies the multi_filter condition

Parameters
----------
index: int
index from the range of the space

Returns
-------
valid: bool
whether the index meets all the constraints
"""
assert 0 <= index < self.range_length
if self._shared_filter is None:
return True
if self._shared_filter_cache is None:
self._make_shared_filter_cache()
return self._shared_filter_cache[index]
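
# Editorial sketch, not part of this patch: assuming `cfg` is a ConfigSpace on
# which multi_filter() has already been called, the valid points of the space
# can be found by probing every index of the full range.
valid_indexes = [i for i in range(cfg.range_length) if cfg.is_index_valid(i)]
assert len(valid_indexes) == len(cfg)  # __len__ counts only the valid indexes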

def multi_filter(self, filter): # pylint: disable=redefined-builtin
"""The filter can restrict combination of parameters in difference to the knob filter,
that restricts only single parameter

Parameters
----------
filter: function
predicate taking one argument, an OrderedDict that maps knob names to entities
(Callable[[OrderedDict], bool])

.. note::

Using this filter changes how ``__len__`` should be interpreted.
Normally, ``__len__`` gives both the count of valid indexes and the range of
the space, but when multi_filter is enabled, ``__len__`` returns only the
count of valid indexes, while ``range_length`` returns the full range of the
space. It is recommended to use ``is_index_valid``, ``get_next_index`` and
``get_rand_index`` to traverse the space

Examples
--------
>>> # Pre-requisites
>>> candidates = [[16, 64], [32, 32], [64, 16]]
>>> filter = lambda v: v.size[0] != 16
>>> multi_filter = lambda e: (e["tile_x"].size[0] + e["tile_y"].size[0]) <= 64

>>> # Case 1 - without filtering
>>> cfg.define_split("tile_x", x, num_outputs=2, policy="candidate", candidate=candidates)
>>> cfg.define_split("tile_y", y, num_outputs=2, policy="candidate", candidate=candidates)
>>> # [('tile_x', [16, 64]), ('tile_y', [16, 64])],None,0
>>> # [('tile_x', [32, 32]), ('tile_y', [16, 64])],None,1
>>> # [('tile_x', [64, 16]), ('tile_y', [16, 64])],None,2
>>> # [('tile_x', [16, 64]), ('tile_y', [32, 32])],None,3
>>> # [('tile_x', [32, 32]), ('tile_y', [32, 32])],None,4
>>> # [('tile_x', [64, 16]), ('tile_y', [32, 32])],None,5
>>> # [('tile_x', [16, 64]), ('tile_y', [64, 16])],None,6
>>> # [('tile_x', [32, 32]), ('tile_y', [64, 16])],None,7
>>> # [('tile_x', [64, 16]), ('tile_y', [64, 16])],None,8

>>> # Case 2 - with filter
>>> cfg.define_split("tile_x", x, num_outputs=2, policy="candidate", candidate=candidates,
>>> filter=filter)
>>> cfg.define_split("tile_y", y, num_outputs=2, policy="candidate", candidate=candidates,
>>> filter=filter)
>>> # [('tile_x', [32, 32]), ('tile_y', [32, 32])],None,0
>>> # [('tile_x', [64, 16]), ('tile_y', [32, 32])],None,1
>>> # [('tile_x', [32, 32]), ('tile_y', [64, 16])],None,2
>>> # [('tile_x', [64, 16]), ('tile_y', [64, 16])],None,3

>>> # Case 3 - with filter and multi_filter
>>> cfg.define_split("tile_x", x, num_outputs=2, policy="candidate", candidate=candidates,
>>> filter=filter)
>>> cfg.define_split("tile_y", y, num_outputs=2, policy="candidate", candidate=candidates,
>>> filter=filter)
>>> cfg.multi_filter(filter=multi_filter)
>>> # [('tile_x', [32, 32]), ('tile_y', [32, 32])],None,0
"""
if self._collect:
self.clear_cache()
self._shared_filter = filter
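
# Editorial sketch, not part of this patch: how a template could combine the
# per-knob filter with the new multi_filter (the axes x, y, the candidate lists
# and `cfg` follow Case 3 of the docstring above).
cfg.define_split("tile_x", x, num_outputs=2, policy="candidate",
                 candidate=[[16, 64], [32, 32], [64, 16]],
                 filter=lambda v: v.size[0] != 16)
cfg.define_split("tile_y", y, num_outputs=2, policy="candidate",
                 candidate=[[16, 64], [32, 32], [64, 16]],
                 filter=lambda v: v.size[0] != 16)
cfg.multi_filter(filter=lambda e: e["tile_x"].size[0] + e["tile_y"].size[0] <= 64)
print(cfg.range_length)  # 4 -> full index range left by the per-knob filters
print(len(cfg))          # 1 -> only [32, 32] x [32, 32] passes the multi_filter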

@property
def range_length(self):
"""Length of the index range in the space"""
if self._range_length is None:
self._range_length = int(np.prod([len(x) for x in self.space_map.values()]))
return self._range_length

@property
def dims(self):
"""Dimensions in the space"""
if self._dims is None:
self._dims = [len(x) for x in self.space_map.values()]
return self._dims

def subrange_length(self, start, end):
"""Returns the number of valid indexes within the limited range from [start, end]

Parameters
----------
start: int
start of subrange, inclusive
end: int
end of subrange, exclusive

Returns
-------
count: int
number of valid indexes
"""
assert 0 <= start <= end <= self.range_length
if self._shared_filter is None:
return end - start
if self._shared_filter_cache is None:
self._make_shared_filter_cache()
return self._shared_filter_cache[start:end].count(True)
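
# Editorial sketch, not part of this patch: counting the valid points in each
# half of the index range, e.g. for a tuner that partitions the space.
half = cfg.range_length // 2
assert cfg.subrange_length(0, half) + cfg.subrange_length(half, cfg.range_length) == len(cfg)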

def get_rand_index(self, start=None, end=None, to_exclude=None):
"""Returns a random valid index unlisted to exclusion

Parameters
----------
start: int, optional
specifying at which position to start, inclusive
end: int, optional
specifying at which position to end, exclusive
to_exclude: list, optional
list of indexes that must not be returned

Returns
-------
rand: int
random index in the space

.. note::

Excluding all valid space indexes will lead to an infinite loop.

"""
start = start or 0
end = end or self.range_length
while True:
index = randrange(start, end)
if self.is_index_valid(index) and index not in (to_exclude or []):
return index

def get_next_index(self, index, n=1, start=None, end=None):
"""Returns the nth valid next index or None if out of range

Parameters
----------
index: int
specifying at which position to start, inclusive
n: int, optional
number of valid indexes to step over; use a negative number to search
in the opposite direction
start: int, optional
start of subrange, inclusive
end: int, optional
end of subrange, exclusive

Returns
-------
next: int
next index in the space
"""
assert n != 0
start = start or 0
end = end or self.range_length
if self._shared_filter is None:
index += n
if start <= index < end:
return index
return None
trend = 1 if n > 0 else -1
counter = abs(n)
while counter != 0:
index += trend
if index < start or index >= end:
return None
if self.is_index_valid(index):
counter -= 1
return index
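
# Editorial sketch, not part of this patch: visiting the valid indexes in
# ascending order, starting from a random valid point.
index = cfg.get_rand_index()
while index is not None:
    entity = cfg.get(index)            # every visited index passes the filters
    index = cfg.get_next_index(index)  # None once the end of the range is reached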

def clear_cache(self):
"""Clears the cache of index validity"""
del self._shared_filter_cache
self._dims = None
self._length = None
self._range_length = None
self._shared_filter_cache = None

def _make_shared_filter_cache(self):
def apply(t):
entities = OrderedDict()
for name, space in self.space_map.items():
entities[name] = space[t % len(space)]
t //= len(space)
return bool(self._shared_filter(entities))

self._shared_filter_cache = tuple(apply(i) for i in range(self.range_length))
self._length = self._shared_filter_cache.count(True)

def point2knob(self, point):
"""Convert point form (single integer) to knob (vector)

Parameters
----------
point: int
point to convert

Returns
-------
knob: list
knob representation of the point
"""
knob = []
for dim in self.dims:
knob.append(point % dim)
point //= dim
return knob

def knob2point(self, knob):
"""Convert knob form (vector) to point form (single integer)

Parameters
----------
knob: list
knob to convert

Returns
-------
point: int
point of the knob representation
"""
point = 0
for j, k in enumerate(knob):
point += int(np.prod(self.dims[:j])) * k
return point
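
# Editorial sketch, not part of this patch: the mixed-radix conversion above,
# worked through by hand for a space with cfg.dims == [3, 3] (two knobs with
# three options each).  Point 5 decodes as 5 % 3 == 2 and 5 // 3 == 1, and the
# knob re-encodes as 2 * 1 + 1 * 3 == 5.
assert cfg.point2knob(5) == [2, 1]
assert cfg.knob2point([2, 1]) == 5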

def sample_ints(self, m):
"""
Sample m different integers from [0, self.range_length) without replacement.
This function is an alternative to `np.random.choice` when self.range_length > 2 ^ 32,
in which case numpy does not work.

Parameters
----------
m: int
The number of sampled int

Returns
-------
ints: a numpy array of size m
"""
assert m <= len(self)
vis = set()
while len(vis) < m:
new = randrange(0, self.range_length)
if self.is_index_valid(new):
vis.add(new)
return np.fromiter(vis, int, len(vis))

def random_walk(self, point):
"""random walk as local transition

Parameters
----------
point: int
index of the ConfigEntity

Returns
-------
new_point: int
new neighborhood index
"""
# transform to knob form
old_knob = self.point2knob(point)
new_knob = old_knob.copy()
new_point = self.knob2point(new_knob)
# mutate
while new_knob == old_knob or not self.is_index_valid(new_point):
from_i = np.random.randint(len(old_knob))
to_v = np.random.randint(self.dims[from_i])
new_knob[from_i] = to_v
new_point = self.knob2point(new_knob)
# transform to index form
return new_point
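
# Editorial sketch, not part of this patch: how a tuner can use random_walk to
# jump from one valid configuration to a neighbouring one.
point = cfg.get_rand_index()
neighbour = cfg.random_walk(point)
assert neighbour != point and cfg.is_index_valid(neighbour)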

def _add_new_transform(self, space_class, name, axes, policy, **kwargs):
"""Add a new transform space in template"""
# If we do not have tuned info (_collect == True) but a KNOB value was already
# defined for "default" scheduling before the call of _add_new_transform, there
# is no need to create a new space and override the previously set KNOB values
if kwargs.get("filter"):
self.clear_cache()
if self._collect and not (self.is_fallback and name in self._entity_map):
# convert schedule axis to space definition axis
axes = [x if isinstance(x, (VirtualAxis, Axis)) else self.axis(x) for x in axes]
@@ -839,8 +1134,11 @@ def _add_new_transform(self, space_class, name, axes, policy, **kwargs):
return [Axis(None, i) for i in range(space_class.get_num_output(axes, policy, **kwargs))]

def __len__(self):
if self._length is None:
self._length = int(np.prod([len(x) for x in self.space_map.values()]))
"""Returns the number of valid indexes in the space"""
if self._shared_filter is None:
return self.range_length
if self._shared_filter_cache is None:
self._make_shared_filter_cache()
return self._length

def get(self, index):
@@ -850,9 +1148,21 @@ def get(self, index):
----------
index: int
index in the space

Returns
-------
config: ConfigEntity
config corresponds to the index
"""
if index < 0 or index >= len(self):
raise IndexError("Index out of range: size {}, got index {}".format(len(self), index))
if index < 0 or index >= self.range_length:
raise IndexError(
"Index out of range: size {}, got index {}".format(self.range_length, index)
)
if not self.is_index_valid(index):
raise IndexError(
"Index does not correspond to the multi-filter condition, got index {}. "
"Use is_index_valid to pre-check".format(index)
)
entities = OrderedDict()
t = index
for name, space in self.space_map.items():
@@ -876,7 +1186,9 @@ def __getitem__(self, name):
return self._entity_map[name]

def __repr__(self):
res = "ConfigSpace (len=%d, space_map=\n" % len(self)
res = "ConfigSpace (len={}, range_length={}, space_map=\n".format(
len(self), self.range_length
)
for i, (name, space) in enumerate(self.space_map.items()):
res += " %2d %s: %s\n" % (i, name, space)
return res + ")"