Layers

Implements Bayesian layers using JAX and NumPyro.

Design:
  • There are three levels of complexity: class level, instance level, and call level.

  • The class level chooses the generic model form and how coefficients are combined with data. It is defined by the class definition itself: class Layer(BLayer).

  • The instance level chooses the specific distributions that fill in the generic model, along with their initial parameters. It is defined by creating an instance of the class: Layer(*args, **kwargs).

  • The call level takes a batch of data, samples from the distributions configured on the instance, and combines the sampled coefficients with the data to produce an output: result = Layer(*args, **kwargs)(name, data).

Notation:
  • n: observations in a batch

  • c: number of categories, e.g. time steps or random-effect groups

  • d: number of coefficients

  • l: low-rank dimension of low-rank models

  • m: embedding dimension

  • u: units, i.e. the output dimension

class blayers.layers.BLayer(*args)

Bases: ABC

Abstract base class for Bayesian layers. It defines the interface that every layer implements.

Parameters:

args (Any)

abstractmethod __init__(*args)

Initialize layer parameters.

Parameters:

args (Any)

Return type:

None

abstractmethod __call__(*args)

Run the layer’s forward pass.

Parameters:
  • name – Name scope for sampled variables. Because of mypy typing constraints, the name argument is only written explicitly in subclass signatures.

  • *args (Any) – Inputs to the layer.

Returns:

The result of the forward computation.

Return type:

jax.Array

class blayers.layers.AdaptiveLayer(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0}, units=1)

Bases: BLayer

Bayesian layer with adaptive prior using hierarchical modeling.

Generates coefficients from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[\beta \sim Normal(0., \lambda)\]
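The hierarchical prior above can be simulated with plain NumPy to see what a single prior draw produces. This is a sketch of the math, not the layer's code: it assumes a 2-D input of shape (n, d), one λ per coefficient, and an output formed as the matrix product of data and coefficients.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, u = 4, 3, 2
x = rng.normal(size=(n, d))

# lambda ~ HalfNormal(1): the absolute value of a standard normal draw.
# Assumption: one scale per coefficient; the layer may shape lambda differently.
lmbda = np.abs(rng.normal(0.0, 1.0, size=(d, u)))

# beta ~ Normal(0, lambda): each coefficient's spread is set by its own lambda.
beta = rng.normal(0.0, lmbda)

# Combine data and coefficients to get one output per unit.
y = x @ beta
print(y.shape)  # (4, 2)
```

Small λ draws shrink the matching coefficients toward zero, which is what makes the prior adaptive.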
Parameters:
  • lmbda_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • lmbda_kwargs (dict[str, float])

  • units (int)

__init__(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0}, units=1)

Parameters:
  • lmbda_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.

  • coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.

  • coef_kwargs (dict[str, float]) – Parameters for the prior distribution.

  • lmbda_kwargs (dict[str, float]) – Parameters for the scale distribution.

  • units (int) – The number of outputs

  • dependent_outputs – For multi-output models, whether to treat the outputs as dependent. By default they are independent.

__call__(name, x)

Forward pass with adaptive prior on coefficients.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input data array of shape (n, d, u).

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.FixedPriorLayer(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0}, units=1)

Bases: BLayer

Bayesian layer with a fixed prior distribution over coefficients.

Generates coefficients from the model

\[\beta \sim Normal(0., 1.)\]
Parameters:
  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • units (int)

__init__(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0}, units=1)

Parameters:
  • coef_dist (Distribution) – NumPyro distribution class for the coefficients.

  • coef_kwargs (dict[str, float]) – Parameters to initialize the prior distribution.

  • units (int) – The number of outputs.

__call__(name, x)

Forward pass with fixed prior.

Parameters:
  • name (str) – Variable name prefix.

  • x (Array) – Input data array of shape (n, d).

Returns:

Output array of shape (n, u).

Return type:

jax.Array

class blayers.layers.InterceptLayer(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0}, units=1)

Bases: BLayer

Bayesian intercept layer with a fixed prior distribution over the intercept coefficients.

Generates coefficients from the model

\[\beta \sim Normal(0., 1.)\]
Parameters:
  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • units (int)

__init__(coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0, 'scale': 1.0}, units=1)

Parameters:
  • coef_dist (Distribution) – NumPyro distribution class for the coefficients.

  • coef_kwargs (dict[str, float]) – Parameters to initialize the prior distribution.

  • units (int) – The number of outputs.

__call__(name)

Forward pass with fixed prior.

Parameters:

name (str) – Variable name prefix.

Returns:

Output array of shape (1, u).

Return type:

jax.Array
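Because the intercept has shape (1, u), it broadcasts against an (n, u) output from another layer. A NumPy sketch, assuming the intercept is simply added to a linear term:

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, u = 5, 3, 2
x = rng.normal(size=(n, d))
beta = rng.normal(size=(d, u))       # e.g. coefficients from a linear layer
intercept = rng.normal(size=(1, u))  # beta ~ Normal(0, 1), shape (1, u)

# NumPy broadcasting stretches the (1, u) intercept across all n rows.
y = x @ beta + intercept
print(y.shape)  # (5, 2)
```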

class blayers.layers.EmbeddingLayer(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0}, units=1)

Bases: BLayer

Bayesian embedding layer for sparse categorical features.

Parameters:
  • lmbda_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • lmbda_kwargs (dict[str, float])

  • units (int)

__init__(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0}, units=1)

Parameters:
  • coef_dist (Distribution) – Prior distribution for embedding weights.

  • coef_kwargs (dict[str, float]) – Parameters for the prior distribution.

  • lmbda_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.

  • lmbda_kwargs (dict[str, float]) – Parameters for the scale distribution.

  • units (int) – The number of outputs.

__call__(name, x, num_categories, embedding_dim)

Forward pass through embedding lookup.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Integer indices of shape (n,) indicating embeddings to use.

  • num_categories (int) – The number of distinct categories (c) that receive an embedding.

  • embedding_dim (int) – The size of each embedding, e.g. 2, 4, 8, etc.

Returns:

Embedding vectors of shape (n, m).

Return type:

jax.Array
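An embedding lookup is integer indexing into a (c, m) table. The NumPy sketch below draws the table from a HalfNormal/Normal hierarchy like the layer's default distributions and then indexes it; the per-dimension scale is an assumption.

```python
import numpy as np

rng = np.random.default_rng(2)
num_categories, embedding_dim = 4, 3  # c and m

# Hierarchical prior: a HalfNormal scale, then Normal embedding weights.
# Assumption: one scale per embedding dimension.
lmbda = np.abs(rng.normal(size=embedding_dim))
table = rng.normal(0.0, lmbda, size=(num_categories, embedding_dim))  # (c, m)

# Integer indices of shape (n,) pick one row of the table per observation.
x = np.array([0, 2, 2, 3, 1])
emb = table[x]
print(emb.shape)  # (5, 3)
```

Observations that share a category index receive the identical embedding row.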

class blayers.layers.RandomEffectsLayer(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0}, units=1)

Bases: BLayer

Exactly like the EmbeddingLayer but with embedding_dim=1.

Parameters:
  • lmbda_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • lmbda_kwargs (dict[str, float])

  • units (int)

__init__(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0}, units=1)

Parameters:
  • coef_dist (Distribution) – Prior distribution for the random effects.

  • coef_kwargs (dict[str, float]) – Parameters for the prior distribution.

  • lmbda_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.

  • lmbda_kwargs (dict[str, float]) – Parameters for the scale distribution.

  • units (int) – The number of outputs.

__call__(name, x, num_categories)

Forward pass through embedding lookup.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Integer indices of shape (n,) indicating embeddings to use.

  • num_categories (int) – The number of distinct categories (c) that receive an embedding.

Returns:

Embedding vectors of shape (n, 1), since embedding_dim is fixed to 1.

Return type:

jax.Array
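With embedding_dim=1 each category gets a single scalar effect, which is the classic random intercept. A NumPy sketch, assuming one shared scale λ for all groups:

```python
import numpy as np

rng = np.random.default_rng(3)
num_categories = 3  # e.g. three groups

# Shared group-level scale (assumed), then one scalar effect per group.
lmbda = abs(rng.normal())
effects = rng.normal(0.0, lmbda, size=(num_categories, 1))  # (c, 1)

group = np.array([0, 1, 1, 2, 0, 2])  # group index for each of n observations
re = effects[group]
print(re.shape)  # (6, 1)
```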

class blayers.layers.FMLayer(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0}, units=1)

Bases: BLayer

Bayesian factorization machine layer with adaptive priors.

Generates coefficients from the hierarchical model

\[\lambda \sim HalfNormal(1.)\]
\[\beta \sim Normal(0., \lambda)\]

The shape of beta is (d, l), where d is the number of input covariates and l is the low-rank dimension.

It then computes the pairwise feature interactions with matrix multiplications, using the reformulation from Rendle (2010).
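Rendle's reformulation turns the sum over all feature pairs, which naively costs O(d²l), into an O(dl) expression. Writing v_i for row i of beta:

\[\sum_{i<j} \langle v_i, v_j \rangle x_i x_j = \frac{1}{2} \sum_{f=1}^{l} \left[ \left( \sum_{i=1}^{d} v_{if} x_i \right)^2 - \sum_{i=1}^{d} v_{if}^2 x_i^2 \right]\]

The NumPy sketch below checks this identity numerically. It illustrates the formula only; whether the layer combines it with other terms is not stated on this page.

```python
import numpy as np

rng = np.random.default_rng(4)
n, d, l = 6, 5, 2
x = rng.normal(size=(n, d))
v = rng.normal(size=(d, l))  # plays the role of beta, shape (d, l)

# Naive O(d^2 l) sum over all feature pairs i < j.
naive = np.zeros(n)
for i in range(d):
    for j in range(i + 1, d):
        naive += (v[i] @ v[j]) * x[:, i] * x[:, j]

# Rendle's O(d l) form: square of sums minus sum of squares, halved.
fast = 0.5 * ((x @ v) ** 2 - (x**2) @ (v**2)).sum(axis=1)

print(np.allclose(naive, fast))  # True
```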

Parameters:
  • lmbda_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • lmbda_kwargs (dict[str, float])

  • units (int)

__init__(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0}, units=1)

Parameters:
  • lmbda_dist (Distribution) – Distribution for scaling factor λ.

  • coef_dist (Distribution) – Prior for beta parameters.

  • coef_kwargs (dict[str, float]) – Arguments for prior distribution.

  • lmbda_kwargs (dict[str, float]) – Arguments for λ distribution.

  • units (int) – The number of outputs.

__call__(name, x, low_rank_dim)

Forward pass through the factorization machine layer.

Parameters:
  • name (str) – Variable name scope.

  • x (Array) – Input matrix of shape (n, d).

  • low_rank_dim (int) – Dimensionality l of the low-rank approximation.

Returns:

Output array of shape (n,).

Return type:

jax.Array

class blayers.layers.LowRankInteractionLayer(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0}, units=1)

Bases: BLayer

Takes two sets of features and learns a low-rank interaction matrix.
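One common way to parameterize a low-rank interaction matrix is W = U Vᵀ with U of shape (d1, l) and V of shape (d2, l). The names U, V, d1 and d2 are hypothetical, and this page does not confirm the layer's exact parameterization. The NumPy sketch shows that the factored form never builds the full d1 × d2 matrix yet gives the same result.

```python
import numpy as np

rng = np.random.default_rng(5)
n, d1, d2, l = 4, 3, 5, 2
x = rng.normal(size=(n, d1))  # first feature set
z = rng.normal(size=(n, d2))  # second feature set

# Hypothetical factors: the interaction matrix is W = U @ V.T, rank at most l.
U = rng.normal(size=(d1, l))
V = rng.normal(size=(d2, l))

# Project each feature set to l dims and take the row-wise dot product.
low_rank = ((x @ U) * (z @ V)).sum(axis=1)  # shape (n,)

# Same result as forming the full d1 x d2 matrix explicitly.
W = U @ V.T
full = np.einsum("ni,ij,nj->n", x, W, z)
print(np.allclose(low_rank, full))  # True
```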

Parameters:
  • lmbda_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • lmbda_kwargs (dict[str, float])

  • units (int)

__init__(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0}, units=1)

Initialize layer parameters.

Parameters:
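  • lmbda_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.

  • coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.

  • coef_kwargs (dict[str, float]) – Parameters for the prior distribution.

  • lmbda_kwargs (dict[str, float]) – Parameters for the scale distribution.

  • units (int) – The number of outputs.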

__call__(name, x, z, low_rank_dim)

Run the layer’s forward pass.

Parameters:
  • name (str) – Name scope for sampled variables.

  • x (Array) – First set of input features.

  • z (Array) – Second set of input features.

  • low_rank_dim (int) – Dimensionality l of the low-rank interaction matrix.

Returns:

The result of the forward computation.

Return type:

jax.Array

class blayers.layers.RandomWalkLayer(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0})

Bases: BLayer

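Random walk over embeddings of dimension m, one per category, so consecutive categories (e.g. time steps) have related embeddings. By default the walk is Gaussian.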
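A Gaussian random walk is the cumulative sum of Gaussian increments. The NumPy sketch below assumes that structure, with a HalfNormal scale for the step size as suggested by the default lmbda_dist, and then looks up one walk position per observation.

```python
import numpy as np

rng = np.random.default_rng(6)
num_categories, embedding_dim = 5, 2  # c time steps, m dims

# Assumed structure: a scale for the step size, Gaussian increments,
# and a cumulative sum so each step starts where the previous one ended.
lmbda = np.abs(rng.normal(size=embedding_dim))
steps = rng.normal(0.0, lmbda, size=(num_categories, embedding_dim))
walk = np.cumsum(steps, axis=0)  # (c, m)

x = np.array([0, 1, 4, 4])  # time index of each observation
out = walk[x]
print(out.shape)  # (4, 2)
```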

Parameters:
  • lmbda_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • lmbda_kwargs (dict[str, float])

__init__(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0})

Initialize layer parameters.

Parameters:
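  • lmbda_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.

  • coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.

  • coef_kwargs (dict[str, float]) – Parameters for the prior distribution.

  • lmbda_kwargs (dict[str, float]) – Parameters for the scale distribution.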

__call__(name, x, num_categories, embedding_dim)

Parameters:
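  • name (str) – Variable name scope.

  • x (Array) – Integer indices of shape (n,) indicating which category each observation belongs to.

  • num_categories (int) – The number of distinct categories (c) in the walk.

  • embedding_dim (int) – The size of each embedding (m).

Return type:

jax.Array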

class blayers.layers.InteractionLayer(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0}, units=1)

Bases: BLayer

Learns a separate coefficient for every pairwise interaction between the features of two input sets.
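A full interaction layer needs one coefficient per pair of features, d1 × d2 × u in total. The NumPy sketch below assumes that parameterization (d1, d2 and W are hypothetical names) and shows it equals a linear model on all pairwise products.

```python
import numpy as np

rng = np.random.default_rng(7)
n, d1, d2, u = 4, 3, 2, 1
x = rng.normal(size=(n, d1))
z = rng.normal(size=(n, d2))

# One coefficient per (x feature, z feature, unit).
W = rng.normal(size=(d1, d2, u))

# Sum over every pair x_i * z_j weighted by its own coefficient.
out = np.einsum("ni,nj,iju->nu", x, z, W)

# Equivalent: flatten all pairwise products into d1 * d2 features.
pairs = (x[:, :, None] * z[:, None, :]).reshape(n, d1 * d2)
print(np.allclose(out, pairs @ W.reshape(d1 * d2, u)))  # True
```

Compared with the low-rank variant, this form grows as d1 × d2, which is why the low-rank layer exists for wide inputs.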

Parameters:
  • lmbda_dist (Distribution)

  • coef_dist (Distribution)

  • coef_kwargs (dict[str, float])

  • lmbda_kwargs (dict[str, float])

  • units (int)

__init__(lmbda_dist=<class 'numpyro.distributions.continuous.HalfNormal'>, coef_dist=<class 'numpyro.distributions.continuous.Normal'>, coef_kwargs={'loc': 0.0}, lmbda_kwargs={'scale': 1.0}, units=1)

Initialize layer parameters.

Parameters:
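  • lmbda_dist (Distribution) – NumPyro distribution class for the scale (λ) of the prior.

  • coef_dist (Distribution) – NumPyro distribution class for the coefficient prior.

  • coef_kwargs (dict[str, float]) – Parameters for the prior distribution.

  • lmbda_kwargs (dict[str, float]) – Parameters for the scale distribution.

  • units (int) – The number of outputs.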

__call__(name, x, z)

Run the layer’s forward pass.

Parameters:
  • name (str) – Name scope for sampled variables.

  • x (Array) – First set of input features.

  • z (Array) – Second set of input features.

Returns:

The result of the forward computation.

Return type:

jax.Array