Skip to content

Commit

Permalink
Add include first spline basis option (#8)
Browse files Browse the repository at this point in the history
Add include first spline basis option
  • Loading branch information
zhengp0 committed Dec 16, 2020
2 parents 4b38df3 + 33ee68d commit 1d9030a
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions src/xspline/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,13 @@ def bspline_ifun(a, x, knots, degree, order, idx, l_extra=False, r_extra=False):
class XSpline:
"""XSpline main class of the package.
"""
def __init__(self, knots, degree, l_linear=False, r_linear=False):

def __init__(self,
knots,
degree,
l_linear=False,
r_linear=False,
include_first_basis: bool = False):
r"""Constructor of the XSpline class.
knots (numpy.ndarray):
Expand All @@ -295,6 +301,7 @@ def __init__(self, knots, degree, l_linear=False, r_linear=False):
self.degree = degree
self.l_linear = l_linear
self.r_linear = r_linear
self.include_first_basis = include_first_basis

# dimensions
self.num_knots = knots.size
Expand All @@ -314,7 +321,7 @@ def __init__(self, knots, degree, l_linear=False, r_linear=False):
self.inner_lb = self.inner_knots[0]
self.inner_ub = self.inner_knots[-1]

self.num_spline_bases = self.inner_knots.size - 1 + self.degree
self.num_spline_bases = self.inner_knots.size - 1 + self.degree - self.include_first_basis

def domain(self, idx, l_extra=False, r_extra=False):
"""Return the support of the XSpline.
Expand Down Expand Up @@ -644,7 +651,8 @@ def design_mat(self, x, l_extra=False, r_extra=False):
"""
mat = np.vstack([
self.fun(x, idx, l_extra=l_extra, r_extra=r_extra)
for idx in range(self.num_spline_bases)
for idx in range(self.include_first_basis,
self.num_spline_bases)
]).T
return mat

Expand Down Expand Up @@ -672,7 +680,8 @@ def design_dmat(self, x, order, l_extra=False, r_extra=False):
"""
dmat = np.vstack([
self.dfun(x, order, idx, l_extra=l_extra, r_extra=r_extra)
for idx in range(self.num_spline_bases)
for idx in range(self.include_first_basis,
self.num_spline_bases)
]).T
return dmat

Expand Down Expand Up @@ -705,7 +714,8 @@ def design_imat(self, a, x, order, l_extra=False, r_extra=False):
"""
imat = np.vstack([
self.ifun(a, x, order, idx, l_extra=l_extra, r_extra=r_extra)
for idx in range(self.num_spline_bases)
for idx in range(self.include_first_basis,
self.num_spline_bases)
]).T
return imat

Expand Down Expand Up @@ -733,9 +743,11 @@ def last_dmat(self):
class NDXSpline:
"""Multi-dimensional xspline.
"""

def __init__(self, ndim, knots_list, degree_list,
l_linear_list=None,
r_linear_list=None):
r_linear_list=None,
include_first_basis_list=None):
"""Constructor of ndXSpline class
Args:
Expand All @@ -757,11 +769,13 @@ def __init__(self, ndim, knots_list, degree_list,
self.degree_list = degree_list
self.l_linear_list = utils.option_to_list(l_linear_list, self.ndim)
self.r_linear_list = utils.option_to_list(r_linear_list, self.ndim)
self.include_first_basis_list = utils.option_to_list(include_first_basis_list, self.ndim)

self.spline_list = [
XSpline(self.knots_list[i], self.degree_list[i],
l_linear=self.l_linear_list[i],
r_linear=self.r_linear_list[i])
r_linear=self.r_linear_list[i],
include_first_basis=self.include_first_basis_list[i])
for i in range(self.ndim)
]

Expand Down

0 comments on commit 1d9030a

Please sign in to comment.