diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 65686ff..4729bdd 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -30,4 +30,4 @@ jobs: - name: Run tests run: | - python -m unittest discover -s test -p "*.py" + python -m unittest discover -s test/inputgen -p "**/*.py" diff --git a/examples/random_seed.py b/examples/random_seed.py new file mode 100644 index 0000000..07e7d79 --- /dev/null +++ b/examples/random_seed.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from inputgen.argtuple.gen import ArgumentTupleGenerator +from inputgen.utils.random_manager import random_manager +from specdb.db import SpecDictDB + + +def main(): + # example to seed all random number generators + random_manager.seed(1729) + + spec = SpecDictDB["add.Tensor"] + op = torch.ops.aten.add.Tensor + for ix, (posargs, inkwargs, outargs) in enumerate( + ArgumentTupleGenerator(spec).gen() + ): + op(*posargs, **inkwargs, **outargs) + print( + posargs[0].shape, + posargs[0].dtype, + posargs[1].shape, + posargs[1].dtype, + inkwargs["alpha"], + ) + if ix == 1: + print(posargs[0]) + + +if __name__ == "__main__": + main() diff --git a/inputgen/argument/engine.py b/inputgen/argument/engine.py index 4935035..fc3c6f8 100644 --- a/inputgen/argument/engine.py +++ b/inputgen/argument/engine.py @@ -4,7 +4,6 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import random from typing import Any, List, Optional, Tuple, Union import torch @@ -13,6 +12,7 @@ from inputgen.attribute.model import Attribute from inputgen.attribute.solve import AttributeSolver from inputgen.specs.model import Constraint, ConstraintSuffix +from inputgen.utils.random_manager import random_manager as rm from inputgen.variable.type import ScalarDtype @@ -60,7 +60,9 @@ def gen_structure_with_depth_and_length( yield from self.gen_structure_with_depth(depth, focus, length) return - focus_ixs = range(length) if focus == attr else (random.choice(range(length)),) + focus_ixs = ( + range(length) if focus == attr else (rm.get_random().choice(range(length)),) + ) for focus_ix in focus_ixs: values = [()] for ix in range(length): @@ -241,7 +243,7 @@ def gen_value_spaces(self, focus, dtype, struct): if focus == Attribute.VALUE: return [v.space for v in variables] else: - return [random.choice(variables).space] + return [rm.get_random().choice(variables).space] def gen(self, focus): # TODO(mcandales): Enable Tensor List generation diff --git a/inputgen/argument/gen.py b/inputgen/argument/gen.py index f351f56..6be178d 100644 --- a/inputgen/argument/gen.py +++ b/inputgen/argument/gen.py @@ -9,6 +9,7 @@ import torch from inputgen.argument.engine import MetaArg +from inputgen.utils.random_manager import random_manager from inputgen.variable.gen import VariableGenerator from inputgen.variable.space import VariableSpace from torch.testing._internal.common_dtype import floating_types, integral_types @@ -41,6 +42,8 @@ def gen(self): ) def get_random_tensor(self, size, dtype, high=None, low=None): + torch_rng = random_manager.get_torch() + if low is None and high is None: low = -100 high = 100 @@ -55,7 +58,9 @@ def get_random_tensor(self, size, dtype, high=None, low=None): elif not self.space.contains(1): return torch.full(size, False, dtype=dtype) else: - return torch.randint(low=0, high=2, size=size, dtype=dtype) + return torch.randint( + low=0, high=2, size=size, dtype=dtype, generator=torch_rng + ) if dtype in integral_types(): low = math.ceil(low) @@ -68,16 +73,38 @@ def get_random_tensor(self, size, dtype, high=None, low=None): if dtype == torch.uint8: if not self.space.contains(0): - return torch.randint(low=max(1, low), high=high, size=size, dtype=dtype) + return torch.randint( + low=max(1, low), + high=high, + size=size, + dtype=dtype, + generator=torch_rng, + ) else: - return torch.randint(low=max(0, low), high=high, size=size, dtype=dtype) + return torch.randint( + low=max(0, low), + high=high, + size=size, + dtype=dtype, + generator=torch_rng, + ) - t = torch.randint(low=low, high=high, size=size, dtype=dtype) + t = torch.randint( + low=low, high=high, size=size, dtype=dtype, generator=torch_rng + ) if not self.space.contains(0): if high > 0: - pos = torch.randint(low=max(1, low), high=high, size=size, dtype=dtype) + pos = torch.randint( + low=max(1, low), + high=high, + size=size, + dtype=dtype, + generator=torch_rng, + ) else: - pos = torch.randint(low=low, high=0, size=size, dtype=dtype) + pos = torch.randint( + low=low, high=0, size=size, dtype=dtype, generator=torch_rng + ) t = torch.where(t == 0, pos, t) if dtype in integral_types(): diff --git a/inputgen/attribute/engine.py b/inputgen/attribute/engine.py index bf27a4d..0c6ed3c 100644 --- a/inputgen/attribute/engine.py +++ b/inputgen/attribute/engine.py @@ -12,7 +12,7 @@ from inputgen.attribute.solve import AttributeSolver from inputgen.specs.model import Constraint from inputgen.variable.gen import VariableGenerator -from inputgen.variable.type import ScalarDtype +from inputgen.variable.type import ScalarDtype, sort_values_of_type class AttributeEngine(AttributeSolver): @@ -51,4 +51,4 @@ def gen(self, focus: Attribute, *args): if len(vals) == 0: vals = VariableGenerator(variable.space).gen(num) gen_vals.update(vals) - return gen_vals + return sort_values_of_type(self.vtype, gen_vals) diff --git a/inputgen/utils/__init__.py b/inputgen/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/inputgen/utils/random_manager.py b/inputgen/utils/random_manager.py new file mode 100644 index 0000000..12046cd --- /dev/null +++ b/inputgen/utils/random_manager.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random + +import torch + + +class RandomManager: + def __init__(self): + self._rng = random.Random() + self._torch_rng = torch.Generator() + + def seed(self, seed): + """ + Seeds the random number generators for random and torch. + """ + self._rng.seed(seed) + self._torch_rng.manual_seed(seed) + + def get_random(self): + # self._rng.seed(42) + return self._rng + + def get_torch(self): + # self._torch_rng.manual_seed(42) + return self._torch_rng + + +random_manager = RandomManager() diff --git a/inputgen/variable/gen.py b/inputgen/variable/gen.py index 9fd0342..5398076 100644 --- a/inputgen/variable/gen.py +++ b/inputgen/variable/gen.py @@ -5,11 +5,12 @@ # LICENSE file in the root directory of this source tree. import math -import random from typing import Any, List, Optional, Set, Union +from inputgen.utils.random_manager import random_manager as rm from inputgen.variable.constants import BOUND_ON_INF, INT64_MAX, INT64_MIN from inputgen.variable.space import Interval, Intervals, VariableSpace +from inputgen.variable.type import sort_values_of_type from inputgen.variable.utils import nextdown, nextup @@ -51,7 +52,7 @@ def gen_float_from_interval(r: Interval) -> Optional[float]: elif lower > upper: return None else: - return random.uniform(lower, upper) + return rm.get_random().uniform(lower, upper) def gen_min_float_from_intervals(rs: Intervals) -> Optional[float]: @@ -69,7 +70,7 @@ def gen_max_float_from_intervals(rs: Intervals) -> Optional[float]: def gen_float_from_intervals(rs: Intervals) -> Optional[float]: if rs.empty(): return None - r = random.choice(rs.intervals) + r = rm.get_random().choice(rs.intervals) return gen_float_from_interval(r) @@ -112,7 +113,7 @@ def gen_int_from_interval(r: Interval) -> Optional[int]: elif upper is None: upper = max(lower, 0) + BOUND_ON_INF assert lower is not None and upper is not None - return random.randint(lower, upper) + return rm.get_random().randint(lower, upper) def gen_min_int_from_intervals(rs: Intervals) -> Optional[int]: @@ -133,7 +134,7 @@ def gen_int_from_intervals(rs: Intervals) -> Optional[int]: intervals_with_ints = [r for r in rs.intervals if r.contains_int()] if len(intervals_with_ints) == 0: return None - r = random.choice(intervals_with_ints) + r = rm.get_random().choice(intervals_with_ints) return gen_int_from_interval(r) @@ -147,6 +148,12 @@ def __init__(self, space: VariableSpace): self.vtype = space.vtype self.space = space + def _sorted(self, values: Set[Any]) -> List[Any]: + return sort_values_of_type(self.vtype, values) + + def _sample(self, values: Set[Any], num: int) -> List[Any]: + return rm.get_random().sample(self._sorted(values), num) + def gen_min(self) -> Any: """Returns the minimum value of the space.""" if self.space.empty() or self.vtype not in [bool, int, float]: @@ -221,7 +228,7 @@ def gen_edges_non_extreme(self, num: int = 2) -> Set[Any]: edges_not_extreme = self.gen_edges() - self.gen_extremes() if num >= len(edges_not_extreme): return edges_not_extreme - return set(random.sample(list(edges_not_extreme), num)) + return set(self._sample(edges_not_extreme, num)) def gen_non_edges(self, num: int = 2) -> Set[Any]: """Generates non-edge (or interior) values of the space.""" @@ -232,7 +239,7 @@ def gen_non_edges(self, num: int = 2) -> Set[Any]: if self.space.discrete.initialized: vals = self.space.discrete.values - edge_or_extreme_vals if num < len(vals): - vals = set(random.sample(list(vals), num)) + vals = set(self._sample(vals, num)) else: for _ in range(100): v: Optional[Union[int, float]] = None @@ -269,11 +276,8 @@ def gen_balanced(self, num: int = 6) -> Set[Any]: if num >= len(balanced): return balanced - return set(random.sample(list(balanced), num)) + return set(self._sample(balanced, num)) def gen(self, num: int = 6) -> List[Any]: """Generates a sorted (if applicable), balanced sample of the space.""" - vals = list(self.gen_balanced(num)) - if self.vtype in [bool, int, float, str]: - return sorted(vals) - return vals + return sort_values_of_type(self.vtype, self.gen_balanced(num)) diff --git a/inputgen/variable/type.py b/inputgen/variable/type.py index f723d75..0a5f124 100644 --- a/inputgen/variable/type.py +++ b/inputgen/variable/type.py @@ -6,7 +6,7 @@ import math from enum import Enum -from typing import Any +from typing import Any, List, Set import torch @@ -93,3 +93,13 @@ def convert_to_vtype(vtype: type, v: Any) -> Any: if vtype == float: return float(v) return v + + +def sort_values_of_type(vtype: type, values: Set[Any]) -> List[Any]: + if vtype in [bool, int, float, str, tuple]: + return sorted(values) + if vtype == torch.dtype: + return [v for v in SUPPORTED_TENSOR_DTYPES if v in values] + if vtype == ScalarDtype: + return [v for v in ScalarDtype if v in values] + return list(values)