# Source code for prob_spaces.dict

"""Module for probability distributions over Dict spaces."""

from collections.abc import Mapping
from typing import Optional

from gymnasium import spaces


class DictDist(spaces.Dict):
    """Probability distribution for Dict spaces.

    Every sub-space in ``self.spaces`` is called like a function to build its
    distribution. Stock gymnasium spaces are not callable, so the sub-spaces
    are presumably the project's distribution-aware subclasses (hence the
    ``# type: ignore`` on the calls) -- TODO confirm against the sibling
    ``prob_spaces`` modules.
    """

    def __call__(self, prob: dict, mask: Optional[Mapping] = None) -> dict:
        """Create a dict of distributions based on input probabilities.

        Args:
            prob: Dictionary of per-space parameters, keyed like ``self.spaces``.
                For ``Discrete``/``MultiDiscrete`` sub-spaces the value is passed
                through unchanged. For every other sub-space the value must be
                indexable, and its first two entries are passed as two
                positional arguments.
            mask: Optional mapping of masks keyed like ``self.spaces``. It is
                only consulted for ``Discrete``/``MultiDiscrete`` sub-spaces,
                and keys that are absent yield ``None``. A value that is not a
                mapping is ignored.

        Returns:
            Dictionary mapping each key of ``self.spaces`` to the object
            returned by calling that sub-space.

        Raises:
            KeyError: If ``prob`` is a dict with no entry for one of the
                sub-space keys.
        """
        # Normalise the mask once, rather than re-checking its type on every
        # iteration. Accepting any Mapping (not only ``dict``) means custom or
        # read-only mask mappings are honoured instead of being silently dropped.
        masks: Mapping = mask if isinstance(mask, Mapping) else {}

        dist_dict = {}
        for key, s in self.spaces.items():
            if isinstance(s, (spaces.Discrete, spaces.MultiDiscrete)):
                dist_dict[key] = s(prob[key], masks.get(key))  # type: ignore
            else:
                # NOTE(review): presumably a pair of distribution parameters
                # (e.g. location/scale for a continuous space) -- confirm.
                # A nested Dict sub-space also lands here and receives
                # prob[key][0] and prob[key][1] positionally; verify intended.
                dist_dict[key] = s(prob[key][0], prob[key][1])  # type: ignore
        return dist_dict