Skip to content

Commit

Permalink
Merge pull request #1 from ihmeuw-msca/develop
Browse files Browse the repository at this point in the history
Add initial module transform.py with delta method
  • Loading branch information
zhengp0 committed May 6, 2024
2 parents 4727bd1 + 67479c9 commit 7c7e90e
Show file tree
Hide file tree
Showing 3 changed files with 280 additions and 0 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

install_requirements = [
'numpy',
'pandas',
'scipy',
]

Expand Down
185 changes: 185 additions & 0 deletions src/distrx/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""Transform data from one space to another.
Transform data, in the form of sample statistics and their standard
errors, from one space to another using a given transform function.
TODO:
* Add user-defined transform function
* Add functions for confidence intervals
* Add decorators for accepting floats or vectors
"""
from typing import Tuple
import warnings

import numpy as np
import numpy.typing as npt


TRANSFORM_DICT = {
'log': [
np.log,
lambda x: 1.0/x
], 'logit': [
lambda x: np.log(x/(1.0 - x)),
lambda x: 1.0/(x*(1.0 - x))
], 'exp': [
np.exp,
np.exp
], 'expit': [
lambda x: 1.0/(1.0 + np.exp(-x)),
lambda x: np.exp(-x)/(1.0 + np.exp(-x))**2
]
}
METHOD_LIST = ['delta']


def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str,
method: str = 'delta') -> Tuple[np.ndarray, np.ndarray]:
"""Transform data from one space to another.
Transform data, in the form of sample statistics and their standard
errors, from one space to another using a given transform function.
No assumptions are made about the underlying distributions of the
given data.
Parameters
----------
mu : array_like
Sample statistics.
sigma : array_like
Standard errors.
transform : {'log', 'logit', 'exp', 'expit'}
Transform function.
method : {'delta'}, optional
Method used to transform data.
Returns
-------
mu_trans : numpy.ndarray
Sample stastistics in the transform space.
sigma_trans : numpy.ndarray
Standard errors in the transform space.
"""
mu, sigma = np.array(mu), np.array(sigma)
_check_input(method, transform, mu, sigma)
if method == 'delta':
return delta_method(mu, sigma, transform)


def delta_method(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> \
Tuple[np.ndarray, np.ndarray]:
"""Transform data using the delta method.
Transform data, in the form of sample statistics and their standard
errors, from one space to another using a given transform function
and the delta method. No assumptions are made about the underlying
distributions of the given data.
Parameters
----------
mu : array_like
Sample statistics.
sigma : array_like
Standard errors.
transform : {'log', 'logit', 'exp', 'expit'}
Transform function.
Returns
-------
mu_trans : numpy.ndarray
Sample statistics in the transform space.
sigma_trans : numpy.ndarray
Standard errors in the transform space.
Notes
-----
The delta method expands a function of a random variable about its
mean with a one-step Taylor approximation and then takes the
variance.
"""
mu, sigma = np.array(mu), np.array(sigma)
_check_input('delta', transform, mu, sigma)
mu_trans = TRANSFORM_DICT[transform][0](mu)
sigma_trans = sigma*TRANSFORM_DICT[transform][1](mu)
return mu_trans, sigma_trans


def _check_input(method: str, transform: str, mu: npt.ArrayLike,
sigma: npt.ArrayLike) -> None:
"""Run checks on input data.
Parameters
----------
method : {'delta'}
Method used to transform data.
transform : {'log', 'logit', 'exp', 'expit'}
Transform function.
mu : array_like
Sample statistics.
sigma : array_like
Standard errors.
"""
_check_method_valid(method)
_check_transform_valid(transform)
_check_lengths_match(mu, sigma)
_check_sigma_positive(sigma)


def _check_method_valid(method: str) -> None:
"""Check that `method` is in METHOD_LIST.
Parameters
----------
method : {'delta'}
Method used to transform data.
"""
if method not in METHOD_LIST:
raise ValueError(f"Invalid method '{method}'.")


def _check_transform_valid(transform: str) -> None:
"""Check that `transform` is in TRANSFORM_DICT.
Parameters
----------
transform : {'log', 'logit', 'exp', 'expit'}
Transform function.
"""
if transform not in TRANSFORM_DICT:
raise ValueError(f"Invalid transform '{transform}'.")


def _check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None:
"""Check that `mu` and `sigma` have the same lengths.
Parameters
----------
mu : array_like
Sample statistics.
sigma : array_like
Standard errors.
"""
if len(mu) != len(sigma):
raise ValueError("Lengths of mu and sigma don't match.")


def _check_sigma_positive(sigma: npt.ArrayLike) -> None:
"""Check that `sigma` is positive.
Parameters
----------
sigma : array_like
Standard errors.
"""
if np.any(sigma == 0.0):
warnings.warn("Sigma vector contains zeros.")
if np.any(sigma < 0.0):
raise ValueError("Sigma values must be positive.")
94 changes: 94 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Tests for transforms.py module."""
import numpy as np
import pytest

from distrx.transforms import transform_data, delta_method


TRANSFORM_DICT = {
'log': [
np.log,
lambda x: 1.0/x
], 'logit': [
lambda x: np.log(x/(1.0 - x)),
lambda x: 1.0/(x*(1.0 - x))
], 'exp': [
np.exp,
np.exp
], 'expit': [
lambda x: 1.0/(1.0 + np.exp(-x)),
lambda x: np.exp(-x)/(1.0 + np.exp(-x))**2
]
}
TRANSFORM_LIST = list(TRANSFORM_DICT.keys())
FUNCTION_LIST = [transform_data, delta_method]
VALS = [0.1]*2


@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_method_name_valid(transform):
"""Raise ValueError for invalue `method`."""
with pytest.raises(ValueError):
transform_data(VALS, VALS, transform, method='dummy')


@pytest.mark.parametrize('function', FUNCTION_LIST)
@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_input_len_match(function, transform):
"""Raise ValueError if lengths of input vectors don't match."""
with pytest.raises(ValueError):
function(VALS, VALS*2, transform)


@pytest.mark.parametrize('function', FUNCTION_LIST)
@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_sigma_negative(function, transform):
"""Raise ValueError if `sigma` contains negative values."""
vals = VALS + [-0.1]
with pytest.raises(ValueError):
function(vals, vals, transform)


@pytest.mark.parametrize('function', FUNCTION_LIST)
@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_sigma_zero(function, transform):
"""Display warning if `sigma` contains zeros."""
vals = VALS + [0.0]
with pytest.warns(UserWarning):
function(vals, vals, transform)


@pytest.mark.parametrize('function', FUNCTION_LIST)
def test_transform_name_valid(function):
"""Raise ValueError for invalid `transform`."""
with pytest.raises(ValueError):
function(VALS, VALS, 'dummy')


@pytest.mark.parametrize('function', FUNCTION_LIST)
@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_output_type(function, transform):
"""Output should be numpy arrays."""
mu, sigma = function(VALS, VALS, transform)
assert isinstance(mu, np.ndarray)
assert isinstance(sigma, np.ndarray)


@pytest.mark.parametrize('function', FUNCTION_LIST)
@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_outout_len_match(function, transform):
"""Length of output vectors should match."""
mu, sigma = function(VALS, VALS, transform)
assert len(mu) == len(sigma)


@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_delta_result(transform):
"""Check expected results."""
mu = np.random.uniform(0.1, 1.0, size=10)
sigma = np.random.uniform(0.1, 1.0, size=10)
mu_ref = TRANSFORM_DICT[transform][0](mu)
sigma_ref = sigma*TRANSFORM_DICT[transform][1](mu)
mu_trans, sigma_trans = delta_method(mu, sigma, transform)
assert np.allclose(mu_trans, mu_ref)
assert np.allclose(sigma_trans, sigma_ref)

0 comments on commit 7c7e90e

Please sign in to comment.