Skip to content

Commit

Permalink
[AutoTVM] Introducing multi_filter into ConfigSpace autotvm
Browse files Browse the repository at this point in the history
  • Loading branch information
Icemist committed Sep 11, 2022
1 parent a96bda4 commit 5da9a68
Show file tree
Hide file tree
Showing 14 changed files with 654 additions and 224 deletions.
277 changes: 271 additions & 6 deletions python/tvm/autotvm/task/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@

import itertools
import functools
import logging
import math
from collections import namedtuple, OrderedDict
from random import randrange
import numpy as np

from tvm.te import schedule, thread_axis
Expand Down Expand Up @@ -665,13 +667,17 @@ def __init__(self):
self.space_map = OrderedDict() # name -> space
self._collect = True
self._length = None
self._range_length = None
self._dims = None
self._entity_map = OrderedDict() # name -> entity
self._constraints = []
self.errors = []
self.code_hash = None
self.flop = 0
self.cost = None
self.is_fallback = False
self._shared_filter = None
self._shared_filter_cache = None

@staticmethod
def axis(var):
Expand Down Expand Up @@ -725,7 +731,6 @@ def define_split(self, name, axis, policy="factors", **kwargs):
>>> # use a filter that only accepts the split scheme whose inner most tile is less then 4
>>> cfg.define_split('tile_y', y, policy='factors', filter=lambda x: x.size[-1] <= 4)
"""

axes = [axis]
return self._add_new_transform(SplitSpace, name, axes, policy, **kwargs)

Expand Down Expand Up @@ -822,11 +827,254 @@ def valid(self):
"""
return not bool(self.errors)

def is_index_valid(self, index):
"""Checks if the index satisfies the multi_filter condition
Parameters
----------
index: int
index from the range of the space
Returns
-------
valid: bool
whether the index meets all the constraints
"""
assert 0 <= index < self.range_length
if self._shared_filter is None:
return True
if self._shared_filter_cache is None:
self._make_shared_filter_cache()
return self._shared_filter_cache[index]

def multi_filter(self, filter):
"""Keeps a function as a multi_filter
Parameters
----------
filter: function
predicate with one argument
Examples
--------
>>> # use custom candidates
>>> cfg.multi_filter(
>>> filter=lambda entity: 32
>>> <= (entity["tile_fc"].size[2] * entity["tile_y"].size[2] * entity["tile_x"].size[2])
>>> < 1024
>>> )
"""
if self._collect:
self.clear_cache()
self._shared_filter = filter

@property
def range_length(self):
"""Length of the index range in the space"""
if self._range_length is None:
self._range_length = int(np.prod([len(x) for x in self.space_map.values()]))
return self._range_length

@property
def dims(self):
"""Dimensions in the space"""
if self._dims is None:
self._dims = [len(x) for x in self.space_map.values()]
return self._dims

def subrange_length(self, start, end):
"""Returns the number of valid indexes within the limited range from [start, end]
Parameters
----------
start: int
start of subrange, inclusive
end: int
end of subrange, exclusive
Returns
-------
count: int
number of valid indexes
"""
assert 0 <= start <= end <= self.range_length
if self._shared_filter is None:
return end - start
if self._shared_filter_cache is None:
self._make_shared_filter_cache()
return self._shared_filter_cache[start:end].count(True)

def get_rand_index(self, start=None, end=None, to_exclude=[]):
"""Returns a random valid index unlisted to exclusion
Parameters
----------
start: int
specifying at which position to start, inclusive
end: int
specifying at which position to end, exclusive
to_exclude: list
determines unsuitable values
Returns
-------
rand: int
random index in the space
.. note::
Excluding all valid space indexes will lead to an infinite loop.
"""
start = start or 0
end = end or self.range_length
while True:
index = randrange(start, end)
if index not in to_exclude and self.is_index_valid(index):
return index

def get_next_index(self, index, n=1, start=None, end=None):
"""Returns the nth valid next index or None if out of range
Parameters
----------
index: int
specifying at which position to start, inclusive
n: int
step by using to find the next index, for the opposite
direction a negative number should be used
start: list
start of subrange, inclusive
end: list
end of subrange, exclusive
Returns
-------
next: int
next index in the space
"""
assert n != 0
start = start or 0
end = end or self.range_length
if self._shared_filter is None:
index += n
if start <= index < end:
return index
return None
trend = 1 if n > 0 else -1
counter = abs(n)
while counter != 0:
index += trend
if index < start or index >= end:
return None
if self.is_index_valid(index):
counter -= 1
return index

def clear_cache(self):
"""Clears the cache of index validity"""
del self._shared_filter_cache
self._length = None
self._range_length = None
self._shared_filter_cache = None

def _make_shared_filter_cache(self):
def apply(t):
entities = OrderedDict()
for name, space in self.space_map.items():
entities[name] = space[t % len(space)]
t //= len(space)
return bool(self._shared_filter(entities))

self._shared_filter_cache = tuple(apply(i) for i in range(self.range_length))
self._length = self._shared_filter_cache.count(True)

def point2knob(self, point):
"""Convert point form (single integer) to knob (vector)
Parameters
----------
point: int
point to convert
Returns
-------
knob: list
knob representation of the point
"""
knob = []
for dim in self.dims:
knob.append(point % dim)
point //= dim
return knob

def knob2point(self, knob):
"""Convert knob form (vector) to point form (single integer)
Parameters
----------
knob: list
knob to convert
Returns
-------
point: int
point of the knob representation
"""
point = 0
for j, k in enumerate(knob):
point += int(np.prod(self.dims[:j])) * k
return point

def sample_ints(self, m):
"""
Sample m different integer numbers from [0, self.range_length) without replacement
This function is an alternative of `np.random.choice` when self.range_length > 2 ^ 32, in
which case numpy does not work.
Parameters
----------
m: int
The number of sampled int
Returns
-------
ints: an numpy array of size m
"""
assert m <= len(self)
vis = set()
while len(vis) < m:
new = randrange(0, self.range_length)
if self.is_index_valid(new):
vis.add(new)
return np.fromiter(vis, int, len(vis))

def random_walk(self, point):
"""random walk as local transition
Parameters
----------
point: int
index of the ConfigEntity
Returns
-------
new_point: int
new neighborhood index
"""
# transform to knob form
old_knob = self.point2knob(point)
new_knob = old_knob.copy()
new_point = self.knob2point(new_knob)
# mutate
while new_knob == old_knob or not self.is_index_valid(new_point):
from_i = np.random.randint(len(old_knob))
to_v = np.random.randint(self.dims[from_i])
new_knob[from_i] = to_v
new_point = self.knob2point(new_knob)
# transform to index form
return new_point

def _add_new_transform(self, space_class, name, axes, policy, **kwargs):
"""Add a new transform space in template"""
# if we do not have tuned info (_collect == True) but defined KNOB value
# for "default" scheduling before call of _add_new_transform, in this case
# no need to create new space and override previously pointed KNOB values
if kwargs.get("filter"):
self.clear_cache()
if self._collect and not (self.is_fallback and name in self._entity_map):
# convert schedule axis to space definition axis
axes = [x if isinstance(x, (VirtualAxis, Axis)) else self.axis(x) for x in axes]
Expand All @@ -839,8 +1087,11 @@ def _add_new_transform(self, space_class, name, axes, policy, **kwargs):
return [Axis(None, i) for i in range(space_class.get_num_output(axes, policy, **kwargs))]

def __len__(self):
if self._length is None:
self._length = int(np.prod([len(x) for x in self.space_map.values()]))
"""Returns the number of valid indexes in the space"""
if self._shared_filter is None:
return self.range_length
if self._shared_filter_cache is None:
self._make_shared_filter_cache()
return self._length

def get(self, index):
Expand All @@ -850,9 +1101,21 @@ def get(self, index):
----------
index: int
index in the space
Returns
-------
config: ConfigEntity
config corresponds to the index
"""
if index < 0 or index >= len(self):
raise IndexError("Index out of range: size {}, got index {}".format(len(self), index))
if index < 0 or index >= self.range_length:
raise IndexError(
"Index out of range: size {}, got index {}".format(self.range_length, index)
)
if not self.is_index_valid(index):
raise IndexError(
"Index does not correspond to the multi-filter condition, got index {}. "
"Use is_index_valid to pre-check".format(index)
)
entities = OrderedDict()
t = index
for name, space in self.space_map.items():
Expand All @@ -876,7 +1139,9 @@ def __getitem__(self, name):
return self._entity_map[name]

def __repr__(self):
res = "ConfigSpace (len=%d, space_map=\n" % len(self)
res = "ConfigSpace (len={}, range_length={}, space_map=\n".format(
len(self), self.range_length
)
for i, (name, space) in enumerate(self.space_map.items()):
res += " %2d %s: %s\n" % (i, name, space)
return res + ")"
Expand Down
Loading

0 comments on commit 5da9a68

Please sign in to comment.