"""
Implements Bayesian Layers using Jax and Numpyro.
Design:
- There are three levels of complexity here: class-level, instance-level, and
call-level
- The class-level handles things like choosing generic model form and how to
multiply coefficients with data. Defined by the ``class Layer(BLayer)`` def
itself.
- The instance-level handles specific distributions that fit into a generic
model and the initial parameters for those distributions. Defined by
creating an instance of the class: ``Layer(*args, **kwargs)``.
- The call-level handles seeing a batch of data, sampling from the
distributions defined on the class and multiplying coefficients and data to
produce an output, works like ``result = Layer(*args, **kwargs)(data)``
Notation:
- ``n``: observations in a batch
- ``c``: number of categories of things for time, random effects, etc
- ``d``: number of coefficients
- ``l``: low rank dimension of low rank models
- ``m``: embedding dimension
- ``u``: units aka output dimension
"""
from abc import ABC, abstractmethod
from typing import Any, Callable
import jax
import jax.nn as jnn
import jax.numpy as jnp
import numpy as np
from numpyro import distributions, sample
from blayers._utils import add_trailing_dim
# ---- Matmul functions ------------------------------------------------------ #
def _pairwise_interactions(x: jax.Array, z: jax.Array) -> jax.Array:
"""
Compute all pairwise interactions between features in X and Y.
Parameters:
X: (n_samples, n_features1)
Y: (n_samples, n_features2)
Returns:
interactions: (n_samples, n_features1 * n_features2)
"""
n, d1 = x.shape
_, d2 = z.shape
return jnp.reshape(x[:, :, None] * z[:, None, :], (n, d1 * d2))
def _matmul_dot_product(x: jax.Array, beta: jax.Array) -> jax.Array:
"""Standard dot product between beta and x.
Args:
beta: Coefficient vector of shape `(d, u)`.
x: Input matrix of shape `(n, d)`.
Returns:
jax.Array: Output of shape `(n, u)`.
"""
return jnp.einsum("nd,du->nu", x, beta)
def _matmul_factorization_machine(x: jax.Array, theta: jax.Array) -> jax.Array:
"""Apply second-order factorization machine interaction.
Based on Rendle (2010). Computes:
.. math::
0.5 * sum((xV)^2 - (x^2 V^2))
Args:
theta: Weight matrix of shape `(d, l, u)`.
x: Input data of shape `(n, d)`.
Returns:
jax.Array: Output of shape `(n, u)`.
"""
vx2 = jnp.einsum("nd,dlu->nlu", x, theta) ** 2
v2x2 = jnp.einsum("nd,dlu->nlu", x**2, theta**2)
return 0.5 * jnp.einsum("nlu->nu", vx2 - v2x2)
def _matmul_fm3(x: jax.Array, theta: jax.Array) -> jax.Array:
"""Apply second-order factorization machine interaction.
Based on Rendle (2010). Computes:
.. math::
0.5 * sum((xV)^2 - (x^2 V^2))
Args:
theta: Weight matrix of shape `(d, l, u)`.
x: Input data of shape `(n, d)`.
Returns:
jax.Array: Output of shape `(n, u)`.
"""
# x: (n_features,)
# E: (n_features, k) embedding matrix
linear_sum = jnp.einsum("nd,dlu->nlu", x, theta) # jnp.dot(x, theta)
square_sum = jnp.einsum(
"nd,dlu->nlu", x**2, theta**2
) # jnp.dot(x**2, theta**2)
cube_sum = jnp.einsum(
"nd,dlu->nlu", x**3, theta**3
) # jnp.dot(x**3, theta**3)
term = (
linear_sum**3 - 3.0 * square_sum * linear_sum + 2.0 * cube_sum
) / 6.0
return jnp.einsum("nlu->nu", term) # scalar
def _matmul_uv_decomp(
theta1: jax.Array,
theta2: jax.Array,
x: jax.Array,
z: jax.Array,
) -> jax.Array:
"""Implements low rank multiplication.
According to ChatGPT this is a "factorized bilinear interaction".
Basically, you just need to project x and z down to a common number of
low rank terms and then just multiply those terms.
This is equivalent to a UV decomposition where you use n=low_rank_dim
on the columns of the U/V matrices.
Args:
theta1: Weight matrix of shape `(d1, l, u)`.
theta2: Weight matrix of shape `(d2, l, u)`.
x: Input data of shape `(n, d1)`.
z: Input data of shape `(n, d2)`.
Returns:
jax.Array: Output of shape `(n, u)`.
"""
xb = jnp.einsum("nd,dlu->nlu", x, theta1)
zb = jnp.einsum("nd,dlu->nlu", z, theta2)
return jnp.einsum("nlu->nu", xb * zb)
def _matmul_randomwalk(
theta: jax.Array,
idx: jax.Array,
) -> jax.Array:
"""Vertical cumsum and then picks out index.
We do a vertical cumsum of `theta` across `m` embedding dimensions, and then
pick out the index.
Args:
theta: Weight matrix of shape `(c, m)`
idx: Integer indexes of shape `(n, 1)` or `(n,)` with indexes up to `c`
Returns:
jax.Array: Output of shape `(n, m)`
"""
theta_cumsum = jnp.cumsum(theta, axis=0)
idx_flat = idx.squeeze().astype(jnp.int32)
return theta_cumsum[idx_flat]
def _matmul_interaction(
    beta: jax.Array,
    x: jax.Array,
    z: jax.Array,
) -> jax.Array:
    """Full interaction between ``x`` and ``z``.

    Expands the inputs into all ``d1 * d2`` pairwise column products and
    applies one coefficient per product.

    Args:
        beta: Weight matrix of shape ``(d1 * d2, u)``, one row per
            interaction between a column of ``x`` and a column of ``z``.
        x: First feature matrix of shape ``(n, d1)``.
        z: Second feature matrix of shape ``(n, d2)``.

    Returns:
        jax.Array: Output of shape ``(n, u)``.
    """
    crossed = _pairwise_interactions(x, z)  # (n, d1 * d2)
    weighted = jnp.einsum("nd,du->nu", crossed, beta)
    return weighted
# ---- Classes --------------------------------------------------------------- #
def _validate_prior_kwargs(coef_dist, coef_kwargs, scale_dist=None, scale_kwargs=None):
"""Eagerly instantiate distributions at construction time to catch bad kwargs.
Raises ``TypeError`` immediately if the supplied kwargs are incompatible
with the distribution, rather than waiting until the layer is called.
"""
try:
if scale_dist is not None:
scale_dist(**scale_kwargs)
coef_dist(scale=1.0, **coef_kwargs)
else:
coef_dist(**coef_kwargs)
except TypeError as e:
raise TypeError(f"Invalid distribution kwargs: {e}") from e
[docs]
class BLayer(ABC):
    """Interface that every Bayesian layer implements.

    Subclasses must provide both ``__init__`` and ``__call__``.
    """

    @abstractmethod
    def __init__(self, *args: Any) -> None:
        """Set up the layer's priors. This is the Bayesian model."""

    @abstractmethod
    def __call__(self, *args: Any) -> Any:
        """Run the layer's forward pass.

        Args:
            *args: Inputs to the layer.

        Returns:
            jax.Array: The result of the forward computation.
        """
[docs]
class AdaptiveLayer(BLayer):
    """Dense Bayesian layer whose coefficient scale is itself sampled.

    With the default distributions the coefficients come from the
    hierarchical model

    .. math::
        \\lambda \\sim HalfNormal(1.)

    .. math::
        \\beta \\sim Normal(0., \\lambda)
    """

    def __init__(
        self,
        scale_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        scale_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        """
        Args:
            scale_dist: NumPyro distribution class for the scale (λ) of the
                coefficient prior.
            coef_dist: NumPyro distribution class for the coefficient prior.
                It is constructed with a ``scale`` keyword.
            coef_kwargs: Extra keyword arguments for ``coef_dist``.
            scale_kwargs: Keyword arguments for ``scale_dist``.

        Raises:
            TypeError: If the kwargs do not fit the distributions.
        """
        self.scale_dist, self.coef_dist = scale_dist, coef_dist
        self.coef_kwargs, self.scale_kwargs = coef_kwargs, scale_kwargs
        # Fail at construction time rather than on the first call.
        _validate_prior_kwargs(coef_dist, coef_kwargs, scale_dist, scale_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        units: int = 1,
        activation: Callable[[jax.Array], jax.Array] = jnn.identity,
    ) -> jax.Array:
        """Forward pass with an adaptive prior on the coefficients.

        Args:
            name: Variable name, used to build the sample-site names.
            x: Input data array of shape ``(n, d)``.
            units: Number of outputs.
            activation: Activation function to apply to the output.

        Returns:
            jax.Array: Output array of shape ``(n, u)``.
        """
        # NOTE(review): add_trailing_dim presumably promotes 1-D input to
        # (n, 1) -- confirm in blayers._utils.
        features = add_trailing_dim(x)
        n_inputs = features.shape[1]
        site = f"{self.__class__.__name__}_{name}"

        # One prior scale per output unit.
        lam = sample(
            name=f"{site}_scale",
            fn=self.scale_dist(**self.scale_kwargs).expand([units]),
        )
        coef_prior = self.coef_dist(scale=lam, **self.coef_kwargs)
        beta = sample(
            name=f"{site}_beta",
            fn=coef_prior.expand([n_inputs, units]),
        )

        return activation(_matmul_dot_product(features, beta))
[docs]
class FixedPriorLayer(BLayer):
    """Dense Bayesian layer whose coefficient prior has fixed parameters.

    With the default distribution the coefficients come from

    .. math::
        \\beta \\sim Normal(0., 1.)
    """

    def __init__(
        self,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0, "scale": 1.0},
    ):
        """
        Args:
            coef_dist: NumPyro distribution class for the coefficients.
            coef_kwargs: Keyword arguments used to build ``coef_dist``.

        Raises:
            TypeError: If ``coef_kwargs`` does not fit ``coef_dist``.
        """
        self.coef_dist, self.coef_kwargs = coef_dist, coef_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        units: int = 1,
        activation: Callable[[jax.Array], jax.Array] = jnn.identity,
    ) -> jax.Array:
        """Forward pass with a fixed prior on the coefficients.

        Args:
            name: Variable name, used to build the sample-site name.
            x: Input data array of shape ``(n, d)``.
            units: Number of outputs.
            activation: Activation function to apply to the output.

        Returns:
            jax.Array: Output array of shape ``(n, u)``.
        """
        # NOTE(review): add_trailing_dim presumably promotes 1-D input to
        # (n, 1) -- confirm in blayers._utils.
        features = add_trailing_dim(x)
        n_inputs = features.shape[1]

        coef_prior = self.coef_dist(**self.coef_kwargs)
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta",
            fn=coef_prior.expand([n_inputs, units]),
        )

        return activation(_matmul_dot_product(features, beta))
[docs]
class InterceptLayer(BLayer):
    """Bayesian intercept: a data-independent coefficient per output unit.

    With the default distribution the intercept comes from

    .. math::
        \\beta \\sim Normal(0., 1.)
    """

    def __init__(
        self,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0, "scale": 1.0},
    ):
        """
        Args:
            coef_dist: NumPyro distribution class for the intercept.
            coef_kwargs: Keyword arguments used to build ``coef_dist``.

        Raises:
            TypeError: If ``coef_kwargs`` does not fit ``coef_dist``.
        """
        self.coef_dist, self.coef_kwargs = coef_dist, coef_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs)

    def __call__(
        self,
        name: str,
        units: int = 1,
        activation: Callable[[jax.Array], jax.Array] = jnn.identity,
    ) -> jax.Array:
        """Sample the intercept and pass it through ``activation``.

        Args:
            name: Variable name, used to build the sample-site name.
            units: Number of outputs.
            activation: Activation function to apply to the output.

        Returns:
            jax.Array: Output array of shape ``(1, u)``.
        """
        intercept_prior = self.coef_dist(**self.coef_kwargs).expand([1, units])
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta",
            fn=intercept_prior,
        )
        return activation(beta)
[docs]
class FMLayer(BLayer):
    """Bayesian factorization machine layer with adaptive priors.

    With the default distributions the factor weights come from the
    hierarchical model

    .. math::
        \\lambda \\sim HalfNormal(1.)

    .. math::
        \\theta \\sim Normal(0., \\lambda)

    The sampled ``theta`` has shape ``(d, l, u)``, where ``d`` is the number
    of input covariates, ``l`` is the low rank dim and ``u`` the number of
    units.

    Then performs matrix multiplication using the formula in `Rendle (2010) <https://jame-zhang.github.io/assets/algo/Factorization-Machines-Rendle2010.pdf>`_.
    """

    def __init__(
        self,
        scale_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        scale_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        """
        Args:
            scale_dist: Distribution class for the scaling factor λ.
            coef_dist: Distribution class for the factor weights. It is
                constructed with a ``scale`` keyword.
            coef_kwargs: Extra keyword arguments for ``coef_dist``.
            scale_kwargs: Keyword arguments for ``scale_dist``.

        Raises:
            TypeError: If the kwargs do not fit the distributions.
        """
        self.scale_dist, self.coef_dist = scale_dist, coef_dist
        self.coef_kwargs, self.scale_kwargs = coef_kwargs, scale_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs, scale_dist, scale_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        low_rank_dim: int,
        units: int = 1,
        activation: Callable[[jax.Array], jax.Array] = jnn.identity,
    ) -> jax.Array:
        """Forward pass through the factorization machine layer.

        Args:
            name: Variable name scope.
            x: Input matrix of shape ``(n, d)``.
            low_rank_dim: Dimensionality of the low-rank approximation.
            units: Number of outputs.
            activation: Activation function to apply to the output.

        Returns:
            jax.Array: Output array of shape ``(n, u)``.
        """
        # NOTE(review): add_trailing_dim presumably promotes 1-D input to
        # (n, 1) -- confirm in blayers._utils.
        features = add_trailing_dim(x)
        n_inputs = features.shape[1]
        site = f"{self.__class__.__name__}_{name}"

        # One prior scale per output unit.
        lam = sample(
            name=f"{site}_scale",
            fn=self.scale_dist(**self.scale_kwargs).expand([units]),
        )
        factor_prior = self.coef_dist(scale=lam, **self.coef_kwargs)
        theta = sample(
            name=f"{site}_theta",
            fn=factor_prior.expand([n_inputs, low_rank_dim, units]),
        )

        return activation(_matmul_factorization_machine(features, theta))
[docs]
class FM3Layer(BLayer):
    """Order 3 FM. See `Blondel et al 2016 <https://proceedings.neurips.cc/paper/2016/file/158fc2ddd52ec2cf54d3c161f2dd6517-Paper.pdf>`_."""

    def __init__(
        self,
        scale_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        scale_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        """
        Args:
            scale_dist: Distribution class for the scaling factor λ.
            coef_dist: Distribution class for the factor weights. It is
                constructed with a ``scale`` keyword.
            coef_kwargs: Extra keyword arguments for ``coef_dist``.
            scale_kwargs: Keyword arguments for ``scale_dist``.

        Raises:
            TypeError: If the kwargs do not fit the distributions.
        """
        self.scale_dist, self.coef_dist = scale_dist, coef_dist
        self.coef_kwargs, self.scale_kwargs = coef_kwargs, scale_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs, scale_dist, scale_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        low_rank_dim: int,
        units: int = 1,
        activation: Callable[[jax.Array], jax.Array] = jnn.identity,
    ) -> jax.Array:
        """Forward pass through the third-order factorization machine layer.

        Args:
            name: Variable name scope.
            x: Input matrix of shape ``(n, d)``.
            low_rank_dim: Dimensionality of the low-rank approximation.
            units: Number of outputs.
            activation: Activation function to apply to the output.

        Returns:
            jax.Array: Output array of shape ``(n, u)``.
        """
        # NOTE(review): add_trailing_dim presumably promotes 1-D input to
        # (n, 1) -- confirm in blayers._utils.
        features = add_trailing_dim(x)
        n_inputs = features.shape[1]
        site = f"{self.__class__.__name__}_{name}"

        # One prior scale per output unit.
        lam = sample(
            name=f"{site}_scale",
            fn=self.scale_dist(**self.scale_kwargs).expand([units]),
        )
        factor_prior = self.coef_dist(scale=lam, **self.coef_kwargs)
        theta = sample(
            name=f"{site}_theta",
            fn=factor_prior.expand([n_inputs, low_rank_dim, units]),
        )

        return activation(_matmul_fm3(features, theta))
[docs]
class LowRankInteractionLayer(BLayer):
    """Takes two sets of features and learns a low-rank interaction matrix.

    Each feature set gets its own sampled scale and its own factor weights.
    """

    def __init__(
        self,
        scale_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        scale_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        """
        Args:
            scale_dist: Distribution class for the prior scales.
            coef_dist: Distribution class for the factor weights. It is
                constructed with a ``scale`` keyword.
            coef_kwargs: Extra keyword arguments for ``coef_dist``.
            scale_kwargs: Keyword arguments for ``scale_dist``.

        Raises:
            TypeError: If the kwargs do not fit the distributions.
        """
        self.scale_dist, self.coef_dist = scale_dist, coef_dist
        self.coef_kwargs, self.scale_kwargs = coef_kwargs, scale_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs, scale_dist, scale_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        z: jax.Array,
        low_rank_dim: int,
        units: int = 1,
        activation: Callable[[jax.Array], jax.Array] = jnn.identity,
    ) -> jax.Array:
        """Low-rank (UV decomposition) interaction between ``x`` and ``z``.

        Args:
            name: Variable name scope.
            x: Input matrix of shape ``(n, d1)``.
            z: Input matrix of shape ``(n, d2)``.
            low_rank_dim: Dimensionality of the low-rank approximation.
            units: Number of outputs.
            activation: Activation function to apply to the output.

        Returns:
            jax.Array: Output array of shape ``(n, u)``.
        """
        # NOTE(review): add_trailing_dim presumably promotes 1-D input to
        # (n, 1) -- confirm in blayers._utils.
        left = add_trailing_dim(x)
        right = add_trailing_dim(z)
        left_dim, right_dim = left.shape[1], right.shape[1]
        site = f"{self.__class__.__name__}_{name}"

        # Factors for x.
        lam_left = sample(
            name=f"{site}_scale1",
            fn=self.scale_dist(**self.scale_kwargs).expand([units]),
        )
        theta_left = sample(
            name=f"{site}_theta1",
            fn=self.coef_dist(scale=lam_left, **self.coef_kwargs).expand(
                [left_dim, low_rank_dim, units]
            ),
        )

        # Factors for z.
        lam_right = sample(
            name=f"{site}_scale2",
            fn=self.scale_dist(**self.scale_kwargs).expand([units]),
        )
        theta_right = sample(
            name=f"{site}_theta2",
            fn=self.coef_dist(scale=lam_right, **self.coef_kwargs).expand(
                [right_dim, low_rank_dim, units]
            ),
        )

        return activation(_matmul_uv_decomp(theta_left, theta_right, left, right))
[docs]
class InteractionLayer(BLayer):
    """Computes every interaction coefficient between two sets of inputs."""

    def __init__(
        self,
        scale_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        scale_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        """
        Args:
            scale_dist: Distribution class for the prior scale.
            coef_dist: Distribution class for the coefficients. It is
                constructed with a ``scale`` keyword.
            coef_kwargs: Extra keyword arguments for ``coef_dist``.
            scale_kwargs: Keyword arguments for ``scale_dist``.

        Raises:
            TypeError: If the kwargs do not fit the distributions.
        """
        self.scale_dist, self.coef_dist = scale_dist, coef_dist
        self.coef_kwargs, self.scale_kwargs = coef_kwargs, scale_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs, scale_dist, scale_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        z: jax.Array,
        units: int = 1,
        activation: Callable[[jax.Array], jax.Array] = jnn.identity,
    ) -> jax.Array:
        """Full pairwise interaction between feature matrices ``x`` and ``z``.

        One coefficient is sampled for each of the ``d1 * d2`` column pairs
        and each output unit.

        Args:
            name: Variable name scope.
            x: Input matrix of shape ``(n, d1)``.
            z: Input matrix of shape ``(n, d2)``.
            units: Number of outputs.
            activation: Activation function to apply to the output.

        Returns:
            jax.Array: Output array of shape ``(n, u)``.
        """
        # NOTE(review): add_trailing_dim presumably promotes 1-D input to
        # (n, 1) -- confirm in blayers._utils.
        left = add_trailing_dim(x)
        right = add_trailing_dim(z)
        n_pairs = left.shape[1] * right.shape[1]
        site = f"{self.__class__.__name__}_{name}"

        lam = sample(
            name=f"{site}_scale1",
            fn=self.scale_dist(**self.scale_kwargs).expand([units]),
        )
        beta = sample(
            name=f"{site}_beta1",
            fn=self.coef_dist(scale=lam, **self.coef_kwargs).expand(
                [n_pairs, units]
            ),
        )

        return activation(_matmul_interaction(beta, left, right))
[docs]
class BilinearLayer(BLayer):
    """Bayesian bilinear interaction layer: computes x^T W z."""

    def __init__(
        self,
        scale_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        scale_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        """
        Args:
            scale_dist: Distribution class for the prior on the coefficient
                scale.
            coef_dist: Distribution class for the coefficients. It is
                constructed with a ``scale`` keyword.
            coef_kwargs: Extra keyword arguments for ``coef_dist``.
            scale_kwargs: Keyword arguments for ``scale_dist``.

        Raises:
            TypeError: If the kwargs do not fit the distributions.
        """
        self.scale_dist, self.coef_dist = scale_dist, coef_dist
        self.coef_kwargs, self.scale_kwargs = coef_kwargs, scale_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs, scale_dist, scale_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        z: jax.Array,
        units: int = 1,
        activation: Callable[[jax.Array], jax.Array] = jnn.identity,
    ) -> jax.Array:
        """Full-rank bilinear form between feature matrices ``x`` and ``z``.

        Args:
            name: Variable name scope.
            x: Input matrix of shape ``(n, d1)``.
            z: Input matrix of shape ``(n, d2)``.
            units: Number of outputs.
            activation: Activation function to apply to the output.

        Returns:
            jax.Array: Output array of shape ``(n, u)``.
        """
        # NOTE(review): add_trailing_dim presumably promotes 1-D input to
        # (n, 1) -- confirm in blayers._utils.
        left = add_trailing_dim(x)
        right = add_trailing_dim(z)
        left_dim = left.shape[1]
        right_dim = right.shape[1]
        site = f"{self.__class__.__name__}_{name}"

        # One prior scale per output unit.
        lam = sample(
            name=f"{site}_scale",
            fn=self.scale_dist(**self.scale_kwargs).expand([units]),
        )

        # Full weight tensor of shape (d1, d2, units).
        weights = sample(
            name=f"{site}_W",
            fn=self.coef_dist(scale=lam, **self.coef_kwargs).expand(
                [left_dim, right_dim, units]
            ),
        )

        # x^T W z, evaluated per observation and per unit.
        bilinear = jnp.einsum("ni,iju,nj->nu", left, weights, right)
        return activation(bilinear)
[docs]
class LowRankBilinearLayer(BLayer):
    """Bayesian bilinear interaction layer: computes x^T W z. W low rank.

    ``W`` is never materialised. It is represented by two factor tensors
    ``A`` and ``B`` that share a single sampled scale.
    """

    def __init__(
        self,
        scale_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        scale_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        """
        Args:
            scale_dist: Distribution class for the prior on the coefficient
                scale.
            coef_dist: Distribution class for the coefficients. It is
                constructed with a ``scale`` keyword.
            coef_kwargs: Extra keyword arguments for ``coef_dist``.
            scale_kwargs: Keyword arguments for ``scale_dist``.

        Raises:
            TypeError: If the kwargs do not fit the distributions.
        """
        self.scale_dist, self.coef_dist = scale_dist, coef_dist
        self.coef_kwargs, self.scale_kwargs = coef_kwargs, scale_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs, scale_dist, scale_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        z: jax.Array,
        low_rank_dim: int,
        units: int = 1,
        activation: Callable[[jax.Array], jax.Array] = jnn.identity,
    ) -> jax.Array:
        """Low-rank bilinear form between feature matrices ``x`` and ``z``.

        Args:
            name: Variable name scope.
            x: Input matrix of shape ``(n, d1)``.
            z: Input matrix of shape ``(n, d2)``.
            low_rank_dim: Dimensionality of the low-rank approximation.
            units: Number of outputs.
            activation: Activation function to apply to the output.

        Returns:
            jax.Array: Output array of shape ``(n, u)``.
        """
        # NOTE(review): add_trailing_dim presumably promotes 1-D input to
        # (n, 1) -- confirm in blayers._utils.
        left = add_trailing_dim(x)
        right = add_trailing_dim(z)
        left_dim = left.shape[1]
        right_dim = right.shape[1]
        site = f"{self.__class__.__name__}_{name}"

        # A single scale (one entry per unit) is shared by both factors.
        lam = sample(
            name=f"{site}_scale",
            fn=self.scale_dist(**self.scale_kwargs).expand([units]),
        )
        factor_a = sample(
            name=f"{site}_A",
            fn=self.coef_dist(scale=lam, **self.coef_kwargs).expand(
                [left_dim, low_rank_dim, units]
            ),
        )
        factor_b = sample(
            name=f"{site}_B",
            fn=self.coef_dist(scale=lam, **self.coef_kwargs).expand(
                [right_dim, low_rank_dim, units]
            ),
        )

        # Project both inputs into the rank-l space: (n, l, u) each.
        left_proj = jnp.einsum("ni,ilu->nlu", left, factor_a)
        right_proj = jnp.einsum("nj,jlu->nlu", right, factor_b)

        # Dot product over the rank axis gives (n, u).
        return activation(jnp.sum(left_proj * right_proj, axis=1))
# ---- Embeddings ------------------------------------------------------------ #
[docs]
class EmbeddingLayer(BLayer):
    """Bayesian embedding layer for sparse categorical features."""

    def __init__(
        self,
        scale_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        scale_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        """
        Args:
            scale_dist: NumPyro distribution class for the scale (λ) of the
                prior.
            coef_dist: NumPyro distribution class for the coefficient prior.
                It is constructed with a ``scale`` keyword.
            coef_kwargs: Extra keyword arguments for ``coef_dist``.
            scale_kwargs: Keyword arguments for ``scale_dist``.

        Raises:
            TypeError: If the kwargs do not fit the distributions.
        """
        self.scale_dist = scale_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.scale_kwargs = scale_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs, scale_dist, scale_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        num_categories: int,
        embedding_dim: int,
    ) -> jax.Array:
        """Forward pass through embedding lookup.

        Args:
            name: Variable name scope.
            x: Integer indices of shape ``(n,)`` or ``(n, 1)`` indicating
                the embeddings to use.
            num_categories: The number of distinct things getting an
                embedding.
            embedding_dim: The size of each embedding, e.g. 2, 4, 8, etc.

        Returns:
            jax.Array: Embedding vectors of shape ``(n, m)``.
        """
        site = f"{self.__class__.__name__}_{name}"

        # A single scalar scale is shared by the whole embedding table.
        scale = sample(
            name=f"{site}_scale",
            fn=self.scale_dist(**self.scale_kwargs),
        )
        theta = sample(
            name=f"{site}_theta",
            fn=self.coef_dist(scale=scale, **self.coef_kwargs).expand(
                [num_categories, embedding_dim]
            ),
        )

        # Cast like _matmul_randomwalk does so float-typed index columns
        # still work as indexers.
        idx = jnp.asarray(x).astype(jnp.int32)
        # Drop only trailing singleton axes. A bare ``squeeze()`` would also
        # remove the batch axis when n == 1 and return ``(m,)`` instead of
        # ``(1, m)``.
        while idx.ndim > 1 and idx.shape[-1] == 1:
            idx = idx[..., 0]
        return theta[idx]
[docs]
class RandomEffectsLayer(BLayer):
    """Exactly like the EmbeddingLayer but with ``embedding_dim=1``."""

    def __init__(
        self,
        scale_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        scale_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        """
        Args:
            scale_dist: NumPyro distribution class for the scale (λ) of the
                prior.
            coef_dist: NumPyro distribution class for the random-effect
                prior. It is constructed with a ``scale`` keyword.
            coef_kwargs: Extra keyword arguments for ``coef_dist``.
            scale_kwargs: Keyword arguments for ``scale_dist``.

        Raises:
            TypeError: If the kwargs do not fit the distributions.
        """
        self.scale_dist = scale_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.scale_kwargs = scale_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs, scale_dist, scale_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        num_categories: int,
    ) -> jax.Array:
        """Forward pass through the random-effect lookup.

        Args:
            name: Variable name scope.
            x: Integer indices of shape ``(n,)`` or ``(n, 1)`` indicating
                which category each observation belongs to.
            num_categories: The number of distinct categories.

        Returns:
            jax.Array: Random effects of shape ``(n, 1)``.
        """
        site = f"{self.__class__.__name__}_{name}"

        # A single scalar scale is shared by all categories.
        scale = sample(
            name=f"{site}_scale",
            fn=self.scale_dist(**self.scale_kwargs),
        )
        theta = sample(
            name=f"{site}_theta",
            fn=self.coef_dist(scale=scale, **self.coef_kwargs).expand(
                [num_categories, 1]
            ),
        )

        # Cast like _matmul_randomwalk does so float-typed index columns
        # still work as indexers.
        idx = jnp.asarray(x).astype(jnp.int32)
        # Drop only trailing singleton axes. A bare ``squeeze()`` would also
        # remove the batch axis when n == 1 and return ``(1,)`` instead of
        # ``(1, 1)``.
        while idx.ndim > 1 and idx.shape[-1] == 1:
            idx = idx[..., 0]
        return theta[idx]
[docs]
class RandomWalkLayer(BLayer):
    """Random walk of embedding dim ``m``, defaults to Gaussian walk.

    The sampled ``theta`` holds the per-step increments. The walk itself is
    their cumulative sum along the category axis.
    """

    def __init__(
        self,
        scale_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        scale_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        """
        Args:
            scale_dist: Distribution class for the scale of the increments.
            coef_dist: Distribution class for the increments. It is
                constructed with a ``scale`` keyword.
            coef_kwargs: Extra keyword arguments for ``coef_dist``.
            scale_kwargs: Keyword arguments for ``scale_dist``.

        Raises:
            TypeError: If the kwargs do not fit the distributions.
        """
        self.scale_dist, self.coef_dist = scale_dist, coef_dist
        self.coef_kwargs, self.scale_kwargs = coef_kwargs, scale_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs, scale_dist, scale_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        num_categories: int,
        embedding_dim: int,
    ) -> jax.Array:
        """Sample the walk increments and look up the walk position per row.

        Args:
            name: Variable name scope.
            x: Integer indices indicating which step of the walk to use.
            num_categories: The number of steps in the walk.
            embedding_dim: The size of each embedding, e.g. 2, 4, 8, etc.

        Returns:
            jax.Array: Walk positions of shape ``(n, m)``.
        """
        site = f"{self.__class__.__name__}_{name}"

        # A single scalar scale is shared by every increment.
        step_scale = sample(
            name=f"{site}_scale",
            fn=self.scale_dist(**self.scale_kwargs),
        )
        increment_prior = self.coef_dist(scale=step_scale, **self.coef_kwargs)
        theta = sample(
            name=f"{site}_theta",
            fn=increment_prior.expand([num_categories, embedding_dim]),
        )

        # Cumulative sum plus row lookup.
        return _matmul_randomwalk(theta, x)
# ---- Sparse priors --------------------------------------------------------- #
[docs]
class HorseshoeLayer(BLayer):
    """Bayesian layer with horseshoe prior for sparse regression.

    Implements the (regularized) horseshoe prior of Piironen & Vehtari (2017).

    Basic horseshoe:

    .. math::
        \\tau \\sim HalfCauchy(1), \\quad
        \\lambda_j \\sim HalfCauchy(1), \\quad
        \\beta_j \\sim Normal(0,\\; \\tau \\lambda_j)

    Regularized horseshoe (``slab_scale`` set) — prevents large coefficients
    from escaping the slab:

    .. math::
        \\tilde{\\lambda}_j^2 = \\frac{c^2 \\lambda_j^2}{c^2 + \\tau^2 \\lambda_j^2},
        \\quad c^2 \\sim InverseGamma(s/2,\\; s/2 \\cdot scale_{slab}^2)
    """

    def __init__(
        self,
        slab_scale: float | None = None,
        slab_df: float = 4.0,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
    ):
        """
        Args:
            slab_scale: If set, uses the regularized horseshoe with this slab
                scale. ``None`` gives the plain horseshoe.
            slab_df: Degrees of freedom for the slab variance prior (only
                used when ``slab_scale`` is set).
            coef_dist: Distribution class for the coefficients. Must accept a
                ``scale`` keyword (derived from the horseshoe shrinkage).
                Defaults to ``Normal``.
            coef_kwargs: Extra kwargs for ``coef_dist`` (beyond ``scale``).
                Default ``{"loc": 0.0}``.

        Raises:
            TypeError: If ``coef_dist`` rejects ``scale`` plus
                ``coef_kwargs``.
        """
        self.slab_scale, self.slab_df = slab_scale, slab_df
        self.coef_dist, self.coef_kwargs = coef_dist, coef_kwargs
        # Build the coefficient prior once with a placeholder scale so bad
        # kwargs surface at construction time.
        try:
            coef_dist(scale=1.0, **coef_kwargs)
        except TypeError as err:
            raise TypeError(f"Invalid coef_dist kwargs: {err}") from err

    def __call__(
        self,
        name: str,
        x: jax.Array,
        units: int = 1,
        activation: Callable[[jax.Array], jax.Array] = jnn.identity,
    ) -> jax.Array:
        """Forward pass with horseshoe prior on coefficients.

        Args:
            name: Variable name scope.
            x: Input array of shape ``(n, d)``.
            units: Number of output dimensions.
            activation: Activation function.

        Returns:
            jax.Array of shape ``(n, units)``.
        """
        # NOTE(review): add_trailing_dim presumably promotes 1-D input to
        # (n, 1) -- confirm in blayers._utils.
        features = add_trailing_dim(x)
        n_inputs = features.shape[1]
        site = f"{self.__class__.__name__}_{name}"

        # Global shrinkage: one scale per output unit.
        tau = sample(
            f"{site}_tau",
            distributions.HalfCauchy(1.0).expand([units]),
        )
        # Local shrinkage: one per feature per output unit.
        lam = sample(
            f"{site}_scale",
            distributions.HalfCauchy(1.0).expand([n_inputs, units]),
        )

        if self.slab_scale is None:
            # Plain horseshoe, shape (d, units).
            coef_scale = tau * lam
        else:
            # Soft upper bound on coefficient size via a finite-variance slab.
            half_df = self.slab_df / 2.0
            c2 = sample(
                f"{site}_c2",
                distributions.InverseGamma(
                    half_df,
                    half_df * self.slab_scale**2,
                ),
            )
            lam_tilde = jnp.sqrt(c2 * lam**2 / (c2 + tau**2 * lam**2))
            coef_scale = tau * lam_tilde

        # The batch shape of ``coef_scale`` fixes beta's shape, so no
        # expand() call is used here.
        beta = sample(
            f"{site}_beta",
            self.coef_dist(scale=coef_scale, **self.coef_kwargs),
        )
        return activation(_matmul_dot_product(features, beta))
# ---- Spike and slab -------------------------------------------------------- #
[docs]
class SpikeAndSlabLayer(BLayer):
    """Sparse regression via a spike-and-slab prior.

    Each coefficient has a Beta-distributed inclusion weight ``z_j`` in
    (0, 1). Included features (``z_j ≈ 1``) take the full slab coefficient;
    excluded features (``z_j ≈ 0``) are gated toward zero (the spike).

    Generative model::

        z_j ~ Beta(alpha, beta)           # inclusion weight (hardcoded Beta)
        β_j ~ coef_dist(**coef_kwargs)    # slab coefficient
        y   ~ link(z · β · x, ...)        # z gates each coefficient

    The default ``Beta(0.5, 0.5)`` (Jeffreys prior) places mass near 0 and 1,
    encouraging features to be clearly included or excluded. The posterior
    mean of ``z_j`` approximates ``P(feature j included | data)``.

    The slab distribution defaults to ``Normal(0, 1)`` but can be swapped for
    e.g. ``StudentT`` for heavier-tailed slab behaviour.

    Args:
        alpha: First concentration parameter of the Beta prior on ``z``.
        beta: Second concentration parameter of the Beta prior on ``z``.
        coef_dist: Distribution class for the slab coefficients.
        coef_kwargs: Kwargs for ``coef_dist``.
    """

    def __init__(
        self,
        alpha: float = 0.5,
        beta: float = 0.5,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0, "scale": 1.0},
    ):
        """Store the prior settings and validate ``coef_kwargs``.

        Raises:
            TypeError: If ``coef_kwargs`` does not fit ``coef_dist``.
        """
        self.alpha, self.beta = alpha, beta
        self.coef_dist, self.coef_kwargs = coef_dist, coef_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        units: int = 1,
        activation: Callable[[jax.Array], jax.Array] = jnn.identity,
    ) -> jax.Array:
        """Forward pass with gated slab coefficients.

        Args:
            name: Variable name scope.
            x: Input of shape ``(n, d)``.
            units: Number of output dimensions.
            activation: Activation function.

        Returns:
            jax.Array of shape ``(n, units)``.
        """
        # NOTE(review): add_trailing_dim presumably promotes 1-D input to
        # (n, 1) -- confirm in blayers._utils.
        features = add_trailing_dim(x)
        coef_shape = [features.shape[1], units]
        site = f"{self.__class__.__name__}_{name}"

        # Inclusion weight: posterior z_j ≈ P(feature j included | data).
        inclusion = sample(
            f"{site}_z",
            distributions.Beta(self.alpha, self.beta).expand(coef_shape),
        )
        # Slab coefficients.
        slab = sample(
            f"{site}_beta",
            self.coef_dist(**self.coef_kwargs).expand(coef_shape),
        )

        # Gate: z≈1 → full slab value; z≈0 → near zero (spike at 0).
        gated = inclusion * slab
        return activation(_matmul_dot_product(features, gated))
# ---- Attention ------------------------------------------------------------- #
[docs]
class AttentionLayer(BLayer):
    """Multi-head Bayesian self-attention over the feature dimension.

    Treats the ``d`` input features as tokens using FT-Transformer style
    tokenisation (Gorishniy et al. 2021, https://arxiv.org/abs/2106.11959):
    each feature gets a per-column bias embedding (identity) plus a
    value-scaled embedding, so tokens are distinct even when the feature
    value is zero.

    For each observation ``x_i ∈ R^d``:

    1. Tokenise: ``H_j = x_{i,j} · W_emb_j + W_bias_j`` (``head_dim * num_heads``-dim each)
    2. Per head: ``Q_m, K_m, V_m = H W_Q_m, H W_K_m, H W_V_m``
    3. ``Attn_m = softmax(Q_m K_m^T / √h_k)``
    4. Concatenate heads → mean-pool over features → project to ``units``

    Requires ``d ≥ 2`` for attention to be non-trivial.

    Total embedding dimension is ``head_dim * num_heads`` — adding heads
    increases capacity rather than splitting a fixed budget.
    """

    def __init__(
        self,
        scale_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        scale_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        """
        Args:
            scale_dist: Distribution class for the prior scales.
            coef_dist: Distribution class for every weight matrix. It is
                constructed with a ``scale`` keyword.
            coef_kwargs: Extra keyword arguments for ``coef_dist``.
            scale_kwargs: Keyword arguments for ``scale_dist``.

        Raises:
            TypeError: If the kwargs do not fit the distributions.
        """
        self.scale_dist, self.coef_dist = scale_dist, coef_dist
        self.coef_kwargs, self.scale_kwargs = coef_kwargs, scale_kwargs
        _validate_prior_kwargs(coef_dist, coef_kwargs, scale_dist, scale_kwargs)

    def __call__(
        self,
        name: str,
        x: jax.Array,
        head_dim: int = 8,
        num_heads: int = 1,
        units: int = 1,
        activation: Callable[[jax.Array], jax.Array] = jnn.identity,
    ) -> jax.Array:
        """Forward pass: tokenise, attend across features, pool and project.

        Args:
            name: Variable name scope.
            x: Input of shape ``(n, d)``. Each column is a feature token.
            head_dim: Dimension of each individual head. Total embedding
                dimension is ``head_dim * num_heads``, so adding heads
                increases capacity.
            num_heads: Number of attention heads.
            units: Number of output dimensions.
            activation: Activation function.

        Returns:
            jax.Array of shape ``(n, units)``.
        """
        # NOTE(review): add_trailing_dim presumably promotes 1-D input to
        # (n, 1) -- confirm in blayers._utils.
        features = add_trailing_dim(x)
        n_obs = features.shape[0]
        n_tokens = features.shape[1]
        embed_dim = head_dim * num_heads  # total embedding dimension
        site = f"{self.__class__.__name__}_{name}"

        # FT-Transformer tokenisation: value scaling + per-column bias.
        # Both embedding tables share one scale per embedding coordinate.
        emb_scale = sample(
            f"{site}_scale_emb",
            self.scale_dist(**self.scale_kwargs).expand([embed_dim]),
        )
        value_emb = sample(
            f"{site}_W_emb",
            self.coef_dist(scale=emb_scale, **self.coef_kwargs).expand(
                [n_tokens, embed_dim]
            ),
        )
        bias_emb = sample(
            f"{site}_W_bias",
            self.coef_dist(scale=emb_scale, **self.coef_kwargs).expand(
                [n_tokens, embed_dim]
            ),
        )
        # tokens[i, j] = x[i, j] * W_emb[j] + W_bias[j]  ->  (n, d, h)
        tokens = features[:, :, None] * value_emb[None, :, :] + bias_emb[None, :, :]

        # Q, K, V projections, one set per head: (m, h, h_k).
        # The sampled scale is (m, h_k); a middle axis is inserted so it
        # broadcasts against (m, h, h_k).
        qkv_scale = sample(
            f"{site}_scale_qkv",
            self.scale_dist(**self.scale_kwargs).expand([num_heads, head_dim]),
        )
        qkv_scale_bc = qkv_scale[:, None, :]  # (m, 1, h_k)
        proj_shape = [num_heads, embed_dim, head_dim]
        w_query = sample(
            f"{site}_W_Q",
            self.coef_dist(scale=qkv_scale_bc, **self.coef_kwargs).expand(proj_shape),
        )
        w_key = sample(
            f"{site}_W_K",
            self.coef_dist(scale=qkv_scale_bc, **self.coef_kwargs).expand(proj_shape),
        )
        w_value = sample(
            f"{site}_W_V",
            self.coef_dist(scale=qkv_scale_bc, **self.coef_kwargs).expand(proj_shape),
        )

        # Project to per-head Q/K/V: (n, d, m, h_k).
        queries = jnp.einsum("ndh,mhk->ndmk", tokens, w_query)
        keys = jnp.einsum("ndh,mhk->ndmk", tokens, w_key)
        values = jnp.einsum("ndh,mhk->ndmk", tokens, w_value)

        # Scaled dot-product attention per head: (n, m, d, d).
        logits = jnp.einsum("ndmk,nqmk->nmdq", queries, keys) / head_dim**0.5
        attn = jax.nn.softmax(logits, axis=-1)
        attended = jnp.einsum("nmdq,nqmk->ndmk", attn, values)  # (n, d, m, h_k)

        # Concatenate heads, then mean-pool over the feature tokens: (n, h).
        pooled = attended.reshape(n_obs, n_tokens, embed_dim).mean(axis=1)

        # Output projection to ``units``.
        out_scale = sample(
            f"{site}_scale_out",
            self.scale_dist(**self.scale_kwargs).expand([units]),
        )
        w_out = sample(
            f"{site}_W_out",
            self.coef_dist(scale=out_scale, **self.coef_kwargs).expand(
                [embed_dim, units]
            ),
        )
        return activation(pooled @ w_out)