Infer
- class blayers.infer.Batched_Trace_ELBO(num_obs, num_particles=1, batch_size=None)
Bases: ELBO
- Parameters:
num_obs (int)
num_particles (int)
batch_size (int | None)
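The signature suggests that `num_obs` and `batch_size` rescale a minibatch likelihood term to the full dataset, as is common in subsampled (stochastic) variational inference. This is an assumption, because the reference above does not say so. The following is a minimal, stdlib-only sketch of that kind of rescaling. The helper `scaled_batch_log_lik` is hypothetical and is not part of blayers.

```python
import itertools
from typing import Optional, Sequence


def scaled_batch_log_lik(
    batch_log_liks: Sequence[float],
    num_obs: int,
    batch_size: Optional[int] = None,
) -> float:
    """Hypothetical helper: rescale a minibatch log-likelihood sum.

    Assumes each entry is log p(x_i | z) for one observation drawn uniformly
    without replacement from num_obs points. Scaling the batch sum by
    num_obs / batch_size gives an unbiased estimate of the full-data sum.
    With batch_size=None the batch is treated as the full dataset.
    """
    batch_sum = sum(batch_log_liks)
    if batch_size is None:
        return batch_sum  # full batch: no rescaling needed
    if len(batch_log_liks) != batch_size:
        raise ValueError("batch_log_liks must contain batch_size entries")
    return batch_sum * num_obs / batch_size


# Averaging the scaled estimate over every possible minibatch of size 2
# recovers the full-data log-likelihood sum exactly (unbiasedness).
full = [-0.5, -1.2, -0.3, -2.0, -0.9, -1.1]
batches = list(itertools.combinations(full, 2))
avg = sum(
    scaled_batch_log_lik(b, num_obs=len(full), batch_size=2) for b in batches
) / len(batches)
print(round(avg, 10), round(sum(full), 10))  # prints: -6.0 -6.0
```

Under this assumption, the rescaling keeps the likelihood term on the same scale as the prior and guide terms, which are not subsampled.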
- __init__(num_obs, num_particles=1, batch_size=None)
- Parameters:
num_obs (int)
num_particles (int)
batch_size (int | None)
- loss(rng_key, param_map, model, guide, *args, **kwargs)
Evaluates the ELBO with a Monte Carlo estimator that averages over num_particles samples (particles) drawn from the guide.
- Parameters:
rng_key (jax.random.PRNGKey) – random number generator seed.
param_map (dict) – dictionary of current parameter values keyed by site name.
model (Callable[[...], Any]) – Python callable with NumPyro primitives for the model.
guide (Callable[[...], Any]) – Python callable with NumPyro primitives for the guide.
args (Any) – positional arguments to the model / guide (these may vary over the course of fitting).
kwargs (Any) – keyword arguments to the model / guide (these may vary over the course of fitting).
- Returns:
negative of the Evidence Lower Bound (ELBO) to be minimized.
- Return type:
Array
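To make the return value concrete, consider a toy model. The negative ELBO can be estimated by averaging `log p(x, z) - log q(z)` over `num_particles` draws of `z` from the guide and then negating the average. The sketch below does this in plain Python instead of going through NumPyro's tracing machinery. It therefore illustrates the estimator only, not blayers' implementation. The model, guide, and function names are all hypothetical.

```python
import math
import random


def normal_logpdf(x: float, loc: float, scale: float) -> float:
    """Log density of Normal(loc, scale) evaluated at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def neg_elbo(
    rng: random.Random,
    x: float,
    guide_loc: float,
    guide_scale: float,
    num_particles: int = 1,
) -> float:
    """Monte Carlo estimate of -ELBO for a toy model (illustrative only).

    Model: z ~ Normal(0, 1), x ~ Normal(z, 1).
    Guide: z ~ Normal(guide_loc, guide_scale).
    """
    total = 0.0
    for _ in range(num_particles):
        z = rng.gauss(guide_loc, guide_scale)  # one particle drawn from the guide
        log_joint = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)
        log_q = normal_logpdf(z, guide_loc, guide_scale)
        total += log_joint - log_q
    return -total / num_particles  # negative ELBO, to be minimized


x_obs = 1.0
log_evidence = normal_logpdf(x_obs, 0.0, math.sqrt(2.0))  # exact log p(x)

rng = random.Random(0)
# If the guide equals the exact posterior N(x/2, sqrt(1/2)), every particle
# yields log p(x), so -ELBO equals -log p(x) for any num_particles.
exact = neg_elbo(rng, x_obs, x_obs / 2, math.sqrt(0.5), num_particles=4)
# A mismatched guide gives a larger -ELBO (a looser bound on -log p(x)).
loose = neg_elbo(rng, x_obs, 3.0, 1.0, num_particles=1000)
print(exact, -log_evidence, loose)
```

The gap between `loose` and `-log_evidence` estimates KL(q || p(z | x)). Minimizing the returned value therefore pushes the guide toward the posterior. Increasing `num_particles` reduces the variance of the estimate but does not change its expectation.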