Wavefunctions
A wavefunction maps electron configurations to log-amplitudes (and optionally signs or phases). All JaQMC wavefunctions are Flax Linen modules that subclass Wavefunction.
For architecture-specific options (FermiNet, Psiformer, etc.), see the configuration reference pages (Molecules, Solids, Quantum Hall).
Base class and protocols
- class jaqmc.wavefunction.Wavefunction(parent=<flax.linen.module._Sentinel object>, name=None)
Base class for JaQMC wavefunctions.
A Wavefunction is a Flax nn.Module with three complementary execution interfaces:
  - __call__(data): model definition for one walker (implemented by subclasses).
  - apply(params, data, ...): Flax runtime API that executes methods with explicit variables.
  - evaluate(params, data): JaQMC wrapper over apply that provides a stable typed contract for framework consumers.
- abstractmethod __call__(data)
Define the forward pass for one walker.
This method should be written in standard Flax style and does not take explicit parameters.
- Return type:
TypeVar(OutputT)
- evaluate(params, data)
Framework-level execution entrypoint.
JaQMC wrapper over apply() that provides a stable typed contract, following the signature of WavefunctionEvaluate.
- Return type:
TypeVar(OutputT)
- Returns:
The wavefunction output for the provided parameters and data.
- init_params(data, rngs)
Initialize parameters from one sample walker.
JaQMC wrapper over init() that provides a stable typed contract, following the signature of WavefunctionInit.
- Return type:
- Returns:
A PyTree containing the initialized wavefunction parameters.
- class jaqmc.wavefunction.WavefunctionLike(*args, **kwargs)
Minimal wavefunction interface required by framework components.
App-level protocols can extend this with domain-specific methods such as logpsi, phase_logpsi, or orbitals.
- init_params: jaqmc.wavefunction.base.WavefunctionInit
- class jaqmc.wavefunction.WavefunctionEvaluate(*args, **kwargs)
- type jaqmc.wavefunction.base.NumericWavefunctionEvaluate = WavefunctionEvaluate[DataT, Array]
Callable protocol for one-walker numeric wavefunction evaluation.
This alias specializes WavefunctionEvaluate to scalar-array outputs, which are typically log-amplitude values such as \(\log|\psi|\).
Output types
- class jaqmc.wavefunction.base.RealWFOutput
Structured output for real-valued wavefunctions.
- class jaqmc.wavefunction.base.ComplexWFOutput
Structured output for complex-valued wavefunctions.
- class jaqmc.wavefunction.base.LogPsiWFOutput
Minimal structured wavefunction output containing logpsi.
Log-determinant output (used by FermiNet / Psiformer)
- class jaqmc.wavefunction.output.logdet.RealLogDetOutput
Output of LogDet for real-valued orbital matrices.
Input features
- class jaqmc.wavefunction.input.atomic.AtomicEmbedding
Output from MoleculeFeatures or SolidFeatures.
- class jaqmc.wavefunction.input.atomic.MoleculeFeatures(rescale=False, parent=<flax.linen.module._Sentinel object>, name=None)
Input features for molecular systems (open boundary conditions, OBC).
Backbone architectures
- class jaqmc.wavefunction.backbone.ferminet.FermiLayers(nspins, hidden_dims, use_last_layer=False, parent=<flax.linen.module._Sentinel object>, name=None)
FermiNet interaction layers with single- and double-electron streams.
Each layer updates the single-electron stream by aggregating features from both streams, then updates the double-electron stream independently. Residual connections are added when input and output dimensions match.
- Parameters:
  - nspins (tuple[int, int]) – Tuple of (num_spin_up, num_spin_down) electrons.
  - hidden_dims (list[tuple[int, int]]) – List of (single_dim, double_dim) pairs, one per layer.
  - use_last_layer (bool, default: False) – If True, also update the double stream in the final layer and return aggregated features. If False, skip the final double-stream update.
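The layer update above can be sketched in NumPy for a single walker. This is an illustrative sketch, not the JaQMC implementation: it assumes mean aggregation over each spin channel, `tanh` activations, and no bias terms, and the function name `fermi_layer` and its weight layout are hypothetical. It does not model `use_last_layer`.

```python
import numpy as np


def fermi_layer(h_one, h_two, nspins, w_one, w_two):
    """One FermiNet-style interaction layer for a single walker (illustrative sketch).

    h_one: (n, d1) single-electron stream; h_two: (n, n, d2) double-electron stream.
    w_one: (3 * d1 + 2 * d2, d1_out); w_two: (d2, d2_out).
    """
    n_up, n_down = nspins
    n, d1 = h_one.shape
    d2 = h_two.shape[-1]
    g_one, g_two = [], []
    for start, stop in ((0, n_up), (n_up, n_up + n_down)):
        if stop > start:
            # Spin-channel mean of single features, broadcast to every electron.
            g_one.append(np.broadcast_to(h_one[start:stop].mean(axis=0), (n, d1)))
            # For each electron i, mean of h_ij over electrons j in this channel.
            g_two.append(h_two[:, start:stop].mean(axis=1))
        else:
            # An empty spin channel contributes zeros so the weight shapes stay fixed.
            g_one.append(np.zeros((n, d1)))
            g_two.append(np.zeros((n, d2)))
    features = np.concatenate([h_one, *g_one, *g_two], axis=-1)
    new_one = np.tanh(features @ w_one)
    # The double stream is updated independently of the single stream.
    new_two = np.tanh(h_two @ w_two)
    # Residual connections only where input and output widths match.
    if new_one.shape == h_one.shape:
        new_one = new_one + h_one
    if new_two.shape == h_two.shape:
        new_two = new_two + h_two
    return new_one, new_two


rng = np.random.default_rng(0)
nspins = (2, 1)
h_one = rng.normal(size=(3, 4))
h_two = rng.normal(size=(3, 3, 3))
w_one = rng.normal(size=(3 * 4 + 2 * 3, 4)) * 0.1
w_two = rng.normal(size=(3, 3)) * 0.1
out_one, out_two = fermi_layer(h_one, h_two, nspins, w_one, w_two)
```

Swapping the two spin-up electrons permutes both output streams in the same way. This permutation equivariance is what lets determinants built on top of these features be antisymmetric.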
- class jaqmc.wavefunction.backbone.psiformer.PsiformerBackbone(nspins, num_layers=2, num_heads=4, heads_dim=64, mlp_hidden_dims=(256, ), layer_norm_mode=LayerNormMode.pre, with_bias=True, input_bias=True, parent=<flax.linen.module._Sentinel object>, name=None)
Self-attention backbone for molecular wavefunctions.
Psiformer processes one-electron features through multiple self-attention layers, modeling electron-electron interactions implicitly through attention rather than through explicit two-electron feature streams.
The architecture:
  1. Concatenates a spin encoding to the input features.
  2. Projects to the attention dimension.
  3. Applies num_layers PsiformerLayer blocks.
  4. Outputs the processed one-electron features.
- Parameters:
  - nspins (tuple[int, int]) – Tuple of (num_spin_up, num_spin_down) electrons.
  - num_layers (int, default: 2) – Number of Psiformer layers.
  - num_heads (int, default: 4) – Number of attention heads.
  - heads_dim (int, default: 64) – Dimension of each attention head.
  - mlp_hidden_dims (Sequence[int], default: (256,)) – Hidden dimensions for MLP blocks.
  - layer_norm_mode (LayerNormMode, default: <LayerNormMode.pre: 'pre'>) – LayerNorm application mode. Options:
    - LayerNormMode.pre: Apply LayerNorm before attention/MLP blocks (Pre-LN, matches internal_ferminet). Default.
    - LayerNormMode.post: Apply LayerNorm after attention/MLP blocks (Post-LN, matches public FermiNet).
    - LayerNormMode.null: No LayerNorm applied.
  - with_bias (bool, default: True) – Whether to use bias in attention QKV projections.
  - input_bias (bool, default: True) – Whether to use bias in the input projection layer.
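A minimal NumPy sketch of the first two steps (spin encoding and input projection), assuming the spin encoding is a single +1/-1 column and that the attention width is `num_heads * heads_dim`. Both choices and the name `psiformer_input` are illustrative assumptions, not taken from JaQMC.

```python
import numpy as np


def psiformer_input(features, nspins, w_in, b_in=None):
    """Spin-encode and project one-electron features (illustrative sketch).

    features: (n, f) one-electron features; w_in: (f + 1, num_heads * heads_dim).
    """
    n_up, n_down = nspins
    # Assumed spin encoding: +1 for spin-up electrons, -1 for spin-down electrons.
    spin = np.concatenate([np.ones(n_up), -np.ones(n_down)])[:, None]
    x = np.concatenate([features, spin], axis=-1)
    h = x @ w_in
    if b_in is not None:  # plays the role of input_bias=True
        h = h + b_in
    return h


rng = np.random.default_rng(0)
num_heads, heads_dim = 2, 4
feats = rng.normal(size=(3, 5))
w_in = rng.normal(size=(5 + 1, num_heads * heads_dim))
h = psiformer_input(feats, (2, 1), w_in, b_in=np.zeros(num_heads * heads_dim))
```

Two electrons with identical features but opposite spins receive different projected features, so the attention layers can distinguish spin channels.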
- class jaqmc.wavefunction.backbone.psiformer.PsiformerLayer(num_heads, heads_dim, mlp_hidden_dims, layer_norm_mode=LayerNormMode.pre, with_bias=True, parent=<flax.linen.module._Sentinel object>, name=None)
Single Psiformer layer combining attention and MLP with residual connections.
Each layer applies:
  1. Optional LayerNorm (before or after, depending on mode).
  2. Multi-head self-attention with a residual connection.
  3. Optional LayerNorm (before or after, depending on mode).
  4. MLP with a residual connection.
- Parameters:
  - num_heads (int) – Number of attention heads.
  - heads_dim (int) – Dimension of each attention head.
  - mlp_hidden_dims (Sequence[int]) – Hidden dimensions for the MLP block.
  - layer_norm_mode (LayerNormMode, default: <LayerNormMode.pre: 'pre'>) – LayerNorm application mode. Options:
    - LayerNormMode.pre: Apply LayerNorm before attention/MLP blocks (Pre-LN, matches internal_ferminet). Default.
    - LayerNormMode.post: Apply LayerNorm after attention/MLP blocks (Post-LN, matches public FermiNet).
    - LayerNormMode.null: No LayerNorm applied.
  - with_bias (bool, default: True) – Whether to use bias in attention QKV projections.
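To make the three layer_norm_mode options concrete, here is a NumPy sketch of one layer. It assumes parameter-free LayerNorm, bias-free projections (ignoring `with_bias`), and a `tanh` MLP whose last matrix maps back to the model width; all function names are hypothetical and the mode strings stand in for the LayerNormMode members.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Parameter-free LayerNorm over the feature axis (no learned scale or offset here).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def self_attention(x, wq, wk, wv, wo, num_heads):
    # Multi-head self-attention over electrons; x has shape (n, d).
    n = x.shape[0]
    hd = wq.shape[1] // num_heads

    def split(t):  # (n, num_heads * hd) -> (num_heads, n, hd)
        return t.reshape(n, num_heads, hd).transpose(1, 0, 2)

    q, k, v = split(x @ wq), split(x @ wk), split(x @ wv)
    logits = q @ k.transpose(0, 2, 1) / np.sqrt(hd)
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = (weights @ v).transpose(1, 0, 2).reshape(n, num_heads * hd)
    return out @ wo


def mlp(x, mats):
    # tanh MLP; the last matrix must map back to the model width for the residual.
    for w in mats:
        x = np.tanh(x @ w)
    return x


def psiformer_layer(x, attn_w, mlp_w, num_heads, mode="pre"):
    if mode == "pre":  # Pre-LN: normalize the input of each block
        x = x + self_attention(layer_norm(x), *attn_w, num_heads)
        return x + mlp(layer_norm(x), mlp_w)
    if mode == "post":  # Post-LN: normalize after each residual sum
        x = layer_norm(x + self_attention(x, *attn_w, num_heads))
        return layer_norm(x + mlp(x, mlp_w))
    if mode == "null":  # no LayerNorm at all
        x = x + self_attention(x, *attn_w, num_heads)
        return x + mlp(x, mlp_w)
    raise ValueError(f"unknown mode: {mode}")


rng = np.random.default_rng(1)
d, num_heads = 8, 2
attn_w = [rng.normal(size=(d, d)) * 0.3 for _ in range(4)]  # wq, wk, wv, wo
mlp_w = [rng.normal(size=(d, 16)) * 0.3, rng.normal(size=(16, d)) * 0.3]
x = rng.normal(size=(5, d))
y = psiformer_layer(x, attn_w, mlp_w, num_heads)
```

Neither attention nor the per-electron LayerNorm and MLP depend on electron order, so in all three modes permuting the input rows permutes the output rows in the same way.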
- class jaqmc.wavefunction.backbone.psiformer.LayerNormMode(*values)
LayerNorm application mode for Psiformer layers.
- pre
Apply LayerNorm before attention/MLP blocks (Pre-LN). Matches the internal_ferminet implementation.
Orbital projection and envelope
- class jaqmc.wavefunction.output.orbital.OrbitalProjection(nspins, ndets, orbitals_spin_split=True, use_bias=False, parent=<flax.linen.module._Sentinel object>, name=None)
Project backbone features to the orbital matrix for molecules.
Handles both spin-split and non-split configurations:
  - Spin-split: separate dense layers for each spin channel, allowing different orbital transformations for spin-up and spin-down electrons.
  - Non-split: a single dense layer shared across all electrons.
The output is reshaped and transposed to produce the standard orbital matrix format (ndets, n_electrons, n_electrons) used by determinant layers.
- Parameters:
  - nspins (tuple[int, int]) – Tuple of (num_spin_up, num_spin_down) electrons.
  - ndets (int) – Number of determinants.
  - orbitals_spin_split (bool, default: True) – If True, use a separate projection for each spin channel. Only effective when both spin channels are occupied.
  - use_bias (bool, default: False) – Whether to use bias in dense layers.
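The reshape-and-transpose step can be sketched as follows. The sketch assumes that each electron's dense output is laid out as `ndets` consecutive blocks of `n` orbital values, and it omits `use_bias`; `project_orbitals` is a hypothetical name, and its spin-split branch only mirrors the idea behind SplitChannelDense rather than reproducing it.

```python
import numpy as np


def project_orbitals(h, nspins, ndets, mats):
    """Map features (n, d) to orbital matrices (ndets, n, n) (illustrative sketch).

    mats: (w_up, w_down) for the spin-split case or (w,) for the shared case;
    each matrix has shape (d, ndets * n).
    """
    n_up, n_down = nspins
    n = n_up + n_down
    if len(mats) == 2 and n_up > 0 and n_down > 0:
        # Spin-split: spin-up and spin-down rows use different projections.
        out = np.concatenate([h[:n_up] @ mats[0], h[n_up:] @ mats[1]], axis=0)
    else:
        # Non-split (or one empty spin channel): a single shared projection.
        out = h @ mats[0]
    # (n, ndets * n) -> (n, ndets, n) -> (ndets, n, n)
    return out.reshape(n, ndets, n).transpose(1, 0, 2)


rng = np.random.default_rng(0)
nspins, ndets, d = (2, 1), 4, 6
n = sum(nspins)
h = rng.normal(size=(n, d))
w_up, w_down = rng.normal(size=(d, ndets * n)), rng.normal(size=(d, ndets * n))
orb = project_orbitals(h, nspins, ndets, (w_up, w_down))
```

In this layout, entry `orb[k, i, j]` is orbital `j` of determinant `k` evaluated at electron `i`.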
- class jaqmc.wavefunction.output.orbital.SplitChannelDense(channels, features, use_bias=True, parent=<flax.linen.module._Sentinel object>, name=None)
Apply separate dense layers to each spin channel.
- class jaqmc.wavefunction.output.envelope.Envelope(envelope_type, ndets, nspins, orbitals_spin_split=False, parent=<flax.linen.module._Sentinel object>, name=None)
Envelope function for wavefunctions.
Computes \(E_{ik} = \sum_I \pi_{kI} \exp(-d_{iIk})\), where \(i\) indexes electrons, \(k\) orbitals, and \(I\) atoms, and the effective distance \(d_{iIk}\) depends on EnvelopeType. Works for both OBC (molecules) and PBC (solids).
- envelope_type
Type of envelope. See EnvelopeType.
- class jaqmc.wavefunction.output.envelope.EnvelopeType(*values)
Envelope types controlling how the wavefunction decays away from the atoms.
The envelope \(E_{ik} = \sum_I \pi_{kI} \exp(-d_{iIk})\) modulates the orbital output. The effective distance \(d_{iIk}\) depends on the type.
- isotropic
Scalar decay rate: \(d = \sigma \| \mathbf{r} \|\). Simple, but \(\sigma\) is unconstrained and can become negative, in which case the envelope grows with distance instead of decaying.
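For the isotropic case with open boundary conditions, and reading the effective distance as \(d_{iIk} = \sigma_{kI} \| \mathbf{r}_i - \mathbf{R}_I \|\), the envelope can be sketched in NumPy as below. The index layout and the function name `isotropic_envelope` are assumptions; the sketch ignores determinants, spin channels, and periodic boundary conditions.

```python
import numpy as np


def isotropic_envelope(r_el, r_atoms, pi, sigma):
    """E[i, k] = sum_I pi[k, I] * exp(-sigma[k, I] * |r_i - R_I|) (illustrative sketch).

    r_el: (n, 3) electron positions; r_atoms: (natom, 3) nuclear positions;
    pi, sigma: (norb, natom) weights and decay rates.
    """
    dist = np.linalg.norm(r_el[:, None, :] - r_atoms[None, :, :], axis=-1)  # (n, natom)
    d = sigma[None, :, :] * dist[:, None, :]  # (n, norb, natom), entry [i, k, I]
    return np.sum(pi[None, :, :] * np.exp(-d), axis=-1)  # (n, norb)


atom = np.zeros((1, 3))
pi = np.ones((1, 1))
env = isotropic_envelope(np.array([[0.5, 0.0, 0.0]]), atom, pi, np.full((1, 1), 2.0))
# With sigma = -1 the "envelope" grows with distance from the atom.
bad = np.full((1, 1), -1.0)
near = isotropic_envelope(np.array([[1.0, 0.0, 0.0]]), atom, pi, bad)
far = isotropic_envelope(np.array([[2.0, 0.0, 0.0]]), atom, pi, bad)
```

The negative-\(\sigma\) case shows why the isotropic form needs care: nothing in the formula itself keeps the decay rate positive.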
- class jaqmc.wavefunction.output.logdet.LogDet(parent=<flax.linen.module._Sentinel object>, name=None)
Compute the log-determinant sum from orbital matrices.
Given an orbital tensor of shape (ndets, n, n), computes
\[\log\psi = \log\!\sum_k \operatorname{sign}(\det M_k)\, \exp(\log|\det M_k|)\]
using the logsumexp trick for numerical stability.
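The signed logsumexp can be sketched with `numpy.linalg.slogdet`. This sketch, with the hypothetical name `log_det_sum`, returns the overall sign and \(\log|\psi|\) separately, which is one common way to represent a real-valued result; it is not necessarily what RealLogDetOutput stores.

```python
import numpy as np


def log_det_sum(orbitals):
    """Return (sign, log|psi|) for psi = sum_k det(M_k); orbitals has shape (ndets, n, n)."""
    sign, logabs = np.linalg.slogdet(orbitals)  # per-determinant sign and log|det|
    shift = logabs.max()  # logsumexp trick: factor out the largest magnitude
    total = np.sum(sign * np.exp(logabs - shift))
    return np.sign(total), shift + np.log(np.abs(total))


rng = np.random.default_rng(0)
mats = rng.normal(size=(3, 4, 4))
s, log_abs = log_det_sum(mats)
# Scaling every matrix by 1e100 would overflow det() directly but is harmless here.
s_big, log_big = log_det_sum(mats * 1e100)
```

For \(4 \times 4\) matrices, the scaled call shifts \(\log|\psi|\) by exactly \(4 \cdot 100 \ln 10\) and leaves the sign unchanged.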
Jastrow factor
- class jaqmc.wavefunction.jastrow.SimpleEEJastrow(nspins, alpha_init=1.0, parent=<flax.linen.module._Sentinel object>, name=None)
Jastrow factor for electron-electron cusp conditions.
Implements the simple electron-electron Jastrow factor from FermiNet:
\[J = \sum_{i<j} f(r_{ij}),\]
where the cusp function is
\[f(r) = -\frac{c \alpha^2}{\alpha + r},\]
with \(c = 0.25\) for parallel spins and \(c = 0.5\) for antiparallel spins.
Since \(f'(r) = c \alpha^2 / (\alpha + r)^2\), this form satisfies the electron-electron cusp condition:
\[\frac{d f}{d r}\bigg|_{r=0} = \frac{c \alpha^2}{\alpha^2} = c.\]
The Jastrow factor is applied multiplicatively to the orbitals as
\[\phi'_{ij} = \phi_{ij} \cdot \exp(J / n_{electrons}).\]
Because every entry of an \(n_{electrons} \times n_{electrons}\) orbital matrix is scaled by the same factor, each determinant is multiplied by exactly \(\exp(J)\).
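The following NumPy sketch evaluates \(J\) with a plain double loop, then checks the determinant scaling and the cusp slope numerically. The function name `simple_ee_jastrow` is hypothetical, and \(\alpha\) is treated as a fixed number rather than a trained parameter.

```python
import numpy as np


def simple_ee_jastrow(r, nspins, alpha=1.0):
    """J = sum_{i<j} -c * alpha**2 / (alpha + r_ij) for positions r of shape (n, 3)."""
    n_up, n_down = nspins
    spins = [0] * n_up + [1] * n_down
    n = n_up + n_down
    total = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            r_ij = np.linalg.norm(r[i] - r[j])
            c = 0.25 if spins[i] == spins[j] else 0.5  # parallel vs antiparallel
            total += -c * alpha**2 / (alpha + r_ij)
    return total


rng = np.random.default_rng(0)
r = rng.normal(size=(3, 3))
J = simple_ee_jastrow(r, (2, 1))
# Scaling every entry of an n x n orbital matrix by exp(J / n) scales its determinant by exp(J).
phi = rng.normal(size=(3, 3))
phi_jastrow = phi * np.exp(J / 3)
```

A finite difference on a single electron pair reproduces the cusp slopes \(0.5\) (antiparallel) and \(0.25\) (parallel) at \(r = 0\).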
Data
In most wavefunction and per-walker estimator hooks,
Data represents one walker’s structured runtime input.
BatchedData pairs a Data-shaped pytree with metadata
describing which fields carry a leading walker axis during batched execution.
Most user-defined wavefunctions and per-walker estimators only work with
Data. BatchedData is the lower-level representation used when framework code
or workflow plumbing needs to manipulate full walker batches explicitly.
For the built-in data-shape convention and the detailed explanation of
fields_with_batch, see Runtime Data Conventions.
- class jaqmc.data.Data
Base container for structured wavefunction input data.
Data instances behave like lightweight, JAX-compatible dataclasses whose fields can be accessed both as attributes and as mapping keys. They are used to pass structured inputs (e.g. coordinates, atomic positions) between samplers, wavefunctions, and estimators.
- merge(values)
Return a new instance with values merged into this one.
- Parameters:
values (Mapping[str, Any]) – Mapping from field names to replacement values.
- Return type:
Self
- Returns:
A new Data (or subclass) instance where the provided values override the corresponding fields of this instance.
- Raises:
KeyError – If any of the keys in values are not valid field names for this dataclass.
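Attribute access, mapping access, and `merge` can be sketched with the standard `dataclasses` module. `ToyData` and its fields are hypothetical, and the sketch omits the JAX pytree registration that a JAX-compatible container needs.

```python
import dataclasses
from collections.abc import Mapping
from typing import Any


@dataclasses.dataclass(frozen=True)
class ToyData:
    """Hypothetical stand-in for a Data subclass; not the actual jaqmc.data.Data."""

    electrons: Any
    atoms: Any

    def __getitem__(self, key: str) -> Any:
        # Mapping-style access alongside ordinary attribute access.
        if key not in {f.name for f in dataclasses.fields(self)}:
            raise KeyError(key)
        return getattr(self, key)

    def merge(self, values: Mapping[str, Any]) -> "ToyData":
        # Reject unknown keys, then build a new instance; the original is untouched.
        valid = {f.name for f in dataclasses.fields(self)}
        unknown = sorted(set(values) - valid)
        if unknown:
            raise KeyError(f"not valid field names: {unknown}")
        return dataclasses.replace(self, **values)


walker = ToyData(electrons=(0.0, 0.0, 0.0), atoms=(1.0, 0.0, 0.0))
moved = walker.merge({"electrons": (0.5, 0.0, 0.0)})
```

Returning a fresh instance instead of mutating in place fits JAX's functional style, where inputs to transformed functions should not be modified.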
- class jaqmc.data.BatchedData(data, fields_with_batch)
Container pairing one data instance with batched-field metadata.
- data
Structured runtime data. Batched fields keep the same dataclass structure as one-walker Data, but carry an extra leading walker axis.
- fields_with_batch
Field names in data whose leaves carry the leading walker axis. Fields not listed here are shared across walkers.
The type variable DataT is the concrete Data subtype stored in data.
- all_gather()
Gather distributed arrays from all devices to each local node.
For fields that are batched (sharded along the batch axis), this collects all shards and materializes the complete array on each device. Unbatched fields are left unchanged.
- Return type:
Self
- Returns:
A new BatchedData with all-gathered batched fields.
- property batch_size: int
Return the leading dimension shared by batched fields.
The size is read from the first leaf of the first field listed in fields_with_batch. Call check() when you need to verify that all batched fields use the same leading size.
- Returns:
The detected batch size, or 0 when no fields are marked as batched.
- check()
Validate the batched-field metadata against the wrapped data.
The check verifies that every name in fields_with_batch exists on data and, for concrete JAX arrays, that every batched leaf shares the same leading batch size. Shape validation is skipped during JAX tracing and for non-array leaves.
- Raises:
  - KeyError – If fields_with_batch names fields that do not exist on data.
  - ValueError – If a batched array leaf does not use the common leading batch size.
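The batch_size and check rules can be sketched with a plain dict of NumPy arrays standing in for the dataclass. Each field holds a single array here instead of a pytree, so "first leaf" becomes "the array itself"; the helper names and field names are hypothetical.

```python
import numpy as np


def batch_size(data, fields_with_batch):
    # Leading size of the first batched field, or 0 when nothing is batched.
    if not fields_with_batch:
        return 0
    return np.shape(data[fields_with_batch[0]])[0]


def check(data, fields_with_batch):
    missing = [name for name in fields_with_batch if name not in data]
    if missing:
        raise KeyError(f"fields_with_batch names unknown fields: {missing}")
    size = batch_size(data, fields_with_batch)
    for name in fields_with_batch:
        leaf = data[name]
        # Only concrete arrays are shape-checked; other leaves are skipped.
        if isinstance(leaf, np.ndarray) and leaf.shape[:1] != (size,):
            raise ValueError(f"{name} has leading size {leaf.shape[:1]}, expected {size}")


data = {"electrons": np.zeros((8, 6)), "spins": np.zeros((8, 2)), "atoms": np.zeros((2, 3))}
fields_with_batch = ("electrons", "spins")
check(data, fields_with_batch)  # passes: both batched fields lead with 8
```

Keeping batch_size cheap (one lookup) and putting the full consistency scan in a separate check mirrors the split described above.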
- fully_batched_data()
Return a new Data with all fields batched.
Fields that are already batched (listed in fields_with_batch) are left unchanged; unbatched fields are duplicated along a new leading batch dimension so that their shapes become (batch_size, *orig_shape).
- Raises:
ValueError – If the current batch_size is zero.
- Return type:
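A sketch of fully_batched_data on the same kind of dict stand-in (hypothetical names, one array per field), using `np.repeat` to copy each unbatched field once per walker:

```python
import numpy as np


def fully_batched(data, fields_with_batch):
    # Batch size is read from the first batched field.
    size = np.shape(data[fields_with_batch[0]])[0] if fields_with_batch else 0
    if size == 0:
        raise ValueError("cannot fully batch data with batch size zero")
    return {
        name: value if name in fields_with_batch
        # Copy the shared value once per walker: shape -> (size, *orig_shape).
        else np.repeat(np.asarray(value)[None], size, axis=0)
        for name, value in data.items()
    }


data = {"electrons": np.zeros((8, 6)), "atoms": np.arange(6.0).reshape(2, 3)}
full = fully_batched(data, ("electrons",))
```

Materializing copies costs memory proportional to the batch size; it is the price for giving every field a uniform leading walker axis.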
- property partition_spec
Describe how this batched data should be sharded.
Fields listed in fields_with_batch are assigned a jax.sharding.PartitionSpec over BATCH_AXIS_NAME. Unbatched fields receive an empty partition spec, meaning they are shared rather than sharded over walkers.
- Returns:
A new BatchedData whose data fields contain partition specs matching the wrapped data structure.
- unbatched_example()
Return a one-walker-shaped example matching this data structure.
Batched fields are replaced with arrays of ones after dropping their leading batch axis. Unbatched fields are replaced with ones_like arrays of the same shape. The result is useful for initialization code that needs representative single-walker input shapes, not actual sampled values.
- property vmap_axis
Describe batch axes for use with jax.vmap().
Each dataclass field is mapped to an axis specification: 0 for fields that are batched along the leading dimension, and None for fields that are treated as broadcast or scalar. The returned object is another instance of this dataclass, intended to be used as the in_axes argument to jax.vmap.
- Returns:
A new instance of the same dataclass with integer/None values describing the batch axes per field.
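The vmap_axis convention lines up with `jax.vmap`, whose `in_axes` accepts a pytree prefix of the arguments with `0` or `None` per leaf. The sketch below uses a dict instead of a Data subclass, and the helper `vmap_axis` and the one-walker function `toy_log_amplitude` are hypothetical.

```python
import jax
import jax.numpy as jnp


def vmap_axis(data, fields_with_batch):
    # 0 for fields with a leading walker axis, None for fields shared by all walkers.
    return {name: (0 if name in fields_with_batch else None) for name in data}


def toy_log_amplitude(walker):
    # Hypothetical one-walker function: a Gaussian centered on the mean atom position.
    center = walker["atoms"].mean(axis=0)
    return -jnp.sum((walker["electrons"] - center) ** 2)


data = {"electrons": jnp.ones((8, 2, 3)), "atoms": jnp.zeros((2, 3))}
in_axes = vmap_axis(data, ("electrons",))
values = jax.vmap(toy_log_amplitude, in_axes=(in_axes,))(data)
```

The one-walker function never sees the walker axis: `vmap` slices `electrons` per walker and passes the same `atoms` array to every call.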