Skip to content

Commit

Permalink
Fix minor bugs of include first basis (#9)
Browse files Browse the repository at this point in the history
Fix minor bugs of include_first_basis
  • Loading branch information
zhengp0 committed Dec 16, 2020
2 parents 1d9030a + 1f355f6 commit 358e35d
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions src/xspline/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def __init__(self,
degree,
l_linear=False,
r_linear=False,
include_first_basis: bool = False):
include_first_basis: bool = True):
r"""Constructor of the XSpline class.
knots (numpy.ndarray):
Expand All @@ -301,7 +301,7 @@ def __init__(self,
self.degree = degree
self.l_linear = l_linear
self.r_linear = r_linear
self.include_first_basis = include_first_basis
self.basis_start = int(not include_first_basis)

# dimensions
self.num_knots = knots.size
Expand All @@ -321,7 +321,7 @@ def __init__(self,
self.inner_lb = self.inner_knots[0]
self.inner_ub = self.inner_knots[-1]

self.num_spline_bases = self.inner_knots.size - 1 + self.degree - self.include_first_basis
self.num_spline_bases = self.inner_knots.size - 1 + self.degree - self.basis_start

def domain(self, idx, l_extra=False, r_extra=False):
"""Return the support of the XSpline.
Expand Down Expand Up @@ -651,8 +651,7 @@ def design_mat(self, x, l_extra=False, r_extra=False):
"""
mat = np.vstack([
self.fun(x, idx, l_extra=l_extra, r_extra=r_extra)
for idx in range(self.include_first_basis,
self.num_spline_bases)
for idx in range(self.basis_start, self.num_spline_bases)
]).T
return mat

Expand Down Expand Up @@ -680,8 +679,7 @@ def design_dmat(self, x, order, l_extra=False, r_extra=False):
"""
dmat = np.vstack([
self.dfun(x, order, idx, l_extra=l_extra, r_extra=r_extra)
for idx in range(self.include_first_basis,
self.num_spline_bases)
for idx in range(self.basis_start, self.num_spline_bases)
]).T
return dmat

Expand Down Expand Up @@ -714,8 +712,7 @@ def design_imat(self, a, x, order, l_extra=False, r_extra=False):
"""
imat = np.vstack([
self.ifun(a, x, order, idx, l_extra=l_extra, r_extra=r_extra)
for idx in range(self.include_first_basis,
self.num_spline_bases)
for idx in range(self.basis_start, self.num_spline_bases)
]).T
return imat

Expand Down Expand Up @@ -747,7 +744,7 @@ class NDXSpline:
def __init__(self, ndim, knots_list, degree_list,
l_linear_list=None,
r_linear_list=None,
include_first_basis_list=None):
include_first_basis_list=True):
"""Constructor of ndXSpline class
Args:
Expand Down

0 comments on commit 358e35d

Please sign in to comment.