Infer

class blayers.infer.Batched_Trace_ELBO(num_obs, num_particles=1, batch_size=None)

Bases: ELBO

Parameters:
  • num_obs (int)

  • num_particles (int)

  • batch_size (int | None)
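The parameter list does not say how num_obs and batch_size are used. One common reason a minibatch ELBO takes the total observation count is to rescale the minibatch log-likelihood by num_obs / batch_size, so that it is an unbiased estimate of the full-data term. The plain-Python sketch below illustrates only that arithmetic. It is an assumption about the intent, not blayers code, and normal_logpdf and scaled_batch_log_likelihood are hypothetical helpers introduced for illustration.

```python
import math


def normal_logpdf(value, loc, scale):
    """Log density of Normal(loc, scale) at `value`."""
    return -0.5 * math.log(2 * math.pi * scale**2) - (value - loc) ** 2 / (2 * scale**2)


def scaled_batch_log_likelihood(batch, num_obs, loc, scale):
    """Estimate the full-data log-likelihood from one minibatch.

    ASSUMPTION: the batch sum is scaled by num_obs / batch_size.
    """
    batch_sum = sum(normal_logpdf(y, loc, scale) for y in batch)
    return (num_obs / len(batch)) * batch_sum


data = [0.1, -0.4, 1.3, 0.7, -1.1, 0.2]

# Exact log-likelihood over all observations.
full = sum(normal_logpdf(y, 0.0, 1.0) for y in data)

# Three disjoint minibatches of size 2, each rescaled to the full data size.
batches = [data[0:2], data[2:4], data[4:6]]
estimates = [scaled_batch_log_likelihood(b, len(data), 0.0, 1.0) for b in batches]

# Averaging the rescaled estimates over one full pass recovers the exact value.
mean_estimate = sum(estimates) / len(estimates)
```

Each individual estimate is noisy, but its average over a full pass through the data matches the full-data sum, which is why the scaling keeps minibatch gradients comparable to full-batch ones.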

__init__(num_obs, num_particles=1, batch_size=None)
Parameters:
  • num_obs (int)

  • num_particles (int)

  • batch_size (int | None)

loss(rng_key, param_map, model, guide, *args, **kwargs)

Evaluates the ELBO with a Monte Carlo estimator that draws num_particles samples (particles).

Parameters:
  • rng_key (jax.random.PRNGKey) – random number generator seed.

  • param_map (dict) – dictionary of current parameter values keyed by site name.

  • model (Callable[[...], Any]) – Python callable with NumPyro primitives for the model.

  • guide (Callable[[...], Any]) – Python callable with NumPyro primitives for the guide.

  • args (Any) – positional arguments passed to the model and guide (these may vary over the course of fitting).

  • kwargs (Any) – keyword arguments passed to the model and guide (these may vary over the course of fitting).

Returns:

The negative of the Evidence Lower Bound (ELBO), to be minimized.

Return type:

Array
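As an illustration of what a particle-based estimator computes, here is a minimal plain-Python sketch of a generic Monte Carlo negative ELBO for a toy model (z ~ Normal(0, 1), x | z ~ Normal(z, 1)) with a Normal guide. It is not the blayers implementation; negative_elbo and normal_logpdf are hypothetical names, and the batching behaviour of Batched_Trace_ELBO is not modelled here.

```python
import math
import random


def normal_logpdf(value, loc, scale):
    """Log density of Normal(loc, scale) at `value`."""
    return -0.5 * math.log(2 * math.pi * scale**2) - (value - loc) ** 2 / (2 * scale**2)


def negative_elbo(x_obs, guide_loc, guide_scale, num_particles, seed=0):
    """Average log p(x, z) - log q(z) over `num_particles` guide samples, negated."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_particles):
        # Draw one particle from the guide q(z).
        z = rng.gauss(guide_loc, guide_scale)
        # Model log joint: prior on z plus likelihood of the observation.
        log_joint = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x_obs, z, 1.0)
        log_q = normal_logpdf(z, guide_loc, guide_scale)
        total += log_joint - log_q
    return -total / num_particles


# With the exact posterior as the guide (Normal(0.5, sqrt(0.5)) for x = 1),
# every particle gives log p(x), so the loss equals -log p(x) exactly.
loss_exact_guide = negative_elbo(1.0, 0.5, math.sqrt(0.5), num_particles=4)

# A mismatched guide gives a noisy estimate; more particles reduce its variance.
loss_poor_guide = negative_elbo(1.0, -1.0, 2.0, num_particles=100)
```

Raising num_particles does not change the expected value of the loss, only the variance of the estimate, at a proportional cost in compute.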

elbo_components(rng_key, param_map, model, guide, *args, **kwargs)
Parameters:
  • rng_key (Array)

  • param_map (dict[str, Array])

  • model (Callable[[...], Any])

  • guide (Callable[[...], Any])

  • args (Any)

  • kwargs (Any)

Return type:

dict[str, Array]

blayers.infer.svi_run_batched(svi, rng_key, batch_size, num_steps=None, num_epochs=None, **data)
Parameters:
  • svi (SVI)

  • rng_key (Array)

  • batch_size (int)

  • num_steps (int | None)

  • num_epochs (int | None)

  • data (dict[str, Array])

Return type:

SVIRunResult
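The signature accepts both num_steps and num_epochs, and takes the dataset as keyword arrays, but the reference does not describe how these interact. The plain-Python sketch below shows one plausible reading: exactly one of the two budgets is supplied, an epoch is ceil(num_obs / batch_size) steps, and every array in data is sliced with the same shuffled indices. Both helpers are hypothetical and are not part of blayers.

```python
import math
import random


def resolve_num_steps(num_obs, batch_size, num_steps=None, num_epochs=None):
    """Hypothetical helper: turn a step or epoch budget into a step count."""
    # ASSUMPTION: exactly one of num_steps / num_epochs is given.
    if (num_steps is None) == (num_epochs is None):
        raise ValueError("pass exactly one of num_steps or num_epochs")
    if num_steps is not None:
        return num_steps
    # ASSUMPTION: one epoch is one pass over the data, last batch may be short.
    steps_per_epoch = math.ceil(num_obs / batch_size)
    return num_epochs * steps_per_epoch


def iter_minibatches(data, batch_size, seed=0):
    """Yield one shuffled pass over `data`, a dict of equal-length lists."""
    num_obs = len(next(iter(data.values())))
    order = list(range(num_obs))
    random.Random(seed).shuffle(order)
    for start in range(0, num_obs, batch_size):
        idx = order[start:start + batch_size]
        # Apply the same indices to every array so rows stay aligned.
        yield {name: [values[i] for i in idx] for name, values in data.items()}


steps = resolve_num_steps(num_obs=1000, batch_size=128, num_epochs=2)
batches = list(iter_minibatches({"x": list(range(10)), "y": list(range(10))}, batch_size=4))
```

Under these assumptions, 1000 observations with a batch size of 128 give 8 steps per epoch, so two epochs correspond to 16 optimisation steps.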