Samplers
The sampler draws electron configurations from \(|\psi|^2\) using Markov Chain Monte Carlo (MCMC). See Sampling for background on how MCMC sampling works and how to tune the sampler.
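For orientation, the target density and the acceptance rule for a symmetric proposal (for example, a Gaussian centred on the current configuration) can be written as below. This is the textbook Metropolis form, given as background only; the symbols \(\mathbf{r}\) and \(\mathbf{r}'\) for the current and proposed configurations are notation introduced for this note, and this page does not state whether the built-in sampler uses exactly this rule.

\[
p(\mathbf{r}) \propto |\psi(\mathbf{r})|^2, \qquad
A(\mathbf{r} \to \mathbf{r}') = \min\!\left(1, \frac{|\psi(\mathbf{r}')|^2}{|\psi(\mathbf{r})|^2}\right)
\]

In log form the ratio is \(\exp\bigl(2\log|\psi(\mathbf{r}')| - 2\log|\psi(\mathbf{r})|\bigr)\), so a log-amplitude is enough to score a proposal.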
Configuration
For sampler config keys, see the configuration reference: Molecule, Solid, or Hall.
Protocols
- class jaqmc.sampler.SamplerLike(*args, **kwargs)
Protocol for samplers.
A sampler must have two callable attributes: init and step. Defining them as methods on a class satisfies this protocol.
- Type Parameters:
StateT – Sampler state type threaded through init and step.
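As a minimal sketch of what satisfying such a protocol can look like, the class below defines init and step as ordinary methods. ToySamplerLike is a local stand-in written for this example, not the jaqmc definition. The argument lists mirror the init(data, rngs) and step(batch_log_prob, data, state, rngs) signatures documented on this page, while the list-of-floats walkers and the random.Random generator are assumptions chosen to keep the example dependency-free. JitterSampler has no accept/reject step, so it does not actually sample from log_prob; it only illustrates the interface.

```python
import random
from typing import Any, Callable, Protocol, runtime_checkable


@runtime_checkable
class ToySamplerLike(Protocol):
    """Hypothetical stand-in for the protocol: two callable attributes."""

    def init(self, data: Any, rngs: Any) -> Any: ...

    def step(self, batch_log_prob: Callable, data: Any, state: Any, rngs: Any) -> tuple: ...


class JitterSampler:
    """Toy sampler that nudges every walker by a small random offset."""

    def init(self, data, rngs):
        # The sampler state here is just a counter of completed steps.
        return {"n_steps": 0}

    def step(self, batch_log_prob, data, state, rngs):
        new_data = [x + rngs.uniform(-0.1, 0.1) for x in data]
        log_p = batch_log_prob(new_data)
        stats = {"mean_log_prob": sum(log_p) / len(log_p)}
        new_state = {"n_steps": state["n_steps"] + 1}
        # Same (data, stats, state) ordering as documented for SamplerStep.
        return new_data, stats, new_state


def log_prob(batch):
    # Unnormalized log-density of a standard normal, one value per walker.
    return [-0.5 * x * x for x in batch]


sampler = JitterSampler()
rng = random.Random(0)
walkers = [0.0, 1.0, -1.0]
state = sampler.init(walkers, rng)
walkers, stats, state = sampler.step(log_prob, walkers, state, rng)
```

JitterSampler never inherits from ToySamplerLike; the match is purely structural, which is the sense in which defining the two methods on a class is sufficient.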
- class jaqmc.sampler.base.SamplerStep(*args, **kwargs)
- __call__(batch_log_prob, data, state, rngs)
Sample walker data.
- Parameters:
batch_log_prob (BatchLogProb) – Function to be sampled.
data (PyTree) – Previously sampled data.
state (TypeVar(StateT, bound=ArrayLikeTree)) – State of the sampler.
rngs (PRNGKey) – Random state.
- Return type:
tuple[PyTree, dict[str, Any], TypeVar(StateT, bound=ArrayLikeTree)]
- Returns:
A tuple of:
- data: Sampled data.
- stats: Statistical variables of the sampler.
- state: New state of the sampler.
- Type Parameters:
StateT – Sampler state type threaded through successive updates.
- class jaqmc.sampler.SamplePlan(log_amplitude, samplers=None)
Coordinate how one or more samplers update batched walker data.
In the common case, a workflow registers a single sampler for the "electrons" field. More complex systems can register different samplers for different fields, or update multiple fields together.
A sample plan initializes sampler state, runs each sampler on its assigned part of the data, and combines the results back into one BatchedData object.
- Parameters:
log_amplitude (NumericWavefunctionEvaluate) – Wavefunction log-amplitude used to score sampling proposals.
samplers (Mapping[str | tuple[str], SamplerLike] | None, default: None) – Optional mapping from field names, or tuples of field names, to sampler instances.
- init(batched_data, rngs)
Initialize the state for all registered samplers.
Each sampler receives the part of batched_data that it is responsible for and returns its own sampler state.
- step(params, batched_data, state, rngs)
Run one sampling round and return updated walker data.
Each registered sampler proposes updates for its own fields, and the plan combines those updates into a new batched data object. Even when a sampler updates only part of the data, proposals are scored using the full wavefunction.
- Parameters:
- Return type:
- Returns:
A tuple (batched_data, stats, state) containing the updated walker data, sampler statistics, and updated sampler state.
- will_sample(keys, sampler)
Register which data fields a sampler should update.
Use this when one sampler should control a specific field such as "electrons", or when a sampler should update several fields together.
- Parameters:
keys (str | tuple[str]) – Field name or field names handled by this sampler.
sampler (SamplerLike) – Sampler instance that proposes updates for those fields.
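The register, initialize, and step flow described for SamplePlan can be sketched with a toy stand-in. ToyPlan and ToyMetropolis below are hypothetical classes written for this example, not jaqmc code: the data is a plain dict holding one float per field, the params argument of the real step is omitted, and the "nuclei" field name is invented for illustration. The point to notice is that each sampler proposes changes only for its own fields, while the log-probability it sees is evaluated on the full data.

```python
import math
import random


class ToyMetropolis:
    """Hypothetical single-walker Metropolis sampler over a dict of floats."""

    def __init__(self, width):
        self.width = width

    def init(self, part, rng):
        return {"n_accepted": 0}

    def step(self, log_prob, part, state, rng):
        proposal = {k: v + rng.gauss(0.0, self.width) for k, v in part.items()}
        log_ratio = log_prob(proposal) - log_prob(part)
        accepted = rng.random() < math.exp(min(0.0, log_ratio))
        new_part = proposal if accepted else part
        new_state = {"n_accepted": state["n_accepted"] + int(accepted)}
        return new_part, {"accepted": accepted}, new_state


class ToyPlan:
    """Hypothetical, simplified plan that routes named fields to samplers."""

    def __init__(self, log_amplitude):
        self.log_amplitude = log_amplitude
        self.samplers = {}

    def will_sample(self, keys, sampler):
        # Accept a single field name or a tuple of field names.
        keys = (keys,) if isinstance(keys, str) else tuple(keys)
        self.samplers[keys] = sampler

    def init(self, data, rng):
        # Each sampler sees only the fields it is responsible for.
        return {keys: s.init({k: data[k] for k in keys}, rng)
                for keys, s in self.samplers.items()}

    def step(self, data, state, rng):
        data, stats, new_state = dict(data), {}, {}
        for keys, sampler in self.samplers.items():
            def log_prob(part, full=data):
                # Score with the full data: log|psi|^2 = 2 * log|psi|.
                return 2.0 * self.log_amplitude({**full, **part})
            part = {k: data[k] for k in keys}
            part, stats[keys], new_state[keys] = sampler.step(
                log_prob, part, state[keys], rng)
            data.update(part)  # merge the partial update back
        return data, stats, new_state


plan = ToyPlan(lambda d: -0.5 * (d["electrons"] ** 2 + d["nuclei"] ** 2))
plan.will_sample("electrons", ToyMetropolis(width=0.5))
plan.will_sample("nuclei", ToyMetropolis(width=0.05))

rng = random.Random(0)
data = {"electrons": 1.0, "nuclei": 0.2}
state = plan.init(data, rng)
data, stats, state = plan.step(data, state, rng)
```

Registering the two fields separately lets each use its own proposal width; passing a tuple such as ("electrons", "nuclei") to will_sample would instead move both fields in a single proposal.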
Built-in samplers
- class jaqmc.sampler.mcmc.MCMCSampler(*, steps=10, initial_width=0.1, adapt_frequency=100, pmove_range=(0.5, 0.55), sampling_proposal=<function gaussian_proposal>)
Metropolis-Hastings MCMC sampler.
- Parameters:
steps (int, default: 10) – Number of Metropolis-Hastings updates per sample draw. Controls decorrelation between consecutive samples. Also determines the granularity of burn-in (see WorkStageConfig.burn_in).
initial_width (float, default: 0.1) – Initial width (standard deviation) of the Gaussian proposal.
adapt_frequency (int, default: 100) – Frequency of adaptive width updates.
pmove_range (tuple[float, float], default: (0.5, 0.55)) – Target range for acceptance rate.
sampling_proposal (SamplingProposal, default: <function gaussian_proposal>) – Proposal function for MCMC moves.
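How a target acceptance range can steer the proposal width is sketched below. The update rule (scale the width by a fixed factor whenever the recent mean acceptance rate leaves the range) is an assumption modeled on common practice in variational Monte Carlo codes, not a description of MCMCSampler's internals; adapt_width and the factor of 1.1 are hypothetical.

```python
def adapt_width(width, mean_pmove, pmove_range=(0.5, 0.55), factor=1.1):
    """Return a new proposal width given the recent mean acceptance rate."""
    low, high = pmove_range
    if mean_pmove > high:
        # Too many moves accepted: proposals are timid, so widen them.
        return width * factor
    if mean_pmove < low:
        # Too few moves accepted: proposals overshoot, so narrow them.
        return width / factor
    # Inside the target range: leave the width alone.
    return width


width = 0.1
widened = adapt_width(width, mean_pmove=0.9)
narrowed = adapt_width(width, mean_pmove=0.2)
unchanged = adapt_width(width, mean_pmove=0.52)
```

Under this assumed rule, a check like this would run once every adapt_frequency sampler steps, using the acceptance rates collected since the previous check.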
- init(data, rngs)
Initialize adaptive Metropolis-Hastings sampler state.
- Parameters:
data – Current sample data. Not used.
rngs – Random key. Not used.
- Returns:
Initial MCMCState containing proposal width, acceptance history, and adaptation counter.
- step(batch_log_prob, data, state, rngs)
Run multiple MH updates and adapt proposal width.
- Parameters:
batch_log_prob (BatchLogProb) – Log-probability function over a batch.
data (TypeVar(StateT, bound=PyTree)) – Current MCMC configurations.
state (MCMCState) – Sampler state.
rngs (PRNGKey) – Random state.
- Return type:
tuple[TypeVar(StateT, bound=PyTree), dict[str, Any], MCMCState]
- Returns:
Tuple of (data, stats, new_state) after one sampler step.
- Type Parameters:
StateT – PyTree sample/state type threaded through the step.
- Raises:
ValueError – If batch_log_prob(data) does not return shape (batch_size,).
- class jaqmc.sampler.mcmc.SamplingProposal(*args, **kwargs)