"""prob_spaces.dict: a gymnasium ``Dict`` space that builds a dict of distributions."""

from collections.abc import Mapping
from typing import Optional

from gymnasium import spaces


class DictDist(spaces.Dict):
    """A gymnasium ``Dict`` space whose sub-spaces are turned into distributions.

    Calling an instance calls every sub-space with its matching entry of
    ``prob`` (and, where applicable, its mask) and collects the results in a
    dict keyed like ``self.spaces``.
    """

    def __call__(self, prob: Mapping, mask: Optional[Mapping] = None) -> dict:
        """
        Create a dict of distributions based on input probabilities.

        Args:
            prob: Mapping from sub-space key to that sub-space's parameters.
                For ``Discrete``/``MultiDiscrete`` sub-spaces and nested
                ``DictDist`` sub-spaces the entry is passed through unchanged.
                For any other sub-space the entry must be indexable, and
                ``prob[key][0]`` and ``prob[key][1]`` are passed as two
                positional arguments.
            mask: Optional mapping from sub-space key to a mask. Only
                ``Discrete``/``MultiDiscrete`` and nested ``DictDist``
                sub-spaces receive a mask. Keys missing from ``mask`` receive
                ``None``.

        Returns:
            Dictionary of distribution objects, keyed like ``self.spaces``.

        Raises:
            TypeError: If ``mask`` is neither ``None`` nor a mapping.
            KeyError: If ``prob`` has no entry for one of the sub-spaces.
        """
        if mask is None:
            mask = {}
        elif not isinstance(mask, Mapping):
            # Ignoring a mask of the wrong type would silently disable
            # masking, so reject it explicitly instead.
            raise TypeError(
                "mask must be a mapping of per-space masks or None, "
                f"got {type(mask).__name__}"
            )

        dist_dict = {}
        for key, space in self.spaces.items():
            if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete, DictDist)):
                # These take their raw parameters plus an optional per-space
                # mask; a nested DictDist receives its own sub-dict and sub-mask.
                dist_dict[key] = space(prob[key], mask.get(key))  # type: ignore
            else:
                # Every other sub-space is called with two positional
                # parameters taken from a pair-like entry (presumably e.g.
                # loc/scale for continuous spaces; confirm against the
                # sub-space classes).
                params = prob[key]
                dist_dict[key] = space(params[0], params[1])  # type: ignore
        return dist_dict