Source code for gym_jiminy.common.compositions.locomotion

"""Rewards mainly relevant for locomotion tasks on floating-base robots.
"""
from functools import partial
from dataclasses import dataclass
from typing import Optional, Union, Sequence, Literal, Callable, cast

import numpy as np
import numba as nb

import jiminy_py.core as jiminy
import pinocchio as pin

from ..bases import (
    InterfaceJiminyEnv, StateQuantity, InterfaceQuantity, QuantityEvalMode,
    QuantityReward)
from ..quantities import (
    OrientationType, MaskedQuantity, UnaryOpQuantity, FrameOrientation,
    BaseRelativeHeight, BaseOdometryPose, BaseOdometryAverageVelocity,
    CapturePoint, MultiFramePosition, MultiFootRelativeXYZQuat,
    MultiContactNormalizedSpatialForce, MultiFootNormalizedForceVertical,
    MultiFootCollisionDetection, AverageBaseMomentum)
from ..utils import quat_difference, quat_to_yaw

from .generic import (
    ArrayLikeOrScalar, TrackingQuantityReward, QuantityTermination,
    DriftTrackingQuantityTermination, ShiftTrackingQuantityTermination)
from .mixin import radial_basis_function


class TrackingBaseHeightReward(TrackingQuantityReward):
    """Reward the agent for making the height of the floating base of the
    robot follow some reference trajectory.

    .. seealso::
        See `TrackingQuantityReward` documentation for technical details.
    """
    def __init__(self, env: InterfaceJiminyEnv, cutoff: float) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param cutoff: Cutoff threshold for the RBF kernel transform.
        """
        # The tracked quantity is built in three stages: the robot state
        # (without kinematics update), from which the configuration vector
        # `q` is extracted, from which only the entry at index 2 is kept.
        super().__init__(
            env,
            "reward_tracking_base_height",
            lambda mode: (MaskedQuantity, {
                "quantity": (UnaryOpQuantity, {
                    "quantity": (StateQuantity, {
                        "update_kinematics": False,
                        "mode": mode,
                    }),
                    "op": lambda state: state.q,
                }),
                "axis": 0,
                "keys": (2,),
            }),
            cutoff)
class TrackingBaseOdometryVelocityReward(TrackingQuantityReward):
    """Reward the agent for making the odometry velocity follow some
    reference trajectory.

    .. seealso::
        See `TrackingQuantityReward` documentation for technical details.
    """
    def __init__(self, env: InterfaceJiminyEnv, cutoff: float) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param cutoff: Cutoff threshold for the RBF kernel transform.
        """
        # The tracked quantity is `BaseOdometryAverageVelocity`, evaluated in
        # whichever mode is requested by the base class.
        super().__init__(
            env,
            "reward_tracking_odometry_velocity",
            lambda mode: (BaseOdometryAverageVelocity, {"mode": mode}),
            cutoff)
class TrackingCapturePointReward(TrackingQuantityReward):
    """Reward the agent for making the capture point follow some reference
    trajectory.

    .. seealso::
        See `TrackingQuantityReward` documentation for technical details.
    """
    def __init__(self, env: InterfaceJiminyEnv, cutoff: float) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param cutoff: Cutoff threshold for the RBF kernel transform.
        """
        # The capture point is requested with the LOCAL reference frame
        super().__init__(
            env,
            "reward_tracking_capture_point",
            lambda mode: (CapturePoint, {
                "reference_frame": pin.ReferenceFrame.LOCAL,
                "mode": mode,
            }),
            cutoff)
class TrackingFootPositionsReward(TrackingQuantityReward):
    """Reward the agent for tracking the relative position of the feet wrt
    each other.

    .. seealso::
        See `TrackingQuantityReward` documentation for technical details.
    """
    def __init__(
            self,
            env: InterfaceJiminyEnv,
            cutoff: float,
            *,
            frame_names: Union[Sequence[str], Literal['auto']] = 'auto'
            ) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param cutoff: Cutoff threshold for the RBF kernel transform.
        :param frame_names: Name of the frames corresponding to the feet of
                            the robot. 'auto' to automatically detect them
                            from the set of contact and force sensors of the
                            robot.
                            Optional: 'auto' by default.
        """
        # Only the first three rows (indices 0, 1, 2) of the relative
        # XYZQuat representation are tracked.
        super().__init__(
            env,
            "reward_tracking_foot_positions",
            lambda mode: (MaskedQuantity, {
                "quantity": (MultiFootRelativeXYZQuat, {
                    "frame_names": frame_names,
                    "mode": mode,
                }),
                "axis": 0,
                "keys": (0, 1, 2),
            }),
            cutoff)
class TrackingFootOrientationsReward(TrackingQuantityReward):
    """Reward the agent for tracking the relative orientation of the feet wrt
    each other.

    .. seealso::
        See `TrackingQuantityReward` documentation for technical details.
    """
    def __init__(
            self,
            env: InterfaceJiminyEnv,
            cutoff: float,
            *,
            frame_names: Union[Sequence[str], Literal['auto']] = 'auto'
            ) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param cutoff: Cutoff threshold for the RBF kernel transform.
        :param frame_names: Name of the frames corresponding to the feet of
                            the robot. 'auto' to automatically detect them
                            from the set of contact and force sensors of the
                            robot.
                            Optional: 'auto' by default.
        """
        # The tracking error is computed with `quat_difference` instead of
        # the default operator. The cast only serves static type checking.
        error_op = cast(
            Callable[[np.ndarray, np.ndarray], np.ndarray], quat_difference)

        # Only the last four rows (indices 3, 4, 5, 6) of the relative
        # XYZQuat representation are tracked.
        super().__init__(
            env,
            "reward_tracking_foot_orientations",
            lambda mode: (MaskedQuantity, {
                "quantity": (MultiFootRelativeXYZQuat, {
                    "frame_names": frame_names,
                    "mode": mode,
                }),
                "axis": 0,
                "keys": (3, 4, 5, 6),
            }),
            cutoff,
            op=error_op)
class TrackingFootForceDistributionReward(TrackingQuantityReward):
    """Reward the agent for tracking the relative vertical force in world
    frame applied on each foot.

    .. note::
        The force is normalized by the weight of the robot rather than the
        total force applied on all feet. This is important as it not only
        takes into account the force distribution between the feet, but also
        the overall ground contact interaction force. This way, building up
        momentum before jumping will be distinguished from standing still.
        Moreover, it ensures that the reward is always properly defined, even
        if the robot has no contact with the ground at all, which typically
        arises during the flying phase of running.

    .. seealso::
        See `TrackingQuantityReward` documentation for technical details.
    """
    def __init__(
            self,
            env: InterfaceJiminyEnv,
            cutoff: float,
            *,
            frame_names: Union[Sequence[str], Literal['auto']] = 'auto'
            ) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param cutoff: Cutoff threshold for the RBF kernel transform.
        :param frame_names: Name of the frames corresponding to the feet of
                            the robot. 'auto' to automatically detect them
                            from the set of contact and force sensors of the
                            robot.
                            Optional: 'auto' by default.
        """
        super().__init__(
            env,
            "reward_tracking_foot_force_distribution",
            lambda mode: (MultiFootNormalizedForceVertical, {
                "frame_names": frame_names,
                "mode": mode,
            }),
            cutoff)
class MinimizeAngularMomentumReward(QuantityReward):
    """Reward the agent for minimizing the angular momentum in world plane.

    The angular momentum along x- and y-axes in local odometry frame is
    transformed into a normalized reward to maximize by applying an RBF
    kernel on the error. See `TrackingQuantityReward` documentation for
    technical details.
    """
    def __init__(self, env: InterfaceJiminyEnv, cutoff: float) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param cutoff: Cutoff threshold for the RBF kernel transform.
        """
        # Keep track of the user-specified cutoff
        self.cutoff = cutoff

        # Second-order RBF kernel mapping the quantity to a reward
        kernel = partial(
            radial_basis_function, cutoff=self.cutoff, order=2)

        # Call base implementation
        super().__init__(
            env,
            "reward_momentum",
            (AverageBaseMomentum, {"mode": QuantityEvalMode.TRUE}),
            kernel,
            is_normalized=True,
            is_terminal=False)
class MinimizeFrictionReward(QuantityReward):
    """Reward the agent for minimizing the tangential forces at all the
    contact points and collision bodies, and to avoid jerky intermittent
    contact state.

    The L^2-norm is used to aggregate all the local tangential forces. While
    the L^1-norm would be more natural in this specific case, using the
    L^2-norm is preferable as it promotes space-time regularity, ie balancing
    the force distribution evenly between all the candidate contact points
    and avoiding jerky contact forces over time (high-frequency vibrations),
    phenomena to which the L^1-norm is completely insensitive.
    """
    def __init__(self, env: InterfaceJiminyEnv, cutoff: float) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param cutoff: Cutoff threshold for the RBF kernel transform.
        """
        # Keep track of the user-specified cutoff
        self.cutoff = cutoff

        # Second-order RBF kernel mapping the quantity to a reward
        kernel = partial(
            radial_basis_function, cutoff=self.cutoff, order=2)

        # Call base implementation. Only the first two rows (indices 0, 1)
        # of the normalized spatial contact force are considered.
        super().__init__(
            env,
            "reward_friction",
            (MaskedQuantity, {
                "quantity": (MultiContactNormalizedSpatialForce, {}),
                "axis": 0,
                "keys": (0, 1),
            }),
            kernel,
            is_normalized=True,
            is_terminal=False)
class BaseRollPitchTermination(QuantityTermination):
    """Encourages the agent to keep the floating base straight, ie its torso
    in case of a humanoid robot, by prohibiting excessive roll and pitch
    angles.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 low: Optional[ArrayLikeOrScalar],
                 high: Optional[ArrayLikeOrScalar],
                 grace_period: float = 0.0,
                 *,
                 is_training_only: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param low: Lower bound below which termination is triggered.
        :param high: Upper bound above which termination is triggered.
        :param grace_period: Grace period effective only at the very
                             beginning of the episode, during which the
                             latter is bound to continue whatever happens.
                             Optional: 0.0 by default.
        :param is_training_only: Whether the termination condition should be
                                 completely by-passed if the environment is
                                 in evaluation mode.
                                 Optional: False by default.
        """
        # The monitored quantity gathers the first two Euler angles (indices
        # 0, 1) of the orientation of frame "root_joint".
        super().__init__(
            env,
            "termination_base_roll_pitch",
            (MaskedQuantity, {  # type: ignore[arg-type]
                "quantity": (FrameOrientation, {
                    "frame_name": "root_joint",
                    "type": OrientationType.EULER,
                }),
                "axis": 0,
                "keys": (0, 1),
            }),
            low,
            high,
            grace_period,
            is_truncation=False,
            is_training_only=is_training_only)
class FallingTermination(QuantityTermination):
    """Terminate the episode immediately if the floating base of the robot
    gets too close to the ground.

    It is assumed that the state is no longer recoverable when its condition
    is triggered. As such, the episode is terminated on the spot as the
    situation is hopeless. Generally speaking, aborting an episode in
    anticipation of catastrophic failure is beneficial. Assuming the
    condition is on point, doing this improves the signal to noise ratio when
    estimating the gradient by avoiding cluttering the training batches with
    irrelevant information.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 min_base_height: float,
                 grace_period: float = 0.0,
                 *,
                 is_training_only: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param min_base_height: Minimum height of the floating base of the
                                robot below which termination is triggered.
        :param grace_period: Grace period effective only at the very
                             beginning of the episode, during which the
                             latter is bound to continue whatever happens.
                             Optional: 0.0 by default.
        :param is_training_only: Whether the termination condition should be
                                 completely by-passed if the environment is
                                 in evaluation mode.
                                 Optional: False by default.
        """
        # Only a lower bound is specified, the upper bound is left to `None`
        super().__init__(
            env,
            "termination_base_height",
            (BaseRelativeHeight, {}),  # type: ignore[arg-type]
            min_base_height,
            None,
            grace_period,
            is_truncation=False,
            is_training_only=is_training_only)
class FootCollisionTermination(QuantityTermination):
    """Terminate the episode immediately if some of the feet of the robot are
    getting too close to each other.

    Self-collision must be avoided at all cost, as it can damage the
    hardware. Considering this condition as a dramatic failure urges the
    agent to do its best in this matter, to the point of becoming risk
    averse.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 security_margin: float = 0.0,
                 grace_period: float = 0.0,
                 frame_names: Union[Sequence[str], Literal['auto']] = 'auto',
                 *,
                 is_training_only: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param security_margin: Minimum signed distance below which
                                termination is triggered. This can be
                                interpreted as inflating or deflating the
                                geometry objects by the safety margin
                                depending on whether it is positive or
                                negative. See `MultiFootCollisionDetection`
                                for details.
        :param grace_period: Grace period effective only at the very
                             beginning of the episode, during which the
                             latter is bound to continue whatever happens.
                             Optional: 0.0 by default.
        :param frame_names: Name of the frames corresponding to the feet of
                            the robot. 'auto' to automatically detect them
                            from the set of contact and force sensors of the
                            robot.
                            Optional: 'auto' by default.
        :param is_training_only: Whether the termination condition should be
                                 completely by-passed if the environment is
                                 in evaluation mode.
                                 Optional: False by default.
        """
        # Both bounds are set to `False` for the collision detection quantity
        super().__init__(
            env,
            "termination_foot_collision",
            (MultiFootCollisionDetection, {  # type: ignore[arg-type]
                "frame_names": frame_names,
                "security_margin": security_margin,
            }),
            False,
            False,
            grace_period,
            is_truncation=False,
            is_training_only=is_training_only)
@dataclass(unsafe_hash=True)
class _MultiContactMinGroundDistance(InterfaceQuantity[float]):
    """Minimum distance from the ground profile among all the contact points.

    .. note::
        Internally, it does not compute the exact shortest distance from the
        ground profile because it would be computationally too demanding for
        now. As a surrogate, it relies on a first order approximation
        assuming zero local curvature around all the contact points
        individually.

    .. warning::
        The set of contact points must not change over episodes. In addition,
        collision bodies are not supported for now.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 parent: Optional[InterfaceQuantity]) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param parent: Higher-level quantity from which this quantity is a
                       requirement if any, `None` otherwise.
        """
        # Name of all the contact frames of the robot, along with their count
        frame_names = env.robot.contact_frame_names
        num_contacts = len(frame_names)

        # Call base implementation, requesting the true position of all the
        # contact frames as only requirement.
        super().__init__(
            env,
            parent,
            requirements={
                "positions": (MultiFramePosition, {
                    "frame_names": frame_names,
                    "mode": QuantityEvalMode.TRUE,
                }),
            },
            auto_refresh=False)

        # Jit-able method computing the minimum first-order depth
        @nb.jit(nopython=True, cache=True, fastmath=True)
        def min_depth(positions: np.ndarray,
                      heights: np.ndarray,
                      normals: np.ndarray) -> float:
            """Approximate minimum distance from the ground profile among a
            set of query points.

            Internally, it uses a first order approximation assuming zero
            local curvature around each query point.

            :param positions: Position of all the query points for which to
                              compute the distance from the ground profile,
                              as a 2D array whose first dimension gathers the
                              3 position coordinates (X, Y, Z) while the
                              second corresponds to the N individual query
                              points.
            :param heights: Vertical height wrt the ground profile of the N
                            individual query points in world frame as 1D
                            array.
            :param normals: Normal of the ground profile for the projection
                            in world plane of all the query points, as a 2D
                            array whose first dimension gathers the 3
                            position coordinates (X, Y, Z) while the second
                            corresponds to the N individual query points.
            """
            return np.min((positions[2] - heights) * normals[2])

        self._min_depth = min_depth

        # Pre-allocated buffers for the height and normal of all the contact
        # points, filled in-place at every refresh.
        self._heights = np.zeros((num_contacts,))
        self._normals = np.zeros((3, num_contacts), order="F")

        # Placeholder for the heightmap function of the ongoing episode,
        # which is overwritten by `initialize`.
        self._heightmap = jiminy.HeightmapFunction(lambda: None)

    def initialize(self) -> None:
        # Call base implementation
        super().initialize()

        # Refresh the heightmap function from the current engine options
        world_options = self.env.unwrapped.engine.get_options()["world"]
        self._heightmap = world_options["groundProfile"]

    def refresh(self) -> float:
        # Query the height and normal to the ground profile for the position
        # in world plane of all the contact points.
        contact_positions = self.positions.get()
        jiminy.query_heightmap(
            self._heightmap,
            contact_positions[:2],
            self._heights,
            self._normals)

        # The ground normals are used as returned by the query. Explicit
        # re-normalization is left disabled:
        # self._normals /= np.linalg.norm(self._normals, axis=0)

        # First-order distance estimation assuming no curvature
        return self._min_depth(
            contact_positions, self._heights, self._normals)
class FlyingTermination(QuantityTermination):
    """Discourage the agent from jumping by terminating the episode
    immediately if the robot is flying too high above the ground.

    This kind of behavior is usually undesirable because it may be
    frightening for people nearby, damage the hardware, be difficult to
    predict and be hardly repeatable. Moreover, such dynamic motions tend to
    transfer poorly to reality because the simulation to real gap is
    worsening.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 max_height: float,
                 grace_period: float = 0.0,
                 *,
                 is_training_only: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param max_height: Maximum height of the lowest contact points wrt
                           the ground above which termination is triggered.
        :param grace_period: Grace period effective only at the very
                             beginning of the episode, during which the
                             latter is bound to continue whatever happens.
                             Optional: 0.0 by default.
        :param is_training_only: Whether the termination condition should be
                                 completely by-passed if the environment is
                                 in evaluation mode.
                                 Optional: False by default.
        """
        # Only an upper bound is specified, the lower bound is left to `None`
        super().__init__(
            env,
            "termination_flying",
            (_MultiContactMinGroundDistance, {}),  # type: ignore[arg-type]
            None,
            max_height,
            grace_period,
            is_truncation=False,
            is_training_only=is_training_only)
class ImpactForceTermination(QuantityTermination):
    """Terminate the episode immediately in case of violent impact on the
    ground.

    Similarly to the jumping behavior, this kind of behavior is usually
    undesirable. See `FlyingTermination` documentation for details.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 max_force_rel: float,
                 grace_period: float = 0.0,
                 *,
                 is_training_only: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param max_force_rel: Maximum vertical force applied on any of the
                              contact points or collision bodies above which
                              termination is triggered.
        :param grace_period: Grace period effective only at the very
                             beginning of the episode, during which the
                             latter is bound to continue whatever happens.
                             Optional: 0.0 by default.
        :param is_training_only: Whether the termination condition should be
                                 completely by-passed if the environment is
                                 in evaluation mode.
                                 Optional: False by default.
        """
        # Only the row at index 2 of the normalized spatial contact force is
        # monitored, with an upper bound but no lower bound.
        super().__init__(
            env,
            "termination_impact_force",
            (MaskedQuantity, {  # type: ignore[arg-type]
                "quantity": (MultiContactNormalizedSpatialForce, {}),
                "axis": 0,
                "keys": (2,),
            }),
            None,
            max_force_rel,
            grace_period,
            is_truncation=False,
            is_training_only=is_training_only)
class DriftTrackingBaseOdometryPositionTermination(
        DriftTrackingQuantityTermination):
    """Terminate the episode if the current base odometry position is
    drifting too much wrt some reference trajectory that is being tracked.

    It is generally important to make sure that the robot is not deviating
    too much from some reference trajectory. It sounds appealing to make sure
    that the absolute error between the current and reference trajectory is
    bounded at all time. However, such a condition is very restrictive,
    especially for robots dealing with external disturbances or evolving on
    an uneven terrain. Moreover, when it comes to infinite-horizon
    trajectories in particular, eg periodic motions, avoiding drifting away
    over time involves being able to sense the absolute position of the robot
    in world frame via exteroceptive navigation sensors such as depth cameras
    or LIDARs. This kind of advanced sensor may not be available, thereby
    making the objective out of reach.

    Still, in the case of legged locomotion, what really matters is tracking
    accurately a nominal limit cycle as long as doing so does not compromise
    local stability. If it does, then the agent is expected to make every
    effort to recover balance as fast as possible before going back to the
    nominal limit cycle, without trying to catch up with the ensuing drift
    since the exact absolute odometry pose in world frame is of little
    interest. See `BaseOdometryPose` and `DriftTrackingQuantityTermination`
    documentations for details.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 max_position_err: float,
                 horizon: float,
                 grace_period: float = 0.0,
                 *,
                 is_training_only: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param max_position_err: Maximum drift error in translation (X, Y) in
                                 world plane above which termination is
                                 triggered.
        :param horizon: Horizon over which values of the quantity will be
                        stacked before computing the drift.
        :param grace_period: Grace period effective only at the very
                             beginning of the episode, during which the
                             latter is bound to continue whatever happens.
                             Optional: 0.0 by default.
        :param is_training_only: Whether the termination condition should be
                                 completely by-passed if the environment is
                                 in evaluation mode.
                                 Optional: False by default.
        """
        # Only the first two rows (indices 0, 1) of the odometry pose are
        # monitored. `np.linalg.norm` is passed as post-processing function
        # and only an upper bound is specified.
        super().__init__(
            env,
            "termination_tracking_base_odom_position",
            lambda mode: (  # type: ignore[arg-type, return-value]
                MaskedQuantity, {
                    "quantity": (BaseOdometryPose, {"mode": mode}),
                    "axis": 0,
                    "keys": (0, 1),
                }),
            None,
            max_position_err,
            horizon,
            grace_period,
            post_fn=np.linalg.norm,
            is_truncation=False,
            is_training_only=is_training_only)
class DriftTrackingBaseOdometryOrientationTermination(
        DriftTrackingQuantityTermination):
    """Terminate the episode if the current base odometry orientation is
    drifting too much wrt some reference trajectory that is being tracked.

    See `BaseOdometryPose` and `DriftTrackingBaseOdometryPositionTermination`
    documentations for details.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 max_orientation_err: float,
                 horizon: float,
                 grace_period: float = 0.0,
                 *,
                 is_training_only: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param max_orientation_err: Maximum drift error in orientation (yaw,)
                                    in world plane above which termination is
                                    triggered.
        :param horizon: Horizon over which values of the quantity will be
                        stacked before computing the drift.
        :param grace_period: Grace period effective only at the very
                             beginning of the episode, during which the
                             latter is bound to continue whatever happens.
                             Optional: 0.0 by default.
        :param is_training_only: Whether the termination condition should be
                                 completely by-passed if the environment is
                                 in evaluation mode.
                                 Optional: False by default.
        """
        # Only the row at index 2 of the odometry pose is monitored, with
        # symmetric bounds around zero.
        super().__init__(
            env,
            "termination_tracking_base_odom_orientation",
            lambda mode: (  # type: ignore[arg-type, return-value]
                MaskedQuantity, {
                    "quantity": (BaseOdometryPose, {"mode": mode}),
                    "axis": 0,
                    "keys": (2,),
                }),
            -max_orientation_err,
            max_orientation_err,
            horizon,
            grace_period,
            is_truncation=False,
            is_training_only=is_training_only)
class ShiftTrackingFootOdometryPositionsTermination(
        ShiftTrackingQuantityTermination):
    """Terminate the episode if the selected reference trajectory is not
    tracked with expected accuracy regarding the relative foot odometry
    positions, whatever the timestep being considered over some fixed-size
    sliding window.

    See `MultiFootRelativeXYZQuat` and
    `ShiftTrackingMotorPositionsTermination` documentation for details.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 max_position_err: float,
                 horizon: float,
                 grace_period: float = 0.0,
                 frame_names: Union[Sequence[str], Literal['auto']] = 'auto',
                 *,
                 is_training_only: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param max_position_err: Maximum shift error in translation (X, Y) in
                                 world plane above which termination is
                                 triggered.
        :param horizon: Horizon over which values of the quantity will be
                        stacked before computing the shift.
        :param grace_period: Grace period effective only at the very
                             beginning of the episode, during which the
                             latter is bound to continue whatever happens.
                             Optional: 0.0 by default.
        :param frame_names: Name of the frames corresponding to the feet of
                            the robot. 'auto' to automatically detect them
                            from the set of contact and force sensors of the
                            robot.
                            Optional: 'auto' by default.
        :param is_training_only: Whether the termination condition should be
                                 completely by-passed if the environment is
                                 in evaluation mode.
                                 Optional: False by default.
        """
        # Only the first two rows (indices 0, 1) of the relative XYZQuat
        # representation are monitored.
        super().__init__(
            env,
            "termination_tracking_foot_odom_positions",
            lambda mode: (  # type: ignore[arg-type, return-value]
                MaskedQuantity, {
                    "quantity": (MultiFootRelativeXYZQuat, {
                        "frame_names": frame_names,
                        "mode": mode,
                    }),
                    "axis": 0,
                    "keys": (0, 1),
                }),
            max_position_err,
            horizon,
            grace_period,
            is_truncation=False,
            is_training_only=is_training_only)
class ShiftTrackingFootOdometryOrientationsTermination(
        ShiftTrackingQuantityTermination):
    """Terminate the episode if the selected reference trajectory is not
    tracked with expected accuracy regarding the relative foot odometry
    orientations, whatever the timestep being considered over some fixed-size
    sliding window.

    See `MultiFootRelativeXYZQuat` and
    `ShiftTrackingMotorPositionsTermination` documentation for details.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 max_orientation_err: float,
                 horizon: float,
                 grace_period: float = 0.0,
                 frame_names: Union[Sequence[str], Literal['auto']] = 'auto',
                 *,
                 is_training_only: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param max_orientation_err: Maximum shift error in orientation (yaw,)
                                    in world plane above which termination is
                                    triggered.
        :param horizon: Horizon over which values of the quantity will be
                        stacked before computing the shift.
        :param grace_period: Grace period effective only at the very
                             beginning of the episode, during which the
                             latter is bound to continue whatever happens.
                             Optional: 0.0 by default.
        :param frame_names: Name of the frames corresponding to the feet of
                            the robot. 'auto' to automatically detect them
                            from the set of contact and force sensors of the
                            robot.
                            Optional: 'auto' by default.
        :param is_training_only: Whether the termination condition should be
                                 completely by-passed if the environment is
                                 in evaluation mode.
                                 Optional: False by default.
        """
        # Call base implementation. The last four rows (indices 3, 4, 5, 6)
        # of the relative XYZQuat representation are extracted first, then
        # `quat_to_yaw` is applied on them.
        super().__init__(
            env,
            "termination_tracking_foot_odom_orientations",
            lambda mode: (UnaryOpQuantity, dict(
                quantity=(MaskedQuantity, dict(
                    quantity=(MultiFootRelativeXYZQuat, dict(
                        frame_names=frame_names,
                        mode=mode)),
                    axis=0,
                    keys=(3, 4, 5, 6))),
                op=quat_to_yaw)),
            max_orientation_err,
            horizon,
            grace_period,
            is_truncation=False,
            is_training_only=is_training_only)