Source code for prob_spaces.dists.categorical

from typing import Any, Optional, Sequence, Union

import numpy as np
import torch
import torch as th
from numpy._typing import NDArray
from torchrl.modules.distributions import MaskedCategorical  # type: ignore[import-untyped]


class CategoricalDist(MaskedCategorical):
    """Masked categorical distribution whose category values are offset by ``start``.

    Sampling delegates to ``MaskedCategorical`` and then adds ``start`` to the
    drawn values; ``log_prob`` subtracts ``start`` from the given values before
    delegating to the parent implementation.
    """

    def __init__(
        self,
        logits: Optional[th.Tensor] = None,
        probs: Optional[th.Tensor] = None,
        *,
        mask: Optional[th.Tensor] = None,
        indices: Optional[th.Tensor] = None,
        neg_inf: float = float("-inf"),
        padding_value: Optional[int] = None,
        start: int | np.integer[Any] | NDArray[np.integer[Any]] | list[int] = 0,
    ) -> None:
        """Build the distribution and remember the offset.

        Args:
            logits: Forwarded positionally to ``MaskedCategorical``.
            probs: Forwarded positionally to ``MaskedCategorical``.
            mask: Forwarded to ``MaskedCategorical`` as ``mask``.
            indices: Forwarded to ``MaskedCategorical`` as ``indices``.
            neg_inf: Forwarded to ``MaskedCategorical`` as ``neg_inf``.
            padding_value: Forwarded to ``MaskedCategorical`` as ``padding_value``.
            start: Offset added to every sample and subtracted from values
                passed to ``log_prob``. Stored unchanged on ``self.start``.
        """
        super().__init__(
            logits,
            probs,
            mask=mask,
            indices=indices,
            neg_inf=neg_inf,
            padding_value=padding_value,
        )
        self.start = start

    @property
    def th_start(self) -> th.Tensor:
        """Return ``start`` as a tensor on the same device as ``self.probs``."""
        return th.tensor(self.start, device=self.probs.device)

    def sample(
        self,
        sample_shape: Optional[Union[th.Size, Sequence[int]]] = None,
    ) -> th.Tensor:
        """Draw samples from the parent distribution and shift them by ``start``."""
        raw_sample = super().sample(sample_shape)
        return self._calc_exact(raw_sample, sample_shape)

    def rsample(
        self,
        sample_shape: Optional[Union[th.Size, Sequence[int]]] = None,
    ) -> th.Tensor:
        """Draw samples via the parent's ``rsample`` and shift them by ``start``."""
        raw_sample = super().rsample(sample_shape)
        return self._calc_exact(raw_sample, sample_shape)

    def _calc_exact(
        self,
        sample: th.Tensor,
        sample_shape: Optional[Union[th.Size, Sequence[int]]],
    ) -> th.Tensor:
        """Add the ``start`` offset to a sample produced by the parent class.

        Args:
            sample: Sample returned by the parent ``sample`` / ``rsample``.
            sample_shape: The ``sample_shape`` that was passed to the sampler.

        Returns:
            ``sample`` shifted by ``start``. When ``start`` is a multi-element
            array the sample is first reshaped to
            ``(*sample_shape, *start.shape)`` (or ``start.shape`` when
            ``sample_shape`` is ``None``).
        """
        start = self.start
        if isinstance(start, (list, tuple)):
            # ``start`` may be given as a plain list; treat it exactly like the
            # equivalent ndarray instead of attempting ``Tensor + list``.
            start = np.asarray(start)

        scalar_like = (
            not isinstance(start, np.ndarray)
            # A 0-d array is a scalar offset: no reshape is needed.
            or start.ndim == 0
            # Original rule for single-element arrays such as shape ``(1,)``.
            or sum(start.shape) == 1
        )
        if scalar_like:
            # Always add the tensor form so NumPy objects never mix with
            # tensors and the result stays a tensor on the sample's device.
            return sample + self.th_start

        if sample_shape is None:
            target_shape = start.shape
        else:
            target_shape = (*sample_shape, *start.shape)
        return sample.reshape(target_shape) + self.th_start

    def log_prob(self, value: th.Tensor) -> th.Tensor:
        """Return the parent's log-probability of ``value`` with ``start`` removed."""
        return super().log_prob(value=value - self.th_start)