import gymnasium as gym
from prob_spaces.box import BoxDist
from prob_spaces.dict import DictDist
from prob_spaces.discrete import DiscreteDist
from prob_spaces.multi_discrete import MultiDiscreteDist
from prob_spaces.tuple import TupleDist
Spaces = gym.spaces.Box | gym.spaces.Discrete | gym.spaces.MultiDiscrete
DistSpaces = BoxDist | DiscreteDist | MultiDiscreteDist | None
[docs]
def convert_to_prob_space(action_space: Spaces) -> DistSpaces:
"""
Converts an action space into its corresponding probability distribution space.
This function supports different types of action spaces and creates an appropriate
distribution space for each one. Supported action spaces include `MultiDiscrete`,
`Discrete`, `Box`, and `Dict`. For `Dict` action spaces, the function recursively
converts each subspace into its probability distribution space.
:param action_space: The input action space to be converted. This can be an instance
of `gym.spaces.MultiDiscrete`, `gym.spaces.Discrete`, `gym.spaces.Box`, or
`gym.spaces.Dict`.
:type action_space: Spaces
:raises NotImplementedError: If the input action space type is not supported.
:return: The corresponding probability distribution space created based on the input
action space type.
"""
if isinstance(action_space, gym.spaces.MultiDiscrete):
space_dist = MultiDiscreteDist.from_space(action_space)
elif isinstance(action_space, gym.spaces.Discrete):
space_dist = DiscreteDist.from_space(action_space) # type: ignore
elif isinstance(action_space, gym.spaces.Box):
space_dist = BoxDist.from_space(action_space) # type: ignore
elif isinstance(action_space, gym.spaces.Dict):
space_dist = DictDist()
for k, v in action_space.spaces.items():
space_dist[k] = convert_to_prob_space(v)
elif isinstance(action_space, gym.spaces.Tuple):
space_list = []
for v in action_space.spaces:
space_list.append(convert_to_prob_space(v))
space_dist = TupleDist(space_list)
else:
raise NotImplementedError(f"Action space {type(action_space)} not supported")
return space_dist