diff --git a/src/xspline/core.py b/src/xspline/core.py index aec9892..0f6b4e6 100644 --- a/src/xspline/core.py +++ b/src/xspline/core.py @@ -278,7 +278,7 @@ def __init__(self, degree, l_linear=False, r_linear=False, - include_first_basis: bool = False): + include_first_basis: bool = True): r"""Constructor of the XSpline class. knots (numpy.ndarray): @@ -301,7 +301,7 @@ def __init__(self, self.degree = degree self.l_linear = l_linear self.r_linear = r_linear - self.include_first_basis = include_first_basis + self.basis_start = int(not include_first_basis) # dimensions self.num_knots = knots.size @@ -321,7 +321,7 @@ def __init__(self, self.inner_lb = self.inner_knots[0] self.inner_ub = self.inner_knots[-1] - self.num_spline_bases = self.inner_knots.size - 1 + self.degree - self.include_first_basis + self.num_spline_bases = self.inner_knots.size - 1 + self.degree - self.basis_start def domain(self, idx, l_extra=False, r_extra=False): """Return the support of the XSpline. @@ -651,8 +651,7 @@ def design_mat(self, x, l_extra=False, r_extra=False): """ mat = np.vstack([ self.fun(x, idx, l_extra=l_extra, r_extra=r_extra) - for idx in range(self.include_first_basis, - self.num_spline_bases) + for idx in range(self.basis_start, self.num_spline_bases) ]).T return mat @@ -680,8 +679,7 @@ def design_dmat(self, x, order, l_extra=False, r_extra=False): """ dmat = np.vstack([ self.dfun(x, order, idx, l_extra=l_extra, r_extra=r_extra) - for idx in range(self.include_first_basis, - self.num_spline_bases) + for idx in range(self.basis_start, self.num_spline_bases) ]).T return dmat @@ -714,8 +712,7 @@ def design_imat(self, a, x, order, l_extra=False, r_extra=False): """ imat = np.vstack([ self.ifun(a, x, order, idx, l_extra=l_extra, r_extra=r_extra) - for idx in range(self.include_first_basis, - self.num_spline_bases) + for idx in range(self.basis_start, self.num_spline_bases) ]).T return imat @@ -747,7 +744,7 @@ class NDXSpline: def __init__(self, ndim, knots_list, degree_list, l_linear_list=None, r_linear_list=None, - include_first_basis_list=None): + include_first_basis_list=True): """Constructor of ndXSpline class Args: