Fit#
High-level fitting API for blayers models.
Reduces boilerplate when fitting Bayesian models by providing sensible defaults for guides, optimizers, learning rate schedules, and prediction.
Example
    from blayers.layers import AdaptiveLayer
    from blayers.links import gaussian_link
    from blayers.fit import fit

    def model(x, y=None):
        mu = AdaptiveLayer()('mu', x)
        return gaussian_link(mu, y)

    # Fit with batched VI
    result = fit(model, y=y_train, batch_size=1024, num_epochs=100, x=x_train)

    # Predict on new data
    preds = result.predict(x=x_test)
    print(preds.mean, preds.std)

    # Fit with MCMC
    result = fit(model, y=y_train, method="mcmc", x=x_train)
    preds = result.predict(x=x_test)

    # Fit with Stein variational gradient descent (SVGD)
    result = fit(model, y=y_train, method="svgd", num_steps=500, x=x_train)
    preds = result.predict(x=x_test)
Constants (non-array kwargs) are automatically bound via functools.partial,
so you never need to wrap your model manually:
    def model(x, n_conditions, y=None):
        ...

    # n_conditions is an int → auto-bound; x is an array → batched
    result = fit(model, y=y_train, batch_size=4096, num_epochs=250,
                 x=x_train, n_conditions=10)
- class blayers.fit.Predictions(mean, std, samples)[source]#
Bases: object

Posterior predictive output from FittedModel.predict().
- Parameters:
mean (Array)
std (Array)
samples (Array)
- mean#
Point predictions averaged over posterior samples. Shape (n,).
- Type:
jax.Array
- std#
Predictive standard deviation over posterior samples. Shape (n,).
- Type:
jax.Array
- samples#
Raw posterior predictive draws. Shape (num_samples, n, ...).
- Type:
jax.Array
- mean: Array#
- std: Array#
- samples: Array#
- __init__(mean, std, samples)#
- Parameters:
mean (Array)
std (Array)
samples (Array)
- Return type:
None
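The mean and std fields are simply reductions over the leading samples axis of the raw draws, and keeping samples around lets you compute any other summary. A minimal sketch of that relationship with plain NumPy (hypothetical array values, not blayers internals):

```python
import numpy as np

# Hypothetical posterior predictive draws: (num_samples, n) = (100, 5)
rng = np.random.default_rng(0)
samples = rng.normal(loc=2.0, scale=0.5, size=(100, 5))

# Predictions.mean / Predictions.std reduce over the samples axis (axis 0)
mean = samples.mean(axis=0)  # shape (5,)
std = samples.std(axis=0)    # shape (5,)

# The raw draws stay available for custom summaries, e.g. a 95% interval
lo, hi = np.quantile(samples, [0.025, 0.975], axis=0)

print(mean.shape, std.shape, lo.shape)
```

This is why samples keeps the full (num_samples, n, ...) shape while mean and std collapse to (n,).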
- class blayers.fit.FittedModel(model_fn, method, params=None, guide=None, losses=None, posterior_samples=None, num_particles=None)[source]#
Bases: object

A fitted blayers model.

Created by fit(). Provides predict() for posterior predictive inference and summary() for inspecting latent variable posteriors.
- Parameters:
model_fn (Callable)
method (str)
params (dict | None)
guide (Any | None)
losses (Array | None)
posterior_samples (dict | None)
num_particles (int | None)
- model_fn#
The model function with any constants already bound.
- Type:
Callable
- method#
"vi"or"mcmc".- Type:
str
- params#
SVI parameters (VI only).
- Type:
dict or None
- guide#
Fitted variational guide (VI only).
- Type:
AutoGuide or None
- losses#
Per-step ELBO loss curve (VI only).
- Type:
jax.Array or None
- posterior_samples#
MCMC posterior samples (MCMC only).
- Type:
dict or None
- model_fn: Callable#
- method: str#
- params: dict | None = None#
- guide: Any | None = None#
- losses: Array | None = None#
- posterior_samples: dict | None = None#
- num_particles: int | None = None#
- predict(*, num_samples=100, seed=1, **data)[source]#
Generate posterior predictive draws on new data.
- Parameters:
num_samples (int) – Number of posterior samples to draw. For VI this controls the guide; for MCMC all posterior samples are used regardless.
seed (int) – Random seed for the predictive distribution.
**data – Model inputs excluding y. Constants that were auto-bound during fit() should not be passed again.
- Return type:
Predictions
- summary(*, num_samples=1000, seed=2, **data)[source]#
Summarize the posterior of each latent variable.
- Parameters:
num_samples (int) – Samples to draw from the guide (VI only; ignored for MCMC).
seed (int) – Random seed.
**data – Model inputs (excluding y) needed so the guide can determine parameter shapes. Required for VI; ignored for MCMC.
- Returns:
{site_name: {"mean": ..., "std": ..., "q025": ..., "q975": ..., "shape": ...}}
- Return type:
dict
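The shape of the returned dict can be emulated with plain NumPy; a hedged sketch that builds one entry from hypothetical draws of a single latent site (the site name "mu" and draw values are illustrative only):

```python
import numpy as np

# Hypothetical draws of one latent site: 1000 samples of a (3,)-shaped weight
rng = np.random.default_rng(1)
draws = rng.normal(size=(1000, 3))

# One entry of the {site_name: {...}} dict that summary() returns
entry = {
    "mean": draws.mean(axis=0),
    "std": draws.std(axis=0),
    "q025": np.quantile(draws, 0.025, axis=0),
    "q975": np.quantile(draws, 0.975, axis=0),
    "shape": draws.shape[1:],
}
summary = {"mu": entry}
print(sorted(summary["mu"].keys()))
```

Each statistic is computed per element of the site, so the values in an entry share the site's shape.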
- __init__(model_fn, method, params=None, guide=None, losses=None, posterior_samples=None, num_particles=None)#
- Parameters:
model_fn (Callable)
method (str)
params (dict | None)
guide (Any | None)
losses (Array | None)
posterior_samples (dict | None)
num_particles (int | None)
- Return type:
None
- blayers.fit.fit(model_fn, *, y, method='vi', batch_size=None, num_epochs=None, num_steps=None, lr=0.01, schedule='cosine', guide=None, optimizer=None, num_warmup=500, num_mcmc_samples=1000, num_chains=1, autoreparam_model=True, num_particles=10, kernel_fn=None, seed=0, **kwargs)[source]#
Fit a blayers model via variational inference, MCMC, or SVGD.
Keyword arguments that are JAX/NumPy arrays are treated as data and batched during training. Non-array keyword arguments (ints, floats, strings, etc.) are treated as constants and bound to the model via functools.partial, so they don't need to be passed again at predict time.
- Parameters:
model_fn (Callable) – A NumPyro model function that accepts y as a keyword argument.
y (jax.Array) – Target / observed values.
method ("vi", "mcmc", or "svgd") – Inference method. Default "vi".
batch_size (int, optional) – Mini-batch size for VI. If None, the full dataset is used each step (appropriate for small datasets).
num_epochs (int, optional) – Number of full passes through the data. Exactly one of num_epochs or num_steps is required for VI and SVGD.
num_steps (int, optional) – Total number of gradient updates. Exactly one of num_epochs or num_steps is required for VI and SVGD.
lr (float) – Peak learning rate (default 0.01). For SVGD this is the Adagrad step size. Ignored when optimizer is given.
schedule (str) – LR schedule name: "cosine" (default), "warmup_cosine", or "constant". Only used for VI.
guide (type or AutoGuide instance, optional) – Variational family. Pass a class (instantiated on model_fn) or a ready-to-use instance. Default: AutoDiagonalNormal. Not used for SVGD (which auto-generates an AutoDelta guide).
optimizer (optax.GradientTransformation, optional) – A fully constructed optax optimizer. When provided, lr and schedule are ignored. Not used for SVGD.
num_warmup (int) – MCMC warmup iterations (default 500).
num_mcmc_samples (int) – MCMC posterior samples to draw (default 1000).
num_chains (int) – Number of MCMC chains (default 1).
autoreparam_model (bool) – Automatically reparameterize LocScale distributions for MCMC (default True).
num_particles (int) – Number of Stein particles (default 10). Only used for SVGD.
kernel_fn (SteinKernel, optional) – Kernel for SVGD. Default: RBFKernel().
seed (int) – Random seed (default 0).
**kwargs – Model inputs. Arrays → batched data. Non-arrays → constants bound via functools.partial.
- Returns:
Object with .predict(**data) and .summary(**data) methods.
- Return type:
FittedModel
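The default "cosine" schedule decays the learning rate from its peak toward zero over the run. The exact schedule blayers constructs is not shown here; the sketch below assumes the standard cosine-decay convention (as in optax.cosine_decay_schedule), which is only an assumption about the internals:

```python
import math

def cosine_lr(step, peak_lr=0.01, total_steps=1000):
    """Standard cosine decay: peak_lr at step 0, approaching 0 at total_steps."""
    frac = min(step, total_steps) / total_steps
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * frac))

# Peak at the start, half the peak midway, ~0 at the end
print(cosine_lr(0), cosine_lr(500), cosine_lr(1000))
```

Under this convention, lr is the peak value and the decay horizon is the total number of gradient updates.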
Examples
Batched VI (the common case for large datasets):
>>> result = fit(model, y=y_train, batch_size=4096, num_epochs=250,
...              x=x_train, n_conditions=10)
>>> preds = result.predict(x=x_test)
Full-dataset VI (small datasets):
>>> result = fit(model, y=y_train, num_steps=20000, x=x_train)
MCMC:
>>> result = fit(model, y=y_train, method="mcmc", x=x_train)
SVGD:
>>> result = fit(model, y=y_train, method="svgd", num_steps=500,
...              num_particles=20, x=x_train)
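Since exactly one of num_epochs or num_steps may be given, it helps to know how they trade off. Assuming the usual convention that an epoch is one full mini-batched pass over the data (an assumption about blayers' internals), the implied number of gradient updates is:

```python
import math

def implied_steps(n_rows, batch_size, num_epochs):
    """Total gradient updates if each epoch is one full pass in mini-batches."""
    batches_per_epoch = math.ceil(n_rows / batch_size)
    return batches_per_epoch * num_epochs

# e.g. 1,000,000 rows with batch_size=4096 and num_epochs=250
print(implied_steps(1_000_000, 4096, 250))
```

With full-dataset VI (batch_size=None), one epoch is one step, so num_epochs and num_steps coincide.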