Skip to content

Commit

Permalink
change coefs to coef
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengp0 committed Jul 26, 2023
1 parent ed8617a commit c9ac1e3
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
30 changes: 15 additions & 15 deletions src/xspline/xfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,40 +193,40 @@ class BasisXFunction(XFunction):
----------
basis_funs
A set of instances of ``XFunction`` as basis functions.
coefs
coef
Coefficients for the linearly combine the basis functions.
"""

coefs = property(attrgetter("_coefs"))
coef = property(attrgetter("_coef"))

def __init__(
self, basis_funs: tuple[XFunction, ...], coefs: NDArray | None = None
self, basis_funs: tuple[XFunction, ...], coef: NDArray | None = None
) -> None:
if not all(isinstance(fun, XFunction) for fun in basis_funs):
raise TypeError("basis functions must all be instances of " "`XFunction`")
raise TypeError("basis functions must all be instances of 'XFunction'")
self.basis_funs = tuple(basis_funs)
self.coefs = coefs
self.coef = coef

def fun(x: NDArray, order: int = 0) -> NDArray:
if self.coefs is None:
if self.coef is None:
raise ValueError(
"please provide the coefficients for the basis " "functions"
"please provide the coefficients for the basis functions"
)
design_mat = self.get_design_mat(x, order=order, check_args=False)
return design_mat.dot(self.coefs)
return design_mat.dot(self.coef)

super().__init__(fun)

@coefs.setter
def coefs(self, coefs: NDArray | None) -> None:
if coefs is not None:
coefs = np.asarray(coefs, dtype=float).ravel()
if coefs.size != len(self):
@coef.setter
def coef(self, coef: NDArray | None) -> None:
if coef is not None:
coef = np.asarray(coef, dtype=float).ravel()
if coef.size != len(self):
raise ValueError(
"number of coeffcients does not match number " "of basis functions"
"number of coeffcients does not match number of basis functions"
)
self._coefs = coefs
self._coef = coef

def get_design_mat(
self, x: NDArray, order: int = 0, check_args: bool = True
Expand Down
6 changes: 3 additions & 3 deletions src/xspline/xspl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class XSpline(BasisXFunction):
Left extrapolation polynomial degree.
rdegree
Right extrapolation polynomial degree.
coefs
coef
The coefficients for linear combining the spline basis.
"""
Expand All @@ -30,7 +30,7 @@ def __init__(
degree: int,
ldegree: Optional[int] = None,
rdegree: Optional[int] = None,
coefs: Optional[NDArray] = None,
coef: Optional[NDArray] = None,
) -> None:
# validate inputs
knots, degree = tuple(sorted(map(float, knots))), int(degree)
Expand All @@ -52,7 +52,7 @@ def __init__(

self.knots, self.degree = knots, degree
self.ldegree, self.rdegree = ldegree, rdegree
super().__init__(funs, coefs=coefs)
super().__init__(funs, coef=coef)

def get_design_mat(
self, x: NDArray, order: int = 0, check_args: bool = True
Expand Down
8 changes: 4 additions & 4 deletions tests/test_basis_xfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ def test_len(xfun):
assert len(xfun) == 2


def test_coefs_error(xfun):
coefs = [1, 2, 3]
def test_coef_error(xfun):
coef = [1, 2, 3]
with pytest.raises(ValueError):
xfun.coefs = coefs
xfun.coef = coef


@pytest.mark.parametrize("order", [-1, 0, 1])
Expand All @@ -31,7 +31,7 @@ def test_get_design_mat(xfun, order):

@pytest.mark.parametrize("order", [-1, 0, 1])
def test_fun(xfun, order):
xfun.coefs = [1, 1]
xfun.coef = [1, 1]
x = np.linspace(-0.5, 3.0, 101)
result = xfun(x, order=order)
assert result.shape == x.shape
Expand Down

0 comments on commit c9ac1e3

Please sign in to comment.