[RLlib] Introduce new OldAPIStack decorator; Do-over of all API decorators #43657

Merged: 7 commits, merged on Mar 4, 2024
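
Every file in this diff makes the same mechanical swap: the `@PublicAPI(stability="alpha")` (or, in `rllib/connectors/util.py`, `@DeveloperAPI`) decorator imported from `ray.util.annotations` is replaced by the new `@OldAPIStack` decorator imported from `ray.rllib.utils.annotations`, marking these connector classes and helpers as belonging to RLlib's old API stack. As a rough sketch only, assuming the decorator is a simple pass-through marker (the actual implementation lives in `rllib/utils/annotations.py` and may attach different metadata or emit warnings):

```python
from typing import TypeVar

T = TypeVar("T")


def OldAPIStack(obj: T) -> T:
    """Tags a class or function as part of RLlib's old API stack.

    The decorated object keeps working unchanged; the tag only documents
    that it is slated to be superseded by the new API stack.
    """
    # Assumed marker attribute, for illustration only.
    setattr(obj, "_rllib_old_api_stack", True)
    return obj
```

On the call sites below, the only changes are therefore the import line and the decorator name, e.g. `@PublicAPI(stability="alpha")` becomes `@OldAPIStack`.
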
4 changes: 2 additions & 2 deletions rllib/connectors/action/clip.py
@@ -7,10 +7,10 @@
from ray.rllib.connectors.registry import register_connector
from ray.rllib.utils.spaces.space_utils import clip_action, get_base_struct_from_space
from ray.rllib.utils.typing import ActionConnectorDataType
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack


@PublicAPI(stability="alpha")
@OldAPIStack
class ClipActionsConnector(ActionConnector):
def __init__(self, ctx: ConnectorContext):
super().__init__(ctx)
4 changes: 2 additions & 2 deletions rllib/connectors/action/immutable.py
@@ -9,10 +9,10 @@
from ray.rllib.connectors.registry import register_connector
from ray.rllib.utils.numpy import make_action_immutable
from ray.rllib.utils.typing import ActionConnectorDataType
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack


@PublicAPI(stability="alpha")
@OldAPIStack
class ImmutableActionsConnector(ActionConnector):
def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
assert isinstance(
6 changes: 3 additions & 3 deletions rllib/connectors/action/lambdas.py
@@ -12,10 +12,10 @@
StateBatches,
TensorStructType,
)
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack


@PublicAPI(stability="alpha")
@OldAPIStack
def register_lambda_action_connector(
name: str, fn: Callable[[TensorStructType, StateBatches, Dict], PolicyOutputType]
) -> Type[ActionConnector]:
@@ -64,7 +64,7 @@ def from_state(ctx: ConnectorContext, params: Any):


# Convert actions and states into numpy arrays if necessary.
ConvertToNumpyConnector = PublicAPI(stability="alpha")(
ConvertToNumpyConnector = OldAPIStack(
register_lambda_action_connector(
"ConvertToNumpyConnector",
lambda actions, states, fetches: (
4 changes: 2 additions & 2 deletions rllib/connectors/action/normalize.py
@@ -10,10 +10,10 @@
unsquash_action,
)
from ray.rllib.utils.typing import ActionConnectorDataType
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack


@PublicAPI(stability="alpha")
@OldAPIStack
class NormalizeActionsConnector(ActionConnector):
def __init__(self, ctx: ConnectorContext):
super().__init__(ctx)
4 changes: 2 additions & 2 deletions rllib/connectors/action/pipeline.py
@@ -9,15 +9,15 @@
ConnectorPipeline,
)
from ray.rllib.connectors.registry import get_connector, register_connector
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.utils.typing import ActionConnectorDataType
from ray.util.annotations import PublicAPI
from ray.util.timer import _Timer


logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
@OldAPIStack
class ActionConnectorPipeline(ConnectorPipeline, ActionConnector):
def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
super().__init__(ctx, connectors)
4 changes: 2 additions & 2 deletions rllib/connectors/agent/clip_reward.py
@@ -9,10 +9,10 @@
from ray.rllib.connectors.registry import register_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack


@PublicAPI(stability="alpha")
@OldAPIStack
class ClipRewardAgentConnector(AgentConnector):
def __init__(self, ctx: ConnectorContext, sign=False, limit=None):
super().__init__(ctx)
4 changes: 2 additions & 2 deletions rllib/connectors/agent/env_sampling.py
@@ -6,10 +6,10 @@
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.utils.typing import AgentConnectorDataType
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack


@PublicAPI(stability="alpha")
@OldAPIStack
class EnvSamplingAgentConnector(AgentConnector):
def __init__(self, ctx: ConnectorContext, sign=False, limit=None):
super().__init__(ctx)
8 changes: 4 additions & 4 deletions rllib/connectors/agent/lambdas.py
@@ -13,10 +13,10 @@
AgentConnectorDataType,
AgentConnectorsOutput,
)
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack


@PublicAPI(stability="alpha")
@OldAPIStack
def register_lambda_agent_connector(
name: str, fn: Callable[[Any], Any]
) -> Type[AgentConnector]:
@@ -54,7 +54,7 @@ def from_state(ctx: ConnectorContext, params: Any):
return LambdaAgentConnector


@PublicAPI(stability="alpha")
@OldAPIStack
def flatten_data(data: AgentConnectorsOutput):
assert isinstance(
data, AgentConnectorsOutput
@@ -81,6 +81,6 @@ def flatten_data(data: AgentConnectorsOutput):

# Agent connector to build and return a flattened observation SampleBatch
# in addition to the original input dict.
FlattenDataAgentConnector = PublicAPI(stability="alpha")(
FlattenDataAgentConnector = OldAPIStack(
register_lambda_agent_connector("FlattenDataAgentConnector", flatten_data)
)
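
Side note on the two lines just above: because `FlattenDataAgentConnector` is produced by `register_lambda_agent_connector` at runtime, the annotation is applied as a plain call rather than with `@` syntax (the action-side `ConvertToNumpyConnector` earlier in this diff follows the same pattern). A minimal sketch of that pattern, using a hypothetical connector name and assuming the decorator simply returns its argument:

```python
from ray.rllib.connectors.agent.lambdas import register_lambda_agent_connector
from ray.rllib.utils.annotations import OldAPIStack  # import added by this PR

# Hypothetical lambda-built agent connector, for illustration only.
MyLambdaConnector = OldAPIStack(
    register_lambda_agent_connector("MyLambdaConnector", lambda data: data)
)
```
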
6 changes: 3 additions & 3 deletions rllib/connectors/agent/mean_std_filter.py
@@ -11,15 +11,15 @@
)
from ray.rllib.connectors.registry import register_connector
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.filter import MeanStdFilter, ConcurrentMeanStdFilter
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import AgentConnectorDataType
from ray.util.annotations import PublicAPI
from ray.rllib.utils.filter import RunningStat


@PublicAPI(stability="alpha")
@OldAPIStack
class MeanStdObservationFilterAgentConnector(SyncedFilterAgentConnector):
"""A connector used to mean-std-filter observations.

@@ -149,7 +149,7 @@ def sync(self, other: "AgentConnector") -> None:
return self.filter.sync(other.filter)


@PublicAPI(stability="alpha")
@OldAPIStack
class ConcurrentMeanStdObservationFilterAgentConnector(
MeanStdObservationFilterAgentConnector
):
7 changes: 2 additions & 5 deletions rllib/connectors/agent/obs_preproc.py
@@ -8,13 +8,10 @@
from ray.rllib.models.preprocessors import get_preprocessor, NoPreprocessor
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentConnectorDataType
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack


# Bridging between current obs preprocessors and connector.
# We should not introduce any new preprocessors.
# TODO(jungong) : migrate and implement preprocessor library in Connector framework.
@PublicAPI(stability="alpha")
@OldAPIStack
class ObsPreprocessorConnector(AgentConnector):
"""A connector that wraps around existing RLlib observation preprocessors.

4 changes: 2 additions & 2 deletions rllib/connectors/agent/pipeline.py
@@ -10,14 +10,14 @@
)
from ray.rllib.connectors.registry import get_connector, register_connector
from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack
from ray.util.timer import _Timer


logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
@OldAPIStack
class AgentConnectorPipeline(ConnectorPipeline, AgentConnector):
def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
super().__init__(ctx, connectors)
4 changes: 2 additions & 2 deletions rllib/connectors/agent/state_buffer.py
@@ -18,13 +18,13 @@
from ray.rllib.core.models.base import STATE_OUT
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack


logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
@OldAPIStack
class StateBufferConnector(AgentConnector):
def __init__(self, ctx: ConnectorContext, states: Any = None):
super().__init__(ctx)
4 changes: 2 additions & 2 deletions rllib/connectors/agent/synced_filter.py
@@ -2,11 +2,11 @@
AgentConnector,
ConnectorContext,
)
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.utils.filter import Filter


@PublicAPI(stability="alpha")
@OldAPIStack
class SyncedFilterAgentConnector(AgentConnector):
"""An agent connector that filters with synchronized parameters."""

4 changes: 2 additions & 2 deletions rllib/connectors/agent/view_requirement.py
@@ -11,11 +11,11 @@
AgentConnectorDataType,
AgentConnectorsOutput,
)
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.evaluation.collectors.agent_collector import AgentCollector


@PublicAPI(stability="alpha")
@OldAPIStack
class ViewRequirementAgentConnector(AgentConnector):
"""This connector does 2 things:
1. It filters data columns based on view_requirements for training and inference.
12 changes: 6 additions & 6 deletions rllib/connectors/connector.py
@@ -14,15 +14,15 @@
AlgorithmConfigDict,
TensorType,
)
from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack

if TYPE_CHECKING:
from ray.rllib.policy.policy import Policy

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
@OldAPIStack
class ConnectorContext:
"""Data bits that may be needed for running connectors.

@@ -79,7 +79,7 @@ def from_policy(policy: "Policy") -> "ConnectorContext":
)


@PublicAPI(stability="alpha")
@OldAPIStack
class Connector(abc.ABC):
"""Connector base class.

@@ -137,7 +137,7 @@ def from_state(self, ctx: ConnectorContext, params: Any) -> "Connector":
return NotImplementedError


@PublicAPI(stability="alpha")
@OldAPIStack
class AgentConnector(Connector):
"""Connector connecting user environments to RLlib policies.

@@ -277,7 +277,7 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
raise NotImplementedError


@PublicAPI(stability="alpha")
@OldAPIStack
class ActionConnector(Connector):
"""Action connector connects policy outputs including actions,
to user environments.
@@ -332,7 +332,7 @@ def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType
raise NotImplementedError


@PublicAPI(stability="alpha")
@OldAPIStack
class ConnectorPipeline(abc.ABC):
"""Utility class for quick manipulation of a connector pipeline."""

6 changes: 3 additions & 3 deletions rllib/connectors/registry.py
@@ -1,14 +1,14 @@
"""Registry of connector names for global access."""
from typing import Any

from ray.util.annotations import PublicAPI
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.connectors.connector import Connector, ConnectorContext


ALL_CONNECTORS = dict()


@PublicAPI(stability="alpha")
@OldAPIStack
def register_connector(name: str, cls: Connector):
"""Register a connector for use with RLlib.

@@ -28,7 +28,7 @@ def register_connector(name: str, cls: Connector):
ALL_CONNECTORS[name] = cls


@PublicAPI(stability="alpha")
@OldAPIStack
def get_connector(name: str, ctx: ConnectorContext, params: Any = None) -> Connector:
# TODO(jungong) : switch the order of parameters man!!
"""Get a connector by its name and serialized config.
14 changes: 7 additions & 7 deletions rllib/connectors/util.py
@@ -17,7 +17,7 @@
MeanStdObservationFilterAgentConnector,
ConcurrentMeanStdObservationFilterAgentConnector,
)
from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector

if TYPE_CHECKING:
@@ -44,7 +44,7 @@ def __clip_rewards(config: "AlgorithmConfig"):
return config.clip_rewards or config.is_atari


@PublicAPI(stability="alpha")
@OldAPIStack
def get_agent_connectors_from_config(
ctx: ConnectorContext,
config: "AlgorithmConfig",
@@ -78,7 +78,7 @@ def get_agent_connectors_from_config(
return AgentConnectorPipeline(ctx, connectors)


@PublicAPI(stability="alpha")
@OldAPIStack
def get_action_connectors_from_config(
ctx: ConnectorContext,
config: "AlgorithmConfig",
@@ -98,7 +98,7 @@ def get_action_connectors_from_config(
return ActionConnectorPipeline(ctx, connectors)


@PublicAPI(stability="alpha")
@OldAPIStack
def create_connectors_for_policy(policy: "Policy", config: "AlgorithmConfig"):
"""Util to create agent and action connectors for a Policy.

@@ -120,7 +120,7 @@ def create_connectors_for_policy(policy: "Policy", config: "AlgorithmConfig"):
logger.info(policy.action_connectors.__str__(indentation=4))


@PublicAPI(stability="alpha")
@OldAPIStack
def restore_connectors_for_policy(
policy: "Policy", connector_config: Tuple[str, Tuple[Any]]
) -> Connector:
@@ -136,7 +136,7 @@ def restore_connectors_for_policy(


# We need this filter selection mechanism temporarily to remain compatible to old API
@DeveloperAPI
@OldAPIStack
def get_synced_filter_connector(ctx: ConnectorContext):
filter_specifier = ctx.config.get("observation_filter")
if filter_specifier == "MeanStdFilter":
@@ -149,7 +149,7 @@ def get_synced_filter_connector(ctx: ConnectorContext):
raise Exception("Unknown observation_filter: " + str(filter_specifier))


@DeveloperAPI
@OldAPIStack
def maybe_get_filters_for_syncing(rollout_worker, policy_id):
# As long as the historic filter synchronization mechanism is in
# place, we need to put filters into self.filters so that they get
3 changes: 3 additions & 0 deletions rllib/env/apis/task_settable_env.py
@@ -1,9 +1,12 @@
import gymnasium as gym
from typing import List, Any

from ray.rllib.utils.annotations import OldAPIStack

TaskType = Any # Can be different types depending on env, e.g., int or dict


@OldAPIStack
class TaskSettableEnv(gym.Env):
"""
Extension of gym.Env to define a task-settable Env.