Source code for blayers.layers

"""
Implements Bayesian Layers using Jax and Numpyro.

Design:
  - There are three levels of complexity here: class-level, instance-level, and
    call-level.
  - The class-level handles things like choosing the generic model form and how
    to multiply coefficients with data. Defined by the `class Layer(BLayer)`
    definition itself.
  - The instance-level handles the specific distributions that fit into a
    generic model and the initial parameters for those distributions. Defined
    by creating an instance of the class: `Layer(*args, **kwargs)`.
  - The call-level handles seeing a batch of data, sampling from the
    distributions defined on the class, and multiplying coefficients and data
    to produce an output. Works like `result = Layer(*args, **kwargs)(data)`;
    see the illustrative `_example_usage` sketch at the bottom of this module.

Notation:
  - `n`: observations in a batch
  - `d`: number of coefficients
  - `l`: low rank dimension of low rank models
  - `u`: units aka output dimension
  - `m`: embedding dimension
"""

from abc import ABC, abstractmethod
from typing import Any

import jax
import jax.numpy as jnp
from numpyro import distributions, sample

from blayers._utils import add_trailing_dim

# ---- Matmul functions ------------------------------------------------------ #


def _matmul_dot_product(x: jax.Array, beta: jax.Array) -> jax.Array:
    """
    Standard dot product between beta and x.

    Args:
        beta: Coefficient vector of shape (d, u).
        x: Input matrix of shape (n, d).

    Returns:
        jax.Array: Output of shape (n, u).
    """
    return jnp.einsum("nd,du->nu", x, beta)


def _matmul_factorization_machine(x: jax.Array, theta: jax.Array) -> jax.Array:
    """
    Apply second-order factorization machine interaction.

    Based on Rendle (2010). Computes:
    0.5 * sum((xV)^2 - (x^2 V^2))

    Args:
        theta: Weight matrix of shape (d, l, u).
        x: Input data of shape (n, d).

    Returns:
        jax.Array: Output of shape (n, u).
    """
    vx2 = jnp.einsum("nd,dlu->nlu", x, theta) ** 2
    v2x2 = jnp.einsum("nd,dlu->nlu", x**2, theta**2)
    return 0.5 * jnp.einsum("nlu->nu", vx2 - v2x2)
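

# A minimal cross-check (illustrative only, not used by the layers; the helper
# name is hypothetical): the Rendle identity above equals the explicit O(d^2)
# pairwise sum over feature pairs, sum_{i<j} <v_i, v_j> x_i x_j, per unit.
def _fm_pairwise_reference(x: jax.Array, theta: jax.Array) -> jax.Array:
    """Naive pairwise form of the factorization machine interaction."""
    # <v_i, v_j> summed over the low-rank dim, for every feature pair (i, j)
    gram = jnp.einsum("ilu,jlu->iju", theta, theta)
    # full quadratic form, then subtract the i == j diagonal terms
    pairwise = jnp.einsum("ni,nj,iju->nu", x, x, gram)
    diag = jnp.einsum("ni,ni,iiu->nu", x, x, gram)
    return 0.5 * (pairwise - diag)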


def _matmul_uv_decomp(
    x: jax.Array,
    z: jax.Array,
    theta1: jax.Array,
    theta2: jax.Array,
) -> jax.Array:
    """Implements low rank multiplication.

    According to ChatGPT this is a "factorized bilinear interaction".
    Basically, you just need to project x and z down to a common number of
    low rank terms and then just multiply those terms.

    This is equivalent to a UV decomposition where you use n=low_rank_dim
    on the columns of the U/V matrices.
    """
    xb = jnp.einsum("nd,dlu->nlu", x, theta1)
    zb = jnp.einsum("nd,dlu->nlu", z, theta2)
    return jnp.einsum("nlu->nu", xb * zb)
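

# An illustrative cross-check (hypothetical helper, not part of the API): the
# factorized form above equals the full bilinear form x^T W z per output unit,
# with W reconstructed as W[d, e, u] = sum_l theta1[d, l, u] * theta2[e, l, u].
def _uv_full_bilinear_reference(
    x: jax.Array,
    z: jax.Array,
    theta1: jax.Array,
    theta2: jax.Array,
) -> jax.Array:
    """Naive full-rank form of `_matmul_uv_decomp`, for testing only."""
    w = jnp.einsum("dlu,elu->deu", theta1, theta2)  # (d1, d2, u)
    return jnp.einsum("nd,ne,deu->nu", x, z, w)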


def _matmul_randomwalk(
    x: jax.Array,
    theta: jax.Array,
) -> jax.Array:
    """."""
    theta_cumsum = jnp.cumsum(theta)
    x_flat = x.squeeze(-1).astype(jnp.int32)
    return jnp.reshape(theta_cumsum[x_flat], (-1, 1))
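

# A tiny worked example (illustrative only; the values are made up): with
# increments theta = [0.5, -0.2, 0.1], period index 0 maps to 0.5 and period
# index 2 maps to 0.5 - 0.2 + 0.1 = 0.4.
def _randomwalk_demo() -> jax.Array:
    """Hypothetical demo of `_matmul_randomwalk`; not part of the API."""
    theta = jnp.array([0.5, -0.2, 0.1])  # per-period increments
    x = jnp.array([[0], [2]])  # period indices, shape (n, 1)
    return _matmul_randomwalk(x, theta)  # [[0.5], [0.4]]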


def _matmul_interaction(
    x: jax.Array,
    z: jax.Array,
    beta: jax.Array,
) -> jax.Array:
    """."""
    n, d1 = x.shape
    _, d2 = z.shape

    # thanks chat GPT
    interactions = jnp.reshape(x[:, :, None] * z[:, None, :], (n, d1 * d2))

    return jnp.einsum("nd,du->nu", interactions, beta)
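

# An illustrative equivalence (hypothetical helper, not part of the API):
# each row of `interactions` above is the flattened outer product of that row
# of x with that row of z, i.e. a per-row Kronecker product.
def _interaction_kron_reference(
    x: jax.Array,
    z: jax.Array,
    beta: jax.Array,
) -> jax.Array:
    """Per-row Kronecker form of `_matmul_interaction`, for testing only."""
    interactions = jax.vmap(jnp.kron)(x, z)  # (n, d1 * d2), same ordering
    return interactions @ beta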


# ---- Classes --------------------------------------------------------------- #


class BLayer(ABC):
    """Abstract base class for Bayesian layers. Lays out an interface."""

    @abstractmethod
    def __init__(self, *args: Any) -> None:
        """Initialize layer parameters."""

    @abstractmethod
    def __call__(self, *args: Any) -> Any:
        """
        Run the layer's forward pass.

        Args:
            name: Name scope for sampled variables. Due to mypy constraints,
                the `name` arg is only written explicitly in subclasses.
            *args: Inputs to the layer.

        Returns:
            jax.Array: The result of the forward computation.
        """
class AdaptiveLayer(BLayer):
    """Bayesian layer with an adaptive prior, using hierarchical modeling.

    Generates coefficients from the hierarchical model

    .. math::
        \\lambda \\sim HalfNormal(1.)

    .. math::
        \\beta \\sim Normal(0., \\lambda)
    """

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: NumPyro distribution class for the scale (λ) of the
                prior.
            coef_dist: NumPyro distribution class for the coefficient prior.
            coef_kwargs: Parameters for the prior distribution.
            lmbda_kwargs: Parameters for the scale distribution.
            units: The number of outputs.
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
    ) -> jax.Array:
        """
        Forward pass with adaptive prior on coefficients.

        Args:
            name: Variable name scope.
            x: Input data array of shape (n, d).

        Returns:
            jax.Array: Output array of shape (n, u).
        """
        x = add_trailing_dim(x)
        input_shape = x.shape[1]

        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [input_shape, self.units]
            ),
        )

        # matmul and return
        return _matmul_dot_product(x, beta)
class FixedPriorLayer(BLayer):
    """Bayesian layer with a fixed prior distribution over coefficients.

    Generates coefficients from the model

    .. math::
        \\beta \\sim Normal(0., 1.)
    """

    def __init__(
        self,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0, "scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            coef_dist: NumPyro distribution class for the coefficients.
            coef_kwargs: Parameters to initialize the prior distribution.
            units: The number of outputs.
        """
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
    ) -> jax.Array:
        """
        Forward pass with fixed prior.

        Args:
            name: Variable name prefix.
            x: Input data array of shape (n, d).

        Returns:
            jax.Array: Output array of shape (n, u).
        """
        x = add_trailing_dim(x)
        input_shape = x.shape[1]

        # sampling block
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta",
            fn=self.coef_dist(**self.coef_kwargs).expand(
                [input_shape, self.units]
            ),
        )

        # matmul and return
        return _matmul_dot_product(x, beta)
class InterceptLayer(BLayer):
    """Bayesian intercept layer with a fixed prior distribution.

    Generates an intercept from the model

    .. math::
        \\beta \\sim Normal(0., 1.)
    """

    def __init__(
        self,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0, "scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            coef_dist: NumPyro distribution class for the intercept.
            coef_kwargs: Parameters to initialize the prior distribution.
            units: The number of outputs.
        """
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
    ) -> jax.Array:
        """
        Forward pass with fixed prior.

        Args:
            name: Variable name prefix.

        Returns:
            jax.Array: Output array of shape (1, u).
        """
        # sampling block
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta",
            fn=self.coef_dist(**self.coef_kwargs).expand([1, self.units]),
        )
        return beta
class EmbeddingLayer(BLayer):
    """Bayesian embedding layer for sparse categorical features."""

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: NumPyro distribution class for the scale (λ) of the
                prior.
            coef_dist: Prior distribution for embedding weights.
            coef_kwargs: Parameters for the prior distribution.
            lmbda_kwargs: Parameters for the scale distribution.
            units: The number of outputs.
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
        num_categories: int,
        embedding_dim: int,
    ) -> jax.Array:
        """
        Forward pass through embedding lookup.

        Args:
            name: Variable name scope.
            x: Integer indices of shape (n,) indicating embeddings to use.
            num_categories: The number of distinct categories that get an
                embedding.
            embedding_dim: The size of each embedding, e.g. 2, 4, 8, etc.

        Returns:
            jax.Array: Embedding vectors of shape (n, m).
        """
        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs),
        )
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [num_categories, embedding_dim]
            ),
        )

        # lookup and return
        return beta[x.squeeze()]
class RandomEffectsLayer(BLayer):
    """Exactly like the EmbeddingLayer but with `embedding_dim=1`."""

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: NumPyro distribution class for the scale (λ) of the
                prior.
            coef_dist: Prior distribution for the random effects.
            coef_kwargs: Parameters for the prior distribution.
            lmbda_kwargs: Parameters for the scale distribution.
            units: The number of outputs.
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
        num_categories: int,
    ) -> jax.Array:
        """
        Forward pass through random-effect lookup.

        Args:
            name: Variable name scope.
            x: Integer indices of shape (n,) indicating effects to use.
            num_categories: The number of distinct categories that get a
                random effect.

        Returns:
            jax.Array: Random-effect values of shape (n, 1).
        """
        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs),
        )
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [num_categories, 1]
            ),
        )
        return beta[x.squeeze()]
class FMLayer(BLayer):
    """Bayesian factorization machine layer with adaptive priors.

    Generates coefficients from the hierarchical model

    .. math::
        \\lambda \\sim HalfNormal(1.)

    .. math::
        \\beta \\sim Normal(0., \\lambda)

    The shape of :math:`\\beta` is :math:`(d, l, u)`, where :math:`d` is the
    number of input covariates, :math:`l` is the low-rank dim, and :math:`u`
    is the number of units. Then performs matrix multiplication using the
    formula in `Rendle (2010)
    <https://jame-zhang.github.io/assets/algo/Factorization-Machines-Rendle2010.pdf>`_.
    """

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: Distribution for the scaling factor λ.
            coef_dist: Prior for the theta parameters.
            coef_kwargs: Arguments for the prior distribution.
            lmbda_kwargs: Arguments for the λ distribution.
            units: The number of outputs.
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
        low_rank_dim: int,
    ) -> jax.Array:
        """
        Forward pass through the factorization machine layer.

        Args:
            name: Variable name scope.
            x: Input matrix of shape (n, d).
            low_rank_dim: Dimensionality of the low-rank approximation.

        Returns:
            jax.Array: Output array of shape (n, u).
        """
        # get shapes and reshape if necessary
        x = add_trailing_dim(x)
        input_shape = x.shape[1]

        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )
        theta = sample(
            name=f"{self.__class__.__name__}_{name}_theta",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [input_shape, low_rank_dim, self.units]
            ),
        )

        # matmul and return
        return _matmul_factorization_machine(x, theta)
class LowRankInteractionLayer(BLayer):
    """Takes two sets of features and learns a low-rank interaction matrix."""

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: Distribution for the scaling factor λ.
            coef_dist: Prior for the theta parameters.
            coef_kwargs: Arguments for the prior distribution.
            lmbda_kwargs: Arguments for the λ distribution.
            units: The number of outputs.
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
        z: jax.Array,
        low_rank_dim: int,
    ) -> jax.Array:
        """
        Forward pass through the low-rank interaction layer.

        Args:
            name: Variable name scope.
            x: First input matrix of shape (n, d1).
            z: Second input matrix of shape (n, d2).
            low_rank_dim: Dimensionality of the low-rank approximation.

        Returns:
            jax.Array: Output array of shape (n, u).
        """
        # get shapes and reshape if necessary
        x = add_trailing_dim(x)
        z = add_trailing_dim(z)
        input_shape1 = x.shape[1]
        input_shape2 = z.shape[1]

        # sampling block
        lmbda1 = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda1",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )
        theta1 = sample(
            name=f"{self.__class__.__name__}_{name}_theta1",
            fn=self.coef_dist(scale=lmbda1, **self.coef_kwargs).expand(
                [input_shape1, low_rank_dim, self.units]
            ),
        )
        lmbda2 = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda2",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )
        theta2 = sample(
            name=f"{self.__class__.__name__}_{name}_theta2",
            fn=self.coef_dist(scale=lmbda2, **self.coef_kwargs).expand(
                [input_shape2, low_rank_dim, self.units]
            ),
        )
        return _matmul_uv_decomp(x, z, theta1, theta2)
class RandomWalkLayer(BLayer):
    """Bayesian random-walk layer over discrete time periods.

    Samples per-period increments and cumulatively sums them, so each
    period's coefficient is a step in a random walk.
    """

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
    ):
        """
        Args:
            lmbda_dist: Distribution for the scaling factor λ.
            coef_dist: Prior for the per-period increments.
            coef_kwargs: Arguments for the prior distribution.
            lmbda_kwargs: Arguments for the λ distribution.
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs

    def __call__(
        self,
        name: str,
        x: jax.Array,
        num_periods: int,
    ) -> jax.Array:
        """
        Forward pass through the random-walk lookup.

        Args:
            name: Variable name scope.
            x: Integer period indices of shape (n, 1).
            num_periods: The total number of time periods.

        Returns:
            jax.Array: Output array of shape (n, 1).
        """
        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda",
            fn=self.lmbda_dist(**self.lmbda_kwargs),
        )
        theta = sample(
            name=f"{self.__class__.__name__}_{name}_theta",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [num_periods]
            ),
        )

        # cumsum, lookup, and return
        return _matmul_randomwalk(x, theta)
class InteractionLayer(BLayer):
    """Bayesian layer over all pairwise interactions between two feature
    sets."""

    def __init__(
        self,
        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
        coef_dist: distributions.Distribution = distributions.Normal,
        coef_kwargs: dict[str, float] = {"loc": 0.0},
        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
        units: int = 1,
    ):
        """
        Args:
            lmbda_dist: Distribution for the scaling factor λ.
            coef_dist: Prior for the beta parameters.
            coef_kwargs: Arguments for the prior distribution.
            lmbda_kwargs: Arguments for the λ distribution.
            units: The number of outputs.
        """
        self.lmbda_dist = lmbda_dist
        self.coef_dist = coef_dist
        self.coef_kwargs = coef_kwargs
        self.lmbda_kwargs = lmbda_kwargs
        self.units = units

    def __call__(
        self,
        name: str,
        x: jax.Array,
        z: jax.Array,
    ) -> jax.Array:
        """
        Forward pass over all pairwise interactions between x and z.

        Args:
            name: Variable name scope.
            x: First input matrix of shape (n, d1).
            z: Second input matrix of shape (n, d2).

        Returns:
            jax.Array: Output array of shape (n, u).
        """
        # get shapes and reshape if necessary
        x = add_trailing_dim(x)
        z = add_trailing_dim(z)
        input_shape1 = x.shape[1]
        input_shape2 = z.shape[1]

        # sampling block
        lmbda = sample(
            name=f"{self.__class__.__name__}_{name}_lmbda1",
            fn=self.lmbda_dist(**self.lmbda_kwargs).expand([self.units]),
        )
        beta = sample(
            name=f"{self.__class__.__name__}_{name}_beta1",
            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
                [input_shape1 * input_shape2, self.units]
            ),
        )
        return _matmul_interaction(x, z, beta)
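

# ---- Usage sketch ------------------------------------------------------------ #


def _example_usage() -> jax.Array:
    """Illustrative sketch only, not part of the public API.

    Shows the three levels from the module docstring: the class defines the
    model form, the instance fixes the priors, and the call samples under a
    NumPyro seed handler and multiplies coefficients with data.
    """
    from numpyro import handlers

    x = jnp.ones((8, 3))  # a batch of n=8 observations with d=3 covariates
    layer = AdaptiveLayer(units=2)  # instance-level: choose priors
    with handlers.seed(rng_seed=0):
        return layer("demo", x)  # call-level: output of shape (8, 2)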