Skip to content

Commit

Permalink
Merge pull request #645 from sogartar/perm-space-seq
Browse files Browse the repository at this point in the history
Add permutation and space sequence spaces
  • Loading branch information
ChrisCummins authored Apr 13, 2022
2 parents 61f460f + fdbbaae commit b803856
Show file tree
Hide file tree
Showing 23 changed files with 419 additions and 1 deletion.
2 changes: 2 additions & 0 deletions compiler_gym/service/proto/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ py_library(
"//compiler_gym/spaces:dict",
"//compiler_gym/spaces:discrete",
"//compiler_gym/spaces:named_discrete",
"//compiler_gym/spaces:permutation",
"//compiler_gym/spaces:scalar",
"//compiler_gym/spaces:sequence",
"//compiler_gym/spaces:space_sequence",
"//compiler_gym/spaces:tuple",
],
)
Expand Down
2 changes: 2 additions & 0 deletions compiler_gym/service/proto/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ cg_py_library(
compiler_gym::spaces::dict
compiler_gym::spaces::discrete
compiler_gym::spaces::named_discrete
compiler_gym::spaces::permutation
compiler_gym::spaces::scalar
compiler_gym::spaces::sequence
compiler_gym::spaces::space_sequence
compiler_gym::spaces::tuple
)

Expand Down
2 changes: 2 additions & 0 deletions compiler_gym/service/proto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
SendSessionParameterRequest,
SessionParameter,
Space,
SpaceSequenceSpace,
StartSessionReply,
StartSessionRequest,
StepReply,
Expand Down Expand Up @@ -125,6 +126,7 @@
"ServiceTransportError",
"SessionParameter",
"Space",
"SpaceSequenceSpace",
"StartSessionReply",
"StartSessionRequest",
"StepReply",
Expand Down
8 changes: 8 additions & 0 deletions compiler_gym/service/proto/compiler_gym_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,13 @@ message StringSpace {
Int64Range length_range = 1;
}

// A variable length sequence of spaces.
message SpaceSequenceSpace {
// The number of spaces in the sequence.
Int64Range length_range = 1;
Space space = 2;
}

// Can be used in Space.any_value or Event.any_value to describe an opaque
// serialized data.
message Opaque {
Expand Down Expand Up @@ -326,6 +333,7 @@ message Space {
FloatSequenceSpace float_sequence = 16;
DoubleSequenceSpace double_sequence = 17;
StringSequenceSpace string_sequence = 18;
SpaceSequenceSpace space_sequence = 25;
BooleanBox boolean_box = 19;
ByteBox byte_box = 20;
Int64Box int64_box = 21;
Expand Down
40 changes: 40 additions & 0 deletions compiler_gym/service/proto/py_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
ObservationSpace,
Opaque,
Space,
SpaceSequenceSpace,
StringSequenceSpace,
StringSpace,
StringTensor,
Expand All @@ -65,8 +66,10 @@
from compiler_gym.spaces.dict import Dict
from compiler_gym.spaces.discrete import Discrete
from compiler_gym.spaces.named_discrete import NamedDiscrete
from compiler_gym.spaces.permutation import Permutation
from compiler_gym.spaces.scalar import Scalar
from compiler_gym.spaces.sequence import Sequence
from compiler_gym.spaces.space_sequence import SpaceSequence
from compiler_gym.spaces.tuple import Tuple


Expand Down Expand Up @@ -215,6 +218,27 @@ def convert_bytes_to_numpy(arr: bytes) -> np.ndarray:
return np.frombuffer(arr, dtype=np.int8)


def convert_permutation_space_message(space: Space) -> Permutation:
if (
space.int64_sequence.scalar_range.max
- space.int64_sequence.scalar_range.min
+ 1
!= space.int64_sequence.length_range.min
or space.int64_sequence.length_range.min
!= space.int64_sequence.length_range.max
):
raise ValueError(
f"Invalid permutation space message:\n{space}."
" Variable sequence length is not allowed."
" A permutation must also include all integers in its range "
"[min, min + length)."
)
return Permutation(
name=None,
scalar_range=convert_range_message(space.int64_sequence.scalar_range),
)


class NumpyToTensorMessageConverter:
dtype_conversion_map: DictType[Type, Callable[[Any], Message]]

Expand Down Expand Up @@ -463,9 +487,11 @@ def make_message_default_converter() -> Callable[[Any], Any]:

conversion_map[Space] = TypeIdDispatchConverter(
default_converter=SpaceMessageDefaultConverter(res),
conversion_map={"permutation": convert_permutation_space_message},
)
conversion_map[ListSpace] = ListSpaceMessageConverter(conversion_map[Space])
conversion_map[DictSpace] = DictSpaceMessageConverter(conversion_map[Space])
conversion_map[SpaceSequenceSpace] = SpaceSequenceSpaceMessageConverter(res)
conversion_map[ActionSpace] = ActionSpaceMessageConverter(res)
conversion_map[ObservationSpace] = ObservationSpaceMessageConverter(res)

Expand Down Expand Up @@ -724,6 +750,20 @@ def __call__(
convert_to_sequence_space_message = ToSequenceSpaceMessageConverter()


class SpaceSequenceSpaceMessageConverter:
space_message_converter: Callable[[Space], GymSpace]

def __init__(self, space_message_converter):
self.space_message_converter = space_message_converter

def __call__(self, seq: SpaceSequenceSpace) -> GymSpace:
return SpaceSequence(
name=None,
space=self.space_message_converter(seq.space),
size_range=(seq.length_range.min, seq.length_range.max),
)


class SpaceMessageDefaultConverter:
message_converter: TypeBasedConverter

Expand Down
18 changes: 18 additions & 0 deletions compiler_gym/spaces/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ py_library(
":dict",
":discrete",
":named_discrete",
":permutation",
":reward",
":scalar",
":sequence",
":space_sequence",
":tuple",
],
)
Expand Down Expand Up @@ -90,8 +92,24 @@ py_library(
],
)

py_library(
name = "space_sequence",
srcs = ["space_sequence.py"],
visibility = ["//compiler_gym:__subpackages__"],
)

py_library(
name = "tuple",
srcs = ["tuple.py"],
visibility = ["//compiler_gym:__subpackages__"],
)

py_library(
name = "permutation",
srcs = ["permutation.py"],
visibility = ["//compiler_gym:__subpackages__"],
deps = [
":scalar",
":sequence",
],
)
21 changes: 21 additions & 0 deletions compiler_gym/spaces/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ cg_py_library(
::dict
::discrete
::named_discrete
::permutation
::reward
::scalar
::sequence
::space_sequence
::tuple
PUBLIC
)
Expand Down Expand Up @@ -96,7 +98,26 @@ cg_py_library(
PUBLIC
)

cg_py_library(
NAME
space_sequence
SRCS
"space_sequence.py"
PUBLIC
)

cg_py_library(
NAME tuple
SRCS "tuple.py"
)

cg_py_library(
NAME
permutation
SRCS
"permutation.py"
DEPS
::scalar
::sequence
PUBLIC
)
4 changes: 4 additions & 0 deletions compiler_gym/spaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from compiler_gym.spaces.dict import Dict
from compiler_gym.spaces.discrete import Discrete
from compiler_gym.spaces.named_discrete import NamedDiscrete
from compiler_gym.spaces.permutation import Permutation
from compiler_gym.spaces.reward import DefaultRewardFromObservation, Reward
from compiler_gym.spaces.scalar import Scalar
from compiler_gym.spaces.sequence import Sequence
from compiler_gym.spaces.space_sequence import SpaceSequence
from compiler_gym.spaces.tuple import Tuple

__all__ = [
Expand All @@ -20,8 +22,10 @@
"Dict",
"Discrete",
"NamedDiscrete",
"Permutation",
"Reward",
"Scalar",
"Sequence",
"SpaceSequence",
"Tuple",
]
43 changes: 43 additions & 0 deletions compiler_gym/spaces/permutation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from numbers import Integral

import numpy as np

from compiler_gym.spaces.scalar import Scalar
from compiler_gym.spaces.sequence import Sequence


class Permutation(Sequence):
"""The space of permutations of all numbers in the range `scalar_range`."""

def __init__(self, name: str, scalar_range: Scalar):
"""Constructor.
:param name: The name of the permutation space.
:param scalar_range: Range of numbers in the permutation.
For example the scalar range [1, 3] would define permutations like
[1, 2, 3] or [2, 1, 3], etc.
:raises TypeError: If `scalar_range.dtype` is not an integral type.
"""
if not issubclass(np.dtype(scalar_range.dtype).type, Integral):
raise TypeError("Permutation space can have integral scalar range only.")
sz = scalar_range.max - scalar_range.min + 1
super().__init__(
name=name,
size_range=(sz, sz),
dtype=scalar_range.dtype,
scalar_range=scalar_range,
)

def sample(self):
return (
np.random.choice(self.size_range[0], size=self.size_range[1], replace=False)
+ self.scalar_range.min
)

def __eq__(self, other) -> bool:
return isinstance(self, other.__class__) and super().__eq__(other)
61 changes: 61 additions & 0 deletions compiler_gym/spaces/space_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import Counter
from collections.abc import Collection
from typing import Optional, Tuple

import numpy as np
from gym.spaces import Space


class SpaceSequence(Space):
"""Variable-length sequence of subspaces that have the same definition."""

def __init__(
self, name: str, space: Space, size_range: Tuple[int, Optional[int]] = (0, None)
):
"""Constructor.
:param name: The name of the space.
:param space: Shared definition of the spaces in the sequence.
:param size_range: Range of the sequence length.
"""
self.name = name
self.space = space
self.size_range = size_range

def contains(self, x):
if not isinstance(x, Collection):
return False

lower_bound = self.size_range[0]
upper_bound = float("inf") if self.size_range[1] is None else self.size_range[1]
if not (lower_bound <= len(x) <= upper_bound):
return False

for element in x:
if not self.space.contains(element):
return False
return True

def __eq__(self, other) -> bool:
return (
isinstance(self, other.__class__)
and self.name == other.name
and Counter(self.size_range) == Counter(other.size_range)
and self.space == other.space
)

def sample(self):
return [
self.space.sample()
for _ in range(
np.random.randint(
low=self.size_range[0],
high=None if self.size_range[1] is None else self.size_range[1] + 1,
)
)
]
1 change: 1 addition & 0 deletions compiler_gym/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ py_library(
"logs.py",
"minimize_trajectory.py",
"parallelization.py",
"permutation.py",
"registration.py",
"runfiles_path.py",
"shell_format.py",
Expand Down
1 change: 1 addition & 0 deletions compiler_gym/util/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ cg_py_library(
"logs.py"
"minimize_trajectory.py"
"parallelization.py"
"permutation.py"
"registration.py"
"runfiles_path.py"
"shell_format.py"
Expand Down
35 changes: 35 additions & 0 deletions compiler_gym/util/permutation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from numbers import Integral
from typing import List

import numpy as np


def convert_number_to_permutation(
n: Integral, permutation_size: Integral
) -> List[Integral]:
m = n
res = np.zeros(permutation_size, dtype=type(permutation_size))
elements = np.arange(permutation_size, dtype=type(permutation_size))
for i in range(permutation_size):
j = m % (permutation_size - i)
m = m // (permutation_size - i)
res[i] = elements[j]
elements[j] = elements[permutation_size - i - 1]
return res


def convert_permutation_to_number(permutation: List[Integral]) -> Integral:
pos = np.arange(len(permutation), dtype=int)
elements = np.arange(len(permutation), dtype=int)
m = 1
res = 0
for i in range(len(permutation) - 1):
res += m * pos[permutation[i]]
m = m * (len(permutation) - i)
pos[elements[len(permutation) - i - 1]] = pos[permutation[i]]
elements[pos[permutation[i]]] = elements[len(permutation) - i - 1]
return res
Loading

0 comments on commit b803856

Please sign in to comment.