Skip to content

Commit

Permalink
fix bugs in transforms.py, tests now passing
Browse files Browse the repository at this point in the history
  • Loading branch information
mbi6245 committed May 14, 2024
1 parent 8be74bd commit 469b794
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 24 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ jobs:
run: python -m pip install .[test] --upgrade pip --no-cache-dir
- name: Test with pytest
run: pytest # --cov=./ --cov-report=xml
# repository > settings > github apps > configure codecov

# - name: Build package distribution
# if: startsWith(github.ref, 'refs/tags')
Expand Down
73 changes: 49 additions & 24 deletions src/distrx/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,64 @@
* Add decorators for accepting floats or vectors
"""
from typing import Tuple

import warnings
from typing import Tuple

import numpy as np
import numpy.typing as npt


class FirstOrder:
def __init__(self, transform: str, mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None:
def __init__(
self, transform: str, mu: npt.ArrayLike, sigma: npt.ArrayLike
) -> None:
# you may be able to make transform a field, though it might not be necessary
# TODO: DSLETE DS.STORE
# TODO: USE RUFF LINTER
# TODO: put each parameter on its separate line if theres a lot (reference pypkg style guide)
self.transform = input(transform)
self.transform = transform
self.mu = mu
self.sigma = sigma
match self.transform:
case "log":
self.mu_trans, self.sigma_trans = self.log_trans(self.mu, self.sigma)
self.mu_trans, self.sigma_trans = self.log_trans(
self.mu, self.sigma
)
case "logit":
self.mu_trans, self.sigma_trans = self.logit_trans(self.mu, self.sigma)
self.mu_trans, self.sigma_trans = self.logit_trans(
self.mu, self.sigma
)
case "exp":
self.mu_trans, self.sigma_trans = self.exp_trans(self.mu, self.sigma)
self.mu_trans, self.sigma_trans = self.exp_trans(
self.mu, self.sigma
)
case "expit":
self.mu_trans, self.sigma_trans = self.expit_trans(self.mu, self.sigma)
self.mu_trans, self.sigma_trans = self.expit_trans(
self.mu, self.sigma
)
case _:
raise ValueError(f"Invalid transform '{self.transform}'.")

def log_trans(self, mu: npt.ArrayLike, sigma: npt.ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
def log_trans(
self, mu: npt.ArrayLike, sigma: npt.ArrayLike
) -> Tuple[np.ndarray, np.ndarray]:
return np.log(mu), sigma / mu

def logit_trans(self, mu: npt.ArrayLike, sigma: npt.ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
def logit_trans(
self, mu: npt.ArrayLike, sigma: npt.ArrayLike
) -> Tuple[np.ndarray, np.ndarray]:
return np.log(mu / (1.0 - mu)), sigma / (mu * (1.0 - mu))

def exp_trans(self, mu: npt.ArrayLike, sigma: npt.ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
return np.exp(mu), np.exp(mu)
def exp_trans(
self, mu: npt.ArrayLike, sigma: npt.ArrayLike
) -> Tuple[np.ndarray, np.ndarray]:
return np.exp(mu), sigma * np.exp(mu)

def expit_trans(self, mu: npt.ArrayLike, sigma: npt.ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
return 1.0 / (1.0 + np.exp(-mu)), sigma * np.exp(-mu) / (1.0 + np.exp(-mu)) ** 2
def expit_trans(
self, mu: npt.ArrayLike, sigma: npt.ArrayLike
) -> Tuple[np.ndarray, np.ndarray]:
return 1.0 / (1.0 + np.exp(-mu)), sigma * np.exp(-mu) / (
1.0 + np.exp(-mu)
) ** 2

def get_mu_trans(self):
return self.mu_trans
Expand All @@ -56,11 +75,15 @@ def get_sigma_trans(self):
return self.sigma_trans


METHOD_LIST = ['delta']
METHOD_LIST = ["delta"]


def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str,
method: str = 'delta') -> Tuple[np.ndarray, np.ndarray]:
def transform_data(
mu: npt.ArrayLike,
sigma: npt.ArrayLike,
transform: str,
method: str = "delta",
) -> Tuple[np.ndarray, np.ndarray]:
"""Transform data from one space to another.
Transform data, in the form of sample statistics and their standard
Expand Down Expand Up @@ -89,12 +112,13 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str,
"""
mu, sigma = np.array(mu), np.array(sigma)
_check_input(method, transform, mu, sigma)
if method == 'delta':
if method == "delta":
return delta_method(mu, sigma, transform)


def delta_method(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> \
Tuple[np.ndarray, np.ndarray]:
def delta_method(
mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str
) -> Tuple[np.ndarray, np.ndarray]:
"""Transform data using the delta method.
Transform data, in the form of sample statistics and their standard
Expand Down Expand Up @@ -126,13 +150,14 @@ def delta_method(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> \
"""
mu, sigma = np.array(mu), np.array(sigma)
_check_input('delta', transform, mu, sigma)
_check_input("delta", transform, mu, sigma)
transformer = FirstOrder(transform, mu, sigma)
return transformer.get_mu_trans(), transformer.get_sigma_trans()


def _check_input(method: str, transform: str, mu: npt.ArrayLike,
sigma: npt.ArrayLike) -> None:
def _check_input(
method: str, transform: str, mu: npt.ArrayLike, sigma: npt.ArrayLike
) -> None:
"""Run checks on input data.
Parameters
Expand Down

0 comments on commit 469b794

Please sign in to comment.