Layers
Implements Bayesian Layers using Jax and Numpyro.
- Design:
There are three levels of complexity here: class-level, instance-level, and call-level.
The class-level handles things like choosing the generic model form and how to multiply coefficients with data. It is defined by the class Layer(BLayer) definition itself.
The instance-level handles the specific distributions that fit into the generic model and the initial parameters for those distributions. It is defined by creating an instance of the class: Layer(*args, **kwargs).
The call-level handles seeing a batch of data, sampling from the distributions defined on the class, and multiplying coefficients and data to produce an output. It works like result = Layer(*args, **kwargs)(data).
- Notation:
n: observations in a batch
c: number of categories of things for time, random effects, etc.
d: number of coefficients
l: low rank dimension of low rank models
m: embedding dimension
u: units, aka output dimension
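The three levels described above can be sketched with a toy class. This is only an illustration in plain NumPy: `ToyLayer`, its `matmul` helper, and the explicit `rng` argument are hypothetical names invented here and are not part of blayers, which samples through NumPyro instead.

```python
import numpy as np


class ToyLayer:
    """Hypothetical stand-in for a BLayer subclass; not the blayers code."""

    # Class level: the generic model form, here a plain linear map x @ beta.
    @staticmethod
    def matmul(beta, x):
        return x @ beta

    # Instance level: which distribution parameters the coefficients use.
    def __init__(self, loc=0.0, scale=1.0):
        self.loc = loc
        self.scale = scale

    # Call level: see a batch, sample coefficients, multiply, return output.
    def __call__(self, rng, x, units=1):
        d = x.shape[1]
        beta = rng.normal(self.loc, self.scale, size=(d, units))
        return self.matmul(beta, x)


rng = np.random.default_rng(0)
x = np.ones((5, 3))  # n=5 observations, d=3 coefficients
result = ToyLayer(scale=0.5)(rng, x, units=2)
print(result.shape)  # (5, 2), i.e. (n, u)
```

The same instance can be called on several batches, which is why distribution choices live on the instance while shapes such as `units` are supplied per call.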
- class blayers.layers.BLayer(*args)[source]#
Bases: ABC
Abstract base class for Bayesian layers. Lays out an interface.
- Parameters:
args (Any)
- class blayers.layers.AdaptiveLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian layer with adaptive prior using hierarchical modeling.
Generates coefficients from the hierarchical model
\[\lambda \sim HalfNormal(1.)\]
\[\beta \sim Normal(0., \lambda)\]
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.
coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.
coef_kwargs (dict[str, float]) – Parameters for the prior distribution.
scale_kwargs (dict[str, float]) – Parameters for the scale distribution.
- __call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#
Forward pass with adaptive prior on coefficients.
- Parameters:
name (str) – Variable name.
x (Array) – Input data array of shape (n, d).
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
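As a rough illustration of the hierarchical model above, the following NumPy sketch draws one scale λ, draws coefficients conditional on it, and applies them to a batch. It assumes a single λ shared by all coefficients and an identity activation; whether AdaptiveLayer shares the scale in exactly this way is an assumption, and the real layer samples through NumPyro rather than NumPy.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, u = 4, 3, 2  # observations, coefficients, units

# lambda ~ HalfNormal(1.): absolute value of a standard normal draw.
lam = np.abs(rng.normal(0.0, 1.0))

# beta ~ Normal(0., lambda): one coefficient per (input, output) pair.
beta = rng.normal(0.0, lam, size=(d, u))

x = rng.normal(size=(n, d))
out = x @ beta  # identity activation, shape (n, u)
```

Because λ is itself sampled, the data can inform how strongly the coefficients are shrunk toward zero, which is what distinguishes this from a fixed prior.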
- class blayers.layers.FixedPriorLayer(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
Bases: BLayer
Bayesian layer with a fixed prior distribution over coefficients.
Generates coefficients from the model
\[\beta \sim Normal(0., 1.)\]
- Parameters:
coef_dist (Distribution)
coef_kwargs (dict[str, float])
- __init__(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
- Parameters:
coef_dist (Distribution) – NumPyro distribution class for the coefficients.
coef_kwargs (dict[str, float]) – Parameters to initialize the prior distribution.
- __call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#
Forward pass with fixed prior.
- Parameters:
name (str) – Variable name.
x (Array) – Input data array of shape (n, d).
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
- class blayers.layers.InterceptLayer(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
Bases: BLayer
Bayesian layer with a fixed prior distribution over coefficients.
Generates coefficients from the model
\[\beta \sim Normal(0., 1.)\]
- Parameters:
coef_dist (Distribution)
coef_kwargs (dict[str, float])
- __init__(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
- Parameters:
coef_dist (Distribution) – NumPyro distribution class for the coefficients.
coef_kwargs (dict[str, float]) – Parameters to initialize the prior distribution.
- __call__(name, units=1, activation=<PjitFunction of <function identity>>)[source]#
Forward pass with fixed prior.
- Parameters:
name (str) – Variable name.
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (1, u).
- Return type:
jax.Array
- class blayers.layers.FMLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian factorization machine layer with adaptive priors.
Generates coefficients from the hierarchical model
\[\lambda \sim HalfNormal(1.)\]
\[\beta \sim Normal(0., \lambda)\]
The shape of beta is (j, l), where j is the number of input covariates and l is the low rank dim.
Then performs matrix multiplication using the formula in Rendle (2010).
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – Distribution for scaling factor λ.
coef_dist (Distribution) – Prior for beta parameters.
coef_kwargs (dict[str, float]) – Arguments for prior distribution.
scale_kwargs (dict[str, float]) – Arguments for λ distribution.
- __call__(name, x, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#
Forward pass through the factorization machine layer.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input matrix of shape (n, d).
low_rank_dim (int) – Dimensionality of low-rank approximation.
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
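The Rendle (2010) formula mentioned for FMLayer rewrites the sum over all pairs of features into a form that is linear in the number of features. A minimal NumPy check of that identity, for a single output unit and with randomly drawn factors, might look like the following; the variable names are illustrative and the code is not taken from blayers.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, l = 6, 5, 3  # observations, input covariates, low rank dim
x = rng.normal(size=(n, d))
beta = rng.normal(size=(d, l))  # low-rank factors, one row per covariate

# Fast form: 0.5 * sum_f [ (sum_i beta_if x_i)^2 - sum_i beta_if^2 x_i^2 ]
fast = 0.5 * ((x @ beta) ** 2 - (x**2) @ (beta**2)).sum(axis=1)

# Brute force: sum over pairs i < j of <beta_i, beta_j> * x_i * x_j
slow = np.zeros(n)
for i in range(d):
    for j in range(i + 1, d):
        slow += (beta[i] @ beta[j]) * x[:, i] * x[:, j]

print(np.allclose(fast, slow))  # True
```

The fast form costs O(d · l) per observation instead of O(d² · l), which is what makes factorization machines practical for wide inputs.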
- class blayers.layers.FM3Layer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Order 3 FM. See Blondel et al. (2016).
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – Distribution for scaling factor λ.
coef_dist (Distribution) – Prior for beta parameters.
coef_kwargs (dict[str, float]) – Arguments for prior distribution.
scale_kwargs (dict[str, float]) – Arguments for λ distribution.
- __call__(name, x, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#
Forward pass through the factorization machine layer.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input matrix of shape (n, d).
low_rank_dim (int) – Dimensionality of low-rank approximation.
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n,).
- Return type:
jax.Array
- class blayers.layers.LowRankInteractionLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Takes two sets of features and learns a low-rank interaction matrix.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Initialize layer parameters. This is the Bayesian model.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __call__(name, x, z, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#
Interaction between feature matrices X and Z in a low rank way. UV decomp.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input matrix of shape (n, d1).
z (Array) – Input matrix of shape (n, d2).
low_rank_dim (int) – Dimensionality of low-rank approximation.
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
- class blayers.layers.InteractionLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Computes every interaction coefficient between two sets of inputs.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Initialize layer parameters. This is the Bayesian model.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __call__(name, x, z, units=1, activation=<PjitFunction of <function identity>>)[source]#
Interaction between feature matrices X and Z, with a coefficient for every pair of input features.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input matrix of shape (n, d1).
z (Array) – Input matrix of shape (n, d2).
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
- class blayers.layers.BilinearLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian bilinear interaction layer: computes x^T W z.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – prior on scale of coefficients
coef_dist (Distribution) – distribution for coefficients
coef_kwargs (dict[str, float]) – kwargs for coef distribution
scale_kwargs (dict[str, float]) – kwargs for scale prior
- __call__(name, x, z, units=1, activation=<PjitFunction of <function identity>>)[source]#
Bilinear interaction between feature matrices X and Z: computes x^T W z.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input matrix of shape (n, d1).
z (Array) – Input matrix of shape (n, d2).
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
- class blayers.layers.LowRankBilinearLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian bilinear interaction layer: computes x^T W z. W low rank.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – prior on scale of coefficients
coef_dist (Distribution) – distribution for coefficients
coef_kwargs (dict[str, float]) – kwargs for coef distribution
scale_kwargs (dict[str, float]) – kwargs for scale prior
- __call__(name, x, z, low_rank_dim, units=1, activation=<PjitFunction of <function identity>>)[source]#
Interaction between feature matrices X and Z in a low rank way. UV decomp.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input matrix of shape (n, d1).
z (Array) – Input matrix of shape (n, d2).
low_rank_dim (int) – Dimensionality of low-rank approximation.
units (int) – Number of outputs.
activation (Callable[[Array], Array]) – Activation function to apply to output.
- Returns:
Output array of shape (n, u).
- Return type:
jax.Array
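To make the "x^T W z with W low rank" description concrete, here is a small NumPy sketch for a single output unit. It assumes the factorization W = U V^T with U of shape (d1, l) and V of shape (d2, l); the names U and V and the exact parameterization are assumptions for illustration, not the blayers internals.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d1, d2, l = 5, 4, 3, 2
x = rng.normal(size=(n, d1))
z = rng.normal(size=(n, d2))
U = rng.normal(size=(d1, l))
V = rng.normal(size=(d2, l))

# Full bilinear form with an explicit (d1, d2) matrix of rank at most l.
W = U @ V.T
full = np.einsum("ni,ij,nj->n", x, W, z)

# Low-rank form: project both inputs to l dims, then take a row-wise dot.
low_rank = ((x @ U) * (z @ V)).sum(axis=1)

print(np.allclose(full, low_rank))  # True
```

The low-rank form never materializes W, so it needs (d1 + d2) · l coefficients rather than d1 · d2.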
- class blayers.layers.EmbeddingLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Bayesian embedding layer for sparse categorical features.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.
coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.
coef_kwargs (dict[str, float]) – Parameters for the prior distribution.
scale_kwargs (dict[str, float]) – Parameters for the scale distribution.
- __call__(name, x, num_categories, embedding_dim)[source]#
Forward pass through embedding lookup.
- Parameters:
name (str) – Variable name scope.
x (Array) – Integer indices indicating embeddings to use.
num_categories (int) – The number of distinct things getting an embedding
embedding_dim (int) – The size of each embedding, e.g. 2, 4, 8, etc.
- Returns:
Embedding vectors of shape (n, m).
- Return type:
jax.Array
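The embedding lookup itself is just integer indexing into a (c, m) table. The sketch below shows the shapes involved using NumPy; the table here is drawn from a plain normal for illustration, whereas the layer would sample it from the configured prior.

```python
import numpy as np

rng = np.random.default_rng(0)
num_categories, embedding_dim = 4, 2  # c and m in the notation above

# One embedding row per category, shape (c, m).
table = rng.normal(size=(num_categories, embedding_dim))

# Integer category index for each of n=5 observations.
x = np.array([0, 3, 3, 1, 0])

emb = table[x]  # shape (n, m)
print(emb.shape)  # (5, 2)
```

Observations that share a category index receive the identical embedding row, which is how information is pooled within a category.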
- class blayers.layers.RandomEffectsLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Exactly like the EmbeddingLayer but with embedding_dim=1.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
- Parameters:
scale_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.
coef_dist (Distribution) – Prior distribution for embedding weights.
coef_kwargs (dict[str, float]) – Parameters for the prior distribution.
scale_kwargs (dict[str, float]) – Parameters for the scale distribution.
- __call__(name, x, num_categories)[source]#
Forward pass through embedding lookup.
- Parameters:
name (str) – Variable name scope.
x (Array) – Integer indices indicating embeddings to use.
num_categories (int) – The number of distinct things getting an embedding
- Returns:
Embedding vectors of shape (n, embedding_dim).
- Return type:
jax.Array
- class blayers.layers.RandomWalkLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Random walk of embedding dim m, defaults to Gaussian walk.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Initialize layer parameters. This is the Bayesian model.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __call__(name, x, num_categories, embedding_dim)[source]#
Forward pass through embedding lookup.
- Parameters:
name (str) – Variable name scope.
x (Array) – Integer indices indicating embeddings to use.
num_categories (int) – The number of distinct things getting an embedding
embedding_dim (int) – The size of each embedding, e.g. 2, 4, 8, etc.
- Returns:
Embedding vectors of shape (n, m).
- Return type:
jax.Array
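One common way to build a Gaussian random walk over c ordered categories (for example time steps) is to cumulatively sum independent normal increments and then index the result like an embedding table. The NumPy sketch below assumes that construction and a fixed step scale of 0.1; how RandomWalkLayer actually parameterizes the walk is not stated here, so treat this as an illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
num_categories, embedding_dim = 6, 2  # c time steps, m dims per step

# Independent Gaussian increments, one row per time step.
steps = rng.normal(0.0, 0.1, size=(num_categories, embedding_dim))

# The walk at time t is the sum of all increments up to and including t.
walk = np.cumsum(steps, axis=0)  # shape (c, m)

# Look up the walk position for each observation's time index.
x = np.array([0, 1, 1, 5])
out = walk[x]  # shape (n, m)
```

Adjacent time steps differ by a single small increment, so nearby categories get similar embeddings, unlike the exchangeable rows of a plain embedding table.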
- class blayers.layers.HorseshoeLayer(slab_scale=None, slab_df=4.0, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0})[source]#
Bases: BLayer
Bayesian layer with horseshoe prior for sparse regression.
Implements the (regularized) horseshoe prior of Piironen & Vehtari (2017).
Basic horseshoe:
\[\tau \sim HalfCauchy(1), \quad \lambda_j \sim HalfCauchy(1), \quad \beta_j \sim Normal(0,\; \tau \lambda_j)\]
Regularized horseshoe (slab_scale set), which prevents large coefficients from escaping the slab:
\[\tilde{\lambda}_j^2 = \frac{c^2 \lambda_j^2}{c^2 + \tau^2 \lambda_j^2}, \quad c^2 \sim InverseGamma(s/2,\; s/2 \cdot scale_{slab}^2)\]
- Parameters:
slab_scale (float | None)
slab_df (float)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
- __init__(slab_scale=None, slab_df=4.0, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0})[source]#
- Parameters:
slab_scale (float | None) – If set, uses the regularized horseshoe with this slab scale. None gives the plain horseshoe.
slab_df (float) – Degrees of freedom for the slab variance prior (only used when slab_scale is set).
coef_dist (Distribution) – Distribution for the coefficients. Must accept a scale keyword (derived from the horseshoe shrinkage). Defaults to Normal.
coef_kwargs (dict[str, float]) – Extra kwargs for coef_dist (beyond scale). Default {"loc": 0.0}.
- __call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#
Forward pass with horseshoe prior on coefficients.
- Parameters:
name (str) – Variable name scope.
x (Array) – Input array of shape (n, d).
units (int) – Number of output dimensions.
activation (Callable[[Array], Array]) – Activation function.
- Returns:
jax.Array of shape (n, units).
- Return type:
Array
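The regularized horseshoe formulas above can be traced numerically. The NumPy sketch below assumes that s in the InverseGamma expression is slab_df, that scale_slab is slab_scale, and that the regularized coefficients are drawn as Normal(0, τ λ̃_j); these readings follow Piironen & Vehtari (2017) but are assumptions about this layer.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 5
slab_scale, slab_df = 2.0, 4.0

# Global and local scales: absolute values of standard Cauchy draws.
tau = np.abs(rng.standard_cauchy())
lam = np.abs(rng.standard_cauchy(size=d))

# c^2 ~ InverseGamma(s/2, s/2 * slab_scale^2), drawn as the reciprocal of a
# Gamma variate with shape s/2 and rate s/2 * slab_scale^2.
rate = slab_df / 2 * slab_scale**2
c2 = 1.0 / rng.gamma(shape=slab_df / 2, scale=1.0 / rate)

# Regularized local scales: close to lam when tau*lam << c, capped otherwise.
lam_tilde = np.sqrt(c2 * lam**2 / (c2 + tau**2 * lam**2))

# Coefficients with the effective scale tau * lam_tilde.
beta = rng.normal(0.0, tau * lam_tilde)
```

The effective scale tau * lam_tilde can never exceed c, so even coefficients with very large local scales are softly bounded by the slab.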
- class blayers.layers.SpikeAndSlabLayer(alpha=0.5, beta=0.5, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
Bases:
Bases: BLayer
Sparse regression via a spike-and-slab prior.
Each coefficient has a Beta-distributed inclusion weight z_j in (0, 1). Included features (z_j ≈ 1) take the full slab coefficient; excluded features (z_j ≈ 0) are gated toward zero (the spike).
Generative model:
    z_j ~ Beta(alpha, beta)          # inclusion weight (hardcoded Beta)
    β_j ~ coef_dist(**coef_kwargs)   # slab coefficient
    y ~ link(z · β · x, ...)         # z gates each coefficient
The default Beta(0.5, 0.5) (Jeffreys prior) places mass near 0 and 1, encouraging features to be clearly included or excluded. The posterior mean of z_j approximates P(feature j included | data).
The slab distribution defaults to Normal(0, 1) but can be swapped for e.g. StudentT for heavier-tailed slab behaviour.
- Parameters:
alpha (float) – First concentration parameter of the Beta prior on z.
beta (float) – Second concentration parameter of the Beta prior on z.
coef_dist (Distribution) – Distribution for the slab coefficients.
coef_kwargs (dict[str, float]) – Kwargs for coef_dist.
- __init__(alpha=0.5, beta=0.5, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0})[source]#
Initialize layer parameters. This is the Bayesian model.
- Parameters:
alpha (float)
beta (float)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
- __call__(name, x, units=1, activation=<PjitFunction of <function identity>>)[source]#
- Parameters:
name (str) – Variable name scope.
x (Array) – Input of shape (n, d).
units (int) – Number of output dimensions.
activation (Callable[[Array], Array]) – Activation function.
- Returns:
jax.Array of shape (n, units).
- Return type:
Array
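The gating in the generative model above amounts to an elementwise product of inclusion weights and slab coefficients before the usual matrix multiply. A NumPy sketch for units=1 follows; it assumes one z per coefficient with shape (d, 1), which is an illustrative choice rather than a statement about the layer's internals.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 4, 3

# Inclusion weights from the default Beta(0.5, 0.5) prior.
z = rng.beta(0.5, 0.5, size=(d, 1))

# Slab coefficients from the default Normal(0, 1) slab.
beta = rng.normal(0.0, 1.0, size=(d, 1))

x = rng.normal(size=(n, d))

# z gates each coefficient before it multiplies the data.
out = x @ (z * beta)  # shape (n, 1)

# A fully excluded feature set (z = 0) contributes nothing.
gated_off = x @ (np.zeros((d, 1)) * beta)
```

Because z is continuous rather than binary, the model stays differentiable, which is what lets gradient-based samplers and variational inference handle it.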
- class blayers.layers.AttentionLayer(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Bases: BLayer
Multi-head Bayesian self-attention over the feature dimension.
Treats the d input features as tokens using FT-Transformer style tokenisation (Gorishniy et al. 2021, https://arxiv.org/abs/2106.11959): each feature gets a per-column bias embedding (identity) plus a value-scaled embedding, so tokens are distinct even when the feature value is zero.
For each observation x_i ∈ R^d:
Tokenise: H_j = x_{i,j} · W_emb_j + W_bias_j (head_dim-dim each)
Per head: Q_m, K_m, V_m = H W_Q_m, H W_K_m, H W_V_m
    Attn_m = softmax(Q_m K_m^T / √h_k)
Concatenate heads → mean-pool over features → project to units
Requires d ≥ 2 for attention to be non-trivial. Total embedding dimension is head_dim * num_heads, so adding heads increases capacity rather than splitting a fixed budget.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __init__(scale_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, scale_kwargs={'scale': 1.0})[source]#
Initialize layer parameters. This is the Bayesian model.
- Parameters:
scale_dist (Distribution)
coef_dist (Distribution)
coef_kwargs (dict[str, float])
scale_kwargs (dict[str, float])
- __call__(name, x, head_dim=8, num_heads=1, units=1, activation=<PjitFunction of <function identity>>)[source]#
- Parameters:
name (str) – Variable name scope.
x (Array) – Input of shape (n, d). Each column is a feature token.
head_dim (int) – Dimension of each individual head. Total embedding dimension is head_dim * num_heads, so adding heads increases capacity.
num_heads (int) – Number of attention heads.
units (int) – Number of output dimensions.
activation (Callable[[Array], Array]) – Activation function.
- Returns:
jax.Array of shape (n, units).
- Return type:
Array
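The tokenise, attend, pool, and project steps described for AttentionLayer can be followed shape by shape in a single-head NumPy sketch. All weight names (W_emb, W_bias, W_Q, W_K, W_V, W_out) are illustrative, the weights are drawn from a plain normal rather than the layer's priors, and only num_heads=1 is shown, so this is a sketch of the technique under those assumptions rather than the blayers implementation.

```python
import numpy as np


def softmax(a, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
n, d, head_dim, units = 4, 3, 8, 1
x = rng.normal(size=(n, d))

# Tokenise: value-scaled embedding plus per-feature bias embedding.
W_emb = rng.normal(size=(d, head_dim))
W_bias = rng.normal(size=(d, head_dim))
H = x[:, :, None] * W_emb + W_bias  # (n, d, head_dim)

# Single head: queries, keys, and values over the d feature tokens.
W_Q, W_K, W_V = (rng.normal(size=(head_dim, head_dim)) for _ in range(3))
Q, K, V = H @ W_Q, H @ W_K, H @ W_V  # each (n, d, head_dim)

scores = Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim)  # (n, d, d)
attn = softmax(scores, axis=-1)  # each row sums to 1
context = attn @ V  # (n, d, head_dim)

# Mean-pool over features, then project to the output units.
pooled = context.mean(axis=1)  # (n, head_dim)
W_out = rng.normal(size=(head_dim, units))
out = pooled @ W_out  # (n, units)
```

The bias embedding is what keeps tokens distinct when a feature value is zero: without W_bias, a zero-valued feature would map to the all-zero token regardless of which column it came from.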