From b2fc571ae4bca986bbbc6ed8b604356755cd0718 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Tue, 7 Sep 2021 13:30:40 -0700 Subject: [PATCH 01/32] add pandas to requirements --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 3baea96..d3a9cb6 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ install_requirements = [ 'numpy', + 'pandas', 'scipy', ] From 8af88fdf3c111050270021e272542083156cfd9a Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Thu, 9 Sep 2021 13:21:14 -0700 Subject: [PATCH 02/32] adding outline of transform module --- src/distrx/transforms.py | 54 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 src/distrx/transforms.py diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py new file mode 100644 index 0000000..d8d3424 --- /dev/null +++ b/src/distrx/transforms.py @@ -0,0 +1,54 @@ +"""Transform data from one space to another. + +Transform data, in the form of sample statistics and their standard +errors, from one space to another using a given transform function. + +""" +import numpy as np + + +def transform_data(mu, sigma, transform, method='delta'): + """Transform data from one space to another. + + Transform data, in the form of sample statistics and their standard + errors, from one space to another using a given transform function. + No assumptions are made about the underlying distributions of the + given data. + + Parameters + ---------- + mu : float or array_like of float + Vector of sample statistics. + sigma : float or array_like of float + Vector of standard errors. + transform : {'log', 'logit', array_like of function} + Transform function. Users may define a transform function by + supplying a function and its derivative. If `method` is + 'delta2', the second derivative is also required. + method : {'delta, 'delta2'}, optional + Method used to transform data. + + Returns + ------- + mu_transform : array_like of float + Vector of sample stastistics in the transform space. + sigma_transform : array_like of float + Vector of standard errors in the transform space. + + """ + pass + + +def transform_delta(mu, sigma, transform): + """Transform data using the delta method.""" + pass + + +def transform_delta2(mu, sigma, transform): + """Transform data using the second-order delta method.""" + pass + + +def get_transform(transform, order=1): + """Get vector of transform function and its derivatives.""" + pass From 18703e47cae03155120b1f6c4886c9ed098a9227 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Thu, 9 Sep 2021 13:38:52 -0700 Subject: [PATCH 03/32] filled in remaining docstrings --- src/distrx/transforms.py | 72 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 66 insertions(+), 6 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index d8d3424..b1cf1fd 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -30,9 +30,9 @@ def transform_data(mu, sigma, transform, method='delta'): Returns ------- - mu_transform : array_like of float + mu_transform : float or array_like of float Vector of sample stastistics in the transform space. - sigma_transform : array_like of float + sigma_transform : float or array_like of float Vector of standard errors in the transform space. """ @@ -40,15 +40,75 @@ def transform_data(mu, sigma, transform, method='delta'): def transform_delta(mu, sigma, transform): - """Transform data using the delta method.""" + """Transform data using the delta method. + + Parameters + ---------- + mu : float or array_like of float + Vector of sample statistics. + sigma : float or array_like of float + Vector of standard errors. + transform : {'log', 'logit', array_like of function} + Transform function. Users may define a transform function by + supplying a function and its derivative. + + Returns + ------- + mu_transform : float or array_like of float + Vector of sample statistics in the transform space. + sigma_transform : float or array_like of float + Vector of standard errors in the transform space. + + Notes + ----- + Description of delta method. + + """ pass def transform_delta2(mu, sigma, transform): - """Transform data using the second-order delta method.""" + """Transform data using the second-order delta method. + + Parameters + ---------- + mu : float or array_like of float + Vector of sample statistics. + sigma : float or array_like of float + Vector of standard errors. + transform : {'log', 'logit', array_like of function} + Transform function. Users may define a transform function by + supplying a function and its first two derivatives. + + Returns + ------- + mu_transform : float or array_like of float + Vector of sample statistics in transform space. + sigma_transform : float or array_like of float + Vector of standard errors in transform space. + + Notes + ----- + Description of second-order delta method. + + """ pass -def get_transform(transform, order=1): - """Get vector of transform function and its derivatives.""" +def get_transform(transform, order=0): + """Get transform function and its derivative(s). + + Parameters + ---------- + transform : {'log', 'logit'} + Transform function. + order : {0, 1, 2}, optional + Highest order of derivative needed. + + Returns + ------- + transform : function or array_like of function + Transform function and its derivative(s). + + """ pass From 366ece7fe279f9eb30a3598706d5051843724de7 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Thu, 9 Sep 2021 14:00:51 -0700 Subject: [PATCH 04/32] add descriptions for delta methods --- src/distrx/transforms.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index b1cf1fd..8dd0870 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -61,7 +61,9 @@ def transform_delta(mu, sigma, transform): Notes ----- - Description of delta method. + The delta method expands a function of a random variable about its + mean with a one-step Taylor approximation and then takes the + variance. """ pass @@ -89,7 +91,11 @@ def transform_delta2(mu, sigma, transform): Notes ----- - Description of second-order delta method. + The second-order delta method expands a function of a random + variable about its mean with a two-step Taylor approximation and + then takes the variance. This method is useful if the derivative of + the transform function is zero (so the first-order delta method + cannot be applied), or the sample size is small. """ pass From a5e36eee69b2549107e57c46786a9547e079435b Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Thu, 9 Sep 2021 14:05:08 -0700 Subject: [PATCH 05/32] expand description of delta functions --- src/distrx/transforms.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index 8dd0870..2f73d71 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -42,6 +42,11 @@ def transform_data(mu, sigma, transform, method='delta'): def transform_delta(mu, sigma, transform): """Transform data using the delta method. + Transform data, in the form of sample statistics and their standard + errors, from one space to another using a given transform function + and the delta method. No assumptions are made about the underlying + distributions of the given data. + Parameters ---------- mu : float or array_like of float @@ -72,6 +77,11 @@ def transform_delta(mu, sigma, transform): def transform_delta2(mu, sigma, transform): """Transform data using the second-order delta method. + Transform data, in the form of sample statistics and their standard + errors, from one space to another using a given transform function + and the second-order delta method. No assumptions are made about + the underlying distributions of the given data. + Parameters ---------- mu : float or array_like of float From 1b913607a2dc737586f2681a2ece19e5b4725e8e Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Thu, 9 Sep 2021 14:11:52 -0700 Subject: [PATCH 06/32] add test module for transforms --- tests/test_transforms.py | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 tests/test_transforms.py diff --git a/tests/test_transforms.py b/tests/test_transforms.py new file mode 100644 index 0000000..b3d90aa --- /dev/null +++ b/tests/test_transforms.py @@ -0,0 +1,8 @@ +"""Tests for transforms.py module.""" +import pytest + +import distrx.transforms as trx + + +def test_dummy(): + assert True From efaf9185447bbb002194e9623154aafea5082326 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Thu, 9 Sep 2021 15:08:45 -0700 Subject: [PATCH 07/32] add tests for get_transform --- tests/test_transforms.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index b3d90aa..bc9c7bf 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,8 +1,37 @@ """Tests for transforms.py module.""" +import types + import pytest -import distrx.transforms as trx +from distrx.transforms import get_transform + + +def test_transform_transform_value(): + """Raise ValueError for invalid `transform`.""" + for order in [0, 1, 2]: + with pytest.raises(ValueError): + get_transform('dummy', order) + + +def test_transform_order_value(): + """Raise ValueError for invalid `order`.""" + for transform in ['log', 'logit']: + with pytest.raises(ValueError): + get_transform(transform, 3) + + +def test_transform_ouput_len(): + """Length of output should correspond to `order`.""" + for transform in ['log', 'logit']: + assert len(get_transform(transform, 1)) == 2 + assert len(get_transform(transform, 2)) == 3 -def test_dummy(): - assert True +def test_transform_output_type(): + """Type of output should correspond to `order`.""" + for transform in ['log', 'logit']: + assert isinstance(get_transform(transform), types.FunctionType) + assert isinstance(get_transform(transform, 0), types.FunctionType) + for order in [1, 2]: + for function in get_transform(transform, order): + assert isinstance(function, types.FunctionType) From 48632caf5bae80c1b86b3c9ea082b017f9a7f3ba Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Thu, 9 Sep 2021 15:27:33 -0700 Subject: [PATCH 08/32] add exp and expit to tests --- tests/test_transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index bc9c7bf..0adc4fd 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -15,21 +15,21 @@ def test_transform_transform_value(): def test_transform_order_value(): """Raise ValueError for invalid `order`.""" - for transform in ['log', 'logit']: + for transform in ['log', 'logit', 'exp', 'expit']: with pytest.raises(ValueError): get_transform(transform, 3) def test_transform_ouput_len(): """Length of output should correspond to `order`.""" - for transform in ['log', 'logit']: + for transform in ['log', 'logit', 'exp', 'expit']: assert len(get_transform(transform, 1)) == 2 assert len(get_transform(transform, 2)) == 3 def test_transform_output_type(): """Type of output should correspond to `order`.""" - for transform in ['log', 'logit']: + for transform in ['log', 'logit', 'exp', 'expit']: assert isinstance(get_transform(transform), types.FunctionType) assert isinstance(get_transform(transform, 0), types.FunctionType) for order in [1, 2]: From a2706eb65de85446a33076c32e4fe28f41fad583 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Thu, 9 Sep 2021 16:37:46 -0700 Subject: [PATCH 09/32] write get_transforms --- src/distrx/transforms.py | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index 2f73d71..42674d4 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -114,17 +114,51 @@ def transform_delta2(mu, sigma, transform): def get_transform(transform, order=0): """Get transform function and its derivative(s). + Returns transform function if `order` is 0. + Otherwise returns an array of functions, including the transform + function and its derivatives up to the specified order. + Parameters ---------- - transform : {'log', 'logit'} + transform : {'log', 'logit', 'exp', 'expit'} Transform function. order : {0, 1, 2}, optional Highest order of derivative needed. Returns ------- - transform : function or array_like of function + transforms : function or array_like of function Transform function and its derivative(s). """ - pass + # Check input + if transform not in ['log', 'logit', 'exp', 'expit']: + raise ValueError(f"Invalid transform function '{transform}'.") + if order not in [0, 1, 2]: + raise ValueError(f"Invalid order '{order}'.") + + # Define transform functions + transform_dict = { + 'log': [ + lambda x: np.log, + lambda x: 1/x, + lambda x: -1/x**2, + ], 'logit': [ + lambda x: np.log(x/(1 - x)), + lambda x: 1/(x*(1 - x)), + lambda x: (2*x - 1)/(x**2*(1 - x)**2) + ], 'exp': [ + lambda x: np.exp, + lambda x: np.exp, + lambda x: np.exp + ], 'expit': [ + lambda x: 1/(1 + np.exp(-x)), + lambda x: np.exp(-x)/(1 + np.exp(-x))**2, + lambda x: np.exp(-x)*(np.exp(-x) - 1)/(1 + np.exp(-x))**3 + ] + } + + # Get function or list of functions + if order == 0: + return transform_dict[transform][order] + return transform_dict[transform][:order+1] From fdcf6a87ca9eb3c1fba5f9a30f977783053af4d2 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Thu, 9 Sep 2021 16:42:19 -0700 Subject: [PATCH 10/32] changed test names --- tests/test_transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 0adc4fd..8b87b77 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -6,28 +6,28 @@ from distrx.transforms import get_transform -def test_transform_transform_value(): +def test_get_transform_transform_value(): """Raise ValueError for invalid `transform`.""" for order in [0, 1, 2]: with pytest.raises(ValueError): get_transform('dummy', order) -def test_transform_order_value(): +def test_get_transform_order_value(): """Raise ValueError for invalid `order`.""" for transform in ['log', 'logit', 'exp', 'expit']: with pytest.raises(ValueError): get_transform(transform, 3) -def test_transform_ouput_len(): +def test_get_transform_ouput_len(): """Length of output should correspond to `order`.""" for transform in ['log', 'logit', 'exp', 'expit']: assert len(get_transform(transform, 1)) == 2 assert len(get_transform(transform, 2)) == 3 -def test_transform_output_type(): +def test_get_transform_output_type(): """Type of output should correspond to `order`.""" for transform in ['log', 'logit', 'exp', 'expit']: assert isinstance(get_transform(transform), types.FunctionType) From 0e4d648f64a5e03aac2d4f8658d9ee40dff4062c Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Thu, 9 Sep 2021 19:38:04 -0700 Subject: [PATCH 11/32] add get_delta and tests --- src/distrx/transforms.py | 94 ++++++++++++++++++++++------------------ tests/test_transforms.py | 59 ++++++++++++++++++++++--- 2 files changed, 105 insertions(+), 48 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index 42674d4..65d7398 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -3,6 +3,12 @@ Transform data, in the form of sample statistics and their standard errors, from one space to another using a given transform function. +TODO: +* Implement transform_delta2 +* Implement transform_data +* Add typing in function definition? +* Add user-defined transform function + """ import numpy as np @@ -17,26 +23,24 @@ def transform_data(mu, sigma, transform, method='delta'): Parameters ---------- - mu : float or array_like of float - Vector of sample statistics. - sigma : float or array_like of float - Vector of standard errors. - transform : {'log', 'logit', array_like of function} - Transform function. Users may define a transform function by - supplying a function and its derivative. If `method` is - 'delta2', the second derivative is also required. + mu : array_like + Sample statistics. + sigma : array_like + Standard errors. + transform : {'log', 'logit', 'exp', 'expit'} + Transform function. method : {'delta, 'delta2'}, optional Method used to transform data. Returns ------- - mu_transform : float or array_like of float - Vector of sample stastistics in the transform space. - sigma_transform : float or array_like of float - Vector of standard errors in the transform space. + mu_transform : numpy.ndarray + Sample stastistics in the transform space. + sigma_transform : numpy.ndarray + Standard errors in the transform space. """ - pass + return def transform_delta(mu, sigma, transform): @@ -49,20 +53,19 @@ def transform_delta(mu, sigma, transform): Parameters ---------- - mu : float or array_like of float - Vector of sample statistics. - sigma : float or array_like of float - Vector of standard errors. - transform : {'log', 'logit', array_like of function} - Transform function. Users may define a transform function by - supplying a function and its derivative. + mu : array_like + Sample statistics. + sigma : array_like + Standard errors. + transform : {'log', 'logit', 'exp', 'expit'} + Transform function. Returns ------- - mu_transform : float or array_like of float - Vector of sample statistics in the transform space. - sigma_transform : float or array_like of float - Vector of standard errors in the transform space. + mu_transform : numpy.ndarray + Sample statistics in the transform space. + sigma_transform : numpy.ndarray + Standard errors in the transform space. Notes ----- @@ -71,7 +74,17 @@ def transform_delta(mu, sigma, transform): variance. """ - pass + # Check mu and sigma + mu = np.array(mu) + sigma = np.array(sigma) + if len(mu) != len(sigma): + raise ValueError("Lengths of mu and sigma don't match.") + if np.any(sigma <= 0.0): + raise ValueError("Sigma values must be positive.") + + # Approximate transformed data + transform = get_transform(transform, 1) + return transform[0](mu), sigma*transform[1](mu)**2 def transform_delta2(mu, sigma, transform): @@ -84,20 +97,19 @@ def transform_delta2(mu, sigma, transform): Parameters ---------- - mu : float or array_like of float - Vector of sample statistics. - sigma : float or array_like of float - Vector of standard errors. - transform : {'log', 'logit', array_like of function} - Transform function. Users may define a transform function by - supplying a function and its first two derivatives. + mu : array_like + Sample statistics. + sigma : array_like + Standard errors. + transform : {'log', 'logit', 'exp', 'expit'} + Transform function. Returns ------- - mu_transform : float or array_like of float - Vector of sample statistics in transform space. - sigma_transform : float or array_like of float - Vector of standard errors in transform space. + mu_transform : numpy.ndarray + Sample statistics in transform space. + sigma_transform : numpy.ndarray + Standard errors in transform space. Notes ----- @@ -108,7 +120,7 @@ def transform_delta2(mu, sigma, transform): cannot be applied), or the sample size is small. """ - pass + return def get_transform(transform, order=0): @@ -140,7 +152,7 @@ def get_transform(transform, order=0): # Define transform functions transform_dict = { 'log': [ - lambda x: np.log, + np.log, lambda x: 1/x, lambda x: -1/x**2, ], 'logit': [ @@ -148,9 +160,9 @@ def get_transform(transform, order=0): lambda x: 1/(x*(1 - x)), lambda x: (2*x - 1)/(x**2*(1 - x)**2) ], 'exp': [ - lambda x: np.exp, - lambda x: np.exp, - lambda x: np.exp + np.exp, + np.exp, + np.exp ], 'expit': [ lambda x: 1/(1 + np.exp(-x)), lambda x: np.exp(-x)/(1 + np.exp(-x))**2, diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 8b87b77..5700fbf 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,12 +1,19 @@ -"""Tests for transforms.py module.""" -import types +"""Tests for transforms.py module. +TODO: +* Add random vectors and vector lengths +* Create global list of transforms +* Implement delta2 tests +* Implement transform_data tests + +""" +import numpy as np import pytest -from distrx.transforms import get_transform +from distrx.transforms import get_transform, transform_delta -def test_get_transform_transform_value(): +def test_get_transform_transform(): """Raise ValueError for invalid `transform`.""" for order in [0, 1, 2]: with pytest.raises(ValueError): @@ -30,8 +37,46 @@ def test_get_transform_ouput_len(): def test_get_transform_output_type(): """Type of output should correspond to `order`.""" for transform in ['log', 'logit', 'exp', 'expit']: - assert isinstance(get_transform(transform), types.FunctionType) - assert isinstance(get_transform(transform, 0), types.FunctionType) + assert callable(get_transform(transform)) + assert callable(get_transform(transform, 0)) for order in [1, 2]: for function in get_transform(transform, order): - assert isinstance(function, types.FunctionType) + assert callable(function) + + +def test_transform_delta_input_len(): + """Raise ValueError if lengths of mu and sigma don't match.""" + for transform in ['log', 'logit', 'exp', 'expit']: + with pytest.raises(ValueError): + transform_delta([0.1]*2, [0.1]*3, transform) + + +def test_transform_delta_sigma(): + """Raise ValueError if `sigma` contains non-positive values.""" + for transform in ['log', 'logit', 'exp', 'expit', [np.sin, np.cos]]: + vals = [0.1, -0.1] + with pytest.raises(ValueError): + transform_delta(vals, vals, transform) + + +def test_transform_delta_transform(): + """Raise ValueError for invalid `transform`.""" + with pytest.raises(ValueError): + transform_delta([0.1], [0.1], 'dummy') + + +def test_transform_delta_output_type(): + """Output should be numpy arrays.""" + vals = [0.1]*2 + for transform in ['log', 'logit', 'exp', 'expit']: + mu, sigma = transform_delta(vals, vals, transform) + assert isinstance(mu, np.ndarray) + assert isinstance(sigma, np.ndarray) + + +def test_transform_delta_outout_len(): + """Length of output vectors should match.""" + vals = [0.1]*2 + for transform in ['log', 'logit', 'exp', 'expit']: + mu, sigma = transform_delta(vals, vals, transform) + assert len(mu) == len(sigma) From 05a23caf02c1271dc443b0f056f2260309d76c15 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Fri, 10 Sep 2021 14:37:45 -0700 Subject: [PATCH 12/32] add type hints and transform_dict, remove get_transform --- src/distrx/transforms.py | 114 ++++++++++++++++----------------------- tests/test_transforms.py | 33 +----------- 2 files changed, 46 insertions(+), 101 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index 65d7398..f98d79c 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -6,14 +6,36 @@ TODO: * Implement transform_delta2 * Implement transform_data -* Add typing in function definition? * Add user-defined transform function """ import numpy as np - - -def transform_data(mu, sigma, transform, method='delta'): +import numpy.typing as npt + + +TRANSFORM_DICT = { + 'log': [ + np.log, + lambda x: 1/x, + lambda x: -1/x**2, + ], 'logit': [ + lambda x: np.log(x/(1 - x)), + lambda x: 1/(x*(1 - x)), + lambda x: (2*x - 1)/(x**2*(1 - x)**2) + ], 'exp': [ + np.exp, + np.exp, + np.exp + ], 'expit': [ + lambda x: 1/(1 + np.exp(-x)), + lambda x: np.exp(-x)/(1 + np.exp(-x))**2, + lambda x: np.exp(-x)*(np.exp(-x) - 1)/(1 + np.exp(-x))**3 + ] +} + + +def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, + method: str = 'delta') -> tuple[np.ndarray, np.ndarray]: """Transform data from one space to another. Transform data, in the form of sample statistics and their standard @@ -34,16 +56,17 @@ def transform_data(mu, sigma, transform, method='delta'): Returns ------- - mu_transform : numpy.ndarray + mu_trans : numpy.ndarray Sample stastistics in the transform space. - sigma_transform : numpy.ndarray + sigma_trans : numpy.ndarray Standard errors in the transform space. """ return -def transform_delta(mu, sigma, transform): +def transform_delta(mu: npt.ArrayLike, sigma: npt.ArrayLike, + transform: str) -> tuple[np.ndarray, np.ndarray]: """Transform data using the delta method. Transform data, in the form of sample statistics and their standard @@ -62,9 +85,9 @@ def transform_delta(mu, sigma, transform): Returns ------- - mu_transform : numpy.ndarray + mu_trans : numpy.ndarray Sample statistics in the transform space. - sigma_transform : numpy.ndarray + sigma_trans : numpy.ndarray Standard errors in the transform space. Notes @@ -82,12 +105,18 @@ def transform_delta(mu, sigma, transform): if np.any(sigma <= 0.0): raise ValueError("Sigma values must be positive.") + # Check transform + if transform not in ['log', 'logit', 'exp', 'expit']: + raise ValueError(f"Invalid transform '{transform}'.") + # Approximate transformed data - transform = get_transform(transform, 1) - return transform[0](mu), sigma*transform[1](mu)**2 + mu_trans = TRANSFORM_DICT[transform][0](mu) + sigma_trans = sigma*TRANSFORM_DICT[transform][1](mu)**2 + return mu_trans, sigma_trans -def transform_delta2(mu, sigma, transform): +def transform_delta2(mu: npt.ArrayLike, sigma: npt.ArrayLike, + transform: str) -> tuple[np.ndarray, np.ndarray]: """Transform data using the second-order delta method. Transform data, in the form of sample statistics and their standard @@ -106,10 +135,10 @@ def transform_delta2(mu, sigma, transform): Returns ------- - mu_transform : numpy.ndarray - Sample statistics in transform space. - sigma_transform : numpy.ndarray - Standard errors in transform space. + mu_trans : numpy.ndarray + Sample statistics in the transform space. + sigma_trans : numpy.ndarray + Standard errors in the transform space. Notes ----- @@ -121,56 +150,3 @@ def transform_delta2(mu, sigma, transform): """ return - - -def get_transform(transform, order=0): - """Get transform function and its derivative(s). - - Returns transform function if `order` is 0. - Otherwise returns an array of functions, including the transform - function and its derivatives up to the specified order. - - Parameters - ---------- - transform : {'log', 'logit', 'exp', 'expit'} - Transform function. - order : {0, 1, 2}, optional - Highest order of derivative needed. - - Returns - ------- - transforms : function or array_like of function - Transform function and its derivative(s). - - """ - # Check input - if transform not in ['log', 'logit', 'exp', 'expit']: - raise ValueError(f"Invalid transform function '{transform}'.") - if order not in [0, 1, 2]: - raise ValueError(f"Invalid order '{order}'.") - - # Define transform functions - transform_dict = { - 'log': [ - np.log, - lambda x: 1/x, - lambda x: -1/x**2, - ], 'logit': [ - lambda x: np.log(x/(1 - x)), - lambda x: 1/(x*(1 - x)), - lambda x: (2*x - 1)/(x**2*(1 - x)**2) - ], 'exp': [ - np.exp, - np.exp, - np.exp - ], 'expit': [ - lambda x: 1/(1 + np.exp(-x)), - lambda x: np.exp(-x)/(1 + np.exp(-x))**2, - lambda x: np.exp(-x)*(np.exp(-x) - 1)/(1 + np.exp(-x))**3 - ] - } - - # Get function or list of functions - if order == 0: - return transform_dict[transform][order] - return transform_dict[transform][:order+1] diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 5700fbf..146982a 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -10,38 +10,7 @@ import numpy as np import pytest -from distrx.transforms import get_transform, transform_delta - - -def test_get_transform_transform(): - """Raise ValueError for invalid `transform`.""" - for order in [0, 1, 2]: - with pytest.raises(ValueError): - get_transform('dummy', order) - - -def test_get_transform_order_value(): - """Raise ValueError for invalid `order`.""" - for transform in ['log', 'logit', 'exp', 'expit']: - with pytest.raises(ValueError): - get_transform(transform, 3) - - -def test_get_transform_ouput_len(): - """Length of output should correspond to `order`.""" - for transform in ['log', 'logit', 'exp', 'expit']: - assert len(get_transform(transform, 1)) == 2 - assert len(get_transform(transform, 2)) == 3 - - -def test_get_transform_output_type(): - """Type of output should correspond to `order`.""" - for transform in ['log', 'logit', 'exp', 'expit']: - assert callable(get_transform(transform)) - assert callable(get_transform(transform, 0)) - for order in [1, 2]: - for function in get_transform(transform, order): - assert callable(function) +from distrx.transforms import transform_delta def test_transform_delta_input_len(): From 3cecb2ad7648b75c36c9e54645e01e596d9ec302 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Fri, 10 Sep 2021 14:41:10 -0700 Subject: [PATCH 13/32] changed tuple to typing.Tuple for Python < 3.9 --- src/distrx/transforms.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index f98d79c..2ce664f 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -9,6 +9,8 @@ * Add user-defined transform function """ +from typing import Tuple + import numpy as np import numpy.typing as npt @@ -35,7 +37,7 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, - method: str = 'delta') -> tuple[np.ndarray, np.ndarray]: + method: str = 'delta') -> Tuple[np.ndarray, np.ndarray]: """Transform data from one space to another. Transform data, in the form of sample statistics and their standard @@ -66,7 +68,7 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, def transform_delta(mu: npt.ArrayLike, sigma: npt.ArrayLike, - transform: str) -> tuple[np.ndarray, np.ndarray]: + transform: str) -> Tuple[np.ndarray, np.ndarray]: """Transform data using the delta method. Transform data, in the form of sample statistics and their standard @@ -116,7 +118,7 @@ def transform_delta(mu: npt.ArrayLike, sigma: npt.ArrayLike, def transform_delta2(mu: npt.ArrayLike, sigma: npt.ArrayLike, - transform: str) -> tuple[np.ndarray, np.ndarray]: + transform: str) -> Tuple[np.ndarray, np.ndarray]: """Transform data using the second-order delta method. Transform data, in the form of sample statistics and their standard From b1a649b6255a4fc7429f53a952fdb5c4846c1a38 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Fri, 10 Sep 2021 15:31:42 -0700 Subject: [PATCH 14/32] start to implement transform_data wrapper function --- src/distrx/transforms.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index 2ce664f..b81a3ee 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -64,6 +64,13 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, Standard errors in the transform space. """ + # Check method + if method not in ['delta', 'delta2']: + raise ValueError(f"Invalid method '{method}'.") + + # Approximate transformed data + if method == 'delta': + return transform_delta(mu, sigma, transform) return @@ -108,7 +115,7 @@ def transform_delta(mu: npt.ArrayLike, sigma: npt.ArrayLike, raise ValueError("Sigma values must be positive.") # Check transform - if transform not in ['log', 'logit', 'exp', 'expit']: + if transform not in TRANSFORM_DICT: raise ValueError(f"Invalid transform '{transform}'.") # Approximate transformed data From e5abb27358a893b316959f3eb696dfd6cba3dcf5 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Fri, 10 Sep 2021 15:32:26 -0700 Subject: [PATCH 15/32] add decorators --- tests/test_transforms.py | 42 +++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 146982a..b85fbb8 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -2,7 +2,6 @@ TODO: * Add random vectors and vector lengths -* Create global list of transforms * Implement delta2 tests * Implement transform_data tests @@ -10,42 +9,59 @@ import numpy as np import pytest -from distrx.transforms import transform_delta +from distrx.transforms import transform_data, transform_delta -def test_transform_delta_input_len(): - """Raise ValueError if lengths of mu and sigma don't match.""" - for transform in ['log', 'logit', 'exp', 'expit']: +TRANSFORM_LIST = ['log', 'logit', 'exp', 'expit'] +FUNCTION_LIST = [transform_data, transform_delta] + + +@pytest.mark.parametrize("transform", TRANSFORM_LIST) +def test_method_name_valid(transform): + """Raise ValueError for invalue `method`.""" + vals = [0.1]*2 + with pytest.raises(ValueError): + transform_data(vals, vals, transform, method='dummy') + + +@pytest.mark.parametrize("function", FUNCTION_LIST) +def test_input_len_match(function): + """Raise ValueError if lengths of input vectors don't match.""" + for transform in TRANSFORM_LIST: with pytest.raises(ValueError): - transform_delta([0.1]*2, [0.1]*3, transform) + function([0.1]*2, [0.1]*3, transform) -def test_transform_delta_sigma(): +@pytest.mark.parametrize("function", FUNCTION_LIST) +def test_sigma_positive(function): """Raise ValueError if `sigma` contains non-positive values.""" - for transform in ['log', 'logit', 'exp', 'expit', [np.sin, np.cos]]: + for transform in TRANSFORM_LIST: vals = [0.1, -0.1] with pytest.raises(ValueError): transform_delta(vals, vals, transform) -def test_transform_delta_transform(): +@pytest.mark.parametrize("function", FUNCTION_LIST) +def test_transform_name_valid(function): """Raise ValueError for invalid `transform`.""" with pytest.raises(ValueError): transform_delta([0.1], [0.1], 'dummy') -def test_transform_delta_output_type(): +@pytest.mark.parametrize("function", FUNCTION_LIST) +def test_output_type(function): """Output should be numpy arrays.""" vals = [0.1]*2 - for transform in ['log', 'logit', 'exp', 'expit']: + for transform in TRANSFORM_LIST: mu, sigma = transform_delta(vals, vals, transform) assert isinstance(mu, np.ndarray) assert isinstance(sigma, np.ndarray) -def test_transform_delta_outout_len(): +@pytest.mark.parametrize("function", FUNCTION_LIST) +def test_outout_len_match(function): """Length of output vectors should match.""" vals = [0.1]*2 - for transform in ['log', 'logit', 'exp', 'expit']: + for transform in TRANSFORM_LIST: mu, sigma = transform_delta(vals, vals, transform) assert len(mu) == len(sigma) From 06c177f6167c175b67da1fddf3b2934e075e055f Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Fri, 10 Sep 2021 16:10:54 -0700 Subject: [PATCH 16/32] move checks into one function --- src/distrx/transforms.py | 102 +++++++++++++++++++++++++++++++-------- 1 file changed, 83 insertions(+), 19 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index b81a3ee..59733b7 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -64,14 +64,10 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, Standard errors in the transform space. """ - # Check method - if method not in ['delta', 'delta2']: - raise ValueError(f"Invalid method '{method}'.") - - # Approximate transformed data + mu, sigma = np.array(mu), np.array(sigma) + check_input(mu, sigma, transform, method) if method == 'delta': return transform_delta(mu, sigma, transform) - return def transform_delta(mu: npt.ArrayLike, sigma: npt.ArrayLike, @@ -106,19 +102,8 @@ def transform_delta(mu: npt.ArrayLike, sigma: npt.ArrayLike, variance. """ - # Check mu and sigma - mu = np.array(mu) - sigma = np.array(sigma) - if len(mu) != len(sigma): - raise ValueError("Lengths of mu and sigma don't match.") - if np.any(sigma <= 0.0): - raise ValueError("Sigma values must be positive.") - - # Check transform - if transform not in TRANSFORM_DICT: - raise ValueError(f"Invalid transform '{transform}'.") - - # Approximate transformed data + mu, sigma = np.array(mu), np.array(sigma) + check_input(mu, sigma, transform) mu_trans = TRANSFORM_DICT[transform][0](mu) sigma_trans = sigma*TRANSFORM_DICT[transform][1](mu)**2 return mu_trans, sigma_trans @@ -158,4 +143,83 @@ def transform_delta2(mu: npt.ArrayLike, sigma: npt.ArrayLike, cannot be applied), or the sample size is small. """ + mu, sigma = np.array(mu), np.array(sigma) + check_input(mu, sigma, transform) return + + +def check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, + method: str = None) -> None: + """Run checks on input data. + + Parameters + ---------- + mu : array_like + Sample statistics. + sigma : array_like + Standard errors. + transform : {'log', 'logit', 'exp', 'expit'} + Transform function. + method : {None, 'delta', 'delta2'}, optional + Method used to transform data. + + """ + check_lengths_match(mu, sigma) + check_sigma_positive(sigma) + check_transform_valid(transform) + if method is not None: + check_method_valid(method) + + +def check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None: + """Check that `mu` and `sigma` have the same lengths. + + Parameters + ---------- + mu : array_like + Sample statistics. + sigma : array_like + Standard errors. + + """ + if len(mu) != len(sigma): + raise ValueError("Lengths of mu and sigma don't match.") + + +def check_sigma_positive(sigma: npt.ArrayLike) -> None: + """Check that `sigma` is positive. + + Parameters + ---------- + sigma : array_like + Standard errors. + + """ + if np.any(sigma <= 0): + raise ValueError("Sigma values must be positive.") + + +def check_transform_valid(transform: str) -> None: + """Check that `transform` is in TRANSFORM_DICT. + + Parameters + ---------- + transform : {'log', 'logit', 'exp', 'expit'} + Transform function. + + """ + if transform not in TRANSFORM_DICT: + raise ValueError(f"Invalid transform '{transform}'.") + + +def check_method_valid(method: str) -> None: + """Check that `method` is in ['delta', 'delta2']. + + Parameters + ---------- + method : {'delta', 'delta2'} + Method used to transform data. + + """ + if method not in ['delta', 'delta2']: + raise ValueError(f"Invalid method '{method}'.") From 0fcd9994004922871968a3f0aab767a8a48f37d3 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Fri, 10 Sep 2021 16:15:03 -0700 Subject: [PATCH 17/32] add transform_delta2 to function list --- tests/test_transforms.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index b85fbb8..f977b60 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -9,11 +9,11 @@ import numpy as np import pytest -from distrx.transforms import transform_data, transform_delta +from distrx.transforms import transform_data, transform_delta, transform_delta2 TRANSFORM_LIST = ['log', 'logit', 'exp', 'expit'] -FUNCTION_LIST = [transform_data, transform_delta] +FUNCTION_LIST = [transform_data, transform_delta, transform_delta2] @pytest.mark.parametrize("transform", TRANSFORM_LIST) @@ -21,7 +21,7 @@ def test_method_name_valid(transform): """Raise ValueError for invalue `method`.""" vals = [0.1]*2 with pytest.raises(ValueError): - transform_data(vals, vals, transform, method='dummy') + transform_data(vals, vals, transform, method='dummy') @pytest.mark.parametrize("function", FUNCTION_LIST) @@ -38,14 +38,14 @@ def test_sigma_positive(function): for transform in TRANSFORM_LIST: vals = [0.1, -0.1] with pytest.raises(ValueError): - transform_delta(vals, vals, transform) + function(vals, vals, transform) @pytest.mark.parametrize("function", FUNCTION_LIST) def test_transform_name_valid(function): """Raise ValueError for invalid `transform`.""" with pytest.raises(ValueError): - transform_delta([0.1], [0.1], 'dummy') + function([0.1], [0.1], 'dummy') @pytest.mark.parametrize("function", FUNCTION_LIST) @@ -53,7 +53,7 @@ def test_output_type(function): """Output should be numpy arrays.""" vals = [0.1]*2 for transform in TRANSFORM_LIST: - mu, sigma = transform_delta(vals, vals, transform) + mu, sigma = function(vals, vals, transform) assert isinstance(mu, np.ndarray) assert isinstance(sigma, np.ndarray) @@ -63,5 +63,5 @@ def test_outout_len_match(function): """Length of output vectors should match.""" vals = [0.1]*2 for transform in TRANSFORM_LIST: - mu, sigma = transform_delta(vals, vals, transform) + mu, sigma = function(vals, vals, transform) assert len(mu) == len(sigma) From 9511839e3fea4183be0806ccbce3eab05768c1ba Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Fri, 10 Sep 2021 17:33:58 -0700 Subject: [PATCH 18/32] save before deleting delta2 functions --- src/distrx/transforms.py | 38 +++++++++++++++++++++++++++++--------- tests/test_transforms.py | 25 +++++++++++++++---------- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index 59733b7..ed1ede1 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -37,7 +37,8 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, - method: str = 'delta') -> Tuple[np.ndarray, np.ndarray]: + method: str = 'delta', kurtosis: npt.ArrayLike = None) -> \ + Tuple[np.ndarray, np.ndarray]: """Transform data from one space to another. Transform data, in the form of sample statistics and their standard @@ -45,6 +46,9 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, No assumptions are made about the underlying distributions of the given data. + If `method` is 'delta2' and `kurtosis` is None, uses the fourth + non-central moment of the Gaussian distribution, 3*`sigma`**4. + Parameters ---------- mu : array_like @@ -55,6 +59,9 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, Transform function. method : {'delta, 'delta2'}, optional Method used to transform data. + kurtosis : array_like, optional + Fourth non-central moments. + If None, uses Gaussian values 3*`sigma`**4. Returns ------- @@ -65,9 +72,11 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, """ mu, sigma = np.array(mu), np.array(sigma) - check_input(mu, sigma, transform, method) + kurtosis = 3*sigma**4 if kurtosis is None else np.array(kurtosis) + check_input(mu, sigma, transform, method, kurtosis) if method == 'delta': return transform_delta(mu, sigma, transform) + return transform_delta2(mu, sigma, kurtosis, transform) def transform_delta(mu: npt.ArrayLike, sigma: npt.ArrayLike, @@ -105,12 +114,13 @@ def transform_delta(mu: npt.ArrayLike, sigma: npt.ArrayLike, mu, sigma = np.array(mu), np.array(sigma) check_input(mu, sigma, transform) mu_trans = TRANSFORM_DICT[transform][0](mu) - sigma_trans = sigma*TRANSFORM_DICT[transform][1](mu)**2 + sigma_trans = sigma*TRANSFORM_DICT[transform][1](mu) return mu_trans, sigma_trans def transform_delta2(mu: npt.ArrayLike, sigma: npt.ArrayLike, - transform: str) -> Tuple[np.ndarray, np.ndarray]: + kurtosis: npt.ArrayLike, transform: str) -> \ + Tuple[np.ndarray, np.ndarray]: """Transform data using the second-order delta method. Transform data, in the form of sample statistics and their standard @@ -124,9 +134,10 @@ def transform_delta2(mu: npt.ArrayLike, sigma: npt.ArrayLike, Sample statistics. sigma : array_like Standard errors. + kurtosis : array_like + Fourth non-central moments. transform : {'log', 'logit', 'exp', 'expit'} Transform function. - Returns ------- mu_trans : numpy.ndarray @@ -143,13 +154,17 @@ def transform_delta2(mu: npt.ArrayLike, sigma: npt.ArrayLike, cannot be applied), or the sample size is small. """ - mu, sigma = np.array(mu), np.array(sigma) - check_input(mu, sigma, transform) - return + mu, sigma, kurtosis = np.array(mu), np.array(sigma), np.array(kurtosis) + check_input(mu, sigma, transform, kurtosis) + mu_trans = TRANSFORM_DICT[transform][0](mu) + \ + sigma**2*TRANSFORM_DICT[transform][2](mu)/2 + sigma_trans = np.sqrt(sigma**2*TRANSFORM_DICT[transform][1]**2 + + kurtosis*TRANSFORM_DICT[transform][2]**2/2) + return mu_trans, sigma_trans def check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, - method: str = None) -> None: + method: str = None, kurtosis: npt.ArrayLike = None) -> None: """Run checks on input data. Parameters @@ -162,6 +177,8 @@ def check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, Transform function. method : {None, 'delta', 'delta2'}, optional Method used to transform data. + kurtosis : array_like + Fourth non-central moments. """ check_lengths_match(mu, sigma) @@ -169,6 +186,9 @@ def check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, check_transform_valid(transform) if method is not None: check_method_valid(method) + if kurtosis is not None: + check_lengths_match(mu, kurtosis) + check_sigma_positive(kurtosis) def check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None: diff --git a/tests/test_transforms.py b/tests/test_transforms.py index f977b60..b4dbe8f 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -14,14 +14,13 @@ TRANSFORM_LIST = ['log', 'logit', 'exp', 'expit'] FUNCTION_LIST = [transform_data, transform_delta, transform_delta2] - +VALS = [0.1]*2 @pytest.mark.parametrize("transform", TRANSFORM_LIST) def test_method_name_valid(transform): """Raise ValueError for invalue `method`.""" - vals = [0.1]*2 with pytest.raises(ValueError): - transform_data(vals, vals, transform, method='dummy') + transform_data(VALS, VALS, transform, method='dummy') @pytest.mark.parametrize("function", FUNCTION_LIST) @@ -29,31 +28,38 @@ def test_input_len_match(function): """Raise ValueError if lengths of input vectors don't match.""" for transform in TRANSFORM_LIST: with pytest.raises(ValueError): - function([0.1]*2, [0.1]*3, transform) + function(VALS, VALS*2, transform) @pytest.mark.parametrize("function", FUNCTION_LIST) def test_sigma_positive(function): """Raise ValueError if `sigma` contains non-positive values.""" + vals = VALS + [-0.1] for transform in TRANSFORM_LIST: - vals = [0.1, -0.1] with pytest.raises(ValueError): function(vals, vals, transform) +def test_kurtosis_positive(function): + """Raise ValueError if `kurtosis` contains non-positive values.""" + vals = VALS + [-0.1] + for transform in TRANSFORM_LIST: + with pytest.raises(ValueError): + transform_delta2(vals, VALS + [0.1], vals, transform) + + @pytest.mark.parametrize("function", FUNCTION_LIST) def test_transform_name_valid(function): """Raise ValueError for invalid `transform`.""" with pytest.raises(ValueError): - function([0.1], [0.1], 'dummy') + function(VALS, VALS, 'dummy') @pytest.mark.parametrize("function", FUNCTION_LIST) def test_output_type(function): """Output should be numpy arrays.""" - vals = [0.1]*2 for transform in TRANSFORM_LIST: - mu, sigma = function(vals, vals, transform) + mu, sigma = function(VALS, VALS, transform) assert isinstance(mu, np.ndarray) assert isinstance(sigma, np.ndarray) @@ -61,7 +67,6 @@ def test_output_type(function): @pytest.mark.parametrize("function", FUNCTION_LIST) def test_outout_len_match(function): """Length of output vectors should match.""" - vals = [0.1]*2 for transform in TRANSFORM_LIST: - mu, sigma = function(vals, vals, transform) + mu, sigma = function(VALS, VALS, transform) assert len(mu) == len(sigma) From d5916a5459ab6d20950d09d53f3386c2bb12b4c3 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Fri, 10 Sep 2021 17:45:05 -0700 Subject: [PATCH 19/32] removed second-order delta method, renamed transform_delta --- src/distrx/transforms.py | 83 ++++++---------------------------------- tests/test_transforms.py | 13 ++----- 2 files changed, 15 insertions(+), 81 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index ed1ede1..88d40be 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -34,11 +34,11 @@ lambda x: np.exp(-x)*(np.exp(-x) - 1)/(1 + np.exp(-x))**3 ] } +METHOD_LIST = ['delta'] def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, - method: str = 'delta', kurtosis: npt.ArrayLike = None) -> \ - Tuple[np.ndarray, np.ndarray]: + method: str = 'delta') -> Tuple[np.ndarray, np.ndarray]: """Transform data from one space to another. Transform data, in the form of sample statistics and their standard @@ -46,9 +46,6 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, No assumptions are made about the underlying distributions of the given data. - If `method` is 'delta2' and `kurtosis` is None, uses the fourth - non-central moment of the Gaussian distribution, 3*`sigma`**4. - Parameters ---------- mu : array_like @@ -57,11 +54,8 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, Standard errors. transform : {'log', 'logit', 'exp', 'expit'} Transform function. - method : {'delta, 'delta2'}, optional + method : {'delta'}, optional Method used to transform data. - kurtosis : array_like, optional - Fourth non-central moments. - If None, uses Gaussian values 3*`sigma`**4. Returns ------- @@ -72,15 +66,13 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, """ mu, sigma = np.array(mu), np.array(sigma) - kurtosis = 3*sigma**4 if kurtosis is None else np.array(kurtosis) - check_input(mu, sigma, transform, method, kurtosis) + check_input(mu, sigma, transform, method) if method == 'delta': - return transform_delta(mu, sigma, transform) - return transform_delta2(mu, sigma, kurtosis, transform) + return delta_method(mu, sigma, transform) -def transform_delta(mu: npt.ArrayLike, sigma: npt.ArrayLike, - transform: str) -> Tuple[np.ndarray, np.ndarray]: +def delta_method(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> \ + Tuple[np.ndarray, np.ndarray]: """Transform data using the delta method. Transform data, in the form of sample statistics and their standard @@ -112,59 +104,14 @@ def transform_delta(mu: npt.ArrayLike, sigma: npt.ArrayLike, """ mu, sigma = np.array(mu), np.array(sigma) - check_input(mu, sigma, transform) + check_input(mu, sigma, transform, 'delta') mu_trans = TRANSFORM_DICT[transform][0](mu) sigma_trans = sigma*TRANSFORM_DICT[transform][1](mu) return mu_trans, sigma_trans -def transform_delta2(mu: npt.ArrayLike, sigma: npt.ArrayLike, - kurtosis: npt.ArrayLike, transform: str) -> \ - Tuple[np.ndarray, np.ndarray]: - """Transform data using the second-order delta method. - - Transform data, in the form of sample statistics and their standard - errors, from one space to another using a given transform function - and the second-order delta method. No assumptions are made about - the underlying distributions of the given data. - - Parameters - ---------- - mu : array_like - Sample statistics. - sigma : array_like - Standard errors. - kurtosis : array_like - Fourth non-central moments. - transform : {'log', 'logit', 'exp', 'expit'} - Transform function. - Returns - ------- - mu_trans : numpy.ndarray - Sample statistics in the transform space. - sigma_trans : numpy.ndarray - Standard errors in the transform space. - - Notes - ----- - The second-order delta method expands a function of a random - variable about its mean with a two-step Taylor approximation and - then takes the variance. This method is useful if the derivative of - the transform function is zero (so the first-order delta method - cannot be applied), or the sample size is small. - - """ - mu, sigma, kurtosis = np.array(mu), np.array(sigma), np.array(kurtosis) - check_input(mu, sigma, transform, kurtosis) - mu_trans = TRANSFORM_DICT[transform][0](mu) + \ - sigma**2*TRANSFORM_DICT[transform][2](mu)/2 - sigma_trans = np.sqrt(sigma**2*TRANSFORM_DICT[transform][1]**2 + - kurtosis*TRANSFORM_DICT[transform][2]**2/2) - return mu_trans, sigma_trans - - def check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, - method: str = None, kurtosis: npt.ArrayLike = None) -> None: + method: str) -> None: """Run checks on input data. Parameters @@ -175,20 +122,14 @@ def check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, Standard errors. transform : {'log', 'logit', 'exp', 'expit'} Transform function. - method : {None, 'delta', 'delta2'}, optional + method : {'delta'} Method used to transform data. - kurtosis : array_like - Fourth non-central moments. """ check_lengths_match(mu, sigma) check_sigma_positive(sigma) check_transform_valid(transform) - if method is not None: - check_method_valid(method) - if kurtosis is not None: - check_lengths_match(mu, kurtosis) - check_sigma_positive(kurtosis) + check_method_valid(method) def check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None: @@ -241,5 +182,5 @@ def check_method_valid(method: str) -> None: Method used to transform data. """ - if method not in ['delta', 'delta2']: + if method not in METHOD_LIST: raise ValueError(f"Invalid method '{method}'.") diff --git a/tests/test_transforms.py b/tests/test_transforms.py index b4dbe8f..b6d3b6d 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -9,13 +9,14 @@ import numpy as np import pytest -from distrx.transforms import transform_data, transform_delta, transform_delta2 +from distrx.transforms import transform_data, delta_method TRANSFORM_LIST = ['log', 'logit', 'exp', 'expit'] -FUNCTION_LIST = [transform_data, transform_delta, transform_delta2] +FUNCTION_LIST = [transform_data, delta_method] VALS = [0.1]*2 + @pytest.mark.parametrize("transform", TRANSFORM_LIST) def test_method_name_valid(transform): """Raise ValueError for invalue `method`.""" @@ -40,14 +41,6 @@ def test_sigma_positive(function): function(vals, vals, transform) -def test_kurtosis_positive(function): - """Raise ValueError if `kurtosis` contains non-positive values.""" - vals = VALS + [-0.1] - for transform in TRANSFORM_LIST: - with pytest.raises(ValueError): - transform_delta2(vals, VALS + [0.1], vals, transform) - - @pytest.mark.parametrize("function", FUNCTION_LIST) def test_transform_name_valid(function): """Raise ValueError for invalid `transform`.""" From 9f8027414d3e593a5af71b2bc854a93688eb854b Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Tue, 14 Sep 2021 16:57:04 -0700 Subject: [PATCH 20/32] remove second derivatives from dictionary --- src/distrx/transforms.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index 88d40be..e0d7d6b 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -18,20 +18,16 @@ TRANSFORM_DICT = { 'log': [ np.log, - lambda x: 1/x, - lambda x: -1/x**2, + lambda x: 1/x ], 'logit': [ lambda x: np.log(x/(1 - x)), - lambda x: 1/(x*(1 - x)), - lambda x: (2*x - 1)/(x**2*(1 - x)**2) + lambda x: 1/(x*(1 - x)) ], 'exp': [ - np.exp, np.exp, np.exp ], 'expit': [ lambda x: 1/(1 + np.exp(-x)), - lambda x: np.exp(-x)/(1 + np.exp(-x))**2, - lambda x: np.exp(-x)*(np.exp(-x) - 1)/(1 + np.exp(-x))**3 + lambda x: np.exp(-x)/(1 + np.exp(-x))**2 ] } METHOD_LIST = ['delta'] From 69c594b5cf3e407ce11badc8f2858dba98458ec5 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Tue, 14 Sep 2021 17:02:11 -0700 Subject: [PATCH 21/32] make test functions private --- src/distrx/transforms.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index e0d7d6b..7e057bf 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -62,7 +62,7 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, """ mu, sigma = np.array(mu), np.array(sigma) - check_input(mu, sigma, transform, method) + _check_input(mu, sigma, transform, method) if method == 'delta': return delta_method(mu, sigma, transform) @@ -100,13 +100,13 @@ def delta_method(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> \ """ mu, sigma = np.array(mu), np.array(sigma) - check_input(mu, sigma, transform, 'delta') + _check_input(mu, sigma, transform, 'delta') mu_trans = TRANSFORM_DICT[transform][0](mu) sigma_trans = sigma*TRANSFORM_DICT[transform][1](mu) return mu_trans, sigma_trans -def check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, +def _check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, method: str) -> None: """Run checks on input data. @@ -122,13 +122,13 @@ def check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, Method used to transform data. """ - check_lengths_match(mu, sigma) - check_sigma_positive(sigma) - check_transform_valid(transform) - check_method_valid(method) + _check_lengths_match(mu, sigma) + _check_sigma_positive(sigma) + _check_transform_valid(transform) + _check_method_valid(method) -def check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None: +def _check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None: """Check that `mu` and `sigma` have the same lengths. Parameters @@ -143,7 +143,7 @@ def check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None: raise ValueError("Lengths of mu and sigma don't match.") -def check_sigma_positive(sigma: npt.ArrayLike) -> None: +def _check_sigma_positive(sigma: npt.ArrayLike) -> None: """Check that `sigma` is positive. Parameters @@ -156,7 +156,7 @@ def check_sigma_positive(sigma: npt.ArrayLike) -> None: raise ValueError("Sigma values must be positive.") -def check_transform_valid(transform: str) -> None: +def _check_transform_valid(transform: str) -> None: """Check that `transform` is in TRANSFORM_DICT. Parameters @@ -169,7 +169,7 @@ def check_transform_valid(transform: str) -> None: raise ValueError(f"Invalid transform '{transform}'.") -def check_method_valid(method: str) -> None: +def _check_method_valid(method: str) -> None: """Check that `method` is in ['delta', 'delta2']. Parameters From c43be460a14d5a01b802c420fbf681385b1bc3a9 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Tue, 14 Sep 2021 17:04:41 -0700 Subject: [PATCH 22/32] update module-level docstrings --- src/distrx/transforms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index 7e057bf..1de0e45 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -4,9 +4,9 @@ errors, from one space to another using a given transform function. TODO: -* Implement transform_delta2 -* Implement transform_data * Add user-defined transform function +* Add functions for confidence intervals +* Add decorators for accepting floats or vectors """ from typing import Tuple @@ -107,7 +107,7 @@ def delta_method(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> \ def _check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, - method: str) -> None: + method: str) -> None: """Run checks on input data. Parameters From 5e1828ee352798404c3955007d8bbc2902c2b8b2 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Tue, 14 Sep 2021 17:19:38 -0700 Subject: [PATCH 23/32] add warning if sigma contains zeros --- src/distrx/transforms.py | 5 ++++- tests/test_transforms.py | 34 ++++++++++++++++++---------------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index 1de0e45..41abe76 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -10,6 +10,7 @@ """ from typing import Tuple +import warnings import numpy as np import numpy.typing as npt @@ -152,7 +153,9 @@ def _check_sigma_positive(sigma: npt.ArrayLike) -> None: Standard errors. """ - if np.any(sigma <= 0): + if np.any(sigma == 0): + warnings.warn("Sigma vector contains zeros.") + if np.any(sigma < 0): raise ValueError("Sigma values must be positive.") diff --git a/tests/test_transforms.py b/tests/test_transforms.py index b6d3b6d..c85ab92 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,11 +1,4 @@ -"""Tests for transforms.py module. - -TODO: -* Add random vectors and vector lengths -* Implement delta2 tests -* Implement transform_data tests - -""" +"""Tests for transforms.py module.""" import numpy as np import pytest @@ -17,14 +10,14 @@ VALS = [0.1]*2 -@pytest.mark.parametrize("transform", TRANSFORM_LIST) +@pytest.mark.parametrize('transform', TRANSFORM_LIST) def test_method_name_valid(transform): """Raise ValueError for invalue `method`.""" with pytest.raises(ValueError): transform_data(VALS, VALS, transform, method='dummy') -@pytest.mark.parametrize("function", FUNCTION_LIST) +@pytest.mark.parametrize('function', FUNCTION_LIST) def test_input_len_match(function): """Raise ValueError if lengths of input vectors don't match.""" for transform in TRANSFORM_LIST: @@ -32,23 +25,32 @@ def test_input_len_match(function): function(VALS, VALS*2, transform) -@pytest.mark.parametrize("function", FUNCTION_LIST) -def test_sigma_positive(function): - """Raise ValueError if `sigma` contains non-positive values.""" +@pytest.mark.parametrize('function', FUNCTION_LIST) +def test_sigma_negative(function): + """Raise ValueError if `sigma` contains negative values.""" vals = VALS + [-0.1] for transform in TRANSFORM_LIST: with pytest.raises(ValueError): function(vals, vals, transform) -@pytest.mark.parametrize("function", FUNCTION_LIST) +@pytest.mark.parametrize('function', FUNCTION_LIST) +def test_sigma_zero(function): + """Display warning if `sigma` contains zeros.""" + vals = VALS + [0.0] + for transform in TRANSFORM_LIST: + with pytest.warns(UserWarning): + function(vals, vals, transform) + + +@pytest.mark.parametrize('function', FUNCTION_LIST) def test_transform_name_valid(function): """Raise ValueError for invalid `transform`.""" with pytest.raises(ValueError): function(VALS, VALS, 'dummy') -@pytest.mark.parametrize("function", FUNCTION_LIST) +@pytest.mark.parametrize('function', FUNCTION_LIST) def test_output_type(function): """Output should be numpy arrays.""" for transform in TRANSFORM_LIST: @@ -57,7 +59,7 @@ def test_output_type(function): assert isinstance(sigma, np.ndarray) -@pytest.mark.parametrize("function", FUNCTION_LIST) +@pytest.mark.parametrize('function', FUNCTION_LIST) def test_outout_len_match(function): """Length of output vectors should match.""" for transform in TRANSFORM_LIST: From ae77bc36f5b81b209814fc5e4586fb8f4aa8f6ee Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Tue, 5 Oct 2021 18:22:12 -0700 Subject: [PATCH 24/32] add warning for sigma == 0 --- src/distrx/transforms.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index 41abe76..3f38067 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -19,16 +19,16 @@ TRANSFORM_DICT = { 'log': [ np.log, - lambda x: 1/x + lambda x: 1.0/x ], 'logit': [ - lambda x: np.log(x/(1 - x)), - lambda x: 1/(x*(1 - x)) + lambda x: np.log(x/(1.0 - x)), + lambda x: 1.0/(x*(1.0 - x)) ], 'exp': [ np.exp, np.exp ], 'expit': [ - lambda x: 1/(1 + np.exp(-x)), - lambda x: np.exp(-x)/(1 + np.exp(-x))**2 + lambda x: 1.0/(1.0 + np.exp(-x)), + lambda x: np.exp(-x)/(1.0 + np.exp(-x))**2 ] } METHOD_LIST = ['delta'] @@ -62,8 +62,6 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, Standard errors in the transform space. """ - mu, sigma = np.array(mu), np.array(sigma) - _check_input(mu, sigma, transform, method) if method == 'delta': return delta_method(mu, sigma, transform) @@ -153,9 +151,9 @@ def _check_sigma_positive(sigma: npt.ArrayLike) -> None: Standard errors. """ - if np.any(sigma == 0): + if np.any(sigma == 0.0): warnings.warn("Sigma vector contains zeros.") - if np.any(sigma < 0): + if np.any(sigma < 0.0): raise ValueError("Sigma values must be positive.") From 5502a30792a95ddc2b7f80a86086ed4e33060481 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Tue, 5 Oct 2021 18:33:55 -0700 Subject: [PATCH 25/32] change where method name is checked --- src/distrx/transforms.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index 3f38067..be65876 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -62,6 +62,7 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, Standard errors in the transform space. """ + _check_method_valid(method) if method == 'delta': return delta_method(mu, sigma, transform) @@ -99,14 +100,26 @@ def delta_method(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> \ """ mu, sigma = np.array(mu), np.array(sigma) - _check_input(mu, sigma, transform, 'delta') + _check_input(mu, sigma, transform) mu_trans = TRANSFORM_DICT[transform][0](mu) sigma_trans = sigma*TRANSFORM_DICT[transform][1](mu) return mu_trans, sigma_trans -def _check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, - method: str) -> None: +def _check_method_valid(method: str) -> None: + """Check that `method` is in METHOD_LIST. + + Parameters + ---------- + method : {'delta'} + Method used to transform data. + + """ + if method not in METHOD_LIST: + raise ValueError(f"Invalid method '{method}'.") + + +def _check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> None: """Run checks on input data. Parameters @@ -117,14 +130,11 @@ def _check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, Standard errors. transform : {'log', 'logit', 'exp', 'expit'} Transform function. - method : {'delta'} - Method used to transform data. """ _check_lengths_match(mu, sigma) _check_sigma_positive(sigma) _check_transform_valid(transform) - _check_method_valid(method) def _check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None: @@ -168,16 +178,3 @@ def _check_transform_valid(transform: str) -> None: """ if transform not in TRANSFORM_DICT: raise ValueError(f"Invalid transform '{transform}'.") - - -def _check_method_valid(method: str) -> None: - """Check that `method` is in ['delta', 'delta2']. - - Parameters - ---------- - method : {'delta', 'delta2'} - Method used to transform data. - - """ - if method not in METHOD_LIST: - raise ValueError(f"Invalid method '{method}'.") From 641c3d24623a92628b13f68ce732e88bcdb2a992 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Wed, 13 Oct 2021 17:14:37 -0700 Subject: [PATCH 26/32] add checks to both methods --- src/distrx/transforms.py | 53 ++++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/src/distrx/transforms.py b/src/distrx/transforms.py index be65876..bae43f0 100644 --- a/src/distrx/transforms.py +++ b/src/distrx/transforms.py @@ -62,7 +62,8 @@ def transform_data(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str, Standard errors in the transform space. """ - _check_method_valid(method) + mu, sigma = np.array(mu), np.array(sigma) + _check_input(method, transform, mu, sigma) if method == 'delta': return delta_method(mu, sigma, transform) @@ -100,12 +101,34 @@ def delta_method(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> \ """ mu, sigma = np.array(mu), np.array(sigma) - _check_input(mu, sigma, transform) + _check_input('delta', transform, mu, sigma) mu_trans = TRANSFORM_DICT[transform][0](mu) sigma_trans = sigma*TRANSFORM_DICT[transform][1](mu) return mu_trans, sigma_trans +def _check_input(method: str, transform: str, mu: npt.ArrayLike, + sigma: npt.ArrayLike) -> None: + """Run checks on input data. + + Parameters + ---------- + method : {'delta'} + Method used to transform data. + transform : {'log', 'logit', 'exp', 'expit'} + Transform function. + mu : array_like + Sample statistics. + sigma : array_like + Standard errors. + + """ + _check_method_valid(method) + _check_transform_valid(transform) + _check_lengths_match(mu, sigma) + _check_sigma_positive(sigma) + + def _check_method_valid(method: str) -> None: """Check that `method` is in METHOD_LIST. @@ -119,22 +142,17 @@ def _check_method_valid(method: str) -> None: raise ValueError(f"Invalid method '{method}'.") -def _check_input(mu: npt.ArrayLike, sigma: npt.ArrayLike, transform: str) -> None: - """Run checks on input data. +def _check_transform_valid(transform: str) -> None: + """Check that `transform` is in TRANSFORM_DICT. Parameters ---------- - mu : array_like - Sample statistics. - sigma : array_like - Standard errors. transform : {'log', 'logit', 'exp', 'expit'} Transform function. """ - _check_lengths_match(mu, sigma) - _check_sigma_positive(sigma) - _check_transform_valid(transform) + if transform not in TRANSFORM_DICT: + raise ValueError(f"Invalid transform '{transform}'.") def _check_lengths_match(mu: npt.ArrayLike, sigma: npt.ArrayLike) -> None: @@ -165,16 +183,3 @@ def _check_sigma_positive(sigma: npt.ArrayLike) -> None: warnings.warn("Sigma vector contains zeros.") if np.any(sigma < 0.0): raise ValueError("Sigma values must be positive.") - - -def _check_transform_valid(transform: str) -> None: - """Check that `transform` is in TRANSFORM_DICT. - - Parameters - ---------- - transform : {'log', 'logit', 'exp', 'expit'} - Transform function. - - """ - if transform not in TRANSFORM_DICT: - raise ValueError(f"Invalid transform '{transform}'.") From f25a64b181b2d2709b8060919b51384caa2271c0 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Wed, 13 Oct 2021 17:21:58 -0700 Subject: [PATCH 27/32] add nested parametrized tests --- tests/test_transforms.py | 42 ++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index c85ab92..58b6e5a 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -18,29 +18,29 @@ def test_method_name_valid(transform): @pytest.mark.parametrize('function', FUNCTION_LIST) -def test_input_len_match(function): +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_input_len_match(function, transform): """Raise ValueError if lengths of input vectors don't match.""" - for transform in TRANSFORM_LIST: - with pytest.raises(ValueError): - function(VALS, VALS*2, transform) + with pytest.raises(ValueError): + function(VALS, VALS*2, transform) @pytest.mark.parametrize('function', FUNCTION_LIST) -def test_sigma_negative(function): +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_sigma_negative(function, transform): """Raise ValueError if `sigma` contains negative values.""" vals = VALS + [-0.1] - for transform in TRANSFORM_LIST: - with pytest.raises(ValueError): - function(vals, vals, transform) + with pytest.raises(ValueError): + function(vals, vals, transform) @pytest.mark.parametrize('function', FUNCTION_LIST) -def test_sigma_zero(function): +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_sigma_zero(function, transform): """Display warning if `sigma` contains zeros.""" vals = VALS + [0.0] - for transform in TRANSFORM_LIST: - with pytest.warns(UserWarning): - function(vals, vals, transform) + with pytest.warns(UserWarning): + function(vals, vals, transform) @pytest.mark.parametrize('function', FUNCTION_LIST) @@ -51,17 +51,17 @@ def test_transform_name_valid(function): @pytest.mark.parametrize('function', FUNCTION_LIST) -def test_output_type(function): +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_output_type(function, transform): """Output should be numpy arrays.""" - for transform in TRANSFORM_LIST: - mu, sigma = function(VALS, VALS, transform) - assert isinstance(mu, np.ndarray) - assert isinstance(sigma, np.ndarray) + mu, sigma = function(VALS, VALS, transform) + assert isinstance(mu, np.ndarray) + assert isinstance(sigma, np.ndarray) @pytest.mark.parametrize('function', FUNCTION_LIST) -def test_outout_len_match(function): +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_outout_len_match(function, transform): """Length of output vectors should match.""" - for transform in TRANSFORM_LIST: - mu, sigma = function(VALS, VALS, transform) - assert len(mu) == len(sigma) + mu, sigma = function(VALS, VALS, transform) + assert len(mu) == len(sigma) From b2c741840456ce12bfeb7aebd9913bb935fb62d1 Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Wed, 13 Oct 2021 17:48:44 -0700 Subject: [PATCH 28/32] add expected results test --- tests/test_transforms.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 58b6e5a..0a84206 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -5,7 +5,22 @@ from distrx.transforms import transform_data, delta_method -TRANSFORM_LIST = ['log', 'logit', 'exp', 'expit'] +TRANSFORM_DICT = { + 'log': [ + np.log, + lambda x: 1.0/x + ], 'logit': [ + lambda x: np.log(x/(1.0 - x)), + lambda x: 1.0/(x*(1.0 - x)) + ], 'exp': [ + np.exp, + np.exp + ], 'expit': [ + lambda x: 1.0/(1.0 + np.exp(-x)), + lambda x: np.exp(-x)/(1.0 + np.exp(-x))**2 + ] +} +TRANSFORM_LIST = list(TRANSFORM_DICT.keys()) FUNCTION_LIST = [transform_data, delta_method] VALS = [0.1]*2 @@ -65,3 +80,15 @@ def test_outout_len_match(function, transform): """Length of output vectors should match.""" mu, sigma = function(VALS, VALS, transform) assert len(mu) == len(sigma) + + +@pytest.mark.parametrize('transform', TRANSFORM_LIST) +def test_delta_result(transform): + """Check expected results.""" + mu = np.random.uniform(0.1, 1.0, size=10) + sigma = np.random.uniform(0.1, 1.0, size=10) + mu_ref = TRANSFORM_DICT[transform][0](mu) + sigma_ref = sigma*TRANSFORM_DICT[transform][1](mu) + mu_trans, sigma_trans = delta_method(mu, sigma, transform) + assert np.allclose(mu_trans, mu_ref) + assert np.allclose(sigma_trans, sigma_ref) From 8cdc66066e5528d999e9c1eb8e0756b6ce3f427c Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Mon, 15 Nov 2021 09:56:30 -0800 Subject: [PATCH 29/32] add test file --- src/distrx/temp.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/distrx/temp.py diff --git a/src/distrx/temp.py b/src/distrx/temp.py new file mode 100644 index 0000000..e69de29 From 72d1beec1ff9fbf05ab2896899d424fefd7df9dd Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Mon, 15 Nov 2021 09:57:18 -0800 Subject: [PATCH 30/32] remove test file --- src/distrx/temp.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/distrx/temp.py diff --git a/src/distrx/temp.py b/src/distrx/temp.py deleted file mode 100644 index e69de29..0000000 From 1526cd74e54b307829772198fad901b84215291c Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Mon, 15 Nov 2021 10:01:10 -0800 Subject: [PATCH 31/32] add test file --- src/distrx/temp.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/distrx/temp.py diff --git a/src/distrx/temp.py b/src/distrx/temp.py new file mode 100644 index 0000000..e69de29 From 67479c9bb8e28e780940dfa2b2861205e5308e9b Mon Sep 17 00:00:00 2001 From: Kelsey Maass Date: Mon, 15 Nov 2021 10:01:57 -0800 Subject: [PATCH 32/32] remove test file --- src/distrx/temp.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/distrx/temp.py diff --git a/src/distrx/temp.py b/src/distrx/temp.py deleted file mode 100644 index e69de29..0000000