Last active
March 9, 2024 00:38
-
-
Save amifalk/950439c10063f0a75ac99fa2d277825f to your computer and use it in GitHub Desktop.
Functional NumPyro MCMC Wrapper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
from operator import attrgetter | |
import jax | |
import jax.numpy as jnp | |
import jax.random as random | |
@partial(jax.jit, static_argnames=['field_names'])
def collect_fields(state, field_names: tuple):
    """
    Collect fields from a state (i.e. a `namedtuple`).

    Each name is resolved with `operator.attrgetter`, so a dotted name such as
    `'outer.inner'` reads a nested attribute.

    Args:
        state: Object (typically a `namedtuple`) whose attributes are read.
        field_names (tuple): Names of the attributes to collect. It is a static
            argument of the jitted function, so it must be hashable. May be empty.

    Returns:
        dict: {*field_names: *collected_fields}
    """
    # Resolve one name at a time: a multi-name `attrgetter` raises TypeError when
    # given no names and returns a bare value (not a tuple) when given exactly
    # one, which would otherwise need special-casing.
    return {name: attrgetter(name)(state) for name in field_names}
@partial(jax.jit, static_argnames=['transform', 'sites_subset'])
def transform_and_subset_sample_sites(z, transform, sites_subset):
    """
    Apply `transform` to `z` and optionally keep only selected sites.

    Args:
        z: Input passed to `transform`.
        transform (callable): Function applied to `z`. Static argument of the
            jitted function, so it must be hashable.
        sites_subset (tuple): Keys to keep from the transformed result. When
            falsy (e.g. `None` or an empty tuple) the full result is returned.
            Static argument of the jitted function, so it must be hashable.

    Returns:
        The transformed result, or a dict restricted to `sites_subset`.
    """
    transformed = transform(z)
    if not sites_subset:
        return transformed
    return {name: transformed[name] for name in sites_subset}
def run_mcmc(rng_key,
             kernel,
             model_args,
             model_kwargs,
             *,
             num_warmup: int,
             num_samples: int,
             extra_fields=(),
             sites_subset=None,
             return_warmup=False,
             ):
    """
    A thin functional wrapper around NumPyro `MCMCKernel`s that allows for vectorization/parallelism
    over model_args and model_kwargs and optionally avoids storing nuisance sample sites. The API and
    implementation draw heavily on work in the [BlackJAX](https://github.com/blackjax-devs/blackjax)
    and [NumPyro](https://github.com/pyro-ppl/numpyro) repositories.

    Args:
        rng_key (PRNGKey): JAX PRNGKey. For `EnsembleSampler` kernels, the number of keys must match the
            number of desired chains.
        kernel (MCMCKernel): MCMCKernel to use for inference.
        model_args (tuple): Tuple containing model arguments.
        model_kwargs (dict): Dictionary containing model keyword arguments.
        num_warmup (int): Number of warmup steps to take. Used as a slice bound and as part of the
            scan length, so it needs to be a concrete (non-traced) integer.
        num_samples (int): Number of samples to take. Part of the scan length, so it needs to be a
            concrete (non-traced) integer.
        extra_fields (tuple, optional): Names of fields from the kernel state to collect in addition
            to the kernel's default fields. Any iterable of names is accepted. If empty, only the
            kernel's default fields are collected. Defaults to ().
        sites_subset (tuple, optional): Names of sample sites to collect. Any iterable of names is
            accepted. If None (or empty), collects all sites. Defaults to None.
        return_warmup (bool, optional): Whether or not to return warmup samples. Defaults to False.

    Returns:
        (samples, other_fields): Tuple containing samples and collected fields from the state object.
    """
    init_state = kernel.init(rng_key,
                             num_warmup=num_warmup,
                             model_args=model_args,
                             model_kwargs=model_kwargs)
    transform = kernel.postprocess_fn(model_args, model_kwargs)

    # Both of these end up as static (hashable) arguments of jitted helpers, so
    # coerce them to tuples; this also lets callers pass lists.
    # `dict.fromkeys` removes duplicate names while keeping a deterministic,
    # first-seen order (a `set` would give an arbitrary order).
    to_collect = tuple(dict.fromkeys(
        (kernel.sample_field,) + tuple(kernel.default_fields) + tuple(extra_fields)
    ))
    if sites_subset is not None:
        sites_subset = tuple(sites_subset)

    def step(state, _):
        # One kernel transition, then pull out what should be stored for this draw.
        state = kernel.sample(state, model_args, model_kwargs)
        collected = collect_fields(state, to_collect)
        # The sample field is returned separately (after transformation), so it
        # is removed from the dict of remaining fields.
        sample_sites = collected.pop(kernel.sample_field)
        sample = transform_and_subset_sample_sites(sample_sites, transform, sites_subset)
        return state, (sample, collected)

    # The scan inputs are unused by `step`; only the number of iterations matters.
    _, (samples, other_fields) = jax.lax.scan(step,
                                              init_state,
                                              xs=None,
                                              length=num_warmup + num_samples)

    if not return_warmup:
        # NOTE: warmup draws are still materialized by the scan above and are
        # only sliced off here.
        samples, other_fields = jax.tree_util.tree_map(lambda x: x[num_warmup:],
                                                       (samples, other_fields))

    return samples, other_fields
if __name__ == '__main__':
    # Demo: fit the same model to 4 independent datasets, one chain per host device.
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import NUTS

    num_chains = 4

    numpyro.set_platform('cpu')
    # Expose one host device per chain so that `jax.pmap` below can map over them.
    numpyro.set_host_device_count(num_chains)

    def model(data):
        mu = numpyro.sample('mu', dist.Normal(0, 2))
        sigma = numpyro.sample('sigma', dist.HalfCauchy(1))
        with numpyro.plate('n_obs', len(data)):
            numpyro.sample('data', dist.Normal(mu, sigma), obs=data)

    # Use independent keys for simulating the data and for seeding the chains
    # instead of reusing a single key for both.
    data_key, mcmc_key = random.split(random.PRNGKey(0))

    ground_truth_mu = jnp.array([1, 2, 3, 4])[:, None]
    data = random.normal(data_key, (num_chains, 200)) * 0.1 + ground_truth_mu

    model_args = (data,)
    model_kwargs = {}

    keys = random.split(mcmc_key, num_chains)
    batch_run_mcmc = partial(run_mcmc, num_warmup=5000, num_samples=5000, sites_subset=('mu',))

    # Keys and data are mapped over their leading axis; the kernel is a static
    # broadcast argument and the (empty) kwargs are broadcast.
    samps, other_fields = jax.pmap(batch_run_mcmc,
                                   in_axes=(0, None, 0, None),
                                   static_broadcasted_argnums=1)(keys,
                                                                 NUTS(model),
                                                                 model_args,
                                                                 model_kwargs)

    print(samps['mu'].shape)

    posterior_mean = jnp.mean(samps['mu'], axis=1)
    # The posterior mean of `mu` should sit on top of each dataset's empirical
    # mean. Comparing against the ground truth at this tolerance would be
    # seed-dependent: the standard error of the empirical mean is about
    # 0.1 / sqrt(200) ~= 0.007.
    assert jnp.allclose(posterior_mean, jnp.mean(data, axis=1),
                        atol=.01)
    # Looser sanity check that the ground truth is recovered.
    assert jnp.allclose(posterior_mean, ground_truth_mu.squeeze(),
                        atol=.05)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment