diff --git a/setup.py b/setup.py index 3baea96..d3a9cb6 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ install_requirements = [ 'numpy', + 'pandas', 'scipy', ] diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py new file mode 100644 index 0000000..bae43f0 --- /dev/null +++ b/src/distrx/transforms.py @@ -0,0 +1,185 @@ +"""Transform data from one space to another. + +Transform data, in the form of sample statistics and their standard +errors, from one space to another using a given transform function. + +TODO: +* Add user-defined transform function +* Add functions for confidence intervals +* Add decorators for accepting floats or vectors + +""" +from typing import Tuple +import warnings + +import numpy as np +import numpy.typing as npt + + +TRANSFORM_DICT = { + 'log': [ + np.log, + lambda x: 1.0/x + ], 'logit': [ + lambda x: np.log(x/(1.0 - x)), + lambda x: 1.0/(x*(1.0 - x)) + ], 'exp': [ + np.exp, + np.exp + ], 'expit': [ + lambda x: 1.0/(1.0 + np.exp(-x)), + lambda x: np.exp(-x)/(1.0 + np.exp(-x))**2 + ] +} +METHOD_LIST = ['delta'] + + +def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, + method: str = 'delta') -> Tuple[np.ndarray, np.ndarray]: + """Transform data from one space to another. + + Transform data, in the form of sample statistics and their standard + errors, from one space to another using a given transform function. + No assumptions are made about the underlying distributions of the + given data. + + Parameters + ---------- + mu : array_like + Sample statistics. + sigma : array_like + Standard errors. + transform : {'log', 'logit', 'exp', 'expit'} + Transform function. + method : {'delta'}, optional + Method used to transform data. + + Returns + ------- + mu_trans : numpy.ndarray + Sample stastistics in the transform space. + sigma_trans : numpy.ndarray + Standard errors in the transform space. + + """ + mu, sigma = np.array(mu), np.array(sigma) + _check_input(method, transform, mu, sigma) + if method == 'delta': + return delta_method(mu, sigma, transform) + + +def delta_method(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> \ + Tuple[np.ndarray, np.ndarray]: + """Transform data using the delta method. + + Transform data, in the form of sample statistics and their standard + errors, from one space to another using a given transform function + and the delta method. No assumptions are made about the underlying + distributions of the given data. + + Parameters + ---------- + mu : array_like + Sample statistics. + sigma : array_like + Standard errors. + transform : {'log', 'logit', 'exp', 'expit'} + Transform function. + + Returns + ------- + mu_trans : numpy.ndarray + Sample statistics in the transform space. + sigma_trans : numpy.ndarray + Standard errors in the transform space. + + Notes + ----- + The delta method expands a function of a random variable about its + mean with a one-step Taylor approximation and then takes the + variance. + + """ + mu, sigma = np.array(mu), np.array(sigma) + _check_input('delta', transform, mu, sigma) + mu_trans = TRANSFORM_DICT[transform][0](mu) + sigma_trans = sigma*TRANSFORM_DICT[transform][1](mu) + return mu_trans, sigma_trans + + +def _check_input(method: str, transform: str, mu: npt.ArrayLike, + sigma: npt.ArrayLike) -> None: + """Run checks on input data. + + Parameters + ---------- + method : {'delta'} + Method used to transform data. + transform : {'log', 'logit', 'exp', 'expit'} + Transform function. + mu : array_like + Sample statistics. + sigma : array_like + Standard errors. + + """ + _check_method_valid(method) + _check_transform_valid(transform) + _check_lengths_match(mu, sigma) + _check_sigma_positive(sigma) + + +def _check_method_valid(method: str) -> None: + """Check that `method` is in METHOD_LIST. + + Parameters + ---------- + method : {'delta'} + Method used to transform data. + + """ + if method not in METHOD_LIST: + raise ValueError(f"Invalid method '{method}'.") + + +def _check_transform_valid(transform: str) -> None: + """Check that `transform` is in TRANSFORM_DICT. + + Parameters + ---------- + transform : {'log', 'logit', 'exp', 'expit'} + Transform function. + + """ + if transform not in TRANSFORM_DICT: + raise ValueError(f"Invalid transform '{transform}'.") + + +def _check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None: + """Check that `mu` and `sigma` have the same lengths. + + Parameters + ---------- + mu : array_like + Sample statistics. + sigma : array_like + Standard errors. + + """ + if len(mu) != len(sigma): + raise ValueError("Lengths of mu and sigma don't match.") + + +def _check_sigma_positive(sigma: npt.ArrayLike) -> None: + """Check that `sigma` is positive. + + Parameters + ---------- + sigma : array_like + Standard errors. + + """ + if np.any(sigma == 0.0): + warnings.warn("Sigma vector contains zeros.") + if np.any(sigma < 0.0): + raise ValueError("Sigma values must be positive.") diff --git a/tests/test_transforms.py b/tests/test_transforms.py new file mode 100644 index 0000000..0a84206 --- /dev/null +++ b/tests/test_transforms.py @@ -0,0 +1,94 @@ +"""Tests for transforms.py module.""" +import numpy as np +import pytest + +from distrx.transforms import transform_data, delta_method + + +TRANSFORM_DICT = { + 'log': [ + np.log, + lambda x: 1.0/x + ], 'logit': [ + lambda x: np.log(x/(1.0 - x)), + lambda x: 1.0/(x*(1.0 - x)) + ], 'exp': [ + np.exp, + np.exp + ], 'expit': [ + lambda x: 1.0/(1.0 + np.exp(-x)), + lambda x: np.exp(-x)/(1.0 + np.exp(-x))**2 + ] +} +TRANSFORM_LIST = list(TRANSFORM_DICT.keys()) +FUNCTION_LIST = [transform_data, delta_method] +VALS = [0.1]*2 + + +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_method_name_valid(transform): + """Raise ValueError for invalue `method`.""" + with pytest.raises(ValueError): + transform_data(VALS, VALS, transform, method='dummy') + + +@pytest.mark.parametrize('function', FUNCTION_LIST) +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_input_len_match(function, transform): + """Raise ValueError if lengths of input vectors don't match.""" + with pytest.raises(ValueError): + function(VALS, VALS*2, transform) + + +@pytest.mark.parametrize('function', FUNCTION_LIST) +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_sigma_negative(function, transform): + """Raise ValueError if `sigma` contains negative values.""" + vals = VALS + [-0.1] + with pytest.raises(ValueError): + function(vals, vals, transform) + + +@pytest.mark.parametrize('function', FUNCTION_LIST) +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_sigma_zero(function, transform): + """Display warning if `sigma` contains zeros.""" + vals = VALS + [0.0] + with pytest.warns(UserWarning): + function(vals, vals, transform) + + +@pytest.mark.parametrize('function', FUNCTION_LIST) +def test_transform_name_valid(function): + """Raise ValueError for invalid `transform`.""" + with pytest.raises(ValueError): + function(VALS, VALS, 'dummy') + + +@pytest.mark.parametrize('function', FUNCTION_LIST) +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_output_type(function, transform): + """Output should be numpy arrays.""" + mu, sigma = function(VALS, VALS, transform) + assert isinstance(mu, np.ndarray) + assert isinstance(sigma, np.ndarray) + + +@pytest.mark.parametrize('function', FUNCTION_LIST) +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_outout_len_match(function, transform): + """Length of output vectors should match.""" + mu, sigma = function(VALS, VALS, transform) + assert len(mu) == len(sigma) + + +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_delta_result(transform): + """Check expected results.""" + mu = np.random.uniform(0.1, 1.0, size=10) + sigma = np.random.uniform(0.1, 1.0, size=10) + mu_ref = TRANSFORM_DICT[transform][0](mu) + sigma_ref = sigma*TRANSFORM_DICT[transform][1](mu) + mu_trans, sigma_trans = delta_method(mu, sigma, transform) + assert np.allclose(mu_trans, mu_ref) + assert np.allclose(sigma_trans, sigma_ref)