Source code for gym_jiminy.common.bases.interfaces

"""Controller and observer abstract interfaces from reinforcement learning,
specifically design for Jiminy engine, and defined as mixin classes. Any
observer/controller block must inherit and implement those interfaces.
"""
from abc import abstractmethod, ABC
from collections import OrderedDict
from typing import Dict, Any, TypeVar, Generic, no_type_check, TYPE_CHECKING
from typing_extensions import TypeAlias

import numpy as np
import numpy.typing as npt
import gymnasium as gym

import jiminy_py.core as jiminy
from jiminy_py.simulator import Simulator
from jiminy_py.viewer.viewer import is_display_available

from ..utils import DataNested
if TYPE_CHECKING:
    from ..quantities import QuantityManager


# Temporal resolution of simulator steps
DT_EPS: float = 1e-6  # 'SIMULATION_MIN_TIMESTEP'


ObsT = TypeVar('ObsT', bound=DataNested)
ActT = TypeVar('ActT', bound=DataNested)
BaseObsT = TypeVar('BaseObsT', bound=DataNested)
BaseActT = TypeVar('BaseActT', bound=DataNested)

SensorMeasurementStackMap = Dict[str, npt.NDArray[np.float64]]
InfoType = Dict[str, Any]


# class EngineObsType(TypedDict):
#     t: np.ndarray
#     state:  DataNested
#     features: DataNested


EngineObsType: TypeAlias = DataNested
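

# For reference, the engine observation passed to the base environment is a
# nested dictionary whose layout mirrors the internal measurement buffer built
# in `InterfaceJiminyEnv.__init__` below (illustrative sketch only; the actual
# array shapes depend on the robot model and its sensors):
#
#     OrderedDict(
#         t=np.array(0.0),                # current simulation time
#         states=OrderedDict(
#             agent=OrderedDict(
#                 q=np.array([...]),      # extended configuration vector
#                 v=np.array([...]))),    # extended velocity vector
#         measurements=OrderedDict(...))  # one entry per sensor type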


class InterfaceObserver(ABC, Generic[ObsT, BaseObsT]):
    """Observer interface for both observers and environments.
    """

    observe_dt: float = -1
    observation_space: gym.Space  # [ObsT]
    observation: ObsT

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the observer interface.

        :param args: Extra arguments that may be useful for mixing multiple
                     interfaces through multiple inheritance.
        :param kwargs: Extra keyword arguments. See 'args'.
        """
        # Call super to allow mixing interfaces through multiple inheritance
        super().__init__(*args, **kwargs)

        # Refresh the observation space
        self._initialize_observation_space()

    @abstractmethod
    def _initialize_observation_space(self) -> None:
        """Configure the observation space.
        """
    @abstractmethod
    def refresh_observation(self, measurement: BaseObsT) -> None:
        """Compute observed features based on the current simulation state
        and lower-level measurements.

        .. warning::
            When overloading this method, one is expected to use the internal
            buffer `observation` to store the observation, updating it by
            reference. It may be error-prone and tricky to get used to, but
            it is computationally more efficient as it avoids allocating
            memory multiple times and performing redundant computations.
            Additionally, it enables retrieving the observation later on by
            calling `get_observation`.

        :param measurement: Low-level measure from the environment to process
                            to get higher-level observation.
        """
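

# A minimal sketch of a concrete observer (hypothetical, for illustration
# only): it observes the first `size` entries of the agent configuration,
# updating the pre-allocated `observation` buffer in place as recommended by
# `refresh_observation`.
#
#     class ConfigurationObserver(
#             InterfaceObserver[np.ndarray, EngineObsType]):
#         def __init__(self, size: int) -> None:
#             self._size = size
#             super().__init__()
#             self.observation = np.zeros((size,))
#
#         def _initialize_observation_space(self) -> None:
#             self.observation_space = gym.spaces.Box(
#                 low=-np.inf, high=np.inf, shape=(self._size,))
#
#         def refresh_observation(self, measurement: EngineObsType) -> None:
#             q = measurement["states"]["agent"]["q"]
#             # Update the buffer by reference, never re-bind it
#             self.observation[:] = q[:self._size]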
class InterfaceController(ABC, Generic[ActT, BaseActT]):
    """Controller interface for both controllers and environments.
    """

    control_dt: float = -1
    action_space: gym.Space  # [ActT]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the controller interface.

        :param args: Extra arguments that may be useful for mixing multiple
                     interfaces through multiple inheritance.
        :param kwargs: Extra keyword arguments. See 'args'.
        """
        # Call super to allow mixing interfaces through multiple inheritance
        super().__init__(*args, **kwargs)

        # Refresh the action space
        self._initialize_action_space()

    @abstractmethod
    def _initialize_action_space(self) -> None:
        """Configure the action space.
        """
    @abstractmethod
    def compute_command(self, action: ActT, command: BaseActT) -> None:
        """Compute the command to send to the subsequent block, based on the
        action and current observation of the environment.

        .. note::
            By design, the observation of the environment has been refreshed
            automatically prior to calling this method.

        :param action: High-level target to achieve by means of the command.
        :param command: Command to send to the subsequent block. It
                        corresponds to the target features of another
                        lower-level controller if any, the target motor
                        efforts of the environment to ultimately control
                        otherwise. It must be updated in-place.
        """
    def compute_reward(self,
                       terminated: bool,  # pylint: disable=unused-argument
                       truncated: bool,  # pylint: disable=unused-argument
                       info: InfoType  # pylint: disable=unused-argument
                       ) -> float:
        """Compute the reward related to a specific control block.

        For the corresponding MDP to be stationary, the computation of the
        reward is supposed to involve only the transition from previous to
        current state of the simulation (possibly comprising multiple agents)
        under the ongoing action.

        By default, it returns 0.0 no matter what. It is up to the user to
        provide a dedicated reward function whenever appropriate.

        .. warning::
            Only returning an aggregated scalar reward is supported. Yet, it
            is possible to update 'info' by reference if one wants to keep
            track of individual reward components or any kind of extra info
            that may be helpful for monitoring or debugging purposes.

        :param terminated: Whether the episode has reached the terminal state
                           of the MDP at the current step. This flag can be
                           used to compute a specific terminal reward.
        :param truncated: Whether a truncation condition outside the scope of
                          the MDP has been satisfied at the current step. This
                          flag can be used to adapt the reward.
        :param info: Dictionary of extra information for monitoring.

        :returns: Aggregated reward for the current step.
        """
        return 0.0
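

# A minimal sketch of a concrete controller (hypothetical, for illustration
# only): it forwards the action to the subsequent block as motor efforts,
# updating `command` in place as required by `compute_command`, and overrides
# `compute_reward` with a simple survival bonus.
#
#     class PassthroughController(
#             InterfaceController[np.ndarray, np.ndarray]):
#         def __init__(self, size: int) -> None:
#             self._size = size
#             super().__init__()
#
#         def _initialize_action_space(self) -> None:
#             self.action_space = gym.spaces.Box(
#                 low=-1.0, high=1.0, shape=(self._size,))
#
#         def compute_command(self,
#                             action: np.ndarray,
#                             command: np.ndarray) -> None:
#             # Update the command buffer by reference, never re-bind it
#             command[:] = action
#
#         def compute_reward(self,
#                            terminated: bool,
#                            truncated: bool,
#                            info: InfoType) -> float:
#             return 0.0 if terminated else 1.0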
# Note that `InterfaceJiminyEnv` must inherit from `InterfaceObserver`
# before `InterfaceController` to initialize the action space before the
# observation space, since the action itself may be part of the observation.
# Similarly, `gym.Env` must be last to make sure all the other initialization
# methods are called first.
class InterfaceJiminyEnv(
        InterfaceObserver[ObsT, EngineObsType],
        InterfaceController[ActT, np.ndarray],
        gym.Env[ObsT, ActT],
        Generic[ObsT, ActT]):
    """Observer plus controller interface for both generic pipeline blocks,
    including environments.
    """

    metadata: Dict[str, Any] = {
        "render_modes": (
            ['rgb_array'] + (['human'] if is_display_available() else []))
    }

    simulator: Simulator
    robot: jiminy.Robot
    stepper_state: jiminy.StepperState
    robot_state: jiminy.RobotState
    sensor_measurements: SensorMeasurementStackMap
    is_simulation_running: npt.NDArray[np.bool_]

    quantities: "QuantityManager"

    action: ActT

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Track whether the observation has been refreshed manually since the
        # last call of '_controller_handle'. It typically happens at the end
        # of every simulation step to return an observation that is
        # consistent with the updated state of the agent.
        self.__is_observation_refreshed = True

        # Store latest engine measurement for efficiency
        self.__measurement: EngineObsType = OrderedDict(
            t=np.array(0.0),
            states=OrderedDict(
                agent=OrderedDict(q=np.array([]), v=np.array([]))),
            measurements=OrderedDict(self.robot.sensor_measurements))
        self._sensors_types = tuple(self.robot.sensor_measurements.keys())

        # Call super to allow mixing interfaces through multiple inheritance
        super().__init__(*args, **kwargs)

    def _setup(self) -> None:
        """Configure the observer-controller.

        In practice, it only resets the controller and observer update
        periods.

        .. note::
            This method must be called once, after the environment has been
            reset. This is done automatically when calling the `reset` method.
        """
        # Re-initialize observation and control periods to an ill-defined
        # value, to trigger an exception if they are not set properly later
        # on.
        self.observe_dt = -1
        self.control_dt = -1

        # The observation must always be refreshed after setup
        self.__is_observation_refreshed = False

    @no_type_check
    def _observer_handle(self,
                         t: float,
                         q: np.ndarray,
                         v: np.ndarray,
                         sensor_measurements: jiminy.SensorMeasurementTree
                         ) -> None:
        """Thin wrapper around user-specified `refresh_observation` method.

        .. warning::
            This method is not supposed to be called manually nor overloaded.

        :param t: Current simulation time.
        :param q: Current extended configuration vector of the robot.
        :param v: Current extended velocity vector of the robot.
        :param sensor_measurements: Current sensor data.
        """
        # Refresh the observation if not already done, but only if a
        # simulation is already running. It would be pointless to refresh the
        # observation before that point, since the controller will be called
        # multiple times at start. Besides, it would defeat the purpose of
        # `_initialize_buffers`, which is supposed to be executed before
        # `refresh_observation` is called for the first time in an episode.
        if not self.__is_observation_refreshed and self.is_simulation_running:
            measurement = self.__measurement
            measurement["t"][()] = t
            measurement["states"]["agent"]["q"] = q
            measurement["states"]["agent"]["v"] = v
            measurement_sensors = measurement["measurements"]
            sensor_measurements_it = iter(sensor_measurements.values())
            for sensor_type in self._sensors_types:
                measurement_sensors[sensor_type] = next(
                    sensor_measurements_it)
            try:
                self.refresh_observation(measurement)
            except RuntimeError as e:
                raise RuntimeError(
                    "The observation space must be constant.") from e
            self.__is_observation_refreshed = True

    def _controller_handle(self,
                           t: float,
                           q: np.ndarray,
                           v: np.ndarray,
                           sensor_measurements: jiminy.SensorMeasurementTree,
                           command: np.ndarray) -> None:
        """Thin wrapper around user-specified `refresh_observation` and
        `compute_command` methods.

        .. note::
            The internal cache of managed quantities is systematically
            cleared right away, before anything else.

        .. warning::
            This method is not supposed to be called manually nor overloaded.
            It will be used by the base environment to instantiate a
            `jiminy.FunctionalController`, which is responsible for both
            refreshing observations and computing commands all the way
            through a given pipeline, in the correct order of the blocks, to
            finally send motor torques directly to the robot.

        :param t: Current simulation time.
        :param q: Current extended configuration vector of the robot.
        :param v: Current extended velocity vector of the robot.
        :param sensor_measurements: Current sensor measurements.
        :param command: Output argument corresponding to the motor torques to
                        apply on the robot. It must be updated by reference
                        using `[:]` or `np.copyto`.
        """
        # Reset the quantity manager.
        # In principle, the internal cache of quantities should be cleared
        # each time the state of the robot and/or its derivative changes.
        # This is hard to do because there is no way to detect this
        # specifically at the time being. However, `_controller_handle` is
        # never called twice in the exact same state by the engine, so
        # resetting quantities at the beginning of the method should cover
        # most cases. Yet, quantities cannot be used reliably in the
        # definition of profile forces because they are always updated before
        # the controller gets called, no matter if either one or the other is
        # time-continuous. Hacking the internal dynamics to clear quantities
        # does not address this issue either.
        self.quantities.clear()

        # Refresh the observation
        self._observer_handle(t, q, v, sensor_measurements)

        # No need to check for breakpoints of the controller because it
        # already matches the update period by design.
        self.compute_command(self.action, command)

        # Always consider that the observation must be refreshed after
        # calling '_controller_handle', as it is never called more often than
        # necessary.
        self.__is_observation_refreshed = False

    @property
    def unwrapped(self) -> "InterfaceJiminyEnv":
        """Base environment of the pipeline.
        """
        return self

    @property
    @abstractmethod
    def step_dt(self) -> float:
        """Get timestep of a single 'step'.
        """

    @property
    @abstractmethod
    def is_training(self) -> bool:
        """Check whether the environment is in 'train' or 'eval' mode.
        """
    @abstractmethod
    def train(self) -> None:
        """Sets the environment in training mode.
        """
    @abstractmethod
    def eval(self) -> None:
        """Sets the environment in evaluation mode.

        This only has an effect on certain environments. It can be used for
        instance to enable clipping or filtering of the action at evaluation
        time specifically. See the documentation of a given environment for
        details about its behavior in training and evaluation modes.
        """
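

# Typical usage of the training/evaluation switches on a concrete
# environment (hypothetical sketch; `env` stands for any instance of an
# `InterfaceJiminyEnv` subclass):
#
#     env.train()          # enable training-specific behaviors
#     obs, info = env.reset()
#     ...
#     env.eval()           # e.g. enable action clipping or filtering
#     obs, info = env.reset()
#     assert not env.is_training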