Infer

class blayers.infer.Batched_Trace_ELBO(n_obs: int, num_particles: int = 1, batch_size: int | None = None)

Bases: ELBO

elbo_components(rng_key: Array, param_map: dict[str, Array], model: Callable[[...], Any], guide: Callable[[...], Any], *args: Any, **kwargs: Any) -> dict[str, Array]

loss(rng_key: Array, param_map: dict[str, Array], model: Callable[[...], Any], guide: Callable[[...], Any], *args: Any, **kwargs: Any) -> Array

Evaluates the ELBO with a Monte Carlo estimator that uses num_particles samples (particles).

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.

  • param_map (dict) – dictionary of current parameter values keyed by site name.

  • model – Python callable with NumPyro primitives for the model.

  • guide – Python callable with NumPyro primitives for the guide.

  • args – positional arguments passed to the model and guide (these may vary during fitting).

  • kwargs – keyword arguments passed to the model and guide (these may vary during fitting).

Returns:

The negative of the Evidence Lower Bound (ELBO), to be minimized.
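
The constructor takes both n_obs and batch_size, which suggests that the log-likelihood of a minibatch is rescaled to the size of the full dataset. The source does not confirm this, so the following is only a minimal sketch of that standard rescaling, not the blayers implementation. The helpers gaussian_log_lik and scaled_batch_log_lik are hypothetical names, and plain Python floats stand in for JAX arrays so the sketch runs without dependencies.

```python
import math
import random


def gaussian_log_lik(y, mu, sigma=1.0):
    """Log density of y under Normal(mu, sigma). Hypothetical helper."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - (y - mu) ** 2 / (2 * sigma**2)


def scaled_batch_log_lik(batch, mu, n_obs):
    """Minibatch log-likelihood rescaled by n_obs / len(batch).

    Assumption: this mirrors what a batched ELBO might do with n_obs;
    it is not taken from the blayers source.
    """
    scale = n_obs / len(batch)
    return scale * sum(gaussian_log_lik(y, mu) for y in batch)


random.seed(0)
data = [random.gauss(1.0, 1.0) for _ in range(12)]
n_obs, batch_size, mu = len(data), 4, 0.5

# Log-likelihood over the full dataset.
full = sum(gaussian_log_lik(y, mu) for y in data)

# Partition the data into equal minibatches and rescale each one.
batches = [data[i:i + batch_size] for i in range(0, n_obs, batch_size)]
estimates = [scaled_batch_log_lik(b, mu, n_obs) for b in batches]

# Averaged over an even partition, the rescaled estimates recover the full value.
mean_estimate = sum(estimates) / len(estimates)
```

With this scaling, each minibatch gives an estimate on the same scale as the full-data log-likelihood, so the prior and likelihood terms of the ELBO keep their relative weight regardless of batch size.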

blayers.infer.svi_run_batched(svi: SVI, rng_key: Array, num_steps: int, batch_size: int, **data: dict[str, Array]) -> SVIRunResult
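
The signature above takes the dataset as keyword arrays together with num_steps and batch_size, but this page does not describe how batches are drawn. As a rough illustration only, the sketch below shows one way a batched loop could slice every keyword array with the same row indices at each step. The helper iter_minibatches is a hypothetical name, random sampling without replacement is an assumption, and Python lists stand in for JAX arrays.

```python
import random


def iter_minibatches(data, batch_size, num_steps, seed=0):
    """Yield num_steps dicts holding the same random rows of every array.

    Hypothetical helper: the sampling scheme is an assumption, not the
    documented behaviour of svi_run_batched.
    """
    n_obs = len(next(iter(data.values())))
    if any(len(values) != n_obs for values in data.values()):
        raise ValueError("all data arrays must share the same leading length")
    rng = random.Random(seed)
    for _ in range(num_steps):
        # Draw one set of row indices and apply it to every array,
        # so that features and targets stay aligned.
        idx = rng.sample(range(n_obs), min(batch_size, n_obs))
        yield {name: [values[i] for i in idx] for name, values in data.items()}


data = {"x": list(range(10)), "y": [2 * v for v in range(10)]}
batches = list(iter_minibatches(data, batch_size=4, num_steps=3))
```

Each yielded dict has the same keys as the input, so it could be unpacked as keyword arguments into a single optimization step.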