Splines

B-spline utilities for non-linear feature transformations.

Typical usage:

from blayers.splines import make_knots, bspline_basis
from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link

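# Build the knot vector once, outside the model (not inside a JAX-traced function).
knots = make_knots(x_train, num_knots=10)

def model(x, y=None):
    # Spline design matrix of shape (len(x), 10 + 3 + 1) = (len(x), 14).
    B = bspline_basis(x, knots)
    f = AdaptiveLayer()("f", B)
    return gaussian_link(f, y)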

blayers.splines.bspline_basis(x, knots, degree=3)

Compute the B-spline design matrix via Cox–de Boor recursion (JAX-compatible).
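For reference, the standard Cox–de Boor recursion over a non-decreasing knot vector t_0 ≤ t_1 ≤ ... is shown below, with the usual convention that a fraction whose denominator is zero (from a repeated knot) counts as 0. Under that convention, column j of the returned matrix is B_{j,degree}(x) for j = 0, ..., len(knots) - degree - 2. How the library treats x lying exactly on the last knot is not stated here, so that detail is an assumption left open.

```latex
B_{i,0}(x) =
\begin{cases}
  1 & t_i \le x < t_{i+1} \\
  0 & \text{otherwise}
\end{cases}
\qquad
B_{i,p}(x) = \frac{x - t_i}{t_{i+p} - t_i}\, B_{i,p-1}(x)
           + \frac{t_{i+p+1} - x}{t_{i+p+1} - t_{i+1}}\, B_{i+1,p-1}(x)
```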

Parameters:
  • x (Array) – 1D input array of shape (n,).

  • knots (Array) – Full clamped knot vector of shape (num_basis + degree + 1,). Use make_knots to construct this.

  • degree (int) – B-spline degree (3 = cubic).

Returns:

jax.Array of shape (n, num_basis) where num_basis = len(knots) - degree - 1.

Return type:

Array
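The recursion can be vectorized over all inputs at once, which is the general idea behind building a design matrix this way. The sketch below is written in NumPy rather than JAX and is not the blayers source; `bspline_basis_sketch` and `_safe_div` are hypothetical names, and assigning x equal to the last knot to the final non-empty span is an assumption made so the right boundary is not all zeros.

```python
import numpy as np


def _safe_div(num, den):
    """Elementwise num / den, returning 0 wherever den == 0 (the 0/0 := 0 convention)."""
    num, den = np.broadcast_arrays(num, den)
    out = np.zeros(num.shape, dtype=float)
    np.divide(num, den, out=out, where=den != 0)
    return out


def bspline_basis_sketch(x, knots, degree=3):
    """Hypothetical Cox-de Boor evaluation; returns shape (n, len(knots) - degree - 1)."""
    x = np.asarray(x, dtype=float)
    t = np.asarray(knots, dtype=float)

    # Degree 0: indicator of the half-open span [t_i, t_{i+1}), shape (n, len(t) - 1).
    B = ((x[:, None] >= t[None, :-1]) & (x[:, None] < t[None, 1:])).astype(float)

    # Assumption: put x == t[-1] into the last non-empty span so the right
    # boundary is covered (with half-open spans every column would be 0 there).
    last_span = np.nonzero(t[:-1] < t[1:])[0][-1]
    B[x == t[-1], last_span] = 1.0

    # Raise the degree one step at a time; each step removes one column.
    for p in range(1, degree + 1):
        # Left weight: (x - t_i) / (t_{i+p} - t_i), applied to B_{i,p-1}
        left = _safe_div(x[:, None] - t[None, :-p - 1], t[p:-1] - t[:-p - 1])
        # Right weight: (t_{i+p+1} - x) / (t_{i+p+1} - t_{i+1}), applied to B_{i+1,p-1}
        right = _safe_div(t[None, p + 1:] - x[:, None], t[p + 1:] - t[1:-p])
        B = left * B[:, :-1] + right * B[:, 1:]
    return B
```

With a clamped cubic knot vector such as [0, 0, 0, 0, 0.5, 1, 1, 1, 1] there are 9 - 3 - 1 = 5 columns, and each row sums to 1 on [0, 1] (the partition-of-unity property of clamped B-splines).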

blayers.splines.make_knots(x, num_knots, degree=3)

Compute a clamped B-spline knot vector from data.

Interior knots are placed at evenly-spaced quantiles of x. Call this once at preprocessing time (outside any JAX-traced function) and pass the returned array to bspline_basis.

Parameters:
  • x (Any) – Reference data (any shape). Only used for quantile computation.

  • num_knots (int) – Number of interior knots. The total number of basis functions will be num_knots + degree + 1.

  • degree (int) – B-spline degree (default 3 for cubic splines).

Returns:

Full clamped knot vector as a jax.Array of shape (num_knots + 2 * (degree + 1),).

Return type:

Array
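The shape statements of the two functions agree: substituting the knot-vector length returned by make_knots into the num_basis formula of bspline_basis gives

```latex
\mathrm{len(knots)} = \mathrm{num\_knots} + 2(\mathrm{degree} + 1),
\qquad
\mathrm{num\_basis} = \mathrm{len(knots)} - \mathrm{degree} - 1
                    = \mathrm{num\_knots} + \mathrm{degree} + 1
```

For example, num_knots = 10 and degree = 3 give 18 knots and 14 basis columns.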

Example:

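from blayers.splines import make_knots, bspline_basis

knots = make_knots(x_train, num_knots=5)  # 5 + 2 * (3 + 1) = 13 knots
B = bspline_basis(x, knots)               # shape (len(x), 5 + 3 + 1) = (len(x), 9)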
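A minimal NumPy sketch of the quantile-based clamped knot construction described above. It assumes the interior knots sit at num_knots evenly spaced quantile levels strictly inside (0, 1) and that the boundary knots are min(x) and max(x), each repeated degree + 1 times; `make_knots_sketch` is a hypothetical name, and the library's exact choice of levels and boundaries may differ.

```python
import numpy as np


def make_knots_sketch(x, num_knots, degree=3):
    """Hypothetical clamped knot vector of length num_knots + 2 * (degree + 1)."""
    x = np.ravel(np.asarray(x, dtype=float))  # any input shape; only the values matter
    # Assumed quantile levels: num_knots evenly spaced points strictly inside (0, 1).
    levels = np.linspace(0.0, 1.0, num_knots + 2)[1:-1]
    interior = np.quantile(x, levels)
    # Clamp: repeat each boundary knot degree + 1 times.
    lo = np.full(degree + 1, x.min())
    hi = np.full(degree + 1, x.max())
    return np.concatenate([lo, interior, hi])


rng = np.random.default_rng(0)
x_train = rng.normal(size=200)
knots = make_knots_sketch(x_train, num_knots=5)
```

With num_knots=5 and degree=3 this produces 13 knots, so a basis built on it has 13 - 3 - 1 = 9 columns, matching num_knots + degree + 1. Because the knots are computed from fixed data, the result can be treated as a constant when passed into a traced model.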