Skip to content

Commit

Permalink
Add param structure to distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-hirsch committed Aug 30, 2024
1 parent 31e191b commit e0cf3f8
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/rolch/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ def derivative(self, x: np.ndarray) -> np.ndarray:

class Distribution(ABC):

self.links: Dict[int, Linkfunction] | List[Linkfunction]
self._param_structure: Dict[int, str]

def _check_links(self):
for p in range(self.n_params):
if self.param_structure[p] not in self.links[p]._valid_structures:
raise ValueError(
f"Link function does not match parameter structure for parameter {p}. \n"
f"Parameter structure is {self.param_structure[p]}. \n"
f"Link function supports {self.links[p]._valid_structures}"
)

@abstractmethod
def theta_to_params(self, theta: np.ndarray) -> Tuple:
"""Take the fitted values and return tuple of vectors for distribution parameters."""
Expand Down
2 changes: 2 additions & 0 deletions src/rolch/distributions/johnsonsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def __init__(
self.shape_link, # skew
self.tail_link, # tail
]
self._param_structure = {0: "vector", 1: "vector", 2: "vector", 3: "vector"}
self._check_links()

def theta_to_params(self, theta):
mu = theta[:, 0]
Expand Down
2 changes: 2 additions & 0 deletions src/rolch/distributions/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def __init__(self, loc_link=IdentityLink(), scale_link=LogLink()):
self.loc_link = loc_link
self.scale_link = scale_link
self.links = [self.loc_link, self.scale_link]
self._param_structure = {0: "vector", 1: "vector", 2}
self._check_links()

def theta_to_params(self, theta):
mu = theta[:, 0]
Expand Down
3 changes: 3 additions & 0 deletions src/rolch/distributions/studentt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def __init__(
self.scale_link = scale_link
self.tail_link = tail_link
self.links = [self.loc_link, self.scale_link, self.tail_link]
self._param_structure = {0: "vector", 1: "vector", 2, 3: "vector"}
self._check_links()


def theta_to_params(self, theta):
mu = theta[:, 0]
Expand Down

0 comments on commit e0cf3f8

Please sign in to comment.