From 469b79451f8d9309b4f96504461a998591e1cffd Mon Sep 17 00:00:00 2001 From: mbi6245 Date: Tue, 14 May 2024 12:15:09 -0700 Subject: [PATCH] fix bugs in transforms.py, tests now passing --- .github/workflows/build.yml | 1 + src/distrx/transforms.py | 73 +++++++++++++++++++++++++------------ 2 files changed, 50 insertions(+), 24 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index bf1a895..74cb5e1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -17,6 +17,7 @@ jobs: run: python -m pip install .[test] --upgrade pip --no-cache-dir - name: Test with pytest run: pytest # --cov=./ --cov-report=xml + # repository > settings > github apps > configure codecov # - name: Build package distribution # if: startsWith(github.ref, 'refs/tags') diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index ec985a3..96e0384 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -9,45 +9,64 @@ * Add decorators for accepting floats or vectors """ -from typing import Tuple + import warnings +from typing import Tuple import numpy as np import numpy.typing as npt class FirstOrder: - def __init__(self, transform: str, mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None: + def __init__( + self, transform: str, mu: npt.ArrayLike, sigma: npt.ArrayLike + ) -> None: # you may be able to make transform a field, though it might not be necessary - # TODO: DSLETE DS.STORE - # TODO: USE RUFF LINTER # TODO: put each parameter on its separate line if theres a lot (reference pypkg style guide) - self.transform = input(transform) + self.transform = transform self.mu = mu self.sigma = sigma match self.transform: case "log": - self.mu_trans, self.sigma_trans = self.log_trans(self.mu, self.sigma) + self.mu_trans, self.sigma_trans = self.log_trans( + self.mu, self.sigma + ) case "logit": - self.mu_trans, self.sigma_trans = self.logit_trans(self.mu, self.sigma) + self.mu_trans, self.sigma_trans = self.logit_trans( + self.mu, self.sigma + ) case "exp": - self.mu_trans, self.sigma_trans = self.exp_trans(self.mu, self.sigma) + self.mu_trans, self.sigma_trans = self.exp_trans( + self.mu, self.sigma + ) case "expit": - self.mu_trans, self.sigma_trans = self.expit_trans(self.mu, self.sigma) + self.mu_trans, self.sigma_trans = self.expit_trans( + self.mu, self.sigma + ) case _: raise ValueError(f"Invalid transform '{self.transform}'.") - def log_trans(self, mu: npt.ArrayLike, sigma: npt.ArrayLike) -> Tuple[np.ndarray, np.ndarray]: + def log_trans( + self, mu: npt.ArrayLike, sigma: npt.ArrayLike + ) -> Tuple[np.ndarray, np.ndarray]: return np.log(mu), sigma / mu - def logit_trans(self, mu: npt.ArrayLike, sigma: npt.ArrayLike) -> Tuple[np.ndarray, np.ndarray]: + def logit_trans( + self, mu: npt.ArrayLike, sigma: npt.ArrayLike + ) -> Tuple[np.ndarray, np.ndarray]: return np.log(mu / (1.0 - mu)), sigma / (mu * (1.0 - mu)) - def exp_trans(self, mu: npt.ArrayLike, sigma: npt.ArrayLike) -> Tuple[np.ndarray, np.ndarray]: - return np.exp(mu), np.exp(mu) + def exp_trans( + self, mu: npt.ArrayLike, sigma: npt.ArrayLike + ) -> Tuple[np.ndarray, np.ndarray]: + return np.exp(mu), sigma * np.exp(mu) - def expit_trans(self, mu: npt.ArrayLike, sigma: npt.ArrayLike) -> Tuple[np.ndarray, np.ndarray]: - return 1.0 / (1.0 + np.exp(-mu)), sigma * np.exp(-mu) / (1.0 + np.exp(-mu)) ** 2 + def expit_trans( + self, mu: npt.ArrayLike, sigma: npt.ArrayLike + ) -> Tuple[np.ndarray, np.ndarray]: + return 1.0 / (1.0 + np.exp(-mu)), sigma * np.exp(-mu) / ( + 1.0 + np.exp(-mu) + ) ** 2 def get_mu_trans(self): return self.mu_trans @@ -56,11 +75,15 @@ def get_sigma_trans(self): return self.sigma_trans -METHOD_LIST = ['delta'] +METHOD_LIST = ["delta"] -def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, - method: str = 'delta') -> Tuple[np.ndarray, np.ndarray]: +def transform_data( + mu: npt.ArrayLike, + sigma: npt.ArrayLike, + transform: str, + method: str = "delta", +) -> Tuple[np.ndarray, np.ndarray]: """Transform data from one space to another. Transform data, in the form of sample statistics and their standard @@ -89,12 +112,13 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, """ mu, sigma = np.array(mu), np.array(sigma) _check_input(method, transform, mu, sigma) - if method == 'delta': + if method == "delta": return delta_method(mu, sigma, transform) -def delta_method(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> \ - Tuple[np.ndarray, np.ndarray]: +def delta_method( + mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str +) -> Tuple[np.ndarray, np.ndarray]: """Transform data using the delta method. Transform data, in the form of sample statistics and their standard @@ -126,13 +150,14 @@ def delta_method(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> \ """ mu, sigma = np.array(mu), np.array(sigma) - _check_input('delta', transform, mu, sigma) + _check_input("delta", transform, mu, sigma) transformer = FirstOrder(transform, mu, sigma) return transformer.get_mu_trans(), transformer.get_sigma_trans() -def _check_input(method: str, transform: str, mu: npt.ArrayLike, - sigma: npt.ArrayLike) -> None: +def _check_input( + method: str, transform: str, mu: npt.ArrayLike, sigma: npt.ArrayLike +) -> None: """Run checks on input data. Parameters