Source code for gym_jiminy.common.wrappers.observation_stack

""" TODO: Write documentation.
"""
from copy import deepcopy
from operator import getitem
from functools import reduce
from collections import deque
from typing import (
    List, Any, Dict, Optional, Tuple, Sequence, Iterator, Union, Generic,
    SupportsFloat)
from typing_extensions import TypeAlias

import numpy as np

import gymnasium as gym

from ..utils import is_breakpoint, zeros, copy, copyto
from ..bases import (DT_EPS,
                     ObsT,
                     ActT,
                     InfoType,
                     EngineObsType,
                     InterfaceJiminyEnv,
                     BasePipelineWrapper)


StackedObsType: TypeAlias = ObsT


[docs] class PartialObservationStack( gym.Wrapper, # [StackedObsType, ActT, ObsT, ActT], Generic[ObsT, ActT]): """Observation wrapper that partially stacks observations in a rolling manner. This wrapper combines and extends OpenAI Gym wrappers `FrameStack` and `FilteredJiminyEnv` to support nested filter keys. It adds one extra dimension to all the leaves of the original observation spaces that must be stacked. If so, the first dimension corresponds to the individual timesteps (from oldest [0] to latest [-1]). .. note:: The observation space must be `gym.spaces.Dict`, while, ultimately, stacked leaf fields must be `gym.spaces.Box`. """ def __init__(self, env: gym.Env[ObsT, ActT], num_stack: int, nested_filter_keys: Optional[ Sequence[Union[Sequence[str], str]]] = None, **kwargs: Any): """ :param env: Environment to wrap. :param nested_filter_keys: List of nested observation fields to stack. Those fields does not have to be leaves. If not, then every leaves fields from this root will be stacked. :param num_stack: Number of observation frames to partially stack. :param kwargs: Extra keyword arguments to allow automatic pipeline wrapper generation. """ # pylint: disable=unused-argument # Sanitize user arguments if necessary assert isinstance(env.observation_space, gym.spaces.Dict) if nested_filter_keys is None: nested_filter_keys = list( env.observation_space.keys()) # type: ignore[attr-defined] # Backup user argument(s) self.nested_filter_keys: List[List[str]] = list( list(fields) for fields in nested_filter_keys) self.num_stack = num_stack # Initialize base wrapper. # Note that `gym.Wrapper` automatically binds the action/observation to # the one of the environment if not overridden explicitly. super().__init__(env) # Do not forward extra arguments, if any # Get the leaf fields to stack def _get_branches(root: Any) -> Iterator[List[str]]: if isinstance(root, gym.spaces.Dict): for field, node in root.spaces.items(): if isinstance(node, gym.spaces.Dict): for path in _get_branches(node): yield [field] + path else: yield [field] self.leaf_fields_list: List[List[str]] = [] for fields in self.nested_filter_keys: root_field = reduce(getitem, # type: ignore[arg-type] fields, self.env.observation_space) if isinstance(root_field, gym.spaces.Dict): leaf_paths = _get_branches(root_field) self.leaf_fields_list += [fields + path for path in leaf_paths] else: self.leaf_fields_list.append(fields) # Compute stacked observation space self.observation_space = deepcopy(self.env.observation_space) for fields in self.leaf_fields_list: assert isinstance(self.observation_space, gym.spaces.Dict) root_space = reduce(getitem, # type: ignore[arg-type] fields[:-1], self.observation_space) space = root_space[fields[-1]] if not isinstance(space, gym.spaces.Box): raise TypeError( "Stacked leaf fields must be associated with " "`gym.spaces.Box` space") low = np.repeat(space.low[np.newaxis], self.num_stack, axis=0) high = np.repeat(space.high[np.newaxis], self.num_stack, axis=0) assert space.dtype is not None assert issubclass(space.dtype.type, (np.floating, np.integer)) root_space[fields[-1]] = gym.spaces.Box( low=low, high=high, dtype=space.dtype.type) # Bind observation of the environment for all keys but the stacked ones if isinstance(self.env, InterfaceJiminyEnv): self.observation = copy(self.env.observation) for fields in self.leaf_fields_list: assert isinstance(self.observation_space, gym.spaces.Dict) root_obs = reduce(getitem, fields[:-1], self.observation) space = reduce(getitem, # type: ignore[arg-type] fields, self.observation_space) root_obs[fields[-1]] = zeros(space) else: # Fallback to classical memory allocation self.observation = zeros(self.observation_space) # Allocate internal frames buffers self._frames: List[deque] = [ deque(maxlen=self.num_stack) for _ in self.leaf_fields_list] def _setup(self) -> None: """ TODO: Write documentation. """ # Reset frames to zero for fields, frames in zip(self.leaf_fields_list, self._frames): assert isinstance(self.env.observation_space, gym.spaces.Dict) leaf_space = reduce(getitem, # type: ignore[arg-type] fields, self.env.observation_space) for _ in range(self.num_stack): frames.append(zeros(leaf_space))
[docs] def refresh_observation(self, measurement: ObsT) -> None: """ TODO: Write documentation. """ # Copy measurement if impossible to bind memory in the first place if not isinstance(self.env, InterfaceJiminyEnv): copyto(self.observation, measurement) # Backup the nested observation fields to stack. # Leaf values are copied to ensure they do not get altered later on. for fields, frames in zip(self.leaf_fields_list, self._frames): leaf_obs = reduce(getitem, # type: ignore[arg-type] fields, measurement) assert isinstance(leaf_obs, np.ndarray) frames.append(leaf_obs.copy()) # Update nested fields of the observation by the stacked ones for fields, frames in zip(self.leaf_fields_list, self._frames): leaf_obs = reduce(getitem, fields, self.observation) assert isinstance(leaf_obs, np.ndarray) leaf_obs[:] = frames
[docs] def step(self, action: ActT ) -> Tuple[StackedObsType, SupportsFloat, bool, bool, InfoType]: obs, reward, terminated, truncated, info = self.env.step(action) self.refresh_observation(obs) return self.observation, reward, terminated, truncated, info
[docs] def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None, ) -> Tuple[StackedObsType, InfoType]: observation, info = self.env.reset(seed=seed, options=options) self._setup() self.refresh_observation(observation) return self.observation, info
[docs] class StackedJiminyEnv( BasePipelineWrapper[StackedObsType, ActT, ObsT, ActT], Generic[ObsT, ActT]): """ TODO: Write documentation. """ def __init__(self, env: InterfaceJiminyEnv[ObsT, ActT], skip_frames_ratio: int = 0, **kwargs: Any) -> None: """ TODO: Write documentation. """ # Backup some user argument(s) self.skip_frames_ratio = skip_frames_ratio # Initialize some internal buffers self.__n_last_stack = 0 # Instantiate wrapper self.wrapper = PartialObservationStack(env, **kwargs) # Initialize base classes super().__init__(env, **kwargs) # Bind the observation of the wrapper self.observation = self.wrapper.observation # Bind the action of the environment assert self.action_space.contains(env.action) self.action = env.action def _initialize_action_space(self) -> None: self.action_space = self.env.action_space def _initialize_observation_space(self) -> None: self.observation_space = self.wrapper.observation_space def _setup(self) -> None: # Call base implementation super()._setup() # Setup wrapper self.wrapper._setup() # Make sure observe update is discrete-time if self.env.observe_dt <= 0.0: raise ValueError( "This wrapper does not support time-continuous update.") # Copy observe and control update periods from wrapped environment self.observe_dt = self.env.observe_dt self.control_dt = self.env.control_dt # Re-initialize some internal buffer(s). # Note that the initial observation is always stored. self.__n_last_stack = self.skip_frames_ratio - 1
[docs] def refresh_observation(self, measurement: EngineObsType) -> None: # Get environment observation self.env.refresh_observation(measurement) # Update observed features if necessary if self.is_simulation_running and is_breakpoint( self.stepper_state.t, self.env.observe_dt, DT_EPS): self.__n_last_stack += 1 if self.__n_last_stack == self.skip_frames_ratio: self.__n_last_stack = -1 self.wrapper.refresh_observation(self.env.observation)
[docs] def compute_command(self, action: ActT, command: np.ndarray) -> None: """Compute the motors efforts to apply on the robot. It simply forwards the command computed by the wrapped environment without any processing. :param action: High-level target to achieve by means of the command. """ self.env.compute_command(action, command)