Splines
B-spline utilities for non-linear feature transformations.
Typical usage:
    from blayers.splines import make_knots, bspline_basis
    from blayers.layers import AdaptiveLayer
    from blayers.links import gaussian_link

    # Build the knot vector once from training data, outside the model.
    knots = make_knots(x_train, num_knots=10)

    def model(x, y=None):
        B = bspline_basis(x, knots)  # (n, num_basis) design matrix
        f = AdaptiveLayer()("f", B)
        return gaussian_link(f, y)
- blayers.splines.bspline_basis(x, knots, degree=3)
Compute the B-spline design matrix via Cox–de Boor recursion (JAX-compatible).
- Parameters:
  x (Array) – 1D input array of shape (n,).
  knots (Array) – Full clamped knot vector of shape (num_basis + degree + 1,). Use make_knots to construct this.
  degree (int) – B-spline degree (3 = cubic).
- Returns:
  jax.Array of shape (n, num_basis), where num_basis = len(knots) - degree - 1.
- Return type:
  Array
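The recursion behind bspline_basis can be sketched in plain jax.numpy as below. This is a minimal illustration, not the blayers source: the function name cox_de_boor_basis is hypothetical, inputs are cast to float32, and closing the last non-empty knot interval (so that x equal to the right boundary knot still receives a basis value of 1) is an assumed endpoint convention. Zero denominators, which come from the repeated knots of a clamped vector, are treated as 0/0 = 0, following the usual Cox–de Boor convention.

```python
import jax.numpy as jnp


def cox_de_boor_basis(x, knots, degree=3):
    """Hypothetical sketch: B-spline design matrix via Cox-de Boor recursion.

    x: 1D array of shape (n,). knots: clamped knot vector of shape
    (num_basis + degree + 1,). Returns an array of shape (n, num_basis).
    """
    x = jnp.asarray(x, dtype=jnp.float32)[:, None]  # (n, 1) for broadcasting
    t = jnp.asarray(knots, dtype=jnp.float32)       # (m,)
    m = t.shape[0]

    # Degree 0: indicator of the half-open interval [t_i, t_{i+1}).
    left, right = t[:-1], t[1:]
    basis = (x >= left) & (x < right)
    # Assumed convention: close the last non-empty interval so that
    # x == t[-1] does not produce an all-zero row.
    is_last = (right == t[-1]) & (left < right)
    basis = basis | ((x == t[-1]) & is_last)
    basis = basis.astype(jnp.float32)               # (n, m - 1)

    def ratio(num, den):
        # Define 0/0 as 0; the safe denominator keeps values and grads finite.
        safe = jnp.where(den > 0, den, 1.0)
        return jnp.where(den > 0, num / safe, 0.0)

    for k in range(1, degree + 1):
        nb = m - k - 1  # number of degree-k basis functions
        w_left = ratio(x - t[:nb], t[k:k + nb] - t[:nb])
        w_right = ratio(t[k + 1:k + 1 + nb] - x, t[k + 1:k + 1 + nb] - t[1:1 + nb])
        # B_{i,k} = w_left_i * B_{i,k-1} + w_right_i * B_{i+1,k-1}
        basis = w_left * basis[:, :-1] + w_right * basis[:, 1:]

    return basis  # (n, m - degree - 1)
```

Because degree is a Python int and the knot vector has a fixed length, the loop unrolls at trace time, so a function like this can be wrapped in jax.jit with degree marked static.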
- blayers.splines.make_knots(x, num_knots, degree=3)
Compute a clamped B-spline knot vector from data.
Interior knots are placed at evenly spaced quantiles of x. Call this once at preprocessing time (outside any JAX-traced function) and pass the returned array to bspline_basis.
- Parameters:
  x (Any) – Reference data (any shape). Only used for quantile computation.
  num_knots (int) – Number of interior knots. The total number of basis functions will be num_knots + degree + 1.
  degree (int) – B-spline degree (default 3 for cubic splines).
- Returns:
  Full clamped knot vector as a jax.Array of shape (num_knots + 2 * (degree + 1),).
- Return type:
  Array
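A knot vector with the documented shape can be built along the following lines. This is a sketch under stated assumptions, not the make_knots implementation: the name clamped_quantile_knots is hypothetical, the boundary knots are assumed to be the data minimum and maximum, and the interior knots use num_knots evenly spaced probabilities strictly inside (0, 1), evaluated with jnp.quantile's default linear interpolation.

```python
import jax.numpy as jnp


def clamped_quantile_knots(x, num_knots, degree=3):
    """Hypothetical sketch of a clamped, quantile-based knot vector.

    Returns an array of shape (num_knots + 2 * (degree + 1),).
    """
    x = jnp.ravel(jnp.asarray(x, dtype=jnp.float32))  # any shape -> 1D
    lo, hi = jnp.min(x), jnp.max(x)  # assumed boundary knots
    # num_knots evenly spaced quantile levels, excluding 0 and 1.
    probs = jnp.linspace(0.0, 1.0, num_knots + 2)[1:-1]
    interior = jnp.quantile(x, probs)
    # Clamp: repeat each boundary knot degree + 1 times.
    return jnp.concatenate([
        jnp.full(degree + 1, lo),
        interior,
        jnp.full(degree + 1, hi),
    ])
```

Repeating each boundary knot degree + 1 times is what makes the vector clamped: the first and last basis functions equal exactly 1 at the two boundary knots.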
Example:
    knots = make_knots(x_train, num_knots=5)
    B = bspline_basis(x, knots)
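The shapes in this example follow directly from the formulas above. With num_knots=5 and the default degree=3, the knot vector has 5 + 2 * (3 + 1) = 13 entries, and the design matrix has 13 - 3 - 1 = 9 columns, matching num_knots + degree + 1. The helper below (spline_shapes is a hypothetical name, not part of blayers) only encodes that bookkeeping:

```python
def spline_shapes(num_knots, degree=3):
    """Hypothetical helper: shapes implied by the make_knots/bspline_basis docs."""
    knot_len = num_knots + 2 * (degree + 1)  # length of the knot vector
    num_basis = knot_len - degree - 1        # columns of the design matrix
    # Consistency check against the num_knots parameter description.
    assert num_basis == num_knots + degree + 1
    return knot_len, num_basis


print(spline_shapes(5))  # (13, 9)
```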