Layers#

Implements Bayesian layers using JAX and NumPyro.

Design:
  • There are three levels of complexity here: class-level, instance-level, and call-level

  • The class-level handles things like choosing the generic model form and how to multiply coefficients with data. Defined by the class definition itself: class Layer(BLayer).

  • The instance-level handles specific distributions that fit into a generic model and the initial parameters for those distributions. Defined by creating an instance of the class: Layer(*args, **kwargs).

  • The call-level handles seeing a batch of data: it samples from the distributions defined on the layer and multiplies coefficients with data to produce an output. Works like result = Layer(*args, **kwargs)(data).
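The three levels can be sketched with a toy stand-in (ToyLayer is hypothetical and draws with jax.random for simplicity; the real blayers layers sample via NumPyro):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in (not blayers code) illustrating the three levels.
class ToyLayer:
    # Class level: the generic model form (a linear map) is fixed by the class.
    def __init__(self, loc=0.0, scale=1.0):
        # Instance level: choose the prior's parameters.
        self.loc, self.scale = loc, scale

    def __call__(self, key, x, units=1):
        # Call level: sample coefficients for this batch and multiply with the data.
        beta = self.loc + self.scale * jax.random.normal(key, (x.shape[1], units))
        return x @ beta

key = jax.random.PRNGKey(0)
x = jnp.ones((8, 3))
result = ToyLayer(loc=0.0, scale=1.0)(key, x)  # result = Layer(*args, **kwargs)(data)
```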

Notation:
  • n: observations in a batch

  • c: number of categories (e.g. time periods, random-effect groups)

  • d: number of coefficients

  • l: low rank dimension of low rank models

  • m: embedding dimension

  • u: units aka output dimension

blayers.layers.pairwise_interactions(x, z)[source]#

Compute all pairwise interactions between the features of x and z.

Parameters:
  • x (Array) – Input of shape (n_samples, n_features1).

  • z (Array) – Input of shape (n_samples, n_features2).

Returns:

Interactions of shape (n_samples, n_features1 * n_features2).

Return type:

jax.Array
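A minimal reference sketch of the computation (the column ordering here is an assumption, not guaranteed to match the library):

```python
import jax.numpy as jnp

def pairwise_interactions_ref(x, z):
    # Flattened per-row outer product: (n, d1), (n, d2) -> (n, d1 * d2).
    return jnp.einsum("ni,nj->nij", x, z).reshape(x.shape[0], -1)

x = jnp.array([[1.0, 2.0]])
z = jnp.array([[3.0, 4.0, 5.0]])
out = pairwise_interactions_ref(x, z)  # [[3., 4., 5., 6., 8., 10.]]
```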

class blayers.layers.BLayer(*args)[source]#

Bases: ABC

Abstract base class for Bayesian layers. Lays out an interface.

Parameters:

args (Any)

abstractmethod __init__(*args)[source]#

Initialize layer parameters. This is the Bayesian model.

Parameters:

args (Any)

Return type:

None

abstractmethod __call__(*args)[source]#

Run the layer’s forward pass.

Parameters:

*args (Any) – Inputs to the layer.

Returns:

The result of the forward computation.

Return type:

jax.Array

class blayers.layers.AdaptiveLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian layer with adaptive prior using hierarchical modeling.

Generates coefficients from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[\beta \sim Normal(0., \lambda)\]
Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.

  • coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.

  • coef_kwargs (dict[str, float]) – Parameters for the prior distribution.

  • scale_kwargs (dict[str, float]) – Parameters for the scale distribution.

__call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#

Forward pass with adaptive prior on coefficients.

Parameters:
  • name (str) – Variable name.

  • x (Array) – Input data array of shape (n, d).

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.FixedPriorLayer(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#

Bases: BLayer

Bayesian layer with a fixed prior distribution over coefficients.

Generates coefficients from the model

\[\beta \sim Normal(0., 1.)\]
Parameters:
  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

__init__(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
Parameters:
  • coef_dist (Distribution) – NumPyro distribution class for the coefficients.

  • coef_kwargs (dict[str, float]) – Parameters to initialize the prior distribution.

__call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#

Forward pass with fixed prior.

Parameters:
  • name (str) – Variable name.

  • x (Array) – Input data array of shape (n, d).

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.InterceptLayer(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#

Bases: BLayer

Bayesian intercept (bias) term with a fixed prior.

Samples a scalar bias from

\[\beta \sim Normal(0., 1.)\]

and broadcasts it to every observation. No input x is needed.

Parameters:
  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

__init__(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
Parameters:
  • coef_dist (Distribution) – NumPyro distribution class for the coefficients.

  • coef_kwargs (dict[str, float]) – Parameters to initialize the prior distribution.

__call__(name, units=1, activation=<PjitFunction of <function identity>>)[source]#

Forward pass with fixed prior.

Parameters:
  • name (str) – Variable name.

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (1, u).

Return type:

jax.Array

class blayers.layers.FMLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian factorization machine layer with adaptive priors.

Generates coefficients from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[\beta \sim Normal(0., \lambda)\]

The shape of beta is (d, l), where d is the number of input covariates and l is the low-rank dimension.

Then performs the pairwise-interaction computation using the identity in Rendle (2010).
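Rendle's identity, which avoids the explicit O(d²) pair sum, can be checked numerically (a standalone sketch, independent of blayers):

```python
import jax
import jax.numpy as jnp

def fm_pairwise(x, V):
    # 0.5 * sum_l ((x V)_l^2 - (x^2)(V^2)_l) == sum_{i<j} <v_i, v_j> x_i x_j
    return 0.5 * ((x @ V) ** 2 - (x ** 2) @ (V ** 2)).sum(axis=-1)

def fm_pairwise_naive(x, V):
    # Brute force over all feature pairs i < j.
    d = V.shape[0]
    out = jnp.zeros(x.shape[0])
    for i in range(d):
        for j in range(i + 1, d):
            out = out + (V[i] @ V[j]) * x[:, i] * x[:, j]
    return out

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (4, 5))
V = jax.random.normal(k2, (5, 2))
```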

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – Distribution for scaling factor λ.

  • coef_dist (Distribution) – Prior for beta parameters.

  • coef_kwargs (dict[str, float]) – Arguments for prior distribution.

  • scale_kwargs (dict[str, float]) – Arguments for λ distribution.

__call__(name, x, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#

Forward pass through the factorization machine layer.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d).

  • low_rank_dim (int) – Dimensionality of low-rank approximation.

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.FM3Layer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian order-3 factorization machine layer with adaptive prior.

Samples low-rank factors from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[\theta \sim Normal(0., \lambda), \quad \theta \in \mathbb{R}^{d \times l}\]

Then computes the 3rd-order ANOVA kernel via Newton’s identity (Blondel et al. 2016). Defining power sums \(p_k = \sum_i x_i^k \theta_i^k\):

\[\text{output} = \frac{p_1^3 - 3\, p_2\, p_1 + 2\, p_3}{6}\]

This efficiently computes all 3rd-order interaction terms without enumerating all \(\binom{d}{3}\) triples.
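The Newton's-identity formula can be verified against a brute-force sum over all triples (a standalone sketch, independent of blayers):

```python
import jax
import jax.numpy as jnp

def anova3(x, theta):
    # x: (d,), theta: (d, l). Power sums p_k = sum_i x_i^k theta_i^k per column.
    xt = x[:, None] * theta
    p1, p2, p3 = xt.sum(0), (xt ** 2).sum(0), (xt ** 3).sum(0)
    # Elementary symmetric polynomial e3 via Newton's identities, summed over l.
    return ((p1 ** 3 - 3.0 * p2 * p1 + 2.0 * p3) / 6.0).sum()

def anova3_naive(x, theta):
    # Brute force over all triples i < j < k.
    d = x.shape[0]
    xt = x[:, None] * theta
    total = 0.0
    for i in range(d):
        for j in range(i + 1, d):
            for k in range(j + 1, d):
                total = total + (xt[i] * xt[j] * xt[k]).sum()
    return total

x = jax.random.normal(jax.random.PRNGKey(1), (6,))
theta = jax.random.normal(jax.random.PRNGKey(2), (6, 3))
```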

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – Distribution for scaling factor λ.

  • coef_dist (Distribution) – Prior for beta parameters.

  • coef_kwargs (dict[str, float]) – Arguments for prior distribution.

  • scale_kwargs (dict[str, float]) – Arguments for λ distribution.

__call__(name, x, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#

Forward pass through the factorization machine layer.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d).

  • low_rank_dim (int) – Dimensionality of low-rank approximation.

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n,).

Return type:

jax.Array

class blayers.layers.LowRankInteractionLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian low-rank bilinear interaction between two feature sets (UV decomposition).

Samples separate low-rank projections for x and z from the hierarchical model

\[\lambda_1 \sim HalfNormal(1.), \quad \theta_1 \sim Normal(0., \lambda_1), \quad \theta_1 \in \mathbb{R}^{d_1 \times l}\]
\[\lambda_2 \sim HalfNormal(1.), \quad \theta_2 \sim Normal(0., \lambda_2), \quad \theta_2 \in \mathbb{R}^{d_2 \times l}\]

and computes the element-wise product of the projections, summed over the low-rank dimension:

\[\text{output} = \sum_{r=1}^{l} (x \theta_1)_r \cdot (z \theta_2)_r = x^\top (\theta_1 \theta_2^\top) z\]

This is equivalent to a rank-\(l\) approximation of the full bilinear form \(x^\top W z\).
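The equivalence between the summed element-wise product of the projections and the rank-l bilinear form can be checked directly:

```python
import jax
import jax.numpy as jnp

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k1, (5, 3))    # (n, d1)
z = jax.random.normal(k2, (5, 4))    # (n, d2)
t1 = jax.random.normal(k3, (3, 2))   # (d1, l)
t2 = jax.random.normal(k4, (4, 2))   # (d2, l)

# Sum over the low-rank dimension of the projected inputs...
low_rank = ((x @ t1) * (z @ t2)).sum(axis=-1)
# ...equals the bilinear form with the rank-l matrix W = t1 t2^T.
bilinear = jnp.einsum("ni,ij,nj->n", x, t1 @ t2.T, z)
```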

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Initialize layer parameters. This is the Bayesian model.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__call__(name, x, z, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#

Interaction between feature matrices x and z in a low-rank way (UV decomposition).

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d1).

  • z (Array) – Input matrix of shape (n, d2).

  • low_rank_dim (int) – Dimensionality of low-rank approximation.

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.InteractionLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian full pairwise interaction layer with adaptive prior.

Samples one coefficient per pair of features from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[\beta \sim Normal(0., \lambda), \quad \beta \in \mathbb{R}^{d_1 d_2}\]

and computes the weighted sum of all outer-product interactions:

\[\text{output} = (x \otimes z)\, \beta\]

where \(x \otimes z\) is the flattened outer product of shape \((n, d_1 d_2)\). For large inputs this scales as \(O(d_1 d_2)\) parameters; prefer LowRankInteractionLayer when \(d_1\) or \(d_2\) is large.
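Applying one coefficient per feature pair is the same as a bilinear form with beta reshaped to (d1, d2) (a sketch; the flattening order is an assumption, as in pairwise_interactions):

```python
import jax
import jax.numpy as jnp

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (5, 3))        # (n, d1)
z = jax.random.normal(k2, (5, 4))        # (n, d2)
beta = jax.random.normal(k3, (3 * 4,))   # one weight per feature pair

# Flattened outer product (n, d1 * d2), then weighted sum...
flat = jnp.einsum("ni,nj->nij", x, z).reshape(5, -1)
out_flat = flat @ beta
# ...equals x^T B z with B = beta reshaped to (d1, d2).
out_bilinear = jnp.einsum("ni,ij,nj->n", x, beta.reshape(3, 4), z)
```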

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Initialize layer parameters. This is the Bayesian model.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__call__(name, x, z, units=1, activation=<PjitFunction of <function identity>>)[source]#

Full pairwise interaction between feature matrices x and z, with one coefficient per feature pair.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d1).

  • z (Array) – Input matrix of shape (n, d2).

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.BilinearLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian full bilinear layer with adaptive prior.

Samples a full interaction matrix from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[W \sim Normal(0., \lambda), \quad W \in \mathbb{R}^{d_1 \times d_2}\]

and computes the bilinear form:

\[\text{output} = x^\top W z\]

This learns a distinct weight for every pair \((x_i, z_j)\), making it the densest two-input layer. Has \(O(d_1 d_2)\) parameters; prefer LowRankBilinearLayer when dimensions are large.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – prior on scale of coefficients

  • coef_dist (Distribution) – distribution for coefficients

  • coef_kwargs (dict[str, float]) – kwargs for coef distribution

  • scale_kwargs (dict[str, float]) – kwargs for scale prior

__call__(name, x, z, units=1, activation=<PjitFunction of <function identity>>)[source]#

Full bilinear interaction x^T W z between feature matrices x and z.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d1).

  • z (Array) – Input matrix of shape (n, d2).

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.LowRankBilinearLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian low-rank bilinear layer with adaptive prior.

Samples shared-scale low-rank factors for both inputs from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[A \sim Normal(0., \lambda), \quad A \in \mathbb{R}^{d_1 \times l}\]
\[B \sim Normal(0., \lambda), \quad B \in \mathbb{R}^{d_2 \times l}\]

and computes the bilinear form with a rank-\(l\) weight matrix \(W = AB^\top\):

\[\text{output} = x^\top W z = (xA) \cdot (zB)\]

Compared to LowRankInteractionLayer, A and B share a single scale \(\lambda\), tying the regularisation across both inputs.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – prior on scale of coefficients

  • coef_dist (Distribution) – distribution for coefficients

  • coef_kwargs (dict[str, float]) – kwargs for coef distribution

  • scale_kwargs (dict[str, float]) – kwargs for scale prior

__call__(name, x, z, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#

Low-rank bilinear interaction between feature matrices x and z (UV decomposition).

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d1).

  • z (Array) – Input matrix of shape (n, d2).

  • low_rank_dim (int) – Dimensionality of low-rank approximation.

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.EmbeddingLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian embedding layer for sparse categorical features.

Samples an embedding table from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[\theta \sim Normal(0., \lambda), \quad \theta \in \mathbb{R}^{c \times m}\]

and performs a lookup for each observation:

\[\text{output}_i = \theta[x_i]\]

where \(c\) is the number of categories and \(m\) is the embedding dimension. For \(m = 1\) prefer RandomEffectsLayer.
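The lookup itself is plain integer indexing into the sampled table (sketched here with a fixed table in place of a sampled one):

```python
import jax.numpy as jnp

theta = jnp.arange(12.0).reshape(4, 3)   # embedding table: c = 4 categories, m = 3
x = jnp.array([0, 2, 2, 1])              # category index per observation
out = theta[x]                           # (n, m): row theta[x_i] per observation
```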

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.

  • coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.

  • coef_kwargs (dict[str, float]) – Parameters for the prior distribution.

  • scale_kwargs (dict[str, float]) – Parameters for the scale distribution.

__call__(name, x, num_categories, embedding_dim)[source]#

Forward pass through embedding lookup.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Integer indices indicating embeddings to use.

  • num_categories (int) – Number of distinct categories, each of which gets an embedding.

  • embedding_dim (int) – The size of each embedding, e.g. 2, 4, 8, etc.

Returns:

Embedding vectors of shape (n, m).

Return type:

jax.Array

class blayers.layers.RandomEffectsLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian random-effects layer — a scalar embedding per category.

Special case of EmbeddingLayer with embedding_dim=1. Samples one scalar random effect per category from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[\theta \sim Normal(0., \lambda), \quad \theta \in \mathbb{R}^{c}\]

and returns the scalar for each observation’s category:

\[\text{output}_i = \theta[x_i]\]

Equivalent to a classical mixed-effects intercept with a learned variance \(\lambda^2\).

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.

  • coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.

  • coef_kwargs (dict[str, float]) – Parameters for the prior distribution.

  • scale_kwargs (dict[str, float]) – Parameters for the scale distribution.

__call__(name, x, num_categories)[source]#

Forward pass through embedding lookup.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Integer indices indicating the category of each observation.

  • num_categories (int) – Number of distinct categories, each of which gets a random effect.

Returns:

Random-effect values of shape (n, 1).

Return type:

jax.Array

class blayers.layers.RandomWalkLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian Gaussian random walk over ordered categories.

Samples i.i.d. increments from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[\delta_t \sim Normal(0., \lambda), \quad t = 1, \ldots, c\]

and accumulates them into positions via a cumulative sum:

\[\theta_t = \sum_{s=1}^{t} \delta_s\]

Each observation is then assigned the position of its category:

\[\text{output}_i = \theta[x_i]\]

The embedding_dim m runs m independent walks in parallel, producing output of shape (n, m). Typical use: a time index where adjacent periods share information through the walk prior.
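The cumulative-sum construction, sketched with plain jax.random increments: adjacent categories differ by exactly one increment, which is how the walk shares information between neighbouring periods.

```python
import jax
import jax.numpy as jnp

c, m = 5, 2                                               # categories, parallel walks
delta = jax.random.normal(jax.random.PRNGKey(0), (c, m))  # i.i.d. increments delta_t
theta = jnp.cumsum(delta, axis=0)                         # theta_t = sum_{s<=t} delta_s

x = jnp.array([0, 3, 3, 4])                               # time index per observation
out = theta[x]                                            # (n, m)
```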

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Initialize layer parameters. This is the Bayesian model.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__call__(name, x, num_categories, embedding_dim)[source]#

Forward pass through embedding lookup.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Integer indices indicating embeddings to use.

  • num_categories (int) – The number of distinct things getting an embedding

  • embedding_dim (int) – The size of each embedding, e.g. 2, 4, 8, etc.

Returns:

Embedding vectors of shape (n, m).

Return type:

jax.Array

class blayers.layers.HorseshoeLayer(slab_scale=None, slab_df=4.0, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0})[source]#

Bases: BLayer

Bayesian layer with horseshoe prior for sparse regression.

Implements the (regularized) horseshoe prior of Piironen & Vehtari (2017).

Basic horseshoe:

\[\tau \sim HalfCauchy(1), \quad \lambda_j \sim HalfCauchy(1), \quad \beta_j \sim Normal(0,\; \tau \lambda_j)\]

Regularized horseshoe (slab_scale set) — prevents large coefficients from escaping the slab:

\[\tilde{\lambda}_j^2 = \frac{c^2 \lambda_j^2}{c^2 + \tau^2 \lambda_j^2}, \quad c^2 \sim InverseGamma(s/2,\; s/2 \cdot scale_{slab}^2)\]
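The regularised scale interpolates between the plain horseshoe and the slab: for small λ the effective scale matches τλ, while for large λ it saturates at c (a quick numeric sketch of the formula above, with a fixed c² rather than a sampled one):

```python
import jax.numpy as jnp

def effective_scale(lam, tau, c):
    # tau * lam_tilde, with lam_tilde^2 = c^2 lam^2 / (c^2 + tau^2 lam^2)
    lam_tilde = jnp.sqrt(c ** 2 * lam ** 2 / (c ** 2 + tau ** 2 * lam ** 2))
    return tau * lam_tilde

tau, c = 1.0, 2.0
small = effective_scale(1e-3, tau, c)   # ~ tau * lam: plain horseshoe regime
large = effective_scale(1e6, tau, c)    # ~ c: capped by the slab
```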
Parameters:
  • slab_scale (float | None)

  • slab_df (float)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

__init__(slab_scale=None, slab_df=4.0, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0})[source]#
Parameters:
  • slab_scale (float | None) – If set, uses the regularized horseshoe with this slab scale. None gives the plain horseshoe.

  • slab_df (float) – Degrees of freedom for the slab variance prior (only used when slab_scale is set).

  • coef_dist (Distribution) – Distribution for the coefficients. Must accept a scale keyword (derived from the horseshoe shrinkage). Defaults to Normal.

  • coef_kwargs (dict[str, float]) – Extra kwargs for coef_dist (beyond scale). Default {"loc": 0.0}.

__call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#

Forward pass with horseshoe prior on coefficients.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input array of shape (n, d).

  • units (int) – Number of output dimensions.

  • activation (Callable[[Array], Array]) – Activation function.

Returns:

jax.Array of shape (n, units).

Return type:

Array

class blayers.layers.SpikeAndSlabLayer(alpha=0.5, beta=0.5, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#

Bases: BLayer

Sparse regression via a spike-and-slab prior.

Each coefficient has a Beta-distributed inclusion weight z_j in (0, 1). Included features (z_j ≈ 1) take the full slab coefficient; excluded features (z_j ≈ 0) are gated toward zero (the spike).

Generative model:

z_j ~ Beta(alpha, beta)          # inclusion weight (hardcoded Beta)
β_j ~ coef_dist(**coef_kwargs)   # slab coefficient
y   ~ link(z · β · x, ...)       # z gates each coefficient

The default Beta(0.5, 0.5) (Jeffreys prior) places mass near 0 and 1, encouraging features to be clearly included or excluded. The posterior mean of z_j approximates P(feature j included | data).

The slab distribution defaults to Normal(0, 1) but can be swapped for e.g. StudentT for heavier-tailed slab behaviour.
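A prior-predictive sketch of the gating using plain jax.random draws (a simplified stand-in; the real layer declares these as NumPyro sample sites):

```python
import jax
import jax.numpy as jnp

kz, kb, kx = jax.random.split(jax.random.PRNGKey(0), 3)
n, d = 6, 4
x = jax.random.normal(kx, (n, d))
z = jax.random.beta(kz, 0.5, 0.5, (d,))   # inclusion weights in (0, 1), Jeffreys prior
beta = jax.random.normal(kb, (d,))        # slab coefficients, Normal(0, 1)
out = x @ (z * beta)                      # z gates each coefficient toward zero
```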

Parameters:
  • alpha (float) – First concentration parameter of the Beta prior on z.

  • beta (float) – Second concentration parameter of the Beta prior on z.

  • coef_dist (Distribution) – Distribution for the slab coefficients.

  • coef_kwargs (dict[str, float]) – Kwargs for coef_dist.

__init__(alpha=0.5, beta=0.5, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#

Initialize layer parameters. This is the Bayesian model.

Parameters:
  • alpha (float)

  • beta (float)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

__call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#
Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input of shape (n, d).

  • units (int) – Number of output dimensions.

  • activation (Callable[[Array], Array]) – Activation function.

Returns:

jax.Array of shape (n, units).

Return type:

Array

class blayers.layers.AttentionLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Multi-head Bayesian self-attention over the feature dimension.

Treats the d input features as tokens using FT-Transformer style tokenisation (Gorishniy et al. 2021, https://arxiv.org/abs/2106.11959): each feature gets a per-column bias embedding (identity) plus a value-scaled embedding, so tokens are distinct even when the feature value is zero.

For each observation x_i ∈ R^d:

  1. Tokenise: H_j = x_{i,j} · W_emb_j + W_bias_j (each token head_dim-dimensional)

  2. Per head: Q_m, K_m, V_m = H W_Q_m, H W_K_m, H W_V_m

  3. Attn_m = softmax(Q_m K_m^T / √head_dim) V_m

  4. Concatenate heads → mean-pool over features → project to units

Requires d ≥ 2 for attention to be non-trivial. Total embedding dimension is head_dim * num_heads, so adding heads increases capacity rather than splitting a fixed budget.
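The tokenisation and a single attention head can be sketched in a few lines (weight names are illustrative; the real layer samples these weights as NumPyro sites):

```python
import jax
import jax.numpy as jnp

def ft_tokenise(x, W_emb, W_bias):
    # Each of the d features becomes a head_dim token: value-scaled embedding
    # plus a per-feature bias, so zero-valued features still get distinct tokens.
    return x[:, :, None] * W_emb[None, :, :] + W_bias[None, :, :]

def attention_head(H, Wq, Wk, Wv):
    Q, K, V = H @ Wq, H @ Wk, H @ Wv
    A = jax.nn.softmax(Q @ jnp.swapaxes(K, -1, -2) / jnp.sqrt(Q.shape[-1]), axis=-1)
    return A @ V                                       # (n, d, head_dim)

n, d, h = 4, 3, 8
keys = jax.random.split(jax.random.PRNGKey(0), 6)
x = jax.random.normal(keys[0], (n, d))
W_emb = jax.random.normal(keys[1], (d, h))
W_bias = jax.random.normal(keys[2], (d, h))
Wq, Wk, Wv = (jax.random.normal(k, (h, h)) for k in keys[3:])
H = ft_tokenise(x, W_emb, W_bias)                      # (n, d, h)
out = attention_head(H, Wq, Wk, Wv).mean(axis=1)       # mean-pool over features
```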

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Initialize layer parameters. This is the Bayesian model.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__call__(name, x, head_dim=8, num_heads=1, units=1, activation=<PjitFunction of <function identity>>)[source]#
Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input of shape (n, d). Each column is a feature token.

  • head_dim (int) – Dimension of each individual head. Total embedding dimension is head_dim * num_heads, so adding heads increases capacity.

  • num_heads (int) – Number of attention heads.

  • units (int) – Number of output dimensions.

  • activation (Callable[[Array], Array]) – Activation function.

Returns:

jax.Array of shape (n, units).

Return type:

Array