Layers#

Implements Bayesian layers using JAX and NumPyro.

Design:
  • There are three levels of complexity here: class-level, instance-level, and call-level.

  • The class level handles choices like the generic model form and how to multiply coefficients with data. It is defined by the class statement itself: class Layer(BLayer).

  • The instance level handles the specific distributions that fit into the generic model, along with the initial parameters for those distributions. It is defined by creating an instance of the class: Layer(*args, **kwargs).

  • The call level handles a batch of data: it samples from the distributions chosen at the instance level and multiplies coefficients with data to produce an output. It works like result = Layer(*args, **kwargs)(data); see the sketch after this list.
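
A minimal sketch of the three levels in use, assuming the default constructor arguments and the result = Layer(*args, **kwargs)(data) call form described above (the exact call signature, and any NumPyro seeding context the call requires, are not shown on this page):

   import jax.numpy as jnp
   from blayers.layers import AdaptiveLayer

   x = jnp.ones((8, 4))  # a batch: i = 8 observations, 4 features

   # Class level: AdaptiveLayer itself fixes the generic model form
   # (hierarchical prior) and the matmul logic (a dot product).

   # Instance level: pin down the distributions and their initial
   # parameters; here we keep the documented defaults.
   layer = AdaptiveLayer()

   # Call level: calling the instance on a batch samples coefficients
   # and multiplies them with the data, yielding shape (8,).
   result = layer(x)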

Notation:
  • i: observations in a batch

  • j, k: number of sampled coefficients

  • l: low-rank dimension of low-rank models

class blayers.layers.AdaptiveLayer(lmbda_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.HalfNormal'>, prior_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.Normal'>, prior_kwargs: dict[str, float] = {'loc': 0.0}, lmbda_kwargs: dict[str, float] = {'scale': 1.0})[source]#

Bases: BLayer

Bayesian layer with an adaptive prior, using hierarchical modeling.

static matmul(beta: Array, x: Array) Array[source]#

Standard dot product between beta and x.

Parameters:
  • beta – Coefficient vector of shape (d,).

  • x – Input matrix of shape (n, d).

Returns:

Output of shape (n,).

Return type:

jax.Array
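
For intuition, a hierarchical ("adaptive") prior of this shape can be written directly in NumPyro. This is a hand-rolled sketch of the pattern, not the library's implementation, assuming a Normal prior whose scale is drawn from a HalfNormal (the documented defaults):

   import numpyro
   import numpyro.distributions as dist

   def adaptive_linear(x):
       # The prior scale lmbda is itself sampled, so the data can adapt
       # how strongly beta is shrunk toward zero. Run under a NumPyro
       # seed handler or inside inference.
       lmbda = numpyro.sample("lmbda", dist.HalfNormal(scale=1.0))
       beta = numpyro.sample(
           "beta", dist.Normal(loc=0.0, scale=lmbda).expand([x.shape[1]])
       )
       return x @ beta  # shape (n,)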

class blayers.layers.BLayer(*args: Any)[source]#

Bases: ABC

Abstract base class for Bayesian layers. It lays out the interface that concrete layers implement.

abstractmethod static matmul(*args: Any) Any[source]#

Abstract static method for matrix multiplication logic.

Parameters:

*args – Parameters to multiply.

Returns:

The result of the matrix multiplication.

Return type:

jax.Array
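
A hypothetical subclass, to illustrate the interface; this assumes matmul is the only abstract member, which is all this page documents:

   import jax
   from blayers.layers import BLayer

   class ScaledDotLayer(BLayer):  # hypothetical, for illustration only
       """Concrete layers supply the coefficient/data combination."""

       @staticmethod
       def matmul(beta: jax.Array, x: jax.Array) -> jax.Array:
           # Any way of combining coefficients with data is fair game.
           return 2.0 * (x @ beta)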

class blayers.layers.EmbeddingLayer(lmbda_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.HalfNormal'>, prior_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.Normal'>, prior_kwargs: dict[str, float] = {'loc': 0.0}, lmbda_kwargs: dict[str, float] = {'scale': 1.0})[source]#

Bases: BLayer

Bayesian embedding layer for sparse categorical features.

static matmul(beta: Array, x: Array) Array[source]#

Index into the embedding table using the provided indices.

Parameters:
  • beta – Embedding table of shape (num_embeddings, embedding_dim).

  • x – Indices array of shape (n,).

Returns:

Looked-up embeddings of shape (n, embedding_dim).

Return type:

jax.Array
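
The lookup itself reduces to JAX integer indexing; a minimal sketch with the documented shapes:

   import jax.numpy as jnp

   beta = jnp.arange(12.0).reshape(4, 3)  # table: 4 embeddings, dim 3
   x = jnp.array([0, 2, 2, 1])            # indices, shape (n,) = (4,)

   out = beta[x]                          # table lookup via indexing
   assert out.shape == (4, 3)             # (n, embedding_dim)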

class blayers.layers.FMLayer(lmbda_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.HalfNormal'>, prior_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.Normal'>, prior_kwargs: dict[str, float] = {'loc': 0.0}, lmbda_kwargs: dict[str, float] = {'scale': 1.0}, low_rank_dim: int = 3)[source]#

Bases: BLayer

Bayesian factorization machine layer with adaptive priors.

static matmul(theta: Array, x: Array) Array[source]#

Apply second-order factorization machine interaction.

Based on Rendle (2010). Computes 0.5 * sum((x V)^2 - (x^2)(V^2)), where the squares are elementwise and the sum runs over the k factor columns.

Parameters:
  • theta – Weight matrix of shape (d, k).

  • x – Input data of shape (n, d).

Returns:

Output of shape (n,).

Return type:

jax.Array
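
The identity above makes the pairwise interaction linear in d rather than quadratic; a standalone sketch of the computation (not necessarily the library's exact code):

   import jax.numpy as jnp

   def fm_interaction(theta, x):
       # Rendle (2010): pairwise interactions without the O(d^2) loop.
       xv = x @ theta                  # (n, k)
       x2v2 = (x ** 2) @ (theta ** 2)  # (n, k), elementwise squares
       return 0.5 * jnp.sum(xv ** 2 - x2v2, axis=-1)  # (n,)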

class blayers.layers.FixedPriorLayer(prior_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.Normal'>, prior_kwargs: dict[str, float] = {'loc': 0.0, 'scale': 1.0})[source]#

Bases: BLayer

Bayesian layer with a fixed prior distribution over coefficients.

static matmul(beta: Array, x: Array) Array[source]#

A dot product between beta and x.

Parameters:
  • beta – Model coefficients of shape (d,).

  • x – Input data array of shape (n, d).

Returns:

Output array of shape (n,).

Return type:

jax.Array
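
The contrast with AdaptiveLayer is that loc and scale are fixed constants rather than sampled; a hand-rolled NumPyro sketch of the pattern, not the library's implementation:

   import numpyro
   import numpyro.distributions as dist

   def fixed_prior_linear(x):
       # Prior hyperparameters are constants, not latent variables.
       beta = numpyro.sample(
           "beta", dist.Normal(loc=0.0, scale=1.0).expand([x.shape[1]])
       )
       return x @ beta  # shape (n,)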

class blayers.layers.LowRankInteractionLayer(lmbda_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.HalfNormal'>, prior_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.Normal'>, low_rank_dim: int = 3, prior_kwargs: dict[str, float] = {'loc': 0.0}, lmbda_kwargs: dict[str, float] = {'scale': 1.0})[source]#

Bases: BLayer

Takes two sets of features and learns a low-rank interaction matrix.

static matmul(theta1: Array, theta2: Array, x: Array, z: Array) Array[source]#

Implements low-rank multiplication.

This is sometimes called a factorized bilinear interaction: project x and z down to a common number of low-rank terms, then multiply the projected terms.

This is equivalent to a UV decomposition of the full interaction matrix in which U and V each have low_rank_dim columns.
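
A standalone sketch of the factorized bilinear form, assuming theta1 has shape (d_x, l) and theta2 has shape (d_z, l) (the shapes are not given on this page):

   import jax.numpy as jnp

   def low_rank_interaction(theta1, theta2, x, z):
       # Equivalent to per-row x @ (U @ V.T) @ z with U = theta1 and
       # V = theta2, without forming the full (d_x, d_z) matrix.
       xu = x @ theta1                   # (n, l)
       zv = z @ theta2                   # (n, l)
       return jnp.sum(xu * zv, axis=-1)  # (n,)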

class blayers.layers.RandomEffectsLayer(lmbda_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.HalfNormal'>, prior_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.Normal'>, prior_kwargs: dict[str, float] = {'loc': 0.0}, lmbda_kwargs: dict[str, float] = {'scale': 1.0})[source]#

Bases: BLayer

Exactly like the EmbeddingLayer but with embedding_dim=1.

static matmul(beta: Array, x: Array) Array[source]#

Index into the embedding table using the provided indices.

Parameters:
  • beta – Embedding table of shape (num_embeddings, embedding_dim).

  • x – Indices array of shape (n,).

Returns:

Looked-up embeddings of shape (n, embedding_dim).

Return type:

jax.Array
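
Since embedding_dim=1, each index picks up a single scalar group effect; a minimal sketch:

   import jax.numpy as jnp

   beta = jnp.array([[0.5], [-1.2], [2.0]])  # table of shape (3, 1)
   x = jnp.array([2, 0, 0, 1])               # group indices, shape (n,)

   out = beta[x]                             # per-group effects, (n, 1)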