Source code for prob_spaces.converter

"""Module for converting gymnasium action spaces to probability distribution spaces."""

import gymnasium as gym

from prob_spaces.box import BoxDist
from prob_spaces.dict import DictDist
from prob_spaces.discrete import DiscreteDist
from prob_spaces.multi_discrete import MultiDiscreteDist
from prob_spaces.tuple import TupleDist

# Action spaces accepted by ``convert_to_prob_space``; the composite ``Dict``
# and ``Tuple`` spaces are converted recursively.
Spaces = (
    gym.spaces.Box
    | gym.spaces.Discrete
    | gym.spaces.MultiDiscrete
    | gym.spaces.Dict
    | gym.spaces.Tuple
)
# Distribution spaces that ``convert_to_prob_space`` can produce, one per
# member of ``Spaces``. ``None`` is kept from the original alias for
# backward compatibility with existing annotations.
DistSpaces = BoxDist | DiscreteDist | MultiDiscreteDist | DictDist | TupleDist | None


def convert_to_prob_space(action_space: Spaces) -> DistSpaces:
    """Convert a gymnasium action space into its probability distribution space.

    Leaf spaces (``MultiDiscrete``, ``Discrete``, ``Box``) are mapped to the
    matching distribution class through its ``from_space`` constructor.
    Composite spaces are converted recursively:

    * ``Dict``: every subspace is converted and stored in a ``DictDist``
      under the same key.
    * ``Tuple``: every subspace is converted, in order, and the results are
      wrapped in a ``TupleDist``.

    :param action_space: The gymnasium action space to convert.
    :raises NotImplementedError: If ``action_space``, or any subspace nested
        inside a ``Dict``/``Tuple``, is of an unsupported type.
    :return: The probability distribution space corresponding to
        ``action_space``.
    """
    if isinstance(action_space, gym.spaces.MultiDiscrete):
        return MultiDiscreteDist.from_space(action_space)
    if isinstance(action_space, gym.spaces.Discrete):
        return DiscreteDist.from_space(action_space)  # type: ignore
    if isinstance(action_space, gym.spaces.Box):
        return BoxDist.from_space(action_space)  # type: ignore
    if isinstance(action_space, gym.spaces.Dict):
        space_dist = DictDist()
        for key, subspace in action_space.spaces.items():
            space_dist[key] = convert_to_prob_space(subspace)
        return space_dist
    if isinstance(action_space, gym.spaces.Tuple):
        # Preserve subspace order so TupleDist lines up with the Tuple space.
        return TupleDist(
            [convert_to_prob_space(subspace) for subspace in action_space.spaces]
        )
    raise NotImplementedError(f"Action space {type(action_space)} not supported")