-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from ihmeuw-msca/develop
Add initial module transform.py with delta method
- Loading branch information
Showing
3 changed files
with
280 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
|
||
# Third-party packages listed as install requirements.
install_requirements = ["numpy", "pandas", "scipy"]
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
"""Transform data from one space to another. | ||
Transform data, in the form of sample statistics and their standard | ||
errors, from one space to another using a given transform function. | ||
TODO: | ||
* Add user-defined transform function | ||
* Add functions for confidence intervals | ||
* Add decorators for accepting floats or vectors | ||
""" | ||
from typing import Tuple | ||
import warnings | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
|
||
|
||
def _expit_derivative(x):
    """Evaluate the derivative of expit without overflow.

    The derivative ``exp(-x)/(1 + exp(-x))**2`` is an even function of
    ``x``, so it is evaluated at ``-|x|``. The exponential then never
    exceeds 1, whereas the direct formula overflows to ``inf/inf = nan``
    for very negative ``x``.
    """
    decay = np.exp(-np.abs(x))
    return decay/(1.0 + decay)**2


# Maps each transform name to [function, first derivative of function].
TRANSFORM_DICT = {
    'log': [
        np.log,
        lambda x: 1.0/x
    ], 'logit': [
        lambda x: np.log(x/(1.0 - x)),
        lambda x: 1.0/(x*(1.0 - x))
    ], 'exp': [
        np.exp,
        np.exp
    ], 'expit': [
        lambda x: 1.0/(1.0 + np.exp(-x)),
        _expit_derivative
    ]
}
# Names of the supported transformation methods.
METHOD_LIST = ['delta']
|
||
|
||
def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str,
                   method: str = 'delta') -> Tuple[np.ndarray, np.ndarray]:
    """Transform data from one space to another.

    Transform data, in the form of sample statistics and their standard
    errors, from one space to another using a given transform function.
    No assumptions are made about the underlying distributions of the
    given data.

    Parameters
    ----------
    mu : array_like
        Sample statistics.
    sigma : array_like
        Standard errors.
    transform : {'log', 'logit', 'exp', 'expit'}
        Transform function.
    method : {'delta'}, optional
        Method used to transform data.

    Returns
    -------
    mu_trans : numpy.ndarray
        Sample statistics in the transform space.
    sigma_trans : numpy.ndarray
        Standard errors in the transform space.

    Raises
    ------
    ValueError
        If `method` or `transform` is not recognized, if `mu` and
        `sigma` do not match, or if `sigma` contains negative values.

    Warns
    -----
    UserWarning
        If `sigma` contains zeros.

    """
    # Only the method name is checked here. The dispatched function
    # validates `transform`, `mu` and `sigma` itself, so checking them
    # here as well would run every check (and emit the zero-sigma
    # warning) twice.
    _check_method_valid(method)
    if method == 'delta':
        return delta_method(mu, sigma, transform)
    # Reached only if METHOD_LIST gains an entry without a branch above;
    # fail loudly instead of silently returning None.
    raise NotImplementedError(f"Method '{method}' is not implemented.")
|
||
|
||
def delta_method(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> \
        Tuple[np.ndarray, np.ndarray]:
    """Transform data using the delta method.

    Transform data, in the form of sample statistics and their standard
    errors, from one space to another using a given transform function
    and the delta method. No assumptions are made about the underlying
    distributions of the given data.

    Parameters
    ----------
    mu : array_like
        Sample statistics.
    sigma : array_like
        Standard errors.
    transform : {'log', 'logit', 'exp', 'expit'}
        Transform function.

    Returns
    -------
    mu_trans : numpy.ndarray
        Sample statistics in the transform space.
    sigma_trans : numpy.ndarray
        Standard errors in the transform space.

    Raises
    ------
    ValueError
        If `transform` is not recognized, if `mu` and `sigma` do not
        match, or if `sigma` contains negative values.

    Notes
    -----
    The delta method expands a function of a random variable about its
    mean with a one-step Taylor approximation and then takes the
    variance, which gives ``sigma_trans = sigma*|f'(mu)|``.

    """
    # The results below are freshly allocated arrays, so the inputs do
    # not need to be copied.
    mu, sigma = np.asarray(mu), np.asarray(sigma)
    _check_input('delta', transform, mu, sigma)
    function, derivative = TRANSFORM_DICT[transform]
    mu_trans = function(mu)
    # A standard error is a magnitude: take the absolute value of the
    # derivative so the result can never be negative.
    sigma_trans = sigma*np.abs(derivative(mu))
    return mu_trans, sigma_trans
|
||
|
||
def _check_input(method: str, transform: str, mu: npt.ArrayLike,
                 sigma: npt.ArrayLike) -> None:
    """Validate all arguments of a transformation call.

    Parameters
    ----------
    method : {'delta'}
        Method used to transform data.
    transform : {'log', 'logit', 'exp', 'expit'}
        Transform function.
    mu : array_like
        Sample statistics.
    sigma : array_like
        Standard errors.

    """
    # Checks run in this order, so the first failing one decides which
    # error the caller sees.
    checks = (
        (_check_method_valid, (method,)),
        (_check_transform_valid, (transform,)),
        (_check_lengths_match, (mu, sigma)),
        (_check_sigma_positive, (sigma,)),
    )
    for check, arguments in checks:
        check(*arguments)
|
||
|
||
def _check_method_valid(method: str) -> None:
    """Ensure `method` names an entry of METHOD_LIST.

    Parameters
    ----------
    method : {'delta'}
        Method used to transform data.

    Raises
    ------
    ValueError
        If `method` is not in METHOD_LIST.

    """
    if method in METHOD_LIST:
        return
    raise ValueError(f"Invalid method '{method}'.")
|
||
|
||
def _check_transform_valid(transform: str) -> None:
    """Ensure `transform` names a key of TRANSFORM_DICT.

    Parameters
    ----------
    transform : {'log', 'logit', 'exp', 'expit'}
        Transform function.

    Raises
    ------
    ValueError
        If `transform` is not a key of TRANSFORM_DICT.

    """
    if transform in TRANSFORM_DICT:
        return
    raise ValueError(f"Invalid transform '{transform}'.")
|
||
|
||
def _check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None:
    """Check that `mu` and `sigma` have the same shape.

    Full shapes are compared rather than ``len()``: ``len()`` only looks
    at the first axis, so a (2, 3) array would match a length-2 vector,
    and it raises ``TypeError`` for 0-d input. For 1-d input this is the
    same as comparing lengths.

    Parameters
    ----------
    mu : array_like
        Sample statistics.
    sigma : array_like
        Standard errors.

    Raises
    ------
    ValueError
        If the shapes of `mu` and `sigma` differ.

    """
    if np.shape(mu) != np.shape(sigma):
        raise ValueError("Lengths of mu and sigma don't match.")
|
||
|
||
def _check_sigma_positive(sigma: npt.ArrayLike) -> None:
    """Check that `sigma` contains no negative values.

    Parameters
    ----------
    sigma : array_like
        Standard errors.

    Raises
    ------
    ValueError
        If any value of `sigma` is negative.

    Warns
    -----
    UserWarning
        If any value of `sigma` is zero.

    """
    # Coerce first: on a plain list `sigma == 0.0` is a single False and
    # `sigma < 0.0` raises TypeError, so the checks would not be
    # elementwise.
    sigma = np.asarray(sigma)
    if np.any(sigma == 0.0):
        warnings.warn("Sigma vector contains zeros.")
    if np.any(sigma < 0.0):
        raise ValueError("Sigma values must be positive.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
"""Tests for transforms.py module.""" | ||
import numpy as np | ||
import pytest | ||
|
||
from distrx.transforms import transform_data, delta_method | ||
|
||
|
||
# Reference copy of the transforms, used to compute expected results:
# each name maps to [function, first derivative of function].
TRANSFORM_DICT = {
    'log': [np.log, lambda x: 1.0/x],
    'logit': [
        lambda x: np.log(x/(1.0 - x)),
        lambda x: 1.0/(x*(1.0 - x)),
    ],
    'exp': [np.exp, np.exp],
    'expit': [
        lambda x: 1.0/(1.0 + np.exp(-x)),
        lambda x: np.exp(-x)/(1.0 + np.exp(-x))**2,
    ],
}
# Transform names to parametrize over.
TRANSFORM_LIST = list(TRANSFORM_DICT)
# Public entry points to parametrize over.
FUNCTION_LIST = [transform_data, delta_method]
# Small valid input vector, used for both mu and sigma.
VALS = [0.1, 0.1]
|
||
|
||
@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_method_name_valid(transform):
    """Raise ValueError for invalid `method`."""
    unknown_method = 'dummy'
    with pytest.raises(ValueError):
        transform_data(VALS, VALS, transform, method=unknown_method)
|
||
|
||
@pytest.mark.parametrize('function', FUNCTION_LIST)
@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_input_len_match(function, transform):
    """Raise ValueError when `mu` and `sigma` differ in length."""
    # Twice as long as VALS, so the two lengths cannot match.
    longer = VALS + VALS
    with pytest.raises(ValueError):
        function(VALS, longer, transform)
|
||
|
||
@pytest.mark.parametrize('function', FUNCTION_LIST)
@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_sigma_negative(function, transform):
    """Raise ValueError if `sigma` contains negative values."""
    # Keep `mu` valid so the error can only come from `sigma`.
    mu = VALS + [0.1]
    sigma = VALS + [-0.1]
    with pytest.raises(ValueError):
        function(mu, sigma, transform)
|
||
|
||
@pytest.mark.parametrize('function', FUNCTION_LIST)
@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_sigma_zero(function, transform):
    """Display warning if `sigma` contains zeros."""
    # Keep `mu` valid: a zero in `mu` would also trigger unrelated
    # numerical warnings (log(0), division by zero) in some transforms.
    mu = VALS + [0.1]
    sigma = VALS + [0.0]
    with pytest.warns(UserWarning, match="zeros"):
        function(mu, sigma, transform)
|
||
|
||
@pytest.mark.parametrize('function', FUNCTION_LIST)
def test_transform_name_valid(function):
    """Raise ValueError when `transform` is not a known name."""
    unknown_transform = 'dummy'
    with pytest.raises(ValueError):
        function(VALS, VALS, unknown_transform)
|
||
|
||
@pytest.mark.parametrize('function', FUNCTION_LIST)
@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_output_type(function, transform):
    """Both returned values should be numpy arrays."""
    result = function(VALS, VALS, transform)
    mu_trans, sigma_trans = result
    assert isinstance(mu_trans, np.ndarray)
    assert isinstance(sigma_trans, np.ndarray)
|
||
|
||
@pytest.mark.parametrize('function', FUNCTION_LIST)
@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_outout_len_match(function, transform):
    """Returned vectors should have equal lengths."""
    mu_trans, sigma_trans = function(VALS, VALS, transform)
    n_mu = len(mu_trans)
    n_sigma = len(sigma_trans)
    assert n_mu == n_sigma
|
||
|
||
@pytest.mark.parametrize('transform', TRANSFORM_LIST)
def test_delta_result(transform):
    """Check expected results."""
    # Seeded generator: every run uses the same inputs, so a failure
    # can be reproduced.
    rng = np.random.default_rng(0)
    # Values in [0.1, 1.0) are within the domain of every transform.
    mu = rng.uniform(0.1, 1.0, size=10)
    sigma = rng.uniform(0.1, 1.0, size=10)
    function, derivative = TRANSFORM_DICT[transform]
    mu_ref = function(mu)
    sigma_ref = sigma*derivative(mu)
    mu_trans, sigma_trans = delta_method(mu, sigma, transform)
    assert np.allclose(mu_trans, mu_ref)
    assert np.allclose(sigma_trans, sigma_ref)