Fit

High-level fitting API for blayers models.

Reduces boilerplate when fitting Bayesian models by providing sensible defaults for guides, optimizers, learning rate schedules, and prediction.

Example

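import jax
import jax.numpy as jnp

from blayers.layers import AdaptiveLayer
from blayers.links import gaussian_link
from blayers.fit import fit

# Synthetic data for illustration: 10,000 rows with 5 features
kx, ky, kt = jax.random.split(jax.random.PRNGKey(0), 3)
x_train = jax.random.normal(kx, (10_000, 5))
y_train = x_train @ jnp.arange(1.0, 6.0) + jax.random.normal(ky, (10_000,))
x_test = jax.random.normal(kt, (1_000, 5))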

def model(x, y=None):
    mu = AdaptiveLayer()('mu', x)
    return gaussian_link(mu, y)

# Fit with batched VI
result = fit(model, y=y_train, batch_size=1024, num_epochs=100, x=x_train)

# Predict on new data
preds = result.predict(x=x_test)
print(preds.mean, preds.std)

# Fit with MCMC
result = fit(model, y=y_train, method="mcmc", x=x_train)
preds = result.predict(x=x_test)

# Fit with Stein Variational Gradient Descent
result = fit(model, y=y_train, method="svgd", num_steps=500, x=x_train)
preds = result.predict(x=x_test)

Constants (non-array kwargs) are automatically bound via functools.partial, so you never need to wrap your model manually:

def model(x, n_conditions, y=None):
    ...

# n_conditions is an int → auto-bound; x is an array → batched
result = fit(model, y=y_train, batch_size=4096, num_epochs=250,
             x=x_train, n_conditions=10)
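The binding rule can be sketched in plain Python. This is a minimal illustration, not blayers' implementation: the helper name `bind_constants` is hypothetical, and it assumes "array" means any object with `shape` and `dtype` attributes (true for both NumPy and JAX arrays).

```python
import functools

import numpy as np


def bind_constants(model_fn, **kwargs):
    """Hypothetical helper: split kwargs into array data and bound constants."""
    # Anything array-like (has .shape and .dtype) is treated as data to batch.
    data = {k: v for k, v in kwargs.items() if hasattr(v, "shape") and hasattr(v, "dtype")}
    # Everything else (ints, floats, strings, ...) is treated as a constant.
    constants = {k: v for k, v in kwargs.items() if k not in data}
    # Bind the constants so later calls only need the array inputs.
    return functools.partial(model_fn, **constants), data


def model(x, n_conditions, y=None):
    # Stand-in model body: just report what it received.
    return x.shape, n_conditions, y


bound_model, data = bind_constants(model, x=np.zeros((8, 3)), n_conditions=10)
print(sorted(data))         # ['x']
print(bound_model(**data))  # ((8, 3), 10, None)
```

Because the constants live inside the `partial` object, a later call such as `bound_model(x=x_new)` needs only the arrays, which matches the documented predict-time behavior.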

class blayers.fit.Predictions(mean, std, samples)

Bases: object

Posterior predictive output from FittedModel.predict().

Parameters:
  • mean (Array)

  • std (Array)

  • samples (Array)

mean

Point predictions averaged over posterior samples. Shape (n,).

Type:

jax.Array

std

Predictive standard deviation over posterior samples. Shape (n,).

Type:

jax.Array

samples

Raw posterior predictive draws. Shape (num_samples, n, ...).

Type:

jax.Array

mean: Array
std: Array
samples: Array
__init__(mean, std, samples)
Parameters:
  • mean (Array)

  • std (Array)

  • samples (Array)

Return type:

None
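Given the documented shapes, `mean` and `std` are presumably reductions of `samples` over its leading (draw) axis. The sketch below assumes exactly that; `PredictionsSketch` and `from_draws` are hypothetical stand-ins that mirror the three fields, not the blayers class, and the population standard deviation (`ddof=0`) is an assumption.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class PredictionsSketch:
    # Hypothetical mirror of blayers.fit.Predictions.
    mean: np.ndarray     # shape (n,)
    std: np.ndarray      # shape (n,)
    samples: np.ndarray  # shape (num_samples, n, ...)


def from_draws(samples):
    # Assumed reduction: average and spread over the draw axis (axis 0).
    return PredictionsSketch(mean=samples.mean(axis=0), std=samples.std(axis=0), samples=samples)


draws = np.array([[1.0, 10.0], [3.0, 14.0]])  # 2 draws for n = 2 points
preds = from_draws(draws)
print(preds.mean)  # [ 2. 12.]
print(preds.std)   # [1. 2.]
```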

class blayers.fit.FittedModel(model_fn, method, params=None, guide=None, losses=None, posterior_samples=None, num_particles=None)

Bases: object

A fitted blayers model.

Created by fit(). Provides predict() for posterior predictive inference and summary() for inspecting latent variable posteriors.

Parameters:
  • model_fn (Callable)

  • method (str)

  • params (dict | None)

  • guide (Any | None)

  • losses (Array | None)

  • posterior_samples (dict | None)

  • num_particles (int | None)

model_fn

The model function with any constants already bound.

Type:

Callable

method

"vi", "mcmc", or "svgd".

Type:

str

params

SVI parameters (VI only).

Type:

dict or None

guide

Fitted variational guide (VI only).

Type:

AutoGuide or None

losses

Per-step ELBO loss curve (VI only).

Type:

jax.Array or None

posterior_samples

MCMC posterior samples (MCMC only).

Type:

dict or None

num_particles

Number of Stein particles (SVGD only).

Type:

int or None

model_fn: Callable
method: str
params: dict | None = None
guide: Any | None = None
losses: Array | None = None
posterior_samples: dict | None = None
num_particles: int | None = None
predict(*, num_samples=100, seed=1, **data)

Generate posterior predictive predictions on new data.

Parameters:
  • num_samples (int) – Number of posterior predictive samples to draw. For VI, this is the number of draws taken from the fitted guide; for MCMC, all stored posterior samples are used and this argument is ignored.

  • seed (int) – Random seed for the predictive distribution.

  • **data – Model inputs excluding y. Constants that were auto-bound during fit() should not be passed again.

Return type:

Predictions
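Beyond `mean` and `std`, the raw `samples` array supports other summaries. A minimal sketch of a 95% predictive interval, assuming `samples` has the documented shape `(num_samples, n)`; the draws here are synthetic stand-ins for `result.predict(x=x_test).samples`.

```python
import numpy as np

# Synthetic stand-in for result.predict(x=x_test).samples: 100 draws for n = 3 points.
rng = np.random.default_rng(0)
samples = rng.normal(loc=[0.0, 5.0, -2.0], scale=1.0, size=(100, 3))

# Per-point 2.5% and 97.5% quantiles over the draw axis.
lower, upper = np.quantile(samples, [0.025, 0.975], axis=0)
print(lower.shape, upper.shape)  # (3,) (3,)
```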

summary(*, num_samples=1000, seed=2, **data)

Summarize the posterior of each latent variable.

Parameters:
  • num_samples (int) – Samples to draw from the guide (VI only; ignored for MCMC).

  • seed (int) – Random seed.

  • **data – Model inputs (excluding y) needed so the guide can determine parameter shapes. Required for VI; ignored for MCMC.

Returns:

{site_name: {"mean": ..., "std": ..., "q025": ..., "q975": ..., "shape": ...}}

Return type:

dict
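The returned dictionary can be reproduced from a set of posterior draws. This sketch assumes draws are stored per site with the sample dimension first, that `q025`/`q975` are the 2.5% and 97.5% quantiles, and that `shape` is the per-draw (parameter) shape; `summarize_draws` is a hypothetical helper, not part of blayers.

```python
import numpy as np


def summarize_draws(draws):
    """Hypothetical helper mirroring the documented summary() output format."""
    out = {}
    for site, values in draws.items():
        values = np.asarray(values)
        out[site] = {
            "mean": values.mean(axis=0),
            "std": values.std(axis=0),
            "q025": np.quantile(values, 0.025, axis=0),
            "q975": np.quantile(values, 0.975, axis=0),
            # Shape of a single draw, i.e. the parameter's own shape (assumed).
            "shape": values.shape[1:],
        }
    return out


rng = np.random.default_rng(0)
draws = {"beta": rng.normal(size=(1000, 4)), "sigma": np.abs(rng.normal(size=(1000,)))}
summary = summarize_draws(draws)
print(summary["beta"]["shape"], summary["sigma"]["shape"])  # (4,) ()
```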

__init__(model_fn, method, params=None, guide=None, losses=None, posterior_samples=None, num_particles=None)
Parameters:
  • model_fn (Callable)

  • method (str)

  • params (dict | None)

  • guide (Any | None)

  • losses (Array | None)

  • posterior_samples (dict | None)

  • num_particles (int | None)

Return type:

None

blayers.fit.fit(model_fn, *, y, method='vi', batch_size=None, num_epochs=None, num_steps=None, lr=0.01, schedule='cosine', guide=None, optimizer=None, num_warmup=500, num_mcmc_samples=1000, num_chains=1, autoreparam_model=True, num_particles=10, kernel_fn=None, seed=0, **kwargs)

Fit a blayers model via variational inference, MCMC, or SVGD.

Keyword arguments that are JAX or NumPy arrays are treated as data and batched during training. Non-array keyword arguments (ints, floats, strings, etc.) are treated as constants and bound to the model via functools.partial, so they don’t need to be passed again at predict time.
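Batching the array arguments means every array is sliced with the same row indices, so inputs and `y` stay aligned. A minimal sketch of one epoch of shuffled mini-batches, assuming batching is along the leading axis; `iter_minibatches` is a hypothetical helper, and blayers may shuffle, drop, or pad the final partial batch differently.

```python
import numpy as np


def iter_minibatches(rng, batch_size, **arrays):
    """Hypothetical helper: yield aligned, shuffled mini-batches for one epoch."""
    n = len(next(iter(arrays.values())))
    if any(len(a) != n for a in arrays.values()):
        raise ValueError("all arrays must share the same leading dimension")
    perm = rng.permutation(n)
    for start in range(0, n, batch_size):
        idx = perm[start:start + batch_size]
        # Index every array with the same rows to keep inputs and targets paired.
        yield {name: a[idx] for name, a in arrays.items()}


rng = np.random.default_rng(0)
x = np.arange(10).reshape(10, 1)
y = np.arange(10) * 2
batches = list(iter_minibatches(rng, 4, x=x, y=y))
print([len(b["y"]) for b in batches])  # [4, 4, 2]
```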

Parameters:
  • model_fn (Callable) – A NumPyro model function that accepts y as a keyword argument.

  • y (jax.Array) – Target / observed values.

  • method ("vi", "mcmc", or "svgd") – Inference method. Default "vi".

  • batch_size (int, optional) – Mini-batch size for VI. If None, the full dataset is used at every step (appropriate for small datasets).

  • num_epochs (int, optional) – Number of full passes through the data. Exactly one of num_epochs or num_steps is required for VI and SVGD.

  • num_steps (int, optional) – Total number of gradient updates. Exactly one of num_epochs or num_steps is required for VI and SVGD.

  • lr (float) – Peak learning rate (default 0.01). For SVGD this is the Adagrad step size. Ignored when optimizer is given.

  • schedule (str) – LR schedule name: "cosine" (default), "warmup_cosine", or "constant". Only used for VI.

  • guide (type or AutoGuide instance, optional) – Variational family. Pass a class (instantiated on model_fn) or a ready-to-use instance. Default: AutoDiagonalNormal. Not used for SVGD (which auto-generates an AutoDelta guide).

  • optimizer (optax.GradientTransformation, optional) – A fully-constructed optax optimizer. When provided, lr and schedule are ignored. Not used for SVGD.

  • num_warmup (int) – MCMC warmup iterations (default 500).

  • num_mcmc_samples (int) – MCMC posterior samples to draw (default 1000).

  • num_chains (int) – Number of MCMC chains (default 1).

  • autoreparam_model (bool) – Automatically reparameterize LocScale distributions for MCMC (default True).

  • num_particles (int) – Number of Stein particles (default 10). Only used for SVGD.

  • kernel_fn (SteinKernel, optional) – Kernel for SVGD. Default: RBFKernel().

  • seed (int) – Random seed (default 0).

  • **kwargs – Model inputs. Arrays → batched data. Non-arrays → constants bound via partial.
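The num_epochs/num_steps rule and the default "cosine" schedule can be illustrated together. This is a sketch under assumptions: it assumes an epoch is `ceil(n / batch_size)` steps (one step when `batch_size` is None) and that "cosine" decays the peak `lr` to zero over the run; `resolve_num_steps` and `cosine_lr` are hypothetical names, and the exact rounding and final learning rate in blayers may differ.

```python
import math


def resolve_num_steps(n, batch_size=None, num_epochs=None, num_steps=None):
    """Hypothetical helper: total gradient updates for VI/SVGD."""
    # Exactly one of num_epochs / num_steps must be given.
    if (num_epochs is None) == (num_steps is None):
        raise ValueError("pass exactly one of num_epochs or num_steps")
    if num_steps is not None:
        return num_steps
    # batch_size=None means full-batch training: one step per epoch.
    steps_per_epoch = 1 if batch_size is None else math.ceil(n / batch_size)
    return num_epochs * steps_per_epoch


def cosine_lr(step, total_steps, lr=0.01):
    # Half-cosine decay from lr at step 0 down to 0 at total_steps.
    return lr * 0.5 * (1 + math.cos(math.pi * step / total_steps))


total = resolve_num_steps(n=10_000, batch_size=4096, num_epochs=250)
print(total)  # 750
print(cosine_lr(0, total), cosine_lr(total, total))  # 0.01 0.0
```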

Returns:

Object with .predict(**data) and .summary(**data) methods.

Return type:

FittedModel

Examples

Batched VI (the common case for large datasets):

>>> result = fit(model, y=y_train, batch_size=4096, num_epochs=250,
...              x=x_train, n_conditions=10)
>>> preds = result.predict(x=x_test)

Full-dataset VI (small datasets):

>>> result = fit(model, y=y_train, num_steps=20000, x=x_train)

MCMC:

>>> result = fit(model, y=y_train, method="mcmc", x=x_train)

SVGD:

>>> result = fit(model, y=y_train, method="svgd", num_steps=500,
...              num_particles=20, x=x_train)