diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 26e2d16..2caaaad 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -30,14 +30,14 @@ def test_method_name_valid(transform): """Raise ValueError for invalue `method`.""" with pytest.raises(ValueError): - transform_univariate(VALS, VALS, N, transform, method="dummy") + transform_univariate(VALS, VALS, transform, method="dummy") @pytest.mark.parametrize("transform", UNIVARIATE_TRANSFORM_LIST) def test_input_len_match(transform): """Raise ValueError if lengths of input vectors don't match.""" with pytest.raises(ValueError): - transform_univariate(VALS, VALS * 2, N, transform) + transform_univariate(VALS, VALS * 2, transform) @pytest.mark.parametrize("transform", UNIVARIATE_TRANSFORM_LIST) @@ -45,7 +45,7 @@ def test_sigma_negative(transform): """Raise ValueError if `sigma` contains negative values.""" vals = VALS + [-0.1] with pytest.raises(ValueError): - transform_univariate(vals, vals, N, transform) + transform_univariate(vals, vals, transform) @pytest.mark.parametrize("transform", UNIVARIATE_TRANSFORM_LIST) @@ -53,7 +53,7 @@ def test_sigma_zero(transform): """Display warning if `sigma` contains zeros.""" vals = VALS + [0.0] with pytest.warns(UserWarning): - transform_univariate(vals, vals, N, transform) + transform_univariate(vals, vals, transform) def test_transform_name_valid(): @@ -66,7 +66,7 @@ def test_transform_name_valid(): @pytest.mark.parametrize("transform", UNIVARIATE_TRANSFORM_LIST) def test_output_type(transform): """Output should be numpy arrays.""" - mu, sigma = transform_univariate(VALS, VALS, N, transform) + mu, sigma = transform_univariate(VALS, VALS, transform) assert isinstance(mu, np.ndarray) assert isinstance(sigma, np.ndarray) @@ -74,7 +74,7 @@ def test_output_type(transform): @pytest.mark.parametrize("transform", UNIVARIATE_TRANSFORM_LIST) def test_outout_len_match(transform): """Length of output vectors should match.""" - mu, sigma = transform_univariate(VALS, VALS, N, transform) + mu, sigma = transform_univariate(VALS, VALS, transform) assert len(mu) == len(sigma) @@ -84,8 +84,8 @@ def test_delta_result(transform): mu = np.random.uniform(0.1, 1.0, size=10) sigma = np.random.uniform(0.1, 1.0, size=10) mu_ref = UNIVARIATE_TRANSFORM_DICT[transform][0](mu) - sigma_ref = sigma * UNIVARIATE_TRANSFORM_DICT[transform][1](mu) / np.sqrt(N) - mu_trans, sigma_trans = transform_univariate(mu, sigma, N, transform) + sigma_ref = sigma * UNIVARIATE_TRANSFORM_DICT[transform][1](mu) + mu_trans, sigma_trans = transform_univariate(mu, sigma, transform) assert np.allclose(mu_trans, mu_ref) assert np.allclose(sigma_trans, sigma_ref)