diff --git a/src/rolch/abc.py b/src/rolch/abc.py index d36b60f..84f018c 100644 --- a/src/rolch/abc.py +++ b/src/rolch/abc.py @@ -24,6 +24,18 @@ def derivative(self, x: np.ndarray) -> np.ndarray: class Distribution(ABC): + self.links: Dict[int, Linkfunction] | List[Linkfunction] + self._param_structure: Dict[int, str] + + def _check_links(self): + for p in range(self.n_params): + if self.param_structure[p] not in self.links[p]._valid_structures: + raise ValueError( + f"Link function does not match parameter structure for parameter {p}. \n" + f"Parameter structure is {self.param_structure[p]}. \n" + f"Link function supports {self.links[p]._valid_structures}" + ) + @abstractmethod def theta_to_params(self, theta: np.ndarray) -> Tuple: """Take the fitted values and return tuple of vectors for distribution parameters.""" diff --git a/src/rolch/distributions/johnsonsu.py b/src/rolch/distributions/johnsonsu.py index 26644a8..371c7a1 100644 --- a/src/rolch/distributions/johnsonsu.py +++ b/src/rolch/distributions/johnsonsu.py @@ -34,6 +34,8 @@ def __init__( self.shape_link, # skew self.tail_link, # tail ] + self._param_structure = {0: "vector", 1: "vector", 2: "vector", 3: "vector"} + self._check_links() def theta_to_params(self, theta): mu = theta[:, 0] diff --git a/src/rolch/distributions/normal.py b/src/rolch/distributions/normal.py index 2ebe6f5..d7e55a0 100644 --- a/src/rolch/distributions/normal.py +++ b/src/rolch/distributions/normal.py @@ -13,6 +13,8 @@ def __init__(self, loc_link=IdentityLink(), scale_link=LogLink()): self.loc_link = loc_link self.scale_link = scale_link self.links = [self.loc_link, self.scale_link] + self._param_structure = {0: "vector", 1: "vector", 2} + self._check_links() def theta_to_params(self, theta): mu = theta[:, 0] diff --git a/src/rolch/distributions/studentt.py b/src/rolch/distributions/studentt.py index d54a012..c48a0d5 100644 --- a/src/rolch/distributions/studentt.py +++ b/src/rolch/distributions/studentt.py @@ -17,6 +17,9 @@ def __init__( self.scale_link = scale_link self.tail_link = tail_link self.links = [self.loc_link, self.scale_link, self.tail_link] + self._param_structure = {0: "vector", 1: "vector", 2, 3: "vector"} + self._check_links() + def theta_to_params(self, theta): mu = theta[:, 0]