Samplers

The sampler draws electron configurations from \(|\psi|^2\) using Markov Chain Monte Carlo (MCMC). See Sampling for background on how MCMC sampling works and how to tune the sampler.

Configuration

For sampler config keys, see the configuration reference: Molecule, Solid, or Hall.

Protocols

class jaqmc.sampler.SamplerLike(*args, **kwargs)

Protocol for samplers.

A sampler must have two callable attributes: init and step. Defining them as methods on a class satisfies this protocol.

Type Parameters:

StateT – Sampler state type threaded through init and step.

init: jaqmc.sampler.base.SamplerInit
step: jaqmc.sampler.base.SamplerStep
class jaqmc.sampler.base.SamplerInit(*args, **kwargs)
__call__(data, rngs)

Returns initial sampler state.

Parameters:
  • data (PyTree) – Initial data.

  • rngs (PRNGKey) – Random state.

Return type:

TypeVar(StateT, bound= ArrayLikeTree)

class jaqmc.sampler.base.SamplerStep(*args, **kwargs)
__call__(batch_log_prob, data, state, rngs)

Sample walker data.

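Parameters:
  • batch_log_prob (BatchLogProb) – Log-probability function over a batch of walkers.

  • data (PyTree) – Current walker data.

  • state (StateT) – Current sampler state.

  • rngs (PRNGKey) – Random state.

Return type: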

tuple[PyTree, dict[str, Any], TypeVar(StateT, bound= ArrayLikeTree)]

Returns:

A tuple (data, stats, state) with:
  • data – Sampled data.

  • stats – Statistics reported by the sampler.

  • state – New state of the sampler.

Type Parameters:

StateT – Sampler state type threaded through successive updates.
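As a rough illustration of the init/step contract, here is a minimal sketch of a class with the same call shapes as SamplerLike. It does not import jaqmc; the class name FixedWidthMetropolis, the "pmove" stats key, the integer step counter used as state, and the use of a plain (batch_size, ...) array instead of a general PyTree are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


class FixedWidthMetropolis:
    """Hypothetical sampler exposing init/step with the signatures above.

    The state is just an integer step counter; the proposal width is fixed.
    """

    def __init__(self, width=0.1):
        self.width = width

    def init(self, data, rngs):
        # Initial sampler state: a step counter (data and rngs are unused).
        del data, rngs
        return jnp.zeros((), dtype=jnp.int32)

    def step(self, batch_log_prob, data, state, rngs):
        key_prop, key_acc = jax.random.split(rngs)
        # Symmetric Gaussian random-walk proposal for every walker.
        proposal = data + self.width * jax.random.normal(key_prop, data.shape)
        # Metropolis test on log-probabilities (2 * log|psi|), one per walker.
        log_ratio = batch_log_prob(proposal) - batch_log_prob(data)
        u = jax.random.uniform(key_acc, log_ratio.shape)
        accept = jnp.log(u) < log_ratio
        # Broadcast the per-walker decision over the remaining axes.
        mask = accept.reshape(accept.shape + (1,) * (data.ndim - 1))
        new_data = jnp.where(mask, proposal, data)
        stats = {"pmove": jnp.mean(accept)}
        return new_data, stats, state + 1


# Usage: 8 walkers, 2 electrons each, 3D coordinates, toy Gaussian density.
sampler = FixedWidthMetropolis(width=0.5)
data = jnp.zeros((8, 2, 3))
state = sampler.init(data, jax.random.PRNGKey(0))
log_prob = lambda x: -jnp.sum(x**2, axis=(1, 2))
data, stats, state = sampler.step(log_prob, data, state, jax.random.PRNGKey(1))
```

Because the protocol only requires two callable attributes, no base class is needed; any object whose init and step accept these arguments fits the same shape.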

class jaqmc.sampler.base.BatchLogProb(*args, **kwargs)
__call__(data)

Returns log-probability (2 * log|psi|) over a batch of walkers.

Parameters:

data (PyTree) – A batched PyTree (the data subset for this sampler’s keys).

Return type:

Array

Returns:

A 1D array of shape (batch_size,) with one log-probability per walker.
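A batch log-probability can be built by vectorizing a single-walker log-amplitude. The sketch below assumes a hypothetical toy wavefunction psi = exp(-sum_i |r_i|) (no antisymmetry), with electrons stored as a plain array of shape (batch_size, n_electrons, 3) rather than a general PyTree; it only illustrates the 2 * log|psi| convention and the (batch_size,) output shape.

```python
import jax
import jax.numpy as jnp


def log_abs_psi(electrons):
    """Hypothetical toy log|psi| for one walker; electrons has shape (n, 3)."""
    return -jnp.sum(jnp.linalg.norm(electrons, axis=-1))


def batch_log_prob(electrons):
    """Return 2 * log|psi| for each walker; input shape (batch, n, 3)."""
    return 2.0 * jax.vmap(log_abs_psi)(electrons)


# Four walkers, each with two electrons at (1, 1, 1).
walkers = jnp.ones((4, 2, 3))
lp = batch_log_prob(walkers)  # shape (4,)
```

Each walker here has two electrons at distance sqrt(3) from the origin, so every entry of lp equals -4 * sqrt(3).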

class jaqmc.sampler.SamplePlan(log_amplitude, samplers=None)

Coordinate how one or more samplers update batched walker data.

In the common case, a workflow registers a single sampler for the "electrons" field. More complex systems can register different samplers for different fields, or update multiple fields together.

A sample plan initializes sampler state, runs each sampler on its assigned part of the data, and combines the results back into one BatchedData object.

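Parameters:
  • log_amplitude – Wavefunction log-amplitude (log|psi|), used to score proposals.

  • samplers (default: None) – Optional initial sampler registrations; samplers can also be registered with will_sample().

init(batched_data, rngs)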

Initialize the state for all registered samplers.

Each sampler receives the part of batched_data that it is responsible for and returns its own sampler state.

Parameters:
  • batched_data (BatchedData) – Current batched walker data.

  • rngs (PRNGKey) – Random key used for sampler initialization.

Return type:

dict[str, Any]

Returns:

A mapping of sampler states to pass back into step().

step(params, batched_data, state, rngs)

Run one sampling round and return updated walker data.

Each registered sampler proposes updates for its own fields, and the plan combines those updates into a new batched data object. Even when a sampler updates only part of the data, proposals are scored using the full wavefunction.

Parameters:
  • params (Params) – Wavefunction parameters.

  • batched_data (BatchedData) – Current batched walker data.

  • state (dict[str, Any]) – Sampler state produced by init() and previous calls to step().

  • rngs (PRNGKey) – Random key for this sampling round.

Return type:

tuple[BatchedData, dict[str, Any], dict[str, Any]]

Returns:

A tuple (batched_data, stats, state) containing the updated walker data, sampler statistics, and updated sampler state.

will_sample(keys, sampler)

Register which data fields a sampler should update.

Use this when one sampler should control a specific field such as "electrons", or when a sampler should update several fields together.

Parameters:
  • keys (str | tuple[str]) – Field name or field names handled by this sampler.

  • sampler (SamplerLike) – Sampler instance that proposes updates for those fields.
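The split, score, and merge behavior described for SamplePlan can be sketched with plain dictionaries. This is not jaqmc's implementation: plan_init, plan_step, make_subset_log_prob, the ScoreOnly sampler, the "center" field, toy_log_amplitude, and running samplers sequentially in registration order are all hypothetical. The point it shows is that each sampler sees only its own fields, while the log-probability it receives merges the proposal back into the full data before evaluating the wavefunction.

```python
import jax
import jax.numpy as jnp


def make_subset_log_prob(log_amplitude, params, full_data):
    """Score a proposed subset of fields using the full wavefunction."""
    def batch_log_prob(subset):
        merged = {**full_data, **subset}  # proposed fields replace current ones
        return 2.0 * log_amplitude(params, merged)  # 2 * log|psi|
    return batch_log_prob


def plan_init(registry, data, rngs):
    """Give each sampler its own fields and collect per-sampler states."""
    keys = jax.random.split(rngs, len(registry))
    return {
        fields: sampler.init({f: data[f] for f in fields}, key)
        for key, (fields, sampler) in zip(keys, registry.items())
    }


def plan_step(log_amplitude, params, registry, data, states, rngs):
    """Run every registered sampler once and merge its updates into data."""
    stats, new_states = {}, {}
    keys = jax.random.split(rngs, len(registry))
    for key, (fields, sampler) in zip(keys, registry.items()):
        subset = {f: data[f] for f in fields}
        log_prob = make_subset_log_prob(log_amplitude, params, data)
        subset, stats[fields], new_states[fields] = sampler.step(
            log_prob, subset, states[fields], key
        )
        data = {**data, **subset}
    return data, stats, new_states


class ScoreOnly:
    """Hypothetical sampler: leaves its fields unchanged, reports log-probs."""

    def init(self, data, rngs):
        return 0

    def step(self, batch_log_prob, data, state, rngs):
        return data, {"log_prob": batch_log_prob(data)}, state + 1


def toy_log_amplitude(params, data):
    """Hypothetical log|psi| that depends on both "electrons" and "center"."""
    diff = data["electrons"] - data["center"][:, None, :]
    return -params * jnp.sum(diff**2, axis=(1, 2))


# Usage: only "electrons" is registered, but scoring still reads "center".
registry = {("electrons",): ScoreOnly()}
data = {"electrons": jnp.ones((2, 1, 3)), "center": jnp.zeros((2, 3))}
states = plan_init(registry, data, jax.random.PRNGKey(0))
data, stats, states = plan_step(
    toy_log_amplitude, 1.0, registry, data, states, jax.random.PRNGKey(1)
)
```

Keying the registry by a tuple of field names lets one sampler own several fields at once, which mirrors the str | tuple[str] form accepted by will_sample().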

Built-in samplers#

class jaqmc.sampler.mcmc.MCMCSampler(*, steps=10, initial_width=0.1, adapt_frequency=100, pmove_range=(0.5, 0.55), sampling_proposal=<function gaussian_proposal>)

Metropolis-Hastings MCMC sampler.

Parameters:
  • steps (int, default: 10) – Number of Metropolis-Hastings updates per sample draw. Controls decorrelation between consecutive samples. Also determines the granularity of burn-in (see WorkStageConfig.burn_in).

  • initial_width (float, default: 0.1) – Initial width (stddev) of the Gaussian proposal.

  • adapt_frequency (int, default: 100) – Frequency of adaptive width updates.

  • pmove_range (tuple[float, float], default: (0.5, 0.55)) – Target range for the acceptance rate; the proposal width is adapted to keep the acceptance rate within this range.

  • sampling_proposal (SamplingProposal, default: <function gaussian_proposal>) – Proposal function for MCMC moves.

init(data, rngs)

Initialize adaptive Metropolis-Hastings sampler state.

Parameters:
  • data – Current sample data. Not used.

  • rngs – Random key. Not used.

Returns:

Initial MCMCState containing proposal width, acceptance history, and adaptation counter.

step(batch_log_prob, data, state, rngs)

Run multiple MH updates and adapt proposal width.

Parameters:
  • batch_log_prob (BatchLogProb) – Log-probability function over a batch.

  • data (TypeVar(StateT, bound= PyTree)) – Current MCMC configurations.

  • state (MCMCState) – Sampler state.

  • rngs (PRNGKey) – Random state.

Return type:

tuple[TypeVar(StateT, bound= PyTree), dict[str, Any], MCMCState]

Returns:

Tuple of (data, stats, new_state) after one sampler step.

Type Parameters:

StateT – PyTree sample/state type threaded through the step.

Raises:

ValueError – If batch_log_prob(data) does not return shape (batch_size,).
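The step() behavior above (several Metropolis-Hastings updates per call, plus a shape check on batch_log_prob) can be sketched as follows. This is an illustrative sketch, not MCMCSampler's code: it assumes a symmetric Gaussian proposal, a plain (batch_size, ...) array instead of a PyTree, and the hypothetical helper names mh_update and run_mh; proposal-width adaptation is left out.

```python
import jax
import jax.numpy as jnp


def mh_update(batch_log_prob, x, log_prob, key, stddev):
    """One Metropolis-Hastings update with a symmetric Gaussian proposal."""
    key_prop, key_acc = jax.random.split(key)
    x_new = x + stddev * jax.random.normal(key_prop, x.shape)
    log_prob_new = batch_log_prob(x_new)
    # Symmetric proposal: the acceptance ratio reduces to the density ratio.
    u = jax.random.uniform(key_acc, log_prob.shape)
    accept = jnp.log(u) < log_prob_new - log_prob
    mask = accept.reshape(accept.shape + (1,) * (x.ndim - 1))
    x = jnp.where(mask, x_new, x)
    log_prob = jnp.where(accept, log_prob_new, log_prob)
    return x, log_prob, jnp.mean(accept)


def run_mh(batch_log_prob, x, key, stddev, steps=10):
    """Apply `steps` MH updates; return new samples and mean acceptance."""
    log_prob = batch_log_prob(x)
    if log_prob.shape != (x.shape[0],):
        raise ValueError(
            f"batch_log_prob must return shape {(x.shape[0],)}, got {log_prob.shape}"
        )

    def body(i, carry):
        x, log_prob, pmove_sum = carry
        key_i = jax.random.fold_in(key, i)  # fresh key for each update
        x, log_prob, pmove = mh_update(batch_log_prob, x, log_prob, key_i, stddev)
        return x, log_prob, pmove_sum + pmove

    init = (x, log_prob, jnp.zeros((), jnp.float32))
    x, _, pmove_sum = jax.lax.fori_loop(0, steps, body, init)
    return x, pmove_sum / steps
```

Caching log_prob across updates means each MH update costs one wavefunction evaluation (for the proposal) rather than two.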

class jaqmc.sampler.mcmc.SamplingProposal(*args, **kwargs)
__call__(rngs, x, stddev)

Propose a new sample from the current state.

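Parameters:
  • rngs (PRNGKey) – Random key for the proposal.

  • x (StateT) – Current sample/state.

  • stddev – Proposal width (standard deviation).

Return type: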

TypeVar(StateT, bound= PyTree)

Returns:

Proposed sample/state with the same structure as x.

Type Parameters:

StateT – PyTree sample/state type preserved by the proposal.
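A proposal with the SamplingProposal call shape (rngs, x, stddev) must return a PyTree with the same structure as x. The sketch below perturbs every leaf with independent Gaussian noise scaled by stddev; the name gaussian_proposal_sketch is hypothetical, and it is not claimed to match jaqmc's built-in gaussian_proposal exactly.

```python
import jax
import jax.numpy as jnp


def gaussian_proposal_sketch(rngs, x, stddev):
    """Hypothetical Gaussian proposal: perturb every leaf of the PyTree x."""
    leaves, treedef = jax.tree_util.tree_flatten(x)
    keys = jax.random.split(rngs, len(leaves))  # one independent key per leaf
    new_leaves = [
        leaf + stddev * jax.random.normal(k, leaf.shape, leaf.dtype)
        for k, leaf in zip(keys, leaves)
    ]
    # Rebuild a PyTree with exactly the same structure as the input.
    return jax.tree_util.tree_unflatten(treedef, new_leaves)
```

Splitting one key per leaf keeps the noise on different fields independent.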

class jaqmc.sampler.mcmc.MCMCState(stddev: Array, pmoves: Array, counter: Array)

State carried by the adaptive Metropolis-Hastings sampler.

stddev

Current proposal width.

pmoves

Rolling acceptance-rate history used for width adaptation.

counter

Number of completed sampler steps.
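One plausible way the three state fields could interact during width adaptation is sketched below. The ring-buffer layout of pmoves, the multiplicative factor of 1.1, and the names ToyMCMCState, init_state, and adapt are assumptions for illustration, not jaqmc's actual rule. Shrinking the width when acceptance falls below pmove_range, and widening it when acceptance exceeds the range, is the usual direction for random-walk proposals.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyMCMCState(NamedTuple):
    stddev: jax.Array   # current proposal width
    pmoves: jax.Array   # ring buffer of recent acceptance rates
    counter: jax.Array  # number of completed sampler steps


def init_state(initial_width=0.1, adapt_frequency=100):
    """Hypothetical initial state: one buffer slot per step between adaptations."""
    return ToyMCMCState(
        stddev=jnp.asarray(initial_width, dtype=jnp.float32),
        pmoves=jnp.zeros(adapt_frequency, dtype=jnp.float32),
        counter=jnp.asarray(0, dtype=jnp.int32),
    )


def adapt(state, pmove, pmove_range=(0.5, 0.55), factor=1.1):
    """Record pmove; once the buffer is full, nudge stddev toward the range."""
    n = state.pmoves.shape[0]
    pmoves = state.pmoves.at[state.counter % n].set(pmove)
    counter = state.counter + 1
    mean = jnp.mean(pmoves)
    lo, hi = pmove_range
    # Too few moves accepted -> smaller steps; too many -> larger steps.
    scale = jnp.where(mean < lo, 1.0 / factor, jnp.where(mean > hi, factor, 1.0))
    stddev = jnp.where(counter % n == 0, state.stddev * scale, state.stddev)
    return ToyMCMCState(stddev, pmoves, counter)
```

Adapting only every adapt_frequency steps, from an averaged history, avoids reacting to the noise in any single step's acceptance rate.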