diff --git a/src/xspline/core.py b/src/xspline/core.py index f25f7cc..aec9892 100644 --- a/src/xspline/core.py +++ b/src/xspline/core.py @@ -272,7 +272,13 @@ def bspline_ifun(a, x, knots, degree, order, idx, l_extra=False, r_extra=False): class XSpline: """XSpline main class of the package. """ - def __init__(self, knots, degree, l_linear=False, r_linear=False): + + def __init__(self, + knots, + degree, + l_linear=False, + r_linear=False, + include_first_basis: bool = False): r"""Constructor of the XSpline class. knots (numpy.ndarray): @@ -295,6 +301,7 @@ def __init__(self, knots, degree, l_linear=False, r_linear=False): self.degree = degree self.l_linear = l_linear self.r_linear = r_linear + self.include_first_basis = include_first_basis # dimensions self.num_knots = knots.size @@ -314,7 +321,7 @@ def __init__(self, knots, degree, l_linear=False, r_linear=False): self.inner_lb = self.inner_knots[0] self.inner_ub = self.inner_knots[-1] - self.num_spline_bases = self.inner_knots.size - 1 + self.degree + self.num_spline_bases = self.inner_knots.size - 1 + self.degree - self.include_first_basis def domain(self, idx, l_extra=False, r_extra=False): """Return the support of the XSpline. @@ -644,7 +651,8 @@ def design_mat(self, x, l_extra=False, r_extra=False): """ mat = np.vstack([ self.fun(x, idx, l_extra=l_extra, r_extra=r_extra) - for idx in range(self.num_spline_bases) + for idx in range(self.include_first_basis, + self.num_spline_bases) ]).T return mat @@ -672,7 +680,8 @@ def design_dmat(self, x, order, l_extra=False, r_extra=False): """ dmat = np.vstack([ self.dfun(x, order, idx, l_extra=l_extra, r_extra=r_extra) - for idx in range(self.num_spline_bases) + for idx in range(self.include_first_basis, + self.num_spline_bases) ]).T return dmat @@ -705,7 +714,8 @@ def design_imat(self, a, x, order, l_extra=False, r_extra=False): """ imat = np.vstack([ self.ifun(a, x, order, idx, l_extra=l_extra, r_extra=r_extra) - for idx in range(self.num_spline_bases) + for idx in range(self.include_first_basis, + self.num_spline_bases) ]).T return imat @@ -733,9 +743,11 @@ def last_dmat(self): class NDXSpline: """Multi-dimensional xspline. """ + def __init__(self, ndim, knots_list, degree_list, l_linear_list=None, - r_linear_list=None): + r_linear_list=None, + include_first_basis_list=None): """Constructor of ndXSpline class Args: @@ -757,11 +769,13 @@ def __init__(self, ndim, knots_list, degree_list, self.degree_list = degree_list self.l_linear_list = utils.option_to_list(l_linear_list, self.ndim) self.r_linear_list = utils.option_to_list(r_linear_list, self.ndim) + self.include_first_basis_list = utils.option_to_list(include_first_basis_list, self.ndim) self.spline_list = [ XSpline(self.knots_list[i], self.degree_list[i], l_linear=self.l_linear_list[i], - r_linear=self.r_linear_list[i]) + r_linear=self.r_linear_list[i], + include_first_basis=self.include_first_basis_list[i]) for i in range(self.ndim) ]