Layers

Implements Bayesian layers using JAX and NumPyro.

Design:
  • There are three levels of complexity here: class-level, instance-level, and call-level.

  • The class level handles things like choosing the generic model form and how to multiply coefficients with data. It is defined by the class definition itself: class Layer(BLayer).

  • The instance level handles the specific distributions that fit into the generic model and the initial parameters for those distributions. It is defined by creating an instance of the class: Layer(*args, **kwargs).

  • The call level handles seeing a batch of data, sampling from the distributions chosen at the instance level, and multiplying coefficients and data to produce an output. It works like result = Layer(*args, **kwargs)(data).

Notation:
  • n: observations in a batch

  • c: number of categories (e.g. time periods or random-effect groups)

  • d: number of coefficients

  • l: low rank dimension of low rank models

  • m: embedding dimension

  • u: units aka output dimension

class blayers.layers.BLayer(*args)[source]#

Bases: ABC

Abstract base class for Bayesian layers. Lays out an interface.

Parameters:

args (Any)

abstractmethod __init__(*args)[source]#

Initialize layer parameters. This is the Bayesian model.

Parameters:

args (Any)

Return type:

None

abstractmethod __call__(*args)[source]#

Run the layer’s forward pass.

Parameters:

*args (Any) – Inputs to the layer.

Returns:

The result of the forward computation.

Return type:

jax.Array

class blayers.layers.AdaptiveLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian layer with adaptive prior using hierarchical modeling.

Generates coefficients from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[\beta \sim Normal(0., \lambda)\]
Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.

  • coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.

  • coef_kwargs (dict[str, float]) – Parameters for the prior distribution.

  • scale_kwargs (dict[str, float]) – Parameters for the scale distribution.

__call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#

Forward pass with adaptive prior on coefficients.

Parameters:
  • name (str) – Variable name.

  • x (Array) – Input data array of shape (n, d).

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.FixedPriorLayer(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#

Bases: BLayer

Bayesian layer with a fixed prior distribution over coefficients.

Generates coefficients from the model

\[\beta \sim Normal(0., 1.)\]
Parameters:
  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

__init__(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
Parameters:
  • coef_dist (Distribution) – NumPyro distribution class for the coefficients.

  • coef_kwargs (dict[str, float]) – Parameters to initialize the prior distribution.

__call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#

Forward pass with fixed prior.

Parameters:
  • name (str) – Variable name.

  • x (Array) – Input data array of shape (n, d).

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.InterceptLayer(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#

Bases: BLayer

Bayesian intercept layer with a fixed prior distribution over the intercept coefficients. Takes no input data.

Generates coefficients from the model

\[\beta \sim Normal(0., 1.)\]
Parameters:
  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

__init__(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
Parameters:
  • coef_dist (Distribution) – NumPyro distribution class for the coefficients.

  • coef_kwargs (dict[str, float]) – Parameters to initialize the prior distribution.

__call__(name, units=1, activation=<PjitFunction of <function identity>>)[source]#

Forward pass with fixed prior.

Parameters:
  • name (str) – Variable name.

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (1, u).

Return type:

jax.Array

class blayers.layers.FMLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian factorization machine layer with adaptive priors.

Generates coefficients from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[\beta \sim Normal(0., \lambda)\]

The shape of beta is (j, l), where j is the number of input covariates and l is the low-rank dimension.

Then performs matrix multiplication using the formula in Rendle (2010).

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – Distribution for scaling factor λ.

  • coef_dist (Distribution) – Prior for beta parameters.

  • coef_kwargs (dict[str, float]) – Arguments for prior distribution.

  • scale_kwargs (dict[str, float]) – Arguments for λ distribution.

__call__(name, x, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#

Forward pass through the factorization machine layer.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d).

  • low_rank_dim (int) – Dimensionality of low-rank approximation.

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array
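As a rough sketch of the Rendle (2010) formula mentioned above, the pairwise term can be computed in O(d·l) time instead of summing over all pairs. The function names below are hypothetical, the example covers a single output unit only, and it draws plain random factors rather than sampling from the layer's priors.

```python
import jax
import jax.numpy as jnp


def fm_pairwise(x, v):
    """Second-order FM term via the O(d * l) identity.

    x: (n, d) inputs, v: (d, l) factor matrix. Returns shape (n,).
    """
    xv = x @ v                    # (n, l): sum_i v_if x_i
    x2v2 = (x ** 2) @ (v ** 2)    # (n, l): sum_i v_if^2 x_i^2
    return 0.5 * jnp.sum(xv ** 2 - x2v2, axis=1)


def fm_pairwise_naive(x, v):
    """Explicit sum over pairs i < j of <v_i, v_j> x_i x_j, for checking."""
    w = v @ v.T                                # (d, d) pairwise weights
    upper = jnp.triu(jnp.ones_like(w), k=1)    # keep i < j only
    return jnp.einsum("ni,ij,nj->n", x, w * upper, x)


key_x, key_v = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4, 6))   # n=4, d=6
v = jax.random.normal(key_v, (6, 3))   # d=6, l=3
fast = fm_pairwise(x, v)
slow = fm_pairwise_naive(x, v)
print(jnp.max(jnp.abs(fast - slow)))   # difference is floating-point noise
```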

class blayers.layers.FM3Layer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Order-3 factorization machine (FM). See Blondel et al. (2016).

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – Distribution for scaling factor λ.

  • coef_dist (Distribution) – Prior for beta parameters.

  • coef_kwargs (dict[str, float]) – Arguments for prior distribution.

  • scale_kwargs (dict[str, float]) – Arguments for λ distribution.

__call__(name, x, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#

Forward pass through the factorization machine layer.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d).

  • low_rank_dim (int) – Dimensionality of low-rank approximation.

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n,).

Return type:

jax.Array
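One way to picture the third-order term, assuming the layer follows the ANOVA-kernel form used in Blondel et al. (2016), is the closed-form expression below. It is a sketch with hypothetical function names and random factors, checked against an explicit sum over triples; the layer's actual computation may differ.

```python
import itertools

import jax
import jax.numpy as jnp


def anova3(x, p):
    """Order-3 ANOVA kernel, summed over the low-rank dimension.

    x: (n, d) inputs, p: (d, l) factors. Returns shape (n,).
    """
    s1 = x @ p                    # (n, l): sum_i p_i x_i
    s2 = (x ** 2) @ (p ** 2)      # (n, l): sum_i (p_i x_i)^2
    s3 = (x ** 3) @ (p ** 3)      # (n, l): sum_i (p_i x_i)^3
    return jnp.sum(s1 ** 3 - 3.0 * s1 * s2 + 2.0 * s3, axis=1) / 6.0


def anova3_naive(x, p):
    """Explicit sum over triples i < j < k, for checking."""
    d = x.shape[1]
    total = jnp.zeros(x.shape[0])
    for i, j, k in itertools.combinations(range(d), 3):
        coef = jnp.sum(p[i] * p[j] * p[k])    # sum over the l factors
        total = total + coef * x[:, i] * x[:, j] * x[:, k]
    return total


key_x, key_p = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4, 5))   # n=4, d=5
p = jax.random.normal(key_p, (5, 2))   # d=5, l=2
fast = anova3(x, p)
slow = anova3_naive(x, p)
```

The closed form avoids enumerating the d-choose-3 triples, which is what makes higher-order interactions affordable.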

class blayers.layers.LowRankInteractionLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Takes two sets of features and learns a low-rank interaction matrix.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Initialize layer parameters. This is the Bayesian model.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__call__(name, x, z, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#

Low-rank interaction between feature matrices X and Z, using a UV decomposition of the interaction matrix.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d1).

  • z (Array) – Input matrix of shape (n, d2).

  • low_rank_dim (int) – Dimensionality of low-rank approximation.

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.InteractionLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Computes every interaction coefficient between two sets of inputs.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Initialize layer parameters. This is the Bayesian model.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__call__(name, x, z, units=1, activation=<PjitFunction of <function identity>>)[source]#

Interaction between feature matrices X and Z, with a separate coefficient for every pair of columns.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d1).

  • z (Array) – Input matrix of shape (n, d2).

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array
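A minimal sketch of what "every interaction coefficient" means, under the assumption that there is one coefficient per (x column, z column) pair and per output unit. The weights here are plain random draws rather than samples from the layer's priors, and the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

key_x, key_z, key_w = jax.random.split(jax.random.PRNGKey(0), 3)
n, d1, d2, u = 4, 3, 2, 1
x = jax.random.normal(key_x, (n, d1))
z = jax.random.normal(key_z, (n, d2))
w = jax.random.normal(key_w, (d1 * d2, u))   # one coefficient per (x_i, z_j) pair

# All pairwise products x_i * z_j for each row, flattened to (n, d1 * d2).
pairs = (x[:, :, None] * z[:, None, :]).reshape(n, d1 * d2)
out = pairs @ w                               # (n, u)

# The same result written as a bilinear form x^T W z, with W of shape (d1, d2, u).
out_bilinear = jnp.einsum("ni,iju,nj->nu", x, w.reshape(d1, d2, u), z)
```

The coefficient count grows as d1 * d2, which is the motivation for the low-rank variants.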

class blayers.layers.BilinearLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian bilinear interaction layer: computes x^T W z.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – prior on scale of coefficients

  • coef_dist (Distribution) – distribution for coefficients

  • coef_kwargs (dict[str, float]) – kwargs for coef distribution

  • scale_kwargs (dict[str, float]) – kwargs for scale prior

__call__(name, x, z, units=1, activation=<PjitFunction of <function identity>>)[source]#

Bilinear interaction between feature matrices X and Z with a full (not low-rank) weight matrix W.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d1).

  • z (Array) – Input matrix of shape (n, d2).

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.LowRankBilinearLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian bilinear interaction layer: computes x^T W z. W low rank.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – prior on scale of coefficients

  • coef_dist (Distribution) – distribution for coefficients

  • coef_kwargs (dict[str, float]) – kwargs for coef distribution

  • scale_kwargs (dict[str, float]) – kwargs for scale prior

__call__(name, x, z, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#

Bilinear interaction between feature matrices X and Z with a low-rank weight matrix W, using a UV decomposition.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d1).

  • z (Array) – Input matrix of shape (n, d2).

  • low_rank_dim (int) – Dimensionality of low-rank approximation.

  • units (int) – Number of outputs.

  • activation (Callable[[Array], Array]) – Activation function to apply to output.

Returns:

Output array of shape (n, u).

Return type:

jax.Array
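To illustrate the low-rank idea, assume W is factored as U V^T with U of shape (d1, l) and V of shape (d2, l). The names U and V and the single-output setup are assumptions for this sketch, not taken from the library source. The product x^T W z can then be computed without ever forming W.

```python
import jax
import jax.numpy as jnp

key_x, key_z, key_u, key_v = jax.random.split(jax.random.PRNGKey(0), 4)
n, d1, d2, l = 5, 8, 6, 2
x = jax.random.normal(key_x, (n, d1))
z = jax.random.normal(key_z, (n, d2))
U = jax.random.normal(key_u, (d1, l))
V = jax.random.normal(key_v, (d2, l))

# Low-rank route: project both inputs to l dimensions, multiply, and sum.
low_rank = jnp.sum((x @ U) * (z @ V), axis=1)    # (n,)

# Full route for comparison: build W = U V^T, then x^T W z per row.
W = U @ V.T                                      # (d1, d2)
full = jnp.einsum("ni,ij,nj->n", x, W, z)        # (n,)

print(U.size + V.size, "low-rank coefficients vs", W.size, "full coefficients")
```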

class blayers.layers.EmbeddingLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Bayesian embedding layer for sparse categorical features.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.

  • coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.

  • coef_kwargs (dict[str, float]) – Parameters for the prior distribution.

  • scale_kwargs (dict[str, float]) – Parameters for the scale distribution.

__call__(name, x, num_categories, embedding_dim)[source]#

Forward pass through embedding lookup.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Integer indices indicating embeddings to use.

  • num_categories (int) – The number of distinct things getting an embedding

  • embedding_dim (int) – The size of each embedding, e.g. 2, 4, 8, etc.

Returns:

Embedding vectors of shape (n, m).

Return type:

jax.Array
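The lookup itself amounts to indexing rows of a (c, m) table. The sketch below draws the table from the default hierarchical prior using plain jax.random calls; that is an assumption about how the pieces fit together, and the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

num_categories, embedding_dim = 4, 2                  # c=4, m=2
key_scale, key_table = jax.random.split(jax.random.PRNGKey(0))

# lambda ~ HalfNormal(1), then each table entry ~ Normal(0, lambda).
lam = jnp.abs(jax.random.normal(key_scale))
table = lam * jax.random.normal(key_table, (num_categories, embedding_dim))

x = jnp.array([0, 3, 3, 1, 0])                        # n=5 integer indices
emb = table[x]                                        # (n, m): one row per observation
print(emb.shape)
```

Observations that share a category index receive the same embedding row.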

class blayers.layers.RandomEffectsLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Exactly like the EmbeddingLayer but with embedding_dim=1.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Parameters:
  • scale_dist (Distribution) – Distribution for the scale (λ) of the prior.

  • coef_dist (Distribution) – Prior distribution for embedding weights.

  • coef_kwargs (dict[str, float]) – Parameters for the prior distribution.

  • scale_kwargs (dict[str, float]) – Parameters for the scale distribution.

__call__(name, x, num_categories)[source]#

Forward pass through embedding lookup.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Integer indices indicating embeddings to use.

  • num_categories (int) – The number of distinct things getting an embedding

Returns:

Embedding vectors of shape (n, 1), since embedding_dim is fixed at 1.

Return type:

jax.Array

class blayers.layers.RandomWalkLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Random walk of embedding dim m, defaults to Gaussian walk.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Initialize layer parameters. This is the Bayesian model.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__call__(name, x, num_categories, embedding_dim)[source]#

Forward pass through embedding lookup.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Integer indices indicating embeddings to use.

  • num_categories (int) – The number of distinct things getting an embedding

  • embedding_dim (int) – The size of each embedding, e.g. 2, 4, 8, etc.

Returns:

Embedding vectors of shape (n, m).

Return type:

jax.Array
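A Gaussian random walk can be sketched as a cumulative sum of independent Normal increments. The example assumes the walk runs along the category axis (for instance, ordered time periods) and that the step scale comes from the default HalfNormal(1); both are assumptions for illustration rather than statements about the library's internals.

```python
import jax
import jax.numpy as jnp

num_categories, embedding_dim = 6, 2                  # c=6 steps, m=2
key_scale, key_steps = jax.random.split(jax.random.PRNGKey(0))

# Step scale drawn from HalfNormal(1), matching the default scale_dist.
scale = jnp.abs(jax.random.normal(key_scale))
# Independent Gaussian increments, accumulated along the category axis.
steps = scale * jax.random.normal(key_steps, (num_categories, embedding_dim))
walk = jnp.cumsum(steps, axis=0)                      # (c, m)

x = jnp.array([0, 1, 1, 5])                           # n=4 category indices
out = walk[x]                                         # (n, m)
print(out.shape)
```

Unlike independent embeddings, neighbouring categories are correlated because they share all earlier increments.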

class blayers.layers.HorseshoeLayer(slab_scale=None, slab_df=4.0, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0})[source]#

Bases: BLayer

Bayesian layer with horseshoe prior for sparse regression.

Implements the (regularized) horseshoe prior of Piironen & Vehtari (2017).

Basic horseshoe:

\[\tau \sim HalfCauchy(1), \quad \lambda_j \sim HalfCauchy(1), \quad \beta_j \sim Normal(0,\; \tau \lambda_j)\]

Regularized horseshoe (used when slab_scale is set), which keeps large coefficients from escaping the slab. Here s is slab_df and scale_{slab} is slab_scale:

\[\tilde{\lambda}_j^2 = \frac{c^2 \lambda_j^2}{c^2 + \tau^2 \lambda_j^2}, \quad c^2 \sim InverseGamma(s/2,\; s/2 \cdot scale_{slab}^2)\]
Parameters:
  • slab_scale (float | None)

  • slab_df (float)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

__init__(slab_scale=None, slab_df=4.0, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0})[source]#
Parameters:
  • slab_scale (float | None) – If set, uses the regularized horseshoe with this slab scale. None gives the plain horseshoe.

  • slab_df (float) – Degrees of freedom for the slab variance prior (only used when slab_scale is set).

  • coef_dist (Distribution) – Distribution for the coefficients. Must accept a scale keyword (derived from the horseshoe shrinkage). Defaults to Normal.

  • coef_kwargs (dict[str, float]) – Extra kwargs for coef_dist (beyond scale). Default {"loc": 0.0}.

__call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#

Forward pass with horseshoe prior on coefficients.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input array of shape (n, d).

  • units (int) – Number of output dimensions.

  • activation (Callable[[Array], Array]) – Activation function.

Returns:

jax.Array of shape (n, units).

Return type:

Array
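The regularized scales can be traced through a single prior draw. This sketch uses plain jax.random calls instead of NumPyro sites and treats s as slab_df; it follows the formulas above but is not the layer's code.

```python
import jax
import jax.numpy as jnp

d, slab_scale, slab_df = 5, 2.0, 4.0
k_tau, k_lam, k_c2, k_beta = jax.random.split(jax.random.PRNGKey(0), 4)

# HalfCauchy(1) draws: absolute value of a standard Cauchy.
tau = jnp.abs(jax.random.cauchy(k_tau))
lam = jnp.abs(jax.random.cauchy(k_lam, (d,)))

# c^2 ~ InverseGamma(a, b) with a = s/2 and b = (s/2) * slab_scale^2,
# sampled as b / Gamma(a, rate=1).
a = slab_df / 2.0
b = a * slab_scale ** 2
c2 = b / jax.random.gamma(k_c2, a)

# Regularized local scales from the formula above.
lam_tilde = jnp.sqrt(c2 * lam ** 2 / (c2 + tau ** 2 * lam ** 2))

# beta_j ~ Normal(0, tau * lam_tilde_j); the scale tau * lam_tilde_j is capped at c.
beta = tau * lam_tilde * jax.random.normal(k_beta, (d,))
```

When tau * lambda_j is small the regularized scale is close to lambda_j (plain horseshoe), and when it is large the prior scale on beta_j approaches c, the slab width.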

class blayers.layers.SpikeAndSlabLayer(alpha=0.5, beta=0.5, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#

Bases: BLayer

Sparse regression via a spike-and-slab prior.

Each coefficient has a Beta-distributed inclusion weight z_j in (0, 1). Included features (z_j ≈ 1) take the full slab coefficient; excluded features (z_j ≈ 0) are gated toward zero (the spike).

Generative model:

z_j ~ Beta(alpha, beta)          # inclusion weight (hardcoded Beta)
β_j ~ coef_dist(**coef_kwargs)   # slab coefficient
y   ~ link(z · β · x, ...)       # z gates each coefficient

The default Beta(0.5, 0.5) (Jeffreys prior) places mass near 0 and 1, encouraging features to be clearly included or excluded. The posterior mean of z_j approximates P(feature j included | data).

The slab distribution defaults to Normal(0, 1) but can be swapped for e.g. StudentT for heavier-tailed slab behaviour.

Parameters:
  • alpha (float) – First concentration parameter of the Beta prior on z.

  • beta (float) – Second concentration parameter of the Beta prior on z.

  • coef_dist (Distribution) – Distribution for the slab coefficients.

  • coef_kwargs (dict[str, float]) – Kwargs for coef_dist.

__init__(alpha=0.5, beta=0.5, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#

Initialize layer parameters. This is the Bayesian model.

Parameters:
  • alpha (float)

  • beta (float)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

__call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#
Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input of shape (n, d).

  • units (int) – Number of output dimensions.

  • activation (Callable[[Array], Array]) – Activation function.

Returns:

jax.Array of shape (n, units).

Return type:

Array
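The gating in the generative model can be sketched as an elementwise product of inclusion weights and slab coefficients. The (d, u) shape for z is an assumption, and the draws below use jax.random directly rather than NumPyro sample sites.

```python
import jax
import jax.numpy as jnp

n, d, u = 6, 4, 1
k_x, k_z, k_b = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k_x, (n, d))

z = jax.random.beta(k_z, 0.5, 0.5, (d, u))    # inclusion weights in (0, 1)
beta = jax.random.normal(k_b, (d, u))         # slab coefficients ~ Normal(0, 1)

gated = z * beta                              # z shrinks each coefficient toward zero
out = x @ gated                               # (n, u)
print(out.shape)
```

Because z is continuous rather than binary, the model stays differentiable, which suits gradient-based inference.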

class blayers.layers.AttentionLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Bases: BLayer

Multi-head Bayesian self-attention over the feature dimension.

Treats the d input features as tokens using FT-Transformer style tokenisation (Gorishniy et al. 2021, https://arxiv.org/abs/2106.11959): each feature gets a per-column bias embedding (identity) plus a value-scaled embedding, so tokens are distinct even when the feature value is zero.

For each observation x_i ∈ R^d:

  1. Tokenise: H_j = x_{i,j} · W_emb_j + W_bias_j (head_dim-dim each)

  2. Per head: Q_m, K_m, V_m = H W_Q_m, H W_K_m, H W_V_m

  3. Attn_m = softmax(Q_m K_m^T / √h_k)

  4. Concatenate heads → mean-pool over features → project to units

Requires d ≥ 2 for attention to be non-trivial. Total embedding dimension is head_dim * num_heads, so adding heads increases capacity rather than splitting a fixed budget.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#

Initialize layer parameters. This is the Bayesian model.

Parameters:
  • scale_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • scale_kwargs (dict[str, float])

__call__(name, x, head_dim=8, num_heads=1, units=1, activation=<PjitFunction of <function identity>>)[source]#
Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input of shape (n, d). Each column is a feature token.

  • head_dim (int) – Dimension of each individual head. Total embedding dimension is head_dim * num_heads, so adding heads increases capacity.

  • num_heads (int) – Number of attention heads.

  • units (int) – Number of output dimensions.

  • activation (Callable[[Array], Array]) – Activation function.

Returns:

jax.Array of shape (n, units).

Return type:

Array
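The four steps in the class description can be sketched as a deterministic forward pass. The parameter layout is hypothetical: the dictionary keys, the choice of a token width of head_dim * num_heads, and the use of random weights in place of prior samples are all assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def attention_forward(x, params, num_heads):
    """x: (n, d). Returns (n, units). Parameter layout is hypothetical."""
    # 1. Tokenise: feature j becomes x_j * w_emb_j + w_bias_j, shape (n, d, e).
    h = x[:, :, None] * params["w_emb"][None] + params["w_bias"][None]
    heads = []
    for m in range(num_heads):
        # 2. Per-head projections, each of shape (n, d, head_dim).
        q = h @ params["w_q"][m]
        k = h @ params["w_k"][m]
        v = h @ params["w_v"][m]
        # 3. Scaled dot-product attention across the d feature tokens.
        scores = jnp.einsum("nik,njk->nij", q, k) / jnp.sqrt(q.shape[-1])
        attn = jax.nn.softmax(scores, axis=-1)        # (n, d, d)
        heads.append(attn @ v)                        # (n, d, head_dim)
    # 4. Concatenate heads, mean-pool over features, project to units.
    pooled = jnp.concatenate(heads, axis=-1).mean(axis=1)   # (n, e)
    return pooled @ params["w_out"]                   # (n, units)


n, d, head_dim, num_heads, units = 3, 4, 8, 2, 1
e = head_dim * num_heads                              # total embedding dimension
keys = jax.random.split(jax.random.PRNGKey(0), 7)
params = {
    "w_emb": jax.random.normal(keys[0], (d, e)),
    "w_bias": jax.random.normal(keys[1], (d, e)),
    "w_q": jax.random.normal(keys[2], (num_heads, e, head_dim)),
    "w_k": jax.random.normal(keys[3], (num_heads, e, head_dim)),
    "w_v": jax.random.normal(keys[4], (num_heads, e, head_dim)),
    "w_out": jax.random.normal(keys[5], (e, units)),
}
x = jax.random.normal(keys[6], (n, d))
out = attention_forward(x, params, num_heads)
print(out.shape)
```

The per-column bias term is what keeps tokens distinct when a feature value is zero: with x_j = 0 the token reduces to w_bias_j, which differs by column.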