Layers#
Implements Bayesian layers using JAX and NumPyro.
- Design:
There are three levels of complexity here: class-level, instance-level, and call-level.
The class-level handles things like choosing the generic model form and how to multiply coefficients with data. Defined by the class definition itself: class Layer(BLayer).
The instance-level handles the specific distributions that fit into a generic model and the initial parameters for those distributions. Defined by creating an instance of the class: Layer(*args, **kwargs).
The call-level handles seeing a batch of data: it samples from the distributions defined on the instance and multiplies coefficients with data to produce an output. It works like result = Layer(*args, **kwargs)(data).
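The three levels can be illustrated with a minimal toy sketch. The class and its names below are illustrative only, not the actual blayers implementation, and plain NumPy sampling stands in for NumPyro:

```python
import numpy as np


class ToyLayer:
    """Class-level: the class definition fixes the generic model form."""

    def __init__(self, loc=0.0, scale=1.0):
        # Instance-level: specific parameters for the prior distribution.
        self.loc = loc
        self.scale = scale

    @staticmethod
    def matmul(beta, x):
        # Class-level rule for combining coefficients with data.
        return x @ beta

    def __call__(self, x, rng=None):
        # Call-level: sample coefficients from the prior, then apply them.
        rng = rng or np.random.default_rng(0)
        beta = rng.normal(self.loc, self.scale, size=x.shape[1])
        return self.matmul(beta, x)


# Works like result = Layer(*args, **kwargs)(data).
result = ToyLayer(loc=0.0, scale=1.0)(np.ones((4, 3)))
```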
- Notation:
i: observations in a batch
j, k: number of sampled coefficients
l: low-rank dimension of low-rank models
- class blayers.layers.AdaptiveLayer(lmbda_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.HalfNormal'>, prior_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.Normal'>, prior_kwargs: dict[str, float] = {'loc': 0.0}, lmbda_kwargs: dict[str, float] = {'scale': 1.0})[source]#
Bases:
BLayer
Bayesian layer with adaptive prior using hierarchical modeling.
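With the defaults in the signature above, the hierarchy amounts to drawing a scale from a HalfNormal and then drawing coefficients from a Normal with that scale. A rough NumPy sketch of that generative structure (not the NumPyro implementation; the feature count is illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# lambda ~ HalfNormal(scale=1.0): the adaptive scale hyperparameter.
lmbda = np.abs(rng.normal(0.0, 1.0))

# beta ~ Normal(loc=0.0, scale=lambda): coefficients shrink toward
# loc when the sampled lambda is small.
d = 5  # number of features (illustrative)
beta = rng.normal(0.0, lmbda, size=d)
```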
- class blayers.layers.BLayer(*args: Any)[source]#
Bases:
ABC
Abstract base class for Bayesian layers. Lays out an interface.
- class blayers.layers.EmbeddingLayer(lmbda_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.HalfNormal'>, prior_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.Normal'>, prior_kwargs: dict[str, float] = {'loc': 0.0}, lmbda_kwargs: dict[str, float] = {'scale': 1.0})[source]#
Bases:
BLayer
Bayesian embedding layer for sparse categorical features.
- static matmul(beta: Array, x: Array) Array [source]#
Index into the embedding table using the provided indices.
- Parameters:
beta – Embedding table of shape (num_embeddings, embedding_dim).
x – Indices array of shape (n,).
- Returns:
Looked-up embeddings of shape (n, embedding_dim).
- Return type:
jax.Array
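The lookup described above is plain integer indexing into the table; a standalone NumPy sketch with illustrative shapes:

```python
import numpy as np

num_embeddings, embedding_dim = 4, 3
# Embedding table of shape (num_embeddings, embedding_dim).
beta = np.arange(num_embeddings * embedding_dim, dtype=float).reshape(
    num_embeddings, embedding_dim
)

x = np.array([2, 0, 2])  # indices of shape (n,)
out = beta[x]            # looked-up embeddings of shape (n, embedding_dim)
```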
- class blayers.layers.FMLayer(lmbda_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.HalfNormal'>, prior_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.Normal'>, prior_kwargs: dict[str, float] = {'loc': 0.0}, lmbda_kwargs: dict[str, float] = {'scale': 1.0}, low_rank_dim: int = 3)[source]#
Bases:
BLayer
Bayesian factorization machine layer with adaptive priors.
- static matmul(theta: Array, x: Array) Array [source]#
Apply second-order factorization machine interaction.
Based on Rendle (2010). Computes: 0.5 * sum((x V)^2 - (x^2)(V^2)), summed over the factor dimension k.
- Parameters:
theta – Weight matrix of shape (d, k).
x – Input data of shape (n, d).
- Returns:
Output of shape (n,).
- Return type:
jax.Array
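The Rendle (2010) identity turns the O(d^2) pairwise interaction sum into the O(d k) expression above. A standalone NumPy check of that equivalence (variable names and sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 8, 5, 3
x = rng.normal(size=(n, d))      # input data, shape (n, d)
theta = rng.normal(size=(d, k))  # factor matrix V, shape (d, k)

# Fast form: 0.5 * sum((x V)^2 - (x^2)(V^2)) over the factor dimension.
fast = 0.5 * ((x @ theta) ** 2 - (x**2) @ (theta**2)).sum(axis=1)

# Brute force: sum over feature pairs i < j of <v_i, v_j> * x_i * x_j.
gram = theta @ theta.T  # (d, d) matrix of inner products <v_i, v_j>
slow = np.zeros(n)
for i in range(d):
    for j in range(i + 1, d):
        slow += gram[i, j] * x[:, i] * x[:, j]
```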
- class blayers.layers.FixedPriorLayer(prior_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.Normal'>, prior_kwargs: dict[str, float] = {'loc': 0.0, 'scale': 1.0})[source]#
Bases:
BLayer
Bayesian layer with a fixed prior distribution over coefficients.
- class blayers.layers.LowRankInteractionLayer(lmbda_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.HalfNormal'>, prior_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.Normal'>, low_rank_dim: int = 3, prior_kwargs: dict[str, float] = {'loc': 0.0}, lmbda_kwargs: dict[str, float] = {'scale': 1.0})[source]#
Bases:
BLayer
Takes two sets of features and learns a low-rank interaction matrix.
- static matmul(theta1: Array, theta2: Array, x: Array, z: Array) Array [source]#
Implements low-rank multiplication.
This is a factorized bilinear interaction: project x and z down to a common number of low-rank terms and then multiply those terms elementwise.
This is equivalent to a UV decomposition where the U/V matrices have n=low_rank_dim columns.
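The project-then-multiply step can be checked against the full bilinear form it factorizes, x^T (U V^T) z per observation. A standalone NumPy sketch (names and sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, dx, dz, l = 6, 4, 5, 3
x = rng.normal(size=(n, dx))
z = rng.normal(size=(n, dz))
theta1 = rng.normal(size=(dx, l))  # U: projects x to l low-rank terms
theta2 = rng.normal(size=(dz, l))  # V: projects z to l low-rank terms

# Factorized form: project both inputs, multiply elementwise, sum over l.
low_rank = ((x @ theta1) * (z @ theta2)).sum(axis=1)

# Full bilinear form with the rank-l interaction matrix W = U V^T.
full = np.einsum("ni,ij,nj->n", x, theta1 @ theta2.T, z)
```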
- class blayers.layers.RandomEffectsLayer(lmbda_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.HalfNormal'>, prior_dist: ~numpyro.distributions.distribution.Distribution = <class 'numpyro.distributions.continuous.Normal'>, prior_kwargs: dict[str, float] = {'loc': 0.0}, lmbda_kwargs: dict[str, float] = {'scale': 1.0})[source]#
Bases:
BLayer
Exactly like the EmbeddingLayer but with embedding_dim=1.
- static matmul(beta: Array, x: Array) Array [source]#
Index into the embedding table using the provided indices.
- Parameters:
beta – Embedding table of shape (num_embeddings, embedding_dim).
x – Indices array of shape (n,).
- Returns:
Looked-up embeddings of shape (n, embedding_dim).
- Return type:
jax.Array