Infer
- class blayers.infer.Batched_Trace_ELBO(n_obs: int, num_particles: int = 1, batch_size: int | None = None)
Bases: ELBO
- elbo_components(rng_key: Array, param_map: dict[str, Array], model: Callable[[...], Any], guide: Callable[[...], Any], *args: Any, **kwargs: Any) → dict[str, Array]
- loss(rng_key: Array, param_map: dict[str, Array], model: Callable[[...], Any], guide: Callable[[...], Any], *args: Any, **kwargs: Any) → Array
Evaluates the ELBO with an estimator that uses num_particles samples (particles).
- Parameters:
rng_key (jax.random.PRNGKey) – random number generator seed.
param_map (dict) – dictionary of current parameter values keyed by site name.
model – Python callable with NumPyro primitives for the model.
guide – Python callable with NumPyro primitives for the guide.
args – arguments to the model / guide (these may change over the course of fitting).
kwargs – keyword arguments to the model / guide (these may change over the course of fitting).
- Returns:
negative of the Evidence Lower Bound (ELBO) to be minimized.
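The loss signature mirrors NumPyro's ELBO.loss, and the constructor arguments n_obs, num_particles and batch_size suggest a minibatch Monte Carlo estimator. This page does not describe the implementation, so the block below is only a pure-JAX sketch of that standard technique, assuming that a random batch of observations is drawn without replacement and its log-likelihood is rescaled by n_obs / batch_size. The toy conjugate model, the neg_elbo, exact_posterior and log_evidence functions, and the "loc"/"scale" keys of param_map are all hypothetical and are not part of blayers.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

# Hypothetical toy model (not blayers): mu ~ N(0, 1), y_i ~ N(mu, 1), i = 1..n_obs.
# Hypothetical guide: q(mu) = N(param_map["loc"], param_map["scale"]).


def neg_elbo(rng_key, param_map, y, n_obs, num_particles=1, batch_size=None):
    """Sketch of a minibatch Monte Carlo estimate of -ELBO for the toy model."""
    batch_key, particle_key = jax.random.split(rng_key)
    if batch_size is None or batch_size >= n_obs:
        y_batch, scale_factor = y, 1.0
    else:
        # Assumed batching scheme: subsample rows without replacement and rescale
        # by n_obs / batch_size so the batch log-likelihood is an unbiased
        # estimate of the full-data log-likelihood.
        idx = jax.random.choice(batch_key, n_obs, shape=(batch_size,), replace=False)
        y_batch, scale_factor = y[idx], n_obs / batch_size

    loc, scale = param_map["loc"], param_map["scale"]
    # One reparameterized draw mu_s = loc + scale * eps_s per particle.
    eps = jax.random.normal(particle_key, (num_particles,))
    mu = loc + scale * eps

    log_prior = norm.logpdf(mu, 0.0, 1.0)
    # (num_particles, batch) log-likelihoods, summed over observations per particle.
    log_lik = norm.logpdf(y_batch[None, :], mu[:, None], 1.0).sum(axis=1)
    log_q = norm.logpdf(mu, loc, scale)

    # Average the per-particle ELBO estimates, then negate so it can be minimized.
    elbo = jnp.mean(log_prior + scale_factor * log_lik - log_q)
    return -elbo


def exact_posterior(y):
    """Conjugate posterior: mean sum(y) / (n + 1), variance 1 / (n + 1)."""
    n = y.shape[0]
    return {"loc": jnp.sum(y) / (n + 1), "scale": jnp.sqrt(1.0 / (n + 1))}


def log_evidence(y):
    """Exact log p(y) for the toy model, since y ~ N(0, I + 1 1^T)."""
    n = y.shape[0]
    return multivariate_normal.logpdf(y, jnp.zeros(n), jnp.eye(n) + jnp.ones((n, n)))


if __name__ == "__main__":
    y = jnp.array([0.3, -1.2, 0.8, 2.0])
    key = jax.random.PRNGKey(0)
    loss = neg_elbo(key, exact_posterior(y), y, n_obs=4, num_particles=8)
    # The two values agree up to float rounding.
    print(float(loss), float(-log_evidence(y)))
```

Because the toy model is conjugate, setting the guide to the exact posterior makes every particle's estimate equal to log p(y), so the returned loss is exactly the negative log evidence; any other guide gives a larger loss on average. Gradients for an optimizer can be taken with jax.grad(neg_elbo, argnums=1).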