Layers#
Implements Bayesian layers using JAX and NumPyro.
- Design:
There are three levels of complexity here: class-level, instance-level, and call-level.
The class-level handles choices such as the generic model form and how coefficients are multiplied with data. It is defined by the class definition itself, e.g. class Layer(BLayer).
The instance-level handles the specific distributions that fit into a generic model and the initial parameters for those distributions. It is defined by creating an instance of the class: Layer(*args, **kwargs).
The call-level handles seeing a batch of data, sampling from the distributions defined on the class, and multiplying coefficients and data to produce an output. It works like
result = Layer(*args, **kwargs)(data)
- Notation:
n: observations in a batch
c: number of categories of things for time, random effects, etc.
d: number of coefficients
l: low-rank dimension of low-rank models
m: embedding dimension
u: units, aka output dimension
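The three levels can be seen in a framework-free toy sketch (plain Python, no JAX or NumPyro; the class and its behaviour are illustrative only, not the blayers implementation):

```python
import random

class ToyLayer:
    """Class-level: the class definition fixes the generic model form
    (here, a plain dot product of sampled coefficients with data)."""

    def __init__(self, loc=0.0, scale=1.0):
        # Instance-level: the specific prior and its initial parameters.
        self.loc = loc
        self.scale = scale

    def __call__(self, data):
        # Call-level: see a batch of data, sample coefficients from the
        # prior, and multiply coefficients and data to produce an output.
        coefs = [random.gauss(self.loc, self.scale) for _ in data[0]]
        return [sum(c * x for c, x in zip(coefs, row)) for row in data]

result = ToyLayer(loc=0.0, scale=1.0)([[1.0, 2.0], [3.0, 4.0]])
```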
- blayers.layers.pairwise_interactions(x, z)[source]#
Compute all pairwise interactions between the features of x and z.
- Parameters:
x (Array) – Input of shape (n_samples, n_features1).
z (Array) – Input of shape (n_samples, n_features2).
- Returns:
Interactions of shape (n_samples, n_features1 * n_features2).
- Return type:
jax.Array
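A minimal NumPy sketch of the same computation (a flattened outer product per row); the helper name here is hypothetical, for illustration only:

```python
import numpy as np

def pairwise_interactions_np(x, z):
    # For each row, the outer product of x_i and z_i flattened to a vector:
    # output shape is (n_samples, n_features1 * n_features2).
    return np.einsum("ni,nj->nij", x, z).reshape(x.shape[0], -1)

x = np.array([[1.0, 2.0]])        # (1, 2)
z = np.array([[3.0, 4.0, 5.0]])   # (1, 3)
out = pairwise_interactions_np(x, z)
```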
- class blayers.layers.BLayer(*args)[source]#
Bases: ABC
Abstract base class for Bayesian layers. Lays out an interface.
- Parameters:
args (Any)
- class blayers.layers.AdaptiveLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian layer with adaptive prior using hierarchical modeling.
Generates coefficients from the hierarchical model
\[\lambda \sim HalfNormal(1.)\]
\[\beta \sim Normal(0., \lambda)\]
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.
coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.
coef_kwargs (dict[str, float]) – Parameters for the prior distribution.
scale_kwargs (dict[str, float]) – Parameters for the scale distribution.
- __call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#
Forward pass with adaptive prior on coefficients.
- Parameters:
name (str) – Variable name.
x (Array) – Input data array of shape (n, d).
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
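The hierarchical prior above can be sketched generatively in NumPy (an illustration of the model, not the layer's NumPyro implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, u = 8, 3, 1

# lambda ~ HalfNormal(1.): adaptive scale shared by the coefficients
lam = np.abs(rng.normal(0.0, 1.0))
# beta ~ Normal(0., lambda): coefficients of shape (d, u)
beta = rng.normal(0.0, lam, size=(d, u))

x = rng.normal(size=(n, d))
out = x @ beta  # output of shape (n, u)
```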
- class blayers.layers.FixedPriorLayer(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
Bases: BLayer
Bayesian layer with a fixed prior distribution over coefficients.
Generates coefficients from the model
\[\beta \sim Normal(0., 1.)\]
- Parameters:
coef_dist (Distribution)
coef_kwargs (dict[str, float])
- __init__(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
- Parameters:
coef_dist (Distribution) – NumPyro distribution class for the coefficients.
coef_kwargs (dict[str, float]) – Parameters to initialize the prior distribution.
- __call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#
Forward pass with fixed prior.
- Parameters:
name (str) – Variable name.
x (Array) – Input data array of shape (n, d).
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
- class blayers.layers.InterceptLayer(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
Bases: BLayer
Bayesian intercept (bias) term with a fixed prior.
Samples a scalar bias from
\[\beta \sim Normal(0., 1.)\]
and broadcasts it to every observation. No input x is needed.
- Parameters:
coef_dist (Distribution)
coef_kwargs (dict[str, float])
- __init__(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
- Parameters:
coef_dist (Distribution) – NumPyro distribution class for the coefficients.
coef_kwargs (dict[str, float]) – Parameters to initialize the prior distribution.
- __call__(name, units=1, activation=<PjitFunction of <function identity>>)[source]#
Forward pass with fixed prior.
- Parameters:
name (str) – Variable name.
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (1, u).
- Return type:
jax.Array
- class blayers.layers.FMLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian factorization machine layer with adaptive priors.
Generates coefficients from the hierarchical model
\[\lambda \sim HalfNormal(1.)\]
\[\beta \sim Normal(0., \lambda)\]
The shape of beta is (j, l), where j is the number of input covariates and l is the low-rank dim. The layer then performs matrix multiplication using the formula in Rendle (2010).
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – Distribution for scaling factor λ.
coef_dist (Distribution) – Prior for beta parameters.
coef_kwargs (dict[str, float]) – Arguments for prior distribution.
scale_kwargs (dict[str, float]) – Arguments for λ distribution.
- __call__(name, x, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#
Forward pass through the factorization machine layer.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input matrix of shape (n, d).
low_rank_dim (int) – Dimensionality of low-rank approximation.
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
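Rendle's (2010) identity lets the pairwise term be computed in O(d·l) rather than enumerating all feature pairs. A NumPy check that the identity matches the brute-force sum over pairs:

```python
import numpy as np

rng = np.random.default_rng(0)
d, l = 5, 2
x = rng.normal(size=d)
theta = rng.normal(size=(d, l))  # low-rank factors, shape (d, l)

# Brute force: sum over all pairs i < j of <theta_i, theta_j> x_i x_j
brute = sum(
    theta[i] @ theta[j] * x[i] * x[j]
    for i in range(d) for j in range(i + 1, d)
)

# Rendle (2010): 0.5 * sum_f [(sum_i x_i theta_if)^2 - sum_i x_i^2 theta_if^2]
fast = 0.5 * (((x @ theta) ** 2).sum() - ((x ** 2) @ (theta ** 2)).sum())
```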
- class blayers.layers.FM3Layer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian order-3 factorization machine layer with adaptive prior.
Samples low-rank factors from the hierarchical model
\[\lambda \sim HalfNormal(1.)\]
\[\theta \sim Normal(0., \lambda), \quad \theta \in \mathbb{R}^{d \times l}\]
Then computes the 3rd-order ANOVA kernel via Newton’s identity (Blondel et al. 2016). Defining power sums \(p_k = \sum_i x_i^k \theta_i^k\):
\[\text{output} = \frac{p_1^3 - 3\, p_2\, p_1 + 2\, p_3}{6}\]This efficiently computes all 3rd-order interaction terms without enumerating all \(\binom{d}{3}\) triples.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – Distribution for scaling factor λ.
coef_dist (Distribution) – Prior for beta parameters.
coef_kwargs (dict[str, float]) – Arguments for prior distribution.
scale_kwargs (dict[str, float]) – Arguments for λ distribution.
- __call__(name, x, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#
Forward pass through the factorization machine layer.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input matrix of shape (n, d).
low_rank_dim (int) – Dimensionality of low-rank approximation.
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n,).
- Return type:
jax.Array
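A NumPy check of the third-order identity against the explicit sum over triples (a single factor column, for clarity):

```python
import numpy as np
from itertools import combinations

rng = np.random.default_rng(0)
d = 6
x = rng.normal(size=d)
theta = rng.normal(size=d)  # one low-rank factor column

# Brute force: sum over all triples i < j < k
brute = sum(
    x[i] * theta[i] * x[j] * theta[j] * x[k] * theta[k]
    for i, j, k in combinations(range(d), 3)
)

# Power sums p_k = sum_i (x_i theta_i)^k, then Newton's identity
t = x * theta
p1, p2, p3 = t.sum(), (t ** 2).sum(), (t ** 3).sum()
fast = (p1 ** 3 - 3 * p2 * p1 + 2 * p3) / 6
```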
- class blayers.layers.LowRankInteractionLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian low-rank bilinear interaction between two feature sets (UV decomposition).
Samples separate low-rank projections for x and z from the hierarchical model
\[\lambda_1 \sim HalfNormal(1.), \quad \theta_1 \sim Normal(0., \lambda_1), \quad \theta_1 \in \mathbb{R}^{d_1 \times l}\]
\[\lambda_2 \sim HalfNormal(1.), \quad \theta_2 \sim Normal(0., \lambda_2), \quad \theta_2 \in \mathbb{R}^{d_2 \times l}\]
and computes the element-wise product of the projections, summed over the low-rank dimension:
\[\text{output} = \sum_{r=1}^{l} (x \theta_1)_r \cdot (z \theta_2)_r = x^\top (\theta_1 \theta_2^\top) z\]This is equivalent to a rank-\(l\) approximation of the full bilinear form \(x^\top W z\).
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Initialize layer parameters. This is the Bayesian model.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __call__(name, x, z, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#
Low-rank interaction between feature matrices x and z via UV decomposition.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input matrix of shape (n, d1).
z (Array) – Input matrix of shape (n, d2).
low_rank_dim (int) – Dimensionality of low-rank approximation.
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
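A NumPy check that the factored form equals the full bilinear form with \(W = \theta_1 \theta_2^\top\):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d1, d2, l = 4, 5, 3, 2
x = rng.normal(size=(n, d1))
z = rng.normal(size=(n, d2))
theta1 = rng.normal(size=(d1, l))
theta2 = rng.normal(size=(d2, l))

# Factored: project both inputs to rank l, multiply element-wise, sum over l
factored = ((x @ theta1) * (z @ theta2)).sum(axis=1)

# Full bilinear form with the rank-l matrix W = theta1 @ theta2.T
W = theta1 @ theta2.T
full = np.einsum("ni,ij,nj->n", x, W, z)
```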
- class blayers.layers.InteractionLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian full pairwise interaction layer with adaptive prior.
Samples one coefficient per pair of features from the hierarchical model
\[\lambda \sim HalfNormal(1.)\]
\[\beta \sim Normal(0., \lambda), \quad \beta \in \mathbb{R}^{d_1 d_2}\]
and computes the weighted sum of all outer-product interactions:
\[\text{output} = (x \otimes z)\, \beta\]
where \(x \otimes z\) is the flattened outer product of shape \((n, d_1 d_2)\). For large inputs this scales as \(O(d_1 d_2)\) parameters; prefer LowRankInteractionLayer when \(d_1\) or \(d_2\) is large.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Initialize layer parameters. This is the Bayesian model.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __call__(name, x, z, units=1, activation=<PjitFunction of <function identity>>)[source]#
Full pairwise interaction between feature matrices x and z.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input matrix of shape (n, d1).
z (Array) – Input matrix of shape (n, d2).
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
- class blayers.layers.BilinearLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian full bilinear layer with adaptive prior.
Samples a full interaction matrix from the hierarchical model
\[\lambda \sim HalfNormal(1.)\]
\[W \sim Normal(0., \lambda), \quad W \in \mathbb{R}^{d_1 \times d_2}\]
and computes the bilinear form:
\[\text{output} = x^\top W z\]
This learns a distinct weight for every pair \((x_i, z_j)\), making it the densest two-input layer. Has \(O(d_1 d_2)\) parameters; prefer LowRankBilinearLayer when dimensions are large.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – prior on scale of coefficients
coef_dist (Distribution) – distribution for coefficients
coef_kwargs (dict[str, float]) – kwargs for coef distribution
scale_kwargs (dict[str, float]) – kwargs for scale prior
- __call__(name, x, z, units=1, activation=<PjitFunction of <function identity>>)[source]#
Full bilinear interaction between feature matrices x and z.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input matrix of shape (n, d1).
z (Array) – Input matrix of shape (n, d2).
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
- class blayers.layers.LowRankBilinearLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian low-rank bilinear layer with adaptive prior.
Samples shared-scale low-rank factors for both inputs from the hierarchical model
\[\lambda \sim HalfNormal(1.)\]
\[A \sim Normal(0., \lambda), \quad A \in \mathbb{R}^{d_1 \times l}\]
\[B \sim Normal(0., \lambda), \quad B \in \mathbb{R}^{d_2 \times l}\]
and computes the bilinear form with a rank-\(l\) weight matrix \(W = AB^\top\):
\[\text{output} = x^\top W z = (xA) \cdot (zB)\]
Compared to LowRankInteractionLayer, A and B share a single scale \(\lambda\), tying the regularisation across both inputs.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – prior on scale of coefficients
coef_dist (Distribution) – distribution for coefficients
coef_kwargs (dict[str, float]) – kwargs for coef distribution
scale_kwargs (dict[str, float]) – kwargs for scale prior
- __call__(name, x, z, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#
Low-rank bilinear interaction between feature matrices x and z (UV decomposition).
- Parameters:
name (str) – Variable name scope.
x (Array) – Input matrix of shape (n, d1).
z (Array) – Input matrix of shape (n, d2).
low_rank_dim (int) – Dimensionality of low-rank approximation.
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
- class blayers.layers.EmbeddingLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian embedding layer for sparse categorical features.
Samples an embedding table from the hierarchical model
\[\lambda \sim HalfNormal(1.)\]
\[\theta \sim Normal(0., \lambda), \quad \theta \in \mathbb{R}^{c \times m}\]
and performs a lookup for each observation:
\[\text{output}_i = \theta[x_i]\]
where \(c\) is the number of categories and \(m\) is the embedding dimension. For \(m = 1\) prefer RandomEffectsLayer.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.
coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.
coef_kwargs (dict[str, float]) – Parameters for the prior distribution.
scale_kwargs (dict[str, float]) – Parameters for the scale distribution.
- __call__(name, x, num_categories, embedding_dim)[source]#
Forward pass through embedding lookup.
- Parameters:
name (str) – Variable name scope.
x (Array) – Integer indices indicating which embeddings to use.
num_categories (int) – The number of distinct things getting an embedding.
embedding_dim (int) – The size of each embedding, e.g. 2, 4, 8, etc.
- Returns:
Embedding vectors of shape (n, m).
- Return type:
jax.Array
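The lookup itself is plain integer indexing into the sampled table; a generative NumPy sketch (illustration only):

```python
import numpy as np

rng = np.random.default_rng(0)
c, m = 4, 2  # number of categories, embedding dimension

# lambda ~ HalfNormal(1.), theta ~ Normal(0, lambda): the (c, m) table
lam = np.abs(rng.normal())
theta = rng.normal(0.0, lam, size=(c, m))

x = np.array([0, 2, 2, 3])  # integer category per observation
out = theta[x]              # lookup, shape (n, m)
```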
- class blayers.layers.RandomEffectsLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian random-effects layer: a scalar embedding per category.
Special case of EmbeddingLayer with embedding_dim=1. Samples one scalar random effect per category from the hierarchical model
\[\lambda \sim HalfNormal(1.)\]
\[\theta \sim Normal(0., \lambda), \quad \theta \in \mathbb{R}^{c}\]
and returns the scalar for each observation’s category:
\[\text{output}_i = \theta[x_i]\]Equivalent to a classical mixed-effects intercept with a learned variance \(\lambda^2\).
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.
coef_dist (Distribution) – Prior distribution for the random effects.
coef_kwargs (dict[str, float]) – Parameters for the prior distribution.
scale_kwargs (dict[str, float]) – Parameters for the scale distribution.
- __call__(name, x, num_categories)[source]#
Forward pass through embedding lookup.
- Parameters:
name (str) – Variable name scope.
x (Array) – Integer indices indicating each observation’s category.
num_categories (int) – The number of distinct things getting an embedding.
- Returns:
Random-effect values of shape (n, 1).
- Return type:
jax.Array
- class blayers.layers.RandomWalkLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian Gaussian random walk over ordered categories.
Samples i.i.d. increments from the hierarchical model
\[\lambda \sim HalfNormal(1.)\]
\[\delta_t \sim Normal(0., \lambda), \quad t = 1, \ldots, c\]
and accumulates them into positions via a cumulative sum:
\[\theta_t = \sum_{s=1}^{t} \delta_s\]Each observation is then assigned the position of its category:
\[\text{output}_i = \theta[x_i]\]
The embedding_dim m runs m independent walks in parallel, producing output of shape (n, m). Typical use: a time index where adjacent periods share information through the walk prior.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Initialize layer parameters. This is the Bayesian model.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __call__(name, x, num_categories, embedding_dim)[source]#
Forward pass through embedding lookup.
- Parameters:
name (str) – Variable name scope.
x (Array) – Integer indices indicating which embeddings to use.
num_categories (int) – The number of distinct things getting an embedding.
embedding_dim (int) – The size of each embedding, e.g. 2, 4, 8, etc.
- Returns:
Embedding vectors of shape (n, m).
- Return type:
jax.Array
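A NumPy sketch of the walk construction: i.i.d. increments, cumulative sum, then lookup by category (illustration of the model, not the NumPyro implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
c, m = 5, 2  # ordered categories (e.g. time periods), parallel walks

# delta_t ~ Normal(0, lambda), accumulated into positions theta_t
lam = np.abs(rng.normal())
delta = rng.normal(0.0, lam, size=(c, m))
theta = np.cumsum(delta, axis=0)  # theta_t = sum of delta_s for s <= t

x = np.array([0, 1, 4])  # time index per observation
out = theta[x]           # lookup, shape (n, m)
```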
- class blayers.layers.HorseshoeLayer(slab_scale=None, slab_df=4.0, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0})[source]#
Bases: BLayer
Bayesian layer with horseshoe prior for sparse regression.
Implements the (regularized) horseshoe prior of Piironen & Vehtari (2017).
Basic horseshoe:
\[\tau \sim HalfCauchy(1), \quad \lambda_j \sim HalfCauchy(1), \quad \beta_j \sim Normal(0,\; \tau \lambda_j)\]
Regularized horseshoe (slab_scale set), which prevents large coefficients from escaping the slab:
\[\tilde{\lambda}_j^2 = \frac{c^2 \lambda_j^2}{c^2 + \tau^2 \lambda_j^2}, \quad c^2 \sim InverseGamma(s/2,\; s/2 \cdot scale_{slab}^2)\]
- Parameters:
slab_scale (float | None)
slab_df (float)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
- __init__(slab_scale=None, slab_df=4.0, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0})[source]#
- Parameters:
slab_scale (float | None) – If set, uses the regularized horseshoe with this slab scale. None gives the plain horseshoe.
slab_df (float) – Degrees of freedom for the slab variance prior (only used when slab_scale is set).
coef_dist (Distribution) – Distribution for the coefficients. Must accept a scale keyword (derived from the horseshoe shrinkage). Defaults to Normal.
coef_kwargs (dict[str, float]) – Extra kwargs for coef_dist (beyond scale). Default {"loc": 0.0}.
- __call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#
Forward pass with horseshoe prior on coefficients.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input array of shape (n, d).
units (int) – Number of output dimensions.
activation (Callable[[Array], Array]) – Activation function.
- Returns:
jax.Array of shape (n, units).
- Return type:
Array
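The regularized-horseshoe formula caps each coefficient's effective prior scale at roughly the slab width \(c\); a NumPy check of that bound:

```python
import numpy as np

rng = np.random.default_rng(0)
tau = 0.1
# HalfCauchy(1) draws via the inverse CDF of the Cauchy distribution
lam = np.abs(np.tan(np.pi * (rng.uniform(size=10) - 0.5)))
c = 2.0  # slab scale

# Regularized local scale: small lambdas pass through nearly unchanged,
# large lambdas saturate so that tau * lambda_tilde stays below c
lam_tilde_sq = (c ** 2 * lam ** 2) / (c ** 2 + tau ** 2 * lam ** 2)
scales = tau * np.sqrt(lam_tilde_sq)  # effective prior scale per coefficient
```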
- class blayers.layers.SpikeAndSlabLayer(alpha=0.5, beta=0.5, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
Bases: BLayer
Sparse regression via a spike-and-slab prior.
Each coefficient has a Beta-distributed inclusion weight z_j in (0, 1). Included features (z_j ≈ 1) take the full slab coefficient; excluded features (z_j ≈ 0) are gated toward zero (the spike).
Generative model:
z_j ~ Beta(alpha, beta) # inclusion weight (hardcoded Beta) β_j ~ coef_dist(**coef_kwargs) # slab coefficient y ~ link(z · β · x, ...) # z gates each coefficient
The default Beta(0.5, 0.5) (Jeffreys prior) places mass near 0 and 1, encouraging features to be clearly included or excluded. The posterior mean of z_j approximates P(feature j included | data).
The slab distribution defaults to Normal(0, 1) but can be swapped for e.g. StudentT for heavier-tailed slab behaviour.
- Parameters:
alpha (float) – First concentration parameter of the Beta prior on z.
beta (float) – Second concentration parameter of the Beta prior on z.
coef_dist (Distribution) – Distribution for the slab coefficients.
coef_kwargs (dict[str, float]) – Kwargs for coef_dist.
- __init__(alpha=0.5, beta=0.5, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
Initialize layer parameters. This is the Bayesian model.
- Parameters:
alpha (float)
beta (float)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
- __call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#
- Parameters:
name (str) – Variable name scope.
x (Array) – Input of shape (n, d).
units (int) – Number of output dimensions.
activation (Callable[[Array], Array]) – Activation function.
- Returns:
jax.Array of shape (n, units).
- Return type:
Array
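A generative NumPy sketch of the gating: Beta-distributed inclusion weights multiply the slab coefficients (illustration only):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 8, 4

z = rng.beta(0.5, 0.5, size=d)       # inclusion weights in (0, 1)
beta = rng.normal(0.0, 1.0, size=d)  # slab coefficients

x = rng.normal(size=(n, d))
out = x @ (z * beta)  # z gates each coefficient toward zero or the slab
```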
- class blayers.layers.AttentionLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Multi-head Bayesian self-attention over the feature dimension.
Treats the d input features as tokens using FT-Transformer style tokenisation (Gorishniy et al. 2021, https://arxiv.org/abs/2106.11959): each feature gets a per-column bias embedding (identity) plus a value-scaled embedding, so tokens are distinct even when the feature value is zero.
For each observation x_i ∈ R^d:
Tokenise: H_j = x_{i,j} · W_emb_j + W_bias_j (head_dim-dimensional each)
Per head: Q_m, K_m, V_m = H W_Q_m, H W_K_m, H W_V_m; Attn_m = softmax(Q_m K_m^T / √h_k)
Concatenate heads → mean-pool over features → project to units
Requires d ≥ 2 for attention to be non-trivial. Total embedding dimension is head_dim * num_heads; adding heads increases capacity rather than splitting a fixed budget.
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Initialize layer parameters. This is the Bayesian model.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __call__(name, x, head_dim=8, num_heads=1, units=1, activation=<PjitFunction of <function identity>>)[source]#
- Parameters:
name (str) – Variable name scope.
x (Array) – Input of shape (n, d). Each column is a feature token.
head_dim (int) – Dimension of each individual head. Total embedding dimension is head_dim * num_heads, so adding heads increases capacity.
num_heads (int) – Number of attention heads.
units (int) – Number of output dimensions.
activation (Callable[[Array], Array]) – Activation function.
- Returns:
jax.Array of shape (n, units).
- Return type:
Array
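A single-head NumPy sketch of the tokenise → attend → pool pipeline described above (weights drawn randomly here rather than sampled from priors; an illustration of the shapes, not the layer's implementation):

```python
import numpy as np

def softmax(a, axis=-1):
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d, h, units = 4, 3, 8, 1

# FT-Transformer style tokens: per-feature value embedding plus bias embedding
W_emb = rng.normal(size=(d, h))
W_bias = rng.normal(size=(d, h))
x = rng.normal(size=(n, d))
H = x[:, :, None] * W_emb[None] + W_bias[None]  # tokens, shape (n, d, h)

# Single head of scaled dot-product self-attention over the d feature tokens
W_q, W_k, W_v = (rng.normal(size=(h, h)) for _ in range(3))
Q, K, V = H @ W_q, H @ W_k, H @ W_v
attn = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(h)) @ V  # (n, d, h)

# Mean-pool over feature tokens, then project to units
W_out = rng.normal(size=(h, units))
out = attn.mean(axis=1) @ W_out  # (n, units)
```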