Skip to content

Commit

Permalink
fix test failures by removing sample size scaling in testing file
Browse files Browse the repository at this point in the history
  • Loading branch information
mbi6245 committed Jul 19, 2024
1 parent 549a2c9 commit 76bd5e1
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,30 @@
def test_method_name_valid(transform):
"""Raise ValueError for invalue `method`."""
with pytest.raises(ValueError):
transform_univariate(VALS, VALS, N, transform, method="dummy")
transform_univariate(VALS, VALS, transform, method="dummy")


@pytest.mark.parametrize("transform", UNIVARIATE_TRANSFORM_LIST)
def test_input_len_match(transform):
"""Raise ValueError if lengths of input vectors don't match."""
with pytest.raises(ValueError):
transform_univariate(VALS, VALS * 2, N, transform)
transform_univariate(VALS, VALS * 2, transform)


@pytest.mark.parametrize("transform", UNIVARIATE_TRANSFORM_LIST)
def test_sigma_negative(transform):
"""Raise ValueError if `sigma` contains negative values."""
vals = VALS + [-0.1]
with pytest.raises(ValueError):
transform_univariate(vals, vals, N, transform)
transform_univariate(vals, vals, transform)


@pytest.mark.parametrize("transform", UNIVARIATE_TRANSFORM_LIST)
def test_sigma_zero(transform):
"""Display warning if `sigma` contains zeros."""
vals = VALS + [0.0]
with pytest.warns(UserWarning):
transform_univariate(vals, vals, N, transform)
transform_univariate(vals, vals, transform)


def test_transform_name_valid():
Expand All @@ -66,15 +66,15 @@ def test_transform_name_valid():
@pytest.mark.parametrize("transform", UNIVARIATE_TRANSFORM_LIST)
def test_output_type(transform):
"""Output should be numpy arrays."""
mu, sigma = transform_univariate(VALS, VALS, N, transform)
mu, sigma = transform_univariate(VALS, VALS, transform)
assert isinstance(mu, np.ndarray)
assert isinstance(sigma, np.ndarray)


@pytest.mark.parametrize("transform", UNIVARIATE_TRANSFORM_LIST)
def test_outout_len_match(transform):
"""Length of output vectors should match."""
mu, sigma = transform_univariate(VALS, VALS, N, transform)
mu, sigma = transform_univariate(VALS, VALS, transform)
assert len(mu) == len(sigma)


Expand All @@ -84,8 +84,8 @@ def test_delta_result(transform):
mu = np.random.uniform(0.1, 1.0, size=10)
sigma = np.random.uniform(0.1, 1.0, size=10)
mu_ref = UNIVARIATE_TRANSFORM_DICT[transform][0](mu)
sigma_ref = sigma * UNIVARIATE_TRANSFORM_DICT[transform][1](mu) / np.sqrt(N)
mu_trans, sigma_trans = transform_univariate(mu, sigma, N, transform)
sigma_ref = sigma * UNIVARIATE_TRANSFORM_DICT[transform][1](mu)
mu_trans, sigma_trans = transform_univariate(mu, sigma, transform)
assert np.allclose(mu_trans, mu_ref)
assert np.allclose(sigma_trans, sigma_ref)

Expand Down

0 comments on commit 76bd5e1

Please sign in to comment.