Skip to content

Commit

Permalink
add docstring to bspline function module
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengp0 committed Jul 24, 2023
1 parent 3aecc29 commit fe7a421
Showing 1 changed file with 93 additions and 16 deletions.
109 changes: 93 additions & 16 deletions src/xspline/bspl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,24 @@


def cache_bspl(function: RawFunction) -> RawFunction:
"""Cache implementation for bspline basis functions, to avoid repetitively
evaluate functions.
Parameters
----------
function
Raw value, derivative and definite integral functions.
Returns
-------
describe
Cached version of the raw functions.
"""
cache = {}

def wrapper_function(*args, **kwargs) -> NDArray:
key = tuple(
tuple(x.ravel()) if isinstance(x, np.ndarray) else x for x in args
)
key = tuple(tuple(x.ravel()) if isinstance(x, np.ndarray) else x for x in args)
if key in cache:
return cache[key]
result = function(*args, **kwargs)
Expand All @@ -28,6 +40,22 @@ def cache_clear():

@cache_bspl
def bspl_val(params: BsplParams, x: NDArray) -> NDArray:
"""Value of the bspline function.
Parameters
----------
params
Bspline function parameters as a tuple including, knots, degree and the
index of the spline basis.
x
Data points.
Returns
-------
describe
Function value of the bspline function.
"""
# knots, degree, and index
t, k, i = params

Expand All @@ -44,10 +72,10 @@ def bspl_val(params: BsplParams, x: NDArray) -> NDArray:

if t[ii[0]] != t[ii[2]]:
n0 = bspl_val((t, k - 1, i), x)
val0 = (x - t[ii[0]])*n0/(t[ii[2]] - t[ii[0]])
val0 = (x - t[ii[0]]) * n0 / (t[ii[2]] - t[ii[0]])
if t[ii[1]] != t[ii[3]]:
n1 = bspl_val((t, k - 1, i + 1), x)
val1 = (t[ii[3]] - x)*n1/(t[ii[3]] - t[ii[1]])
val1 = (t[ii[3]] - x) * n1 / (t[ii[3]] - t[ii[1]])

val = val0 + val1

Expand All @@ -56,6 +84,22 @@ def bspl_val(params: BsplParams, x: NDArray) -> NDArray:

@cache_bspl
def bspl_der(params: BsplParams, x: NDArray, order: int) -> NDArray:
"""Derivative of the bspline function.
Parameters
----------
params
Bspline function parameters as a tuple including, knots, degree and the
index of the spline basis.
x
Data points.
Returns
-------
describe
Derivative of the bspline function.
"""
# knots, degree, and index
t, k, i = params

Expand All @@ -72,10 +116,10 @@ def bspl_der(params: BsplParams, x: NDArray, order: int) -> NDArray:

if t[ii[0]] != t[ii[2]]:
n0 = bspl_der((t, k - 1, i), x, order - 1)
val0 = k*n0/(t[ii[2]] - t[ii[0]])
val0 = k * n0 / (t[ii[2]] - t[ii[0]])
if t[ii[1]] != t[ii[3]]:
n1 = bspl_der((t, k - 1, i + 1), x, order - 1)
val1 = k*n1/(t[ii[3]] - t[ii[1]])
val1 = k * n1 / (t[ii[3]] - t[ii[1]])

val = val0 - val1

Expand All @@ -84,6 +128,22 @@ def bspl_der(params: BsplParams, x: NDArray, order: int) -> NDArray:

@cache_bspl
def bspl_int(params: BsplParams, x: NDArray, order: int) -> NDArray:
"""Definite integral of the bspline function.
Parameters
----------
params
Bspline function parameters as a tuple including, knots, degree and the
index of the spline basis.
x
Data points.
Returns
-------
describe
Definite integral of the bspline function.
"""
# knots, degree, and index
t, k, i = params

Expand All @@ -101,21 +161,25 @@ def bspl_int(params: BsplParams, x: NDArray, order: int) -> NDArray:

if t[ii[0]] != t[ii[2]]:
val0 = (
(x - t[ii[0]])*bspl_int((t, k - 1, i), x, order) +
order*bspl_int((t, k - 1, i), x, order - 1)
)/(t[ii[2]] - t[ii[0]])
(x - t[ii[0]]) * bspl_int((t, k - 1, i), x, order)
+ order * bspl_int((t, k - 1, i), x, order - 1)
) / (t[ii[2]] - t[ii[0]])
if t[ii[1]] != t[ii[3]]:
val1 = (
(t[ii[3]] - x)*bspl_int((t, k - 1, i + 1), x, order) -
order*bspl_int((t, k - 1, i + 1), x, order - 1)
)/(t[ii[3]] - t[ii[1]])
(t[ii[3]] - x) * bspl_int((t, k - 1, i + 1), x, order)
- order * bspl_int((t, k - 1, i + 1), x, order - 1)
) / (t[ii[3]] - t[ii[1]])

val = val0 + val1

return val


def clear_bspl_cache() -> None:
"""Clear all cache of the value, derivative and definite integral for
bspline function.
"""
bspl_val.cache_clear()
bspl_der.cache_clear()
bspl_int.cache_clear()
Expand Down Expand Up @@ -149,6 +213,19 @@ def __init__(self, params: BsplParams) -> None:


def get_bspl_funs(knots: tuple[float, ...], degree: int) -> tuple[Bspl]:
return tuple(
Bspl((knots, degree, i)) for i in range(-degree, len(knots) - 1)
)
"""Create the bspline basis functions give knots and degree.
Parameters
----------
knots
Bspline knots.
degree
Bspline degree.
Returns
-------
describe
A full set of bspline functions.
"""
return tuple(Bspl((knots, degree, i)) for i in range(-degree, len(knots) - 1))

0 comments on commit fe7a421

Please sign in to comment.