From 58664ac6ad0168ebd423448ed9c4330da41763f7 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Sat, 5 Mar 2022 10:19:36 -0800 Subject: [PATCH 1/4] Add permutation and space sequence spaces --- compiler_gym/service/proto/BUILD | 2 + compiler_gym/service/proto/CMakeLists.txt | 26 +++++---- compiler_gym/service/proto/__init__.py | 2 + .../service/proto/compiler_gym_service.proto | 8 +++ compiler_gym/service/proto/py_converters.py | 35 ++++++++++++ compiler_gym/spaces/BUILD | 18 ++++++ compiler_gym/spaces/CMakeLists.txt | 21 +++++++ compiler_gym/spaces/__init__.py | 4 ++ compiler_gym/spaces/permutation.py | 32 +++++++++++ compiler_gym/spaces/space_sequence.py | 57 +++++++++++++++++++ compiler_gym/util/BUILD | 1 + compiler_gym/util/CMakeLists.txt | 1 + compiler_gym/util/permutation.py | 35 ++++++++++++ tests/service/proto/py_converters_test.py | 35 ++++++++++++ tests/spaces/sequence_test.py | 11 +++- tests/util/BUILD | 10 ++++ tests/util/CMakeLists.txt | 10 ++++ tests/util/permutation_test.py | 30 ++++++++++ 18 files changed, 325 insertions(+), 13 deletions(-) create mode 100644 compiler_gym/spaces/permutation.py create mode 100644 compiler_gym/spaces/space_sequence.py create mode 100644 compiler_gym/util/permutation.py create mode 100644 tests/util/permutation_test.py diff --git a/compiler_gym/service/proto/BUILD b/compiler_gym/service/proto/BUILD index 946eecdc3..69e0fcda3 100644 --- a/compiler_gym/service/proto/BUILD +++ b/compiler_gym/service/proto/BUILD @@ -23,8 +23,10 @@ py_library( "//compiler_gym/spaces:dict", "//compiler_gym/spaces:discrete", "//compiler_gym/spaces:named_discrete", + "//compiler_gym/spaces:permutation", "//compiler_gym/spaces:scalar", "//compiler_gym/spaces:sequence", + "//compiler_gym/spaces:space_sequence", "//compiler_gym/spaces:tuple", ], ) diff --git a/compiler_gym/service/proto/CMakeLists.txt b/compiler_gym/service/proto/CMakeLists.txt index 152a8add5..15e2974ae 100644 --- a/compiler_gym/service/proto/CMakeLists.txt +++ b/compiler_gym/service/proto/CMakeLists.txt @@ -19,8 +19,10 @@ cg_py_library( compiler_gym::spaces::dict compiler_gym::spaces::discrete compiler_gym::spaces::named_discrete + compiler_gym::spaces::permutation compiler_gym::spaces::scalar compiler_gym::spaces::sequence + compiler_gym::spaces::space_sequence compiler_gym::spaces::tuple ) @@ -33,10 +35,10 @@ proto_library( ) py_proto_library( - NAME - compiler_gym_service_py - DEPS - ::compiler_gym_service + NAME + compiler_gym_service_py + DEPS + ::compiler_gym_service ) cc_proto_library( @@ -48,15 +50,15 @@ cc_proto_library( ) cc_grpc_library( - NAME compiler_gym_service_cc_grpc - SRCS ::compiler_gym_service - GRPC_ONLY - PUBLIC - DEPS ::compiler_gym_service_cc + NAME compiler_gym_service_cc_grpc + SRCS ::compiler_gym_service + GRPC_ONLY + PUBLIC + DEPS ::compiler_gym_service_cc ) py_grpc_library( - NAME "compiler_gym_service_py_grpc" - SRCS "::compiler_gym_service" - DEPS "::compiler_gym_service_py" + NAME "compiler_gym_service_py_grpc" + SRCS "::compiler_gym_service" + DEPS "::compiler_gym_service_py" ) diff --git a/compiler_gym/service/proto/__init__.py b/compiler_gym/service/proto/__init__.py index d863bfd4e..85cbb8c27 100644 --- a/compiler_gym/service/proto/__init__.py +++ b/compiler_gym/service/proto/__init__.py @@ -52,6 +52,7 @@ SendSessionParameterRequest, SessionParameter, Space, + SpaceSequenceSpace, StartSessionReply, StartSessionRequest, StepReply, @@ -126,6 +127,7 @@ "ServiceTransportError", "SessionParameter", "Space", + "SpaceSequenceSpace", "StartSessionReply", "StartSessionRequest", "StepReply", diff --git a/compiler_gym/service/proto/compiler_gym_service.proto b/compiler_gym/service/proto/compiler_gym_service.proto index 43d835b36..3d286f60d 100644 --- a/compiler_gym/service/proto/compiler_gym_service.proto +++ b/compiler_gym/service/proto/compiler_gym_service.proto @@ -294,6 +294,13 @@ message StringSpace { Int64Range length_range = 1; } +// A variable length sequence of spaces. +message SpaceSequenceSpace { + // The number of spaces in the sequence. + Int64Range length_range = 1; + Space space = 2; +} + // Can be used in Space.any_value or Event.any_value to describe an opaque // serialized data. message Opaque { @@ -326,6 +333,7 @@ message Space { FloatSequenceSpace float_sequence = 16; DoubleSequenceSpace double_sequence = 17; StringSequenceSpace string_sequence = 18; + SpaceSequenceSpace space_sequence = 25; BooleanBox boolean_box = 19; ByteBox byte_box = 20; Int64Box int64_box = 21; diff --git a/compiler_gym/service/proto/py_converters.py b/compiler_gym/service/proto/py_converters.py index ad9d6cd4d..3e04775d9 100644 --- a/compiler_gym/service/proto/py_converters.py +++ b/compiler_gym/service/proto/py_converters.py @@ -56,6 +56,7 @@ ObservationSpace, Opaque, Space, + SpaceSequenceSpace, StringSequenceSpace, StringSpace, StringTensor, @@ -65,8 +66,10 @@ from compiler_gym.spaces.dict import Dict from compiler_gym.spaces.discrete import Discrete from compiler_gym.spaces.named_discrete import NamedDiscrete +from compiler_gym.spaces.permutation import Permutation from compiler_gym.spaces.scalar import Scalar from compiler_gym.spaces.sequence import Sequence +from compiler_gym.spaces.space_sequence import SpaceSequence from compiler_gym.spaces.tuple import Tuple @@ -189,6 +192,22 @@ def convert_bytes_to_numpy(arr: bytes) -> np.ndarray: return np.frombuffer(arr, dtype=np.int8) +def convert_permutation_space_message(space: Space) -> Permutation: + if ( + space.int64_sequence.scalar_range.max + - space.int64_sequence.scalar_range.min + + 1 + != space.int64_sequence.length_range.min + or space.int64_sequence.length_range.min + != space.int64_sequence.length_range.max + ): + raise ValueError(f"Invalid permutation space message:\n{space}.") + return Permutation( + name=None, + scalar_range=convert_range_message(space.int64_sequence.scalar_range), + ) + + class NumpyToTensorMessageConverter: dtype_conversion_map: DictType[Type, Callable[[Any], Message]] @@ -434,9 +453,11 @@ def make_message_default_converter() -> Callable[[Any], Any]: conversion_map[Space] = TypeIdDispatchConverter( default_converter=SpaceMessageDefaultConverter(res), + conversion_map={"permutation": convert_permutation_space_message}, ) conversion_map[ListSpace] = ListSpaceMessageConverter(conversion_map[Space]) conversion_map[DictSpace] = DictSpaceMessageConverter(conversion_map[Space]) + conversion_map[SpaceSequenceSpace] = SpaceSequenceSpaceMessageConverter(res) conversion_map[ActionSpace] = ActionSpaceMessageConverter(res) conversion_map[ObservationSpace] = ObservationSpaceMessageConverter(res) @@ -693,6 +714,20 @@ def __call__( convert_to_sequence_space_message = ToSequenceSpaceMessageConverter() +class SpaceSequenceSpaceMessageConverter: + space_message_converter: Callable[[Space], GymSpace] + + def __init__(self, space_message_converter): + self.space_message_converter = space_message_converter + + def __call__(self, seq: SpaceSequenceSpace) -> GymSpace: + return SpaceSequence( + name=None, + space=self.space_message_converter(seq.space), + size_range=(seq.length_range.min, seq.length_range.max), + ) + + class SpaceMessageDefaultConverter: message_converter: TypeBasedConverter diff --git a/compiler_gym/spaces/BUILD b/compiler_gym/spaces/BUILD index 683272c17..66adfcc83 100644 --- a/compiler_gym/spaces/BUILD +++ b/compiler_gym/spaces/BUILD @@ -17,9 +17,11 @@ py_library( ":dict", ":discrete", ":named_discrete", + ":permutation", ":reward", ":scalar", ":sequence", + ":space_sequence", ":tuple", ], ) @@ -90,8 +92,24 @@ py_library( ], ) +py_library( + name = "space_sequence", + srcs = ["space_sequence.py"], + visibility = ["//compiler_gym:__subpackages__"], +) + py_library( name = "tuple", srcs = ["tuple.py"], visibility = ["//compiler_gym:__subpackages__"], ) + +py_library( + name = "permutation", + srcs = ["permutation.py"], + visibility = ["//compiler_gym:__subpackages__"], + deps = [ + ":scalar", + ":sequence", + ], +) diff --git a/compiler_gym/spaces/CMakeLists.txt b/compiler_gym/spaces/CMakeLists.txt index 6bba6af76..16cd54668 100644 --- a/compiler_gym/spaces/CMakeLists.txt +++ b/compiler_gym/spaces/CMakeLists.txt @@ -17,9 +17,11 @@ cg_py_library( ::dict ::discrete ::named_discrete + ::permutation ::reward ::scalar ::sequence + ::space_sequence ::tuple PUBLIC ) @@ -96,7 +98,26 @@ cg_py_library( PUBLIC ) +cg_py_library( + NAME + space_sequence + SRCS + "space_sequence.py" + PUBLIC +) + cg_py_library( NAME tuple SRCS "tuple.py" ) + +cg_py_library( + NAME + permutation + SRCS + "permutation.py" + DEPS + ::scalar + ::sequence + PUBLIC +) diff --git a/compiler_gym/spaces/__init__.py b/compiler_gym/spaces/__init__.py index 7b9b1a79b..597f798e3 100644 --- a/compiler_gym/spaces/__init__.py +++ b/compiler_gym/spaces/__init__.py @@ -7,9 +7,11 @@ from compiler_gym.spaces.dict import Dict from compiler_gym.spaces.discrete import Discrete from compiler_gym.spaces.named_discrete import NamedDiscrete +from compiler_gym.spaces.permutation import Permutation from compiler_gym.spaces.reward import DefaultRewardFromObservation, Reward from compiler_gym.spaces.scalar import Scalar from compiler_gym.spaces.sequence import Sequence +from compiler_gym.spaces.space_sequence import SpaceSequence from compiler_gym.spaces.tuple import Tuple __all__ = [ @@ -20,8 +22,10 @@ "Dict", "Discrete", "NamedDiscrete", + "Permutation", "Reward", "Scalar", "Sequence", + "SpaceSequence", "Tuple", ] diff --git a/compiler_gym/spaces/permutation.py b/compiler_gym/spaces/permutation.py new file mode 100644 index 000000000..f12f55f54 --- /dev/null +++ b/compiler_gym/spaces/permutation.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import numpy as np + +from compiler_gym.spaces.scalar import Scalar +from compiler_gym.spaces.sequence import Sequence + + +class Permutation(Sequence): + def __init__(self, name: str, scalar_range: Scalar): + sz = scalar_range.max - scalar_range.min + 1 + super().__init__( + name=name, + size_range=(sz, sz), + dtype=scalar_range.dtype, + scalar_range=scalar_range, + ) + + def sample(self): + return ( + np.random.choice(self.size_range[0], size=self.size_range[1], replace=False) + + self.scalar_range.min + ) + + def __eq__(self, other) -> bool: + return ( + isinstance(self, other.__class__) + and self.name == other.name + and super().__eq__(other) + ) diff --git a/compiler_gym/spaces/space_sequence.py b/compiler_gym/spaces/space_sequence.py new file mode 100644 index 000000000..3883b2612 --- /dev/null +++ b/compiler_gym/spaces/space_sequence.py @@ -0,0 +1,57 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from collections import Counter +from collections.abc import Collection +from typing import Optional, Tuple + +import numpy as np +from gym.spaces import Space + + +class SpaceSequence(Space): + name: str + space: Space + size_range: Tuple[int, Optional[int]] + + def __init__( + self, name: str, space: Space, size_range: Tuple[int, Optional[int]] = (0, None) + ): + self.name = name + self.space = space + self.size_range = size_range + + def contains(self, x): + if not isinstance(x, Collection): + return False + + lower_bound = self.size_range[0] + upper_bound = float("inf") if self.size_range[1] is None else self.size_range[1] + if not (lower_bound <= len(x) <= upper_bound): + return False + + for element in x: + if not self.space.contains(element): + return False + return True + + def __eq__(self, other) -> bool: + return ( + isinstance(self, other.__class__) + and self.name == other.name + and Counter(self.size_range) == Counter(other.size_range) + and self.space == other.space + ) + + def sample(self): + return [ + self.space.sample() + for _ in range( + np.random.randint( + low=self.size_range[0], + high=None if self.size_range[1] is None else self.size_range[1] + 1, + ) + ) + ] diff --git a/compiler_gym/util/BUILD b/compiler_gym/util/BUILD index a9cad9397..27f75acd1 100644 --- a/compiler_gym/util/BUILD +++ b/compiler_gym/util/BUILD @@ -21,6 +21,7 @@ py_library( "logs.py", "minimize_trajectory.py", "parallelization.py", + "permutation.py", "registration.py", "runfiles_path.py", "shell_format.py", diff --git a/compiler_gym/util/CMakeLists.txt b/compiler_gym/util/CMakeLists.txt index 5bffc6add..ccb63371d 100644 --- a/compiler_gym/util/CMakeLists.txt +++ b/compiler_gym/util/CMakeLists.txt @@ -22,6 +22,7 @@ cg_py_library( "logs.py" "minimize_trajectory.py" "parallelization.py" + "permutation.py" "registration.py" "runfiles_path.py" "shell_format.py" diff --git a/compiler_gym/util/permutation.py b/compiler_gym/util/permutation.py new file mode 100644 index 000000000..7cd4dba6b --- /dev/null +++ b/compiler_gym/util/permutation.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from numbers import Integral +from typing import List + +import numpy as np + + +def convert_number_to_permutation( + n: Integral, permutation_size: Integral +) -> List[Integral]: + m = n + res = np.zeros(permutation_size, dtype=type(permutation_size)) + elements = np.arange(permutation_size, dtype=type(permutation_size)) + for i in range(permutation_size): + j = m % (permutation_size - i) + m = m // (permutation_size - i) + res[i] = elements[j] + elements[j] = elements[permutation_size - i - 1] + return res + + +def convert_permutation_to_number(permutation: List[Integral]) -> Integral: + pos = np.arange(len(permutation), dtype=int) + elements = np.arange(len(permutation), dtype=int) + m = 1 + res = 0 + for i in range(len(permutation) - 1): + res += m * pos[permutation[i]] + m = m * (len(permutation) - i) + pos[elements[len(permutation) - i - 1]] = pos[permutation[i]] + elements[pos[permutation[i]]] = elements[len(permutation) - i - 1] + return res diff --git a/tests/service/proto/py_converters_test.py b/tests/service/proto/py_converters_test.py index 40237c85b..dc7b95cd8 100644 --- a/tests/service/proto/py_converters_test.py +++ b/tests/service/proto/py_converters_test.py @@ -40,6 +40,7 @@ NamedDiscreteSpace, Opaque, Space, + SpaceSequenceSpace, StringSpace, StringTensor, py_converters, @@ -50,8 +51,10 @@ Dict, Discrete, NamedDiscrete, + Permutation, Scalar, Sequence, + SpaceSequence, Tuple, ) from tests.test_main import main @@ -776,6 +779,23 @@ def test_convert_to_string_space(): assert converted_space.length_range.max == 2 +def test_convert_space_sequence_space(): + space = Space( + space_sequence=SpaceSequenceSpace( + length_range=Int64Range(min=0, max=2), + space=Space(int64_value=Int64Range(min=-1, max=1)), + ), + ) + converted_space = py_converters.message_default_converter(space) + assert isinstance(converted_space, SpaceSequence) + assert converted_space.size_range[0] == space.space_sequence.length_range.min + assert converted_space.size_range[1] == space.space_sequence.length_range.max + assert isinstance(converted_space.space, Scalar) + assert np.dtype(converted_space.space.dtype) == np.int64 + assert converted_space.space.min == space.space_sequence.space.int64_value.min + assert converted_space.space.max == space.space_sequence.space.int64_value.max + + def test_space_message_default_converter(): message_converter = py_converters.TypeBasedConverter( conversion_map={StringSpace: py_converters.convert_sequence_space} @@ -928,5 +948,20 @@ def default_converter(msg): ) +def test_convert_permutation_space_message(): + msg = Space( + type_id="permutation", + int64_sequence=Int64SequenceSpace( + length_range=Int64Range(min=5, max=5), scalar_range=Int64Range(min=0, max=4) + ), + ) + permutation = py_converters.message_default_converter(msg) + assert isinstance(permutation, Permutation) + assert permutation.scalar_range.min == 0 + assert permutation.scalar_range.max == 4 + assert permutation.size_range[0] == 5 + assert permutation.size_range[1] == 5 + + if __name__ == "__main__": main() diff --git a/tests/spaces/sequence_test.py b/tests/spaces/sequence_test.py index 1b5777ca2..84ae9494c 100644 --- a/tests/spaces/sequence_test.py +++ b/tests/spaces/sequence_test.py @@ -5,7 +5,7 @@ """Unit tests for //compiler_gym/spaces:sequence.""" import pytest -from compiler_gym.spaces import Scalar, Sequence +from compiler_gym.spaces import Scalar, Sequence, SpaceSequence from tests.test_main import main @@ -64,5 +64,14 @@ def test_bytes_contains(): assert not space.contains("Hello, world!") +def test_space_sequence_contains(): + subspace = Scalar(name="subspace", min=0, max=1, dtype=float) + space_seq = SpaceSequence(name="seq", space=subspace, size_range=(0, 2)) + assert space_seq.contains([0.5, 0.6]) + assert not space_seq.contains(["not-a-number"]) + assert not space_seq.contains([2.0]) + assert not space_seq.contains([0.1, 0.2, 0.3]) + + if __name__ == "__main__": main() diff --git a/tests/util/BUILD b/tests/util/BUILD index d977ecef9..68bb400f4 100644 --- a/tests/util/BUILD +++ b/tests/util/BUILD @@ -93,6 +93,16 @@ py_test( ], ) +py_test( + name = "permutation_test", + timeout = "short", + srcs = ["permutation_test.py"], + deps = [ + "//compiler_gym/util", + "//tests:test_main", + ], +) + py_test( name = "runfiles_path_test", srcs = ["runfiles_path_test.py"], diff --git a/tests/util/CMakeLists.txt b/tests/util/CMakeLists.txt index 92a607442..1d78de002 100644 --- a/tests/util/CMakeLists.txt +++ b/tests/util/CMakeLists.txt @@ -89,6 +89,16 @@ cg_py_test( tests::test_main ) +cg_py_test( + NAME + permutation_test + SRCS + "permutation_test.py" + DEPS + compiler_gym::util::util + tests::test_main +) + cg_py_test( NAME runfiles_path_test SRCS "runfiles_path_test.py" diff --git a/tests/util/permutation_test.py b/tests/util/permutation_test.py new file mode 100644 index 000000000..232eceea8 --- /dev/null +++ b/tests/util/permutation_test.py @@ -0,0 +1,30 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import numpy as np + +import compiler_gym.util.permutation as permutation +from tests.test_main import main + + +def test_permutation_number_mapping(): + original_permutation = np.array([4, 3, 1, 5, 2, 6, 0], dtype=int) + permutation_number = permutation.convert_permutation_to_number(original_permutation) + mapped_permutation = permutation.convert_number_to_permutation( + n=permutation_number, permutation_size=len(original_permutation) + ) + assert np.array_equal(original_permutation, mapped_permutation) + + original_permutation2 = np.array([2, 0, 5, 1, 4, 6, 3], dtype=int) + permutation_number2 = permutation.convert_permutation_to_number( + original_permutation2 + ) + mapped_permutation2 = permutation.convert_number_to_permutation( + n=permutation_number2, permutation_size=len(original_permutation2) + ) + assert np.array_equal(original_permutation2, mapped_permutation2) + + +if __name__ == "__main__": + main() From 1a2d063ce2be5bd62f371a2215a68455e108a5c4 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 1 Apr 2022 19:45:53 -0700 Subject: [PATCH 2/4] Fix formatting in compiler_gym/service/proto/CMakeLists.txt --- compiler_gym/service/proto/CMakeLists.txt | 24 +++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/compiler_gym/service/proto/CMakeLists.txt b/compiler_gym/service/proto/CMakeLists.txt index 15e2974ae..fdd2cdc8d 100644 --- a/compiler_gym/service/proto/CMakeLists.txt +++ b/compiler_gym/service/proto/CMakeLists.txt @@ -35,10 +35,10 @@ proto_library( ) py_proto_library( - NAME - compiler_gym_service_py - DEPS - ::compiler_gym_service + NAME + compiler_gym_service_py + DEPS + ::compiler_gym_service ) cc_proto_library( @@ -50,15 +50,15 @@ cc_proto_library( ) cc_grpc_library( - NAME compiler_gym_service_cc_grpc - SRCS ::compiler_gym_service - GRPC_ONLY - PUBLIC - DEPS ::compiler_gym_service_cc + NAME compiler_gym_service_cc_grpc + SRCS ::compiler_gym_service + GRPC_ONLY + PUBLIC + DEPS ::compiler_gym_service_cc ) py_grpc_library( - NAME "compiler_gym_service_py_grpc" - SRCS "::compiler_gym_service" - DEPS "::compiler_gym_service_py" + NAME "compiler_gym_service_py_grpc" + SRCS "::compiler_gym_service" + DEPS "::compiler_gym_service_py" ) From 1da67b5b4f948830efc7daeb0d8115909d38c1dd Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 1 Apr 2022 19:50:14 -0700 Subject: [PATCH 3/4] Add SpaceSequenceSpace to documentation --- docs/source/rpc.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/rpc.rst b/docs/source/rpc.rst index f525a1b0d..e4bdb146b 100644 --- a/docs/source/rpc.rst +++ b/docs/source/rpc.rst @@ -179,6 +179,9 @@ Core Message Types .. doxygenstruct:: StringSequenceSpace :members: +.. doxygenstruct:: SpaceSequenceSpace + :members: + .. doxygenstruct:: StringSpace :members: From d95c95a17fa6c860a9a7d54a4411e988788cbff2 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 1 Apr 2022 21:44:57 -0700 Subject: [PATCH 4/4] Add documentation and tests --- compiler_gym/service/proto/py_converters.py | 7 +++- compiler_gym/spaces/permutation.py | 21 +++++++++--- compiler_gym/spaces/space_sequence.py | 10 ++++-- docs/source/compiler_gym/spaces.rst | 18 +++++++++++ tests/service/proto/py_converters_test.py | 9 ++++++ tests/spaces/BUILD | 10 ++++++ tests/spaces/CMakeLists.txt | 10 ++++++ tests/spaces/permutation_test.py | 36 +++++++++++++++++++++ tests/util/permutation_test.py | 2 +- 9 files changed, 113 insertions(+), 10 deletions(-) create mode 100644 tests/spaces/permutation_test.py diff --git a/compiler_gym/service/proto/py_converters.py b/compiler_gym/service/proto/py_converters.py index 3e04775d9..ca358efc4 100644 --- a/compiler_gym/service/proto/py_converters.py +++ b/compiler_gym/service/proto/py_converters.py @@ -201,7 +201,12 @@ def convert_permutation_space_message(space: Space) -> Permutation: or space.int64_sequence.length_range.min != space.int64_sequence.length_range.max ): - raise ValueError(f"Invalid permutation space message:\n{space}.") + raise ValueError( + f"Invalid permutation space message:\n{space}." + " Variable sequence length is not allowed." + " A permutation must also include all integers in its range " + "[min, min + length)." + ) return Permutation( name=None, scalar_range=convert_range_message(space.int64_sequence.scalar_range), diff --git a/compiler_gym/spaces/permutation.py b/compiler_gym/spaces/permutation.py index f12f55f54..d187bdfbf 100644 --- a/compiler_gym/spaces/permutation.py +++ b/compiler_gym/spaces/permutation.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from numbers import Integral + import numpy as np from compiler_gym.spaces.scalar import Scalar @@ -9,7 +11,20 @@ class Permutation(Sequence): + """The space of permutations of all numbers in the range `scalar_range`.""" + def __init__(self, name: str, scalar_range: Scalar): + """Constructor. + + :param name: The name of the permutation space. + :param scalar_range: Range of numbers in the permutation. + For example the scalar range [1, 3] would define permutations like + [1, 2, 3] or [2, 1, 3], etc. + + :raises TypeError: If `scalar_range.dtype` is not an integral type. + """ + if not issubclass(np.dtype(scalar_range.dtype).type, Integral): + raise TypeError("Permutation space can have integral scalar range only.") sz = scalar_range.max - scalar_range.min + 1 super().__init__( name=name, @@ -25,8 +40,4 @@ def sample(self): ) def __eq__(self, other) -> bool: - return ( - isinstance(self, other.__class__) - and self.name == other.name - and super().__eq__(other) - ) + return isinstance(self, other.__class__) and super().__eq__(other) diff --git a/compiler_gym/spaces/space_sequence.py b/compiler_gym/spaces/space_sequence.py index 3883b2612..ed6aaaaf8 100644 --- a/compiler_gym/spaces/space_sequence.py +++ b/compiler_gym/spaces/space_sequence.py @@ -12,13 +12,17 @@ class SpaceSequence(Space): - name: str - space: Space - size_range: Tuple[int, Optional[int]] + """Variable-length sequence of subspaces that have the same definition.""" def __init__( self, name: str, space: Space, size_range: Tuple[int, Optional[int]] = (0, None) ): + """Constructor. + + :param name: The name of the space. + :param space: Shared definition of the spaces in the sequence. + :param size_range: Range of the sequence length. + """ self.name = name self.space = space self.size_range = size_range diff --git a/docs/source/compiler_gym/spaces.rst b/docs/source/compiler_gym/spaces.rst index 21a5a4745..f3ace14dd 100644 --- a/docs/source/compiler_gym/spaces.rst +++ b/docs/source/compiler_gym/spaces.rst @@ -52,6 +52,15 @@ NamedDiscrete .. automethod:: __getitem__ +Permutation +-------- + +.. autoclass:: Permutation + :members: + + .. automethod:: __init__ + + Reward ------ @@ -70,6 +79,15 @@ Scalar .. automethod:: __init__ +SpaceSequence +------ + +.. autoclass:: SpaceSequence + :members: + + .. automethod:: __init__ + + Sequence -------- diff --git a/tests/service/proto/py_converters_test.py b/tests/service/proto/py_converters_test.py index dc7b95cd8..928716e34 100644 --- a/tests/service/proto/py_converters_test.py +++ b/tests/service/proto/py_converters_test.py @@ -962,6 +962,15 @@ def test_convert_permutation_space_message(): assert permutation.size_range[0] == 5 assert permutation.size_range[1] == 5 + invalid_permutation_space_msg = Space( + type_id="permutation", + int64_sequence=Int64SequenceSpace( + length_range=Int64Range(min=3, max=5), scalar_range=Int64Range(min=0, max=4) + ), + ) + with pytest.raises(ValueError, match="Invalid permutation space message"): + py_converters.message_default_converter(invalid_permutation_space_msg) + if __name__ == "__main__": main() diff --git a/tests/spaces/BUILD b/tests/spaces/BUILD index 678e73878..473a5d1b7 100644 --- a/tests/spaces/BUILD +++ b/tests/spaces/BUILD @@ -24,6 +24,16 @@ py_test( ], ) +py_test( + name = "permutation_test", + timeout = "short", + srcs = ["permutation_test.py"], + deps = [ + "//compiler_gym/spaces", + "//tests:test_main", + ], +) + py_test( name = "reward_test", timeout = "short", diff --git a/tests/spaces/CMakeLists.txt b/tests/spaces/CMakeLists.txt index 1231328b9..e7d03cfea 100644 --- a/tests/spaces/CMakeLists.txt +++ b/tests/spaces/CMakeLists.txt @@ -25,6 +25,16 @@ cg_py_test( tests::test_main ) +cg_py_test( + NAME + permutation_test + SRCS + "permutation_test.py" + DEPS + compiler_gym::spaces::spaces + tests::test_main +) + cg_py_test( NAME reward_test diff --git a/tests/spaces/permutation_test.py b/tests/spaces/permutation_test.py new file mode 100644 index 000000000..01af4edb5 --- /dev/null +++ b/tests/spaces/permutation_test.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +from compiler_gym.spaces import Permutation, Scalar +from tests.test_main import main + + +def test_invalid_scalar_range_dtype(): + with pytest.raises( + TypeError, match="Permutation space can have integral scalar range only." + ): + Permutation(name="", scalar_range=Scalar(name="", min=0, max=2, dtype=float)) + + +def test_equal(): + assert Permutation( + name="perm", scalar_range=Scalar(name="range", min=0, max=2, dtype=int) + ) == Permutation( + name="perm", scalar_range=Scalar(name="range", min=0, max=2, dtype=int) + ) + + +def test_not_equal(): + permutation = Permutation( + name="perm", scalar_range=Scalar(name="range", min=0, max=2, dtype=int) + ) + assert permutation != Permutation( + name="perm", scalar_range=Scalar(name="range", min=0, max=1, dtype=int) + ) + + +if __name__ == "__main__": + main() diff --git a/tests/util/permutation_test.py b/tests/util/permutation_test.py index 232eceea8..5950222fe 100644 --- a/tests/util/permutation_test.py +++ b/tests/util/permutation_test.py @@ -23,7 +23,7 @@ def test_permutation_number_mapping(): mapped_permutation2 = permutation.convert_number_to_permutation( n=permutation_number2, permutation_size=len(original_permutation2) ) - assert np.array_equal(original_permutation2, mapped_permutation2) + np.testing.assert_array_equal(original_permutation2, mapped_permutation2) if __name__ == "__main__":