Skip to content

Commit

Permalink
add docstring to xfunction
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengp0 committed Jul 24, 2023
1 parent 5b576c9 commit 676bb30
Showing 1 changed file with 153 additions and 28 deletions.
181 changes: 153 additions & 28 deletions src/xspline/xfunction.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,59 @@
from __future__ import annotations
from functools import partial
from math import factorial
from operator import attrgetter
from typing import Optional

import numpy as np
from numpy.typing import NDArray

from xspline.typing import (BoundaryPoint, Callable, RawDFunction,
RawIFunction, RawVFunction)
from xspline.typing import (
BoundaryPoint,
Callable,
RawDFunction,
RawIFunction,
RawVFunction,
NDArray,
)


class XFunction:
"""Function interface that provide easy access to function value,
derivatives and definite integrals.
def __init__(self, fun: Optional[Callable] = None) -> None:
There are two different ways to use this class, either use this class to
wrap around a function that aligns with the ``XFunction`` function call
interface. The added benefit is ``XFunction`` will automatic check and parse
the inputs. The other way is to inherit this class and implement a ``fun``
member function for the function call.
Parameters
----------
fun
Optional function implementation. If this is ``None``, it will require
user to implement ``fun`` class member function.
"""

def __init__(self, fun: Callable | None = None) -> None:
if not hasattr(self, "fun"):
self.fun = fun

def _check_args(self, x: NDArray, order: int) -> tuple[NDArray, int, bool]:
x, order = np.asarray(x, dtype=float), int(order)
if (x.ndim not in [0, 1, 2]) or (x.ndim == 2 and len(x) != 2):
raise ValueError("please provide a scalar, an 1d array, or a 2d "
"array with two rows")
raise ValueError(
"please provide a scalar, an 1d array, or a 2d " "array with two rows"
)

# reshape array
isscalar = x.ndim == 0
if isscalar:
x = x.ravel()
if order >= 0 and x.ndim == 2:
raise ValueError("please provide an 1d array for function value "
"defivative computation")
raise ValueError(
"please provide an 1d array for function value "
"defivative computation"
)
if order < 0 and x.ndim == 1:
x = np.vstack([np.repeat(x.min(), x.size), x])

Expand All @@ -44,7 +69,7 @@ def __call__(self, x: NDArray, order: int = 0) -> NDArray:
Parameters
----------
x
Data points where the function is evaluated. If `order < 0` and `x`
Data points where the function is evaluated. If `order < 0` and `x`
is a 2d array with two rows, the rows will be treated as the
starting and ending points for definite interval. If `order < 0` and
`x` is a 1d array, function will use the smallest number in `x` as
Expand Down Expand Up @@ -85,7 +110,19 @@ def __call__(self, x: NDArray, order: int = 0) -> NDArray:
result = result[0]
return result

def append(self, other: "XFunction", sep: BoundaryPoint) -> "XFunction":
def append(self, other: XFunction, sep: BoundaryPoint) -> XFunction:
"""Splice with another instance of ``XFunction`` to create a new
``XFunction``.
Parameters
----------
other
Another ``XFunction`` after the current function.
sep
The boundary point to separate two functions, before is the current
function and after the the ``other`` function.
"""

def fun(x: NDArray, order: int = 0) -> NDArray:
left = x <= sep[0] if sep[1] else x < sep[0]
Expand Down Expand Up @@ -114,18 +151,53 @@ def fun(x: NDArray, order: int = 0) -> NDArray:


class BundleXFunction(XFunction):

def __init__(self,
params: tuple,
val_fun: RawVFunction,
der_fun: RawDFunction,
int_fun: RawIFunction) -> None:
"""This is one implementation of the ``XFunction``, it takes the value,
derivative and definite integral function and bundle them together as a
``XFunction``.
Parameters
----------
params
This is the parameters that is needed for the value, derivatives and
the definitely integral function.
val_fun
Value function.
der_fun
Derviative function.
int_fun
Defintie integral function.
"""

def __init__(
self,
params: tuple,
val_fun: RawVFunction,
der_fun: RawDFunction,
int_fun: RawIFunction,
) -> None:
self.params = params
self.val_fun = partial(val_fun, params)
self.der_fun = partial(der_fun, params)
self.int_fun = partial(int_fun, params)

def fun(self, x: NDArray, order: int = 0) -> NDArray:
"""Function implementation, aligns with the ``XFunction`` function call
interface.
Parameters
----------
x
Data points
order
Order of differentiation/integration.
Returns
-------
describe
Function value, derivatives or definite integrals.
"""
if order == 0:
return self.val_fun(x)
if order > 0:
Expand All @@ -138,15 +210,26 @@ def fun(self, x: NDArray, order: int = 0) -> NDArray:


class BasisXFunction(XFunction):
"""This is one implementation of ``XFunction`` by taking in a set of
instances of ``XFunction`` as basis functions. And the linear combination
coefficients to provide function value, derivative and definite integral.
Parameters
----------
basis_funs
A set of instances of ``XFunction`` as basis functions.
coefs
Coefficients for the linearly combine the basis functions.
"""

coefs = property(attrgetter("_coefs"))

def __init__(self,
basis_funs: tuple[XFunction, ...],
coefs: Optional[NDArray] = None) -> None:
def __init__(
self, basis_funs: tuple[XFunction, ...], coefs: Optional[NDArray] = None
) -> None:
if not all(isinstance(fun, XFunction) for fun in basis_funs):
raise TypeError("basis functions must all be instances of "
"`XFunction`")
raise TypeError("basis functions must all be instances of " "`XFunction`")
self.basis_funs = tuple(basis_funs)
self.coefs = coefs

Expand All @@ -155,23 +238,65 @@ def coefs(self, coefs: Optional[NDArray]) -> None:
if coefs is not None:
coefs = np.asarray(coefs, dtype=float).ravel()
if coefs.size != len(self):
raise ValueError("number of coeffcients does not match number "
"of basis functions")
raise ValueError(
"number of coeffcients does not match number " "of basis functions"
)
self._coefs = coefs

def get_design_mat(self, x: NDArray,
order: int = 0,
check_args: bool = True) -> NDArray:
def get_design_mat(
self, x: NDArray, order: int = 0, check_args: bool = True
) -> NDArray:
"""Provide design matrix from the set of basis functions.
Parameters
----------
x
Data points
order
Order of differentiation/integration.
check_args
If ``True`` it will check and parse the arguments.
Returns
-------
describe
Design matrix with dimention number of data points by number of
basis functions.
"""
if check_args:
x, order, _ = self._check_args(x, order)
return np.vstack([fun.fun(x, order) for fun in self.basis_funs]).T

def fun(self, x: NDArray, order: int = 0) -> NDArray:
"""Function implementation, aligns with the ``XFunction`` function call
interface.
Parameters
----------
x
Data points
order
Order of differentiation/integration.
Returns
-------
describe
Function value, derivatives or definite integrals.
Raises
------
ValueError
Raised when the ``coefs`` is not provided.
"""
if self.coefs is None:
raise ValueError("please provide the coefficients for the basis "
"functions")
raise ValueError(
"please provide the coefficients for the basis " "functions"
)
design_mat = self.get_design_mat(x, order=order, check_args=False)
return design_mat.dot(self.coefs)

def __len__(self) -> int:
"""Number of basis functions."""
return len(self.basis_funs)

0 comments on commit 676bb30

Please sign in to comment.