Wavefunctions

A wavefunction maps electron configurations to log-amplitudes (and optionally signs or phases). All JaQMC wavefunctions are Flax Linen modules that subclass Wavefunction.

For architecture-specific options (FermiNet, Psiformer, etc.), see the configuration reference pages (Molecules, Solids, Quantum Hall).

Base class and protocols

class jaqmc.wavefunction.Wavefunction(parent=<flax.linen.module._Sentinel object>, name=None)

Base class for JaQMC wavefunctions.

A Wavefunction is a Flax nn.Module with three complementary execution interfaces:

  • __call__(data): model definition for one walker (implemented by subclasses).

  • apply(params, data, ...): Flax runtime API that executes methods with explicit variables.

  • evaluate(params, data): JaQMC wrapper over apply that provides a stable typed contract for framework consumers.

abstractmethod __call__(data)

Define the forward pass for one walker.

This method should be written in standard Flax style and does not take explicit parameters.

Return type:

TypeVar(OutputT)

evaluate(params, data)

Framework-level execution entrypoint.

JaQMC wrapper over apply() that provides a stable typed contract, following the signature of WavefunctionEvaluate.

Return type:

TypeVar(OutputT)

Returns:

The wavefunction output for the provided parameters and data.

init_params(data, rngs)

Initialize parameters from one sample walker.

JaQMC wrapper over init() that provides a stable typed contract, following the signature of WavefunctionInit.

Return type:

Params

Returns:

A PyTree containing the initialized wavefunction parameters.

class jaqmc.wavefunction.WavefunctionLike(*args, **kwargs)

Minimal wavefunction interface required by framework components.

App-level protocols can extend this with domain-specific methods such as logpsi, phase_logpsi, or orbitals.

init_params: jaqmc.wavefunction.base.WavefunctionInit
evaluate: jaqmc.wavefunction.base.WavefunctionEvaluate

class jaqmc.wavefunction.WavefunctionEvaluate(*args, **kwargs)

Callable protocol for one-walker wavefunction evaluation.

__call__(params, data)

Evaluate a wavefunction with explicit parameters.

Parameters:
  • params (Params) – Parameter PyTree to evaluate.

  • data (TypeVar(DataT, bound=Data)) – One-walker input sample.

Return type:

TypeVar(OutputT)

Returns:

Model output for this walker (scalar or structured output).

class jaqmc.wavefunction.WavefunctionInit(*args, **kwargs)

Callable protocol for one-walker parameter initialization.

__call__(data, rngs)

Initialize wavefunction parameters from one walker sample.

Return type:

Params

Returns:

A PyTree representing the initial wavefunction parameters.

type jaqmc.wavefunction.base.NumericWavefunctionEvaluate = WavefunctionEvaluate[DataT, Array]

Callable protocol for one-walker numeric wavefunction evaluation.

This alias specializes WavefunctionEvaluate to scalar-array outputs, which are typically log-amplitude values such as log|psi|.
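
Because WavefunctionLike only requires init_params and evaluate, any object with matching call signatures should satisfy it structurally. The sketch below illustrates that calling convention with hypothetical names (MinimalWavefunctionLike, ToyWalker, ToyHydrogenic) and plain NumPy in place of Flax modules and JAX PRNG keys. It assumes WavefunctionLike behaves like a typing.Protocol and is not JaQMC's own code.

```python
from dataclasses import dataclass
from typing import Any, Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class MinimalWavefunctionLike(Protocol):
    """Hypothetical mirror of WavefunctionLike: only init_params and evaluate."""

    def init_params(self, data: Any, rngs: Any) -> Any: ...

    def evaluate(self, params: Any, data: Any) -> Any: ...


@dataclass(frozen=True)
class ToyWalker:
    # One walker: electron positions with shape (n_elec, 3).
    electrons: np.ndarray


class ToyHydrogenic:
    """Toy product of 1s orbitals: log|psi| = -alpha * sum_i |r_i|."""

    def init_params(self, data: ToyWalker, rngs: Any) -> dict:
        # rngs is unused; a real model would draw initial weights from it.
        return {"alpha": 1.0}

    def evaluate(self, params: dict, data: ToyWalker) -> np.ndarray:
        # Scalar log-amplitude for one walker, the shape NumericWavefunctionEvaluate describes.
        r = np.linalg.norm(data.electrons, axis=-1)
        return -params["alpha"] * np.sum(r)


wf = ToyHydrogenic()
walker = ToyWalker(electrons=np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]))
params = wf.init_params(walker, rngs=None)
logpsi = wf.evaluate(params, walker)  # -(1.0 + 2.0) = -3.0
```

A real JaQMC wavefunction would instead subclass Wavefunction, define __call__ in Flax style, and rely on the base class for init_params and evaluate.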

Output types

class jaqmc.wavefunction.base.RealWFOutput

Structured output for real-valued wavefunctions.

class jaqmc.wavefunction.base.ComplexWFOutput

Structured output for complex-valued wavefunctions.

class jaqmc.wavefunction.base.LogPsiWFOutput

Minimal structured wavefunction output containing logpsi.

Log-determinant output (used by FermiNet / Psiformer)

class jaqmc.wavefunction.output.logdet.RealLogDetOutput

Output of LogDet for real-valued orbital matrices.

logpsi

Log absolute wavefunction value (real scalar).

sign_logpsi

Sign of the wavefunction (+1 or -1).

sign_logdets

Signs of individual determinants.

abs_logdets

Log absolute values of individual determinants.

class jaqmc.wavefunction.output.logdet.ComplexLogDetOutput

Output of LogDet for complex-valued orbital matrices.

logpsi

Log wavefunction value (complex scalar).

sign_logdets

Signs (phases) of individual determinants.

abs_logdets

Log absolute values of individual determinants.

Input features

class jaqmc.wavefunction.input.atomic.AtomicEmbedding

Output from MoleculeFeatures or SolidFeatures.

ae_features

Flattened atom-electron features for backbone (n_elec, n_atoms * (ndim + 1)).

ee_features

Electron-electron features for backbone (n_elec, n_elec, ndim + 1).

r_ae

Atom-electron distances (n_elec, n_atoms).

ae_vec

Atom-electron displacement vectors (n_elec, n_atoms, ndim).
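
As a rough illustration of the shapes listed above, the following NumPy sketch builds open-boundary features in the FermiNet style. The function name atomic_embedding and the ordering of displacement and distance inside each atom block are assumptions; MoleculeFeatures may order or scale them differently.

```python
import numpy as np


def atomic_embedding(electrons: np.ndarray, atoms: np.ndarray):
    """Hypothetical sketch: electrons (n_elec, ndim), atoms (n_atoms, ndim)."""
    n_elec = electrons.shape[0]
    # Atom-electron displacements and distances.
    ae_vec = electrons[:, None, :] - atoms[None, :, :]        # (n_elec, n_atoms, ndim)
    r_ae = np.linalg.norm(ae_vec, axis=-1)                    # (n_elec, n_atoms)
    # Per atom: [displacement, distance], then flatten over atoms.
    ae_features = np.concatenate([ae_vec, r_ae[..., None]], axis=-1).reshape(n_elec, -1)
    # Electron-electron displacements and distances.
    ee_vec = electrons[:, None, :] - electrons[None, :, :]    # (n_elec, n_elec, ndim)
    r_ee = np.linalg.norm(ee_vec, axis=-1, keepdims=True)     # (n_elec, n_elec, 1)
    ee_features = np.concatenate([ee_vec, r_ee], axis=-1)     # (n_elec, n_elec, ndim + 1)
    return ae_features, ee_features, r_ae, ae_vec


rng = np.random.default_rng(0)
electrons = rng.normal(size=(4, 3))
atoms = np.array([[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]])
ae_features, ee_features, r_ae, ae_vec = atomic_embedding(electrons, atoms)
# ae_features.shape == (4, 8); ee_features.shape == (4, 4, 4)
```

In a differentiable JAX version, the zero-length diagonal of the electron-electron displacements needs care, because the gradient of a vector norm is undefined at zero. One common workaround is to add an identity matrix to the diagonal before taking the norm and mask the diagonal out afterwards.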

class jaqmc.wavefunction.input.atomic.MoleculeFeatures(rescale=False, parent=<flax.linen.module._Sentinel object>, name=None)

Input features for molecular systems (open boundary conditions, OBC).

rescale

If True, make input features grow as log(r) rather than r.
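
The rescale option can be pictured as replacing each distance r with a slowly growing function and shrinking the displacement vector to match, so its direction is kept. This sketch assumes the map from r to log(1 + r); the exact function MoleculeFeatures uses is not stated here, and rescale_features is a hypothetical name.

```python
import numpy as np


def rescale_features(ae_vec: np.ndarray, r_ae: np.ndarray):
    """Map distances r to log(1 + r) and rescale displacement vectors to match.

    ae_vec: (n_elec, n_atoms, ndim); r_ae: (n_elec, n_atoms), assumed > 0.
    """
    log_r = np.log1p(r_ae)
    # Keep each vector's direction but give it length log(1 + r).
    scaled_vec = ae_vec * (log_r / r_ae)[..., None]
    return scaled_vec, log_r


ae_vec = np.array([[[3.0, 4.0, 0.0]], [[0.0, 0.0, 1000.0]]])  # 2 electrons, 1 atom
r_ae = np.linalg.norm(ae_vec, axis=-1)                          # [[5.0], [1000.0]]
scaled_vec, log_r = rescale_features(ae_vec, r_ae)
```

An electron 1000 bohr away now contributes a feature of length log(1001), about 6.9, instead of 1000.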

class jaqmc.wavefunction.input.atomic.SolidFeatures(simulation_lattice, primitive_lattice, distance_type=DistanceType.nu, sym_type=SymmetryType.minimal, parent=<flax.linen.module._Sentinel object>, name=None)

Input features for periodic systems (solids).

simulation_lattice

Lattice vectors of the simulation cell (nelectrons).

primitive_lattice

Lattice vectors of the primitive cell (natoms).

distance_type

Type of periodic distance to use (‘nu’ or ‘tri’).

sym_type

Symmetry type for auxiliary lattice vectors.

setup()

Precompute symmetry-reduced lattice vectors for distance evaluation.

Backbone architectures

class jaqmc.wavefunction.backbone.ferminet.FermiLayers(nspins, hidden_dims, use_last_layer=False, parent=<flax.linen.module._Sentinel object>, name=None)

FermiNet interaction layers with single- and double-electron streams.

Each layer updates the single-electron stream by aggregating features from both streams, then updates the double-electron stream independently. Residual connections are added when input and output dimensions match.

Parameters:
  • nspins (tuple[int, int]) – Tuple of (num_spin_up, num_spin_down) electrons.

  • hidden_dims (list[tuple[int, int]]) – List of (single_dim, double_dim) pairs, one per layer.

  • use_last_layer (bool, default: False) – If True, also update the double stream in the final layer and return aggregated features. If False, skip the final double-stream update.

aggregate_features(h_one, h_two)

Concatenate electron features with spin-channel averages.

For each non-empty spin channel, this computes mean single-stream features and mean pairwise-stream features, then concatenates those aggregates with the original single-electron features.

Parameters:
  • h_one (Array) – Single-electron features of shape (n_electrons, single_dim).

  • h_two (Array) – Pairwise electron features of shape (n_electrons, n_electrons, double_dim).

Returns:

Aggregated single-electron features with the same leading electron axis as h_one.
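
The aggregation described above can be sketched in NumPy as follows, assuming electrons are ordered spin-up first and that the aggregates are concatenated in the order [h_one, single-stream means, pairwise means]. The function signature (with an explicit nspins argument) and that concatenation order are assumptions, not FermiLayers' exact code.

```python
import numpy as np


def aggregate_features(h_one, h_two, nspins):
    """Sketch: h_one (n_elec, single_dim), h_two (n_elec, n_elec, double_dim)."""
    n_elec = h_one.shape[0]
    bounds = np.cumsum((0,) + tuple(nspins))  # channel boundaries, e.g. [0, 2, 3]
    g_one, g_two = [], []
    for start, stop in zip(bounds[:-1], bounds[1:]):
        if stop == start:  # skip empty spin channels
            continue
        # Mean single-stream features of this channel, broadcast to every electron.
        mean_one = h_one[start:stop].mean(axis=0, keepdims=True)
        g_one.append(np.broadcast_to(mean_one, (n_elec, h_one.shape[1])))
        # For each electron i, mean of h_two[i, j] over electrons j in this channel.
        g_two.append(h_two[:, start:stop].mean(axis=1))
    return np.concatenate([h_one] + g_one + g_two, axis=-1)


rng = np.random.default_rng(0)
h_one = rng.normal(size=(3, 4))      # single_dim = 4
h_two = rng.normal(size=(3, 3, 3))   # double_dim = 3
out = aggregate_features(h_one, h_two, nspins=(2, 1))         # shape (3, 18)
out_polarized = aggregate_features(h_one, h_two, nspins=(3, 0))  # shape (3, 11)
```

With nspins=(2, 1), single_dim=4, and double_dim=3, the output width is 4 + 2 * 4 + 2 * 3 = 18; with an empty spin-down channel it drops to 4 + 4 + 3 = 11. Because the means are taken over whole spin channels, permuting electrons within a channel permutes the output rows in the same way.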

class jaqmc.wavefunction.backbone.psiformer.PsiformerBackbone(nspins, num_layers=2, num_heads=4, heads_dim=64, mlp_hidden_dims=(256,), layer_norm_mode=LayerNormMode.pre, with_bias=True, input_bias=True, parent=<flax.linen.module._Sentinel object>, name=None)

Self-attention backbone for molecular wavefunctions.

Psiformer processes one-electron features through multiple self-attention layers, modeling electron-electron interactions implicitly through attention rather than explicit two-electron feature streams.

The architecture:

  1. Concatenates spin encoding to input features

  2. Projects to attention dimension

  3. Applies num_layers PsiformerLayer blocks

  4. Outputs processed one-electron features

Parameters:
  • nspins (tuple[int, int]) – Tuple of (num_spin_up, num_spin_down) electrons.

  • num_layers (int, default: 2) – Number of Psiformer layers.

  • num_heads (int, default: 4) – Number of attention heads.

  • heads_dim (int, default: 64) – Dimension of each attention head.

  • mlp_hidden_dims (Sequence[int], default: (256,)) – Hidden dimensions for MLP blocks.

  • layer_norm_mode (LayerNormMode, default: <LayerNormMode.pre: 'pre'>) – LayerNorm application mode. One of pre, post, or null; see LayerNormMode.

  • with_bias (bool, default: True) – Whether to use bias in attention QKV projections.

  • input_bias (bool, default: True) – Whether to use bias in the input projection layer.
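
Steps 1 and 2 of the PsiformerBackbone pipeline amount to a concatenation followed by a linear map. The sketch assumes a +1/-1 spin encoding and a projection width of num_heads * heads_dim; both the encoding and the function name psiformer_input are illustrative assumptions, not JaQMC's exact choices.

```python
import numpy as np


def psiformer_input(features, nspins, w_in, b_in=None):
    """Append a spin feature, then project to the attention width.

    features: (n_elec, feat_dim); w_in: (feat_dim + 1, num_heads * heads_dim).
    """
    n_up, n_down = nspins
    # Assumed spin encoding: +1 for spin-up, -1 for spin-down electrons.
    spin = np.concatenate([np.ones(n_up), -np.ones(n_down)])[:, None]
    x = np.concatenate([features, spin], axis=-1)   # (n_elec, feat_dim + 1)
    h = x @ w_in                                    # (n_elec, num_heads * heads_dim)
    if b_in is not None:  # plays the role of the input_bias option
        h = h + b_in
    return h


num_heads, heads_dim = 4, 64
rng = np.random.default_rng(0)
features = rng.normal(size=(3, 8))                  # nspins = (2, 1), 8 input features
w_in = rng.normal(size=(9, num_heads * heads_dim))
h = psiformer_input(features, (2, 1), w_in)         # shape (3, 256)
```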

class jaqmc.wavefunction.backbone.psiformer.PsiformerLayer(num_heads, heads_dim, mlp_hidden_dims, layer_norm_mode=LayerNormMode.pre, with_bias=True, parent=<flax.linen.module._Sentinel object>, name=None)

Single Psiformer layer combining attention and MLP with residual connections.

Each layer applies:

  1. Optional LayerNorm (before or after, depending on mode)

  2. Multi-head self-attention with residual connection

  3. Optional LayerNorm (before or after, depending on mode)

  4. MLP with residual connection

Parameters:
  • num_heads (int) – Number of attention heads.

  • heads_dim (int) – Dimension of each attention head.

  • mlp_hidden_dims (Sequence[int]) – Hidden dimensions for the MLP block.

  • layer_norm_mode (LayerNormMode, default: <LayerNormMode.pre: 'pre'>) – LayerNorm application mode. One of pre, post, or null; see LayerNormMode.

  • with_bias (bool, default: True) – Whether to use bias in attention QKV projections.

class jaqmc.wavefunction.backbone.psiformer.LayerNormMode(*values)

LayerNorm application mode for Psiformer layers.

pre

Apply LayerNorm before attention/MLP blocks (Pre-LN). Matches the internal_ferminet implementation.

post

Apply LayerNorm after attention/MLP blocks (Post-LN). Matches the public FermiNet implementation.

null

No LayerNorm applied.
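
The three LayerNormMode options differ only in where normalization sits relative to the residual sums. The sketch below uses placeholder attn and mlp callables and a parameter-free layer norm in place of Flax's learned LayerNorm, so it shows only the ordering of operations in a PsiformerLayer, under that assumption.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Parameter-free LayerNorm over the feature axis of each electron.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def psiformer_layer(x, attn, mlp, mode="pre"):
    """Residual attention + MLP with the LayerNorm placement set by mode."""
    if mode == "pre":     # Pre-LN: normalize each block's input
        x = x + attn(layer_norm(x))
        x = x + mlp(layer_norm(x))
    elif mode == "post":  # Post-LN: normalize each residual sum
        x = layer_norm(x + attn(x))
        x = layer_norm(x + mlp(x))
    elif mode == "null":  # no LayerNorm
        x = x + attn(x)
        x = x + mlp(x)
    else:
        raise ValueError(f"unknown LayerNorm mode: {mode!r}")
    return x


def attn(h):
    # Placeholder for multi-head self-attention (shape preserving).
    return 0.5 * h


mlp = np.tanh  # placeholder for the MLP block
x = np.random.default_rng(0).normal(size=(3, 8))
outputs = {mode: psiformer_layer(x, attn, mlp, mode) for mode in ("pre", "post", "null")}
```

Only the post mode ends on a normalization, so its output rows have zero mean and unit variance; pre and null leave the residual stream unnormalized.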

Orbital projection and envelope

class jaqmc.wavefunction.output.orbital.OrbitalProjection(nspins, ndets, orbitals_spin_split=True, use_bias=False, parent=<flax.linen.module._Sentinel object>, name=None)

Project backbone features to the orbital matrix for molecules.

Handles both spin-split and non-split configurations:

  • Spin-split: Separate dense layers for each spin channel, allowing different orbital transformations for spin-up and spin-down electrons.

  • Non-split: Single dense layer shared across all electrons.

The output is reshaped and transposed to produce the standard orbital matrix format (ndets, n_electrons, n_electrons) used by determinant layers.

Parameters:
  • nspins (tuple[int, int]) – Tuple of (num_spin_up, num_spin_down) electrons.

  • ndets (int) – Number of determinants.

  • orbitals_spin_split (bool, default: True) – If True, use separate projection for each spin channel. Only effective when both spin channels are occupied.

  • use_bias (bool, default: False) – Whether to use bias in dense layers.
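
The reshape-and-transpose step can be sketched as below. It assumes the dense output for each electron is laid out as ndets consecutive blocks of n_electrons orbital values; the function name orbital_projection and that layout are assumptions, and the real module uses Flax dense layers rather than raw weight matrices.

```python
import numpy as np


def orbital_projection(h, w, nspins=None):
    """Project features h (n_elec, feat) to orbitals (ndets, n_elec, n_elec)."""
    n_elec = h.shape[0]
    if nspins is None:
        # Non-split: one weight matrix shared by all electrons.
        out = h @ w                                   # (n_elec, ndets * n_elec)
    else:
        # Spin-split: separate weights for the spin-up and spin-down rows.
        n_up = nspins[0]
        w_up, w_down = w
        out = np.concatenate([h[:n_up] @ w_up, h[n_up:] @ w_down], axis=0)
    ndets = out.shape[1] // n_elec
    # (n_elec, ndets, n_orb) -> (ndets, n_elec, n_orb), with n_orb == n_elec.
    return out.reshape(n_elec, ndets, n_elec).transpose(1, 0, 2)


rng = np.random.default_rng(0)
h = rng.normal(size=(3, 5))        # 3 electrons, 5 backbone features
w = rng.normal(size=(5, 2 * 3))    # ndets = 2
orbitals = orbital_projection(h, w)                       # shape (2, 3, 3)
w_up, w_down = rng.normal(size=(2, 5, 6))
orbitals_split = orbital_projection(h, (w_up, w_down), nspins=(2, 1))
```

In the spin-split path, rows for spin-up electrons come from w_up and rows for spin-down electrons from w_down, which is why the option only has an effect when both channels are occupied.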

class jaqmc.wavefunction.output.orbital.SplitChannelDense(channels, features, use_bias=True, parent=<flax.linen.module._Sentinel object>, name=None)

Apply separate dense layers to each spin channel.

Parameters:
  • channels (tuple[int, int]) – Tuple of (num_spin_up, num_spin_down) electrons.

  • features (list[int]) – Output feature dimensions for DenseGeneral.

  • use_bias (bool, default: True) – Whether to use bias in dense layers.

class jaqmc.wavefunction.output.envelope.Envelope(envelope_type, ndets, nspins, orbitals_spin_split=False, parent=<flax.linen.module._Sentinel object>, name=None)

Envelope function for wavefunctions.

Computes \(E_{ik} = \sum_I \pi_{kI} \exp(-d_{iIk})\) where the effective distance \(d_{iIk}\) depends on EnvelopeType. Here \(i\) indexes electrons, \(I\) atoms, and \(k\) orbitals. Works for both OBC (molecules) and PBC (solids).

envelope_type

Type of envelope. See EnvelopeType.

ndets

Number of determinants.

nspins

Tuple of (num_spin_up, num_spin_down) electrons.

orbitals_spin_split

If True, use separate envelope parameters for each spin channel.

setup()

Initialize envelope submodules based on the configuration.

class jaqmc.wavefunction.output.envelope.EnvelopeType(*values)

Envelope types controlling how orbitals decay with distance from the atoms.

The envelope \(E_{ik} = \sum_I \pi_{kI} \exp(-d_{iIk})\) modulates the orbital outputs. The effective distance \(d_{iIk}\) depends on the type.

isotropic

Scalar decay rate \(d = \sigma \| \mathbf{r} \|\). Simple, but \(\sigma\) can become negative during optimization, in which case the envelope grows with distance instead of decaying.

abs_isotropic

\(d = |\sigma| \| \mathbf{r} \|\). Always-decaying variant of isotropic. Generally preferred.

diagonal

Per-dimension decay rates \(d = \| \mathbf{r} \odot \boldsymbol{\sigma} \|\).

null

No envelope (returns ones).
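
The envelope formula can be evaluated directly for the open-boundary case. In this sketch the orbital index k folds together determinants and orbitals, PBC handling is omitted, and the argument layout (pi as (n_orb, n_atoms), sigma as (n_atoms, n_orb) or (n_atoms, n_orb, ndim)) is an assumption rather than the Envelope module's parameter layout.

```python
import numpy as np


def envelope(ae_vec, pi, sigma, envelope_type="abs_isotropic"):
    """E[i, k] = sum_I pi[k, I] * exp(-d[i, I, k]).

    ae_vec: (n_elec, n_atoms, ndim) electron-atom displacements.
    """
    n_elec = ae_vec.shape[0]
    if envelope_type == "null":
        return np.ones((n_elec, pi.shape[0]))
    r = np.linalg.norm(ae_vec, axis=-1)[..., None]          # (n_elec, n_atoms, 1)
    if envelope_type == "isotropic":
        d = sigma[None] * r                                  # (n_elec, n_atoms, n_orb)
    elif envelope_type == "abs_isotropic":
        d = np.abs(sigma)[None] * r
    elif envelope_type == "diagonal":
        # Norm of the elementwise product r * sigma over the spatial axis.
        d = np.linalg.norm(ae_vec[:, :, None, :] * sigma[None], axis=-1)
    else:
        raise ValueError(f"unknown envelope type: {envelope_type!r}")
    return np.einsum("kI,iIk->ik", pi, np.exp(-d))


ae_vec = np.array([[[1.0, 0.0, 0.0]], [[0.0, 2.0, 0.0]]])   # 2 electrons, 1 atom
pi = np.ones((2, 1))                                          # 2 orbitals
sigma = np.array([[1.0, -1.0]])                               # second rate is negative
E_iso = envelope(ae_vec, pi, sigma, "isotropic")
E_abs = envelope(ae_vec, pi, sigma, "abs_isotropic")
E_diag = envelope(ae_vec, pi, np.ones((1, 2, 3)), "diagonal")
E_null = envelope(ae_vec, pi, sigma, "null")
```

With a negative rate, the isotropic entry E_iso[0, 1] exceeds 1 and keeps growing with distance, which is the failure mode abs_isotropic avoids.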

class jaqmc.wavefunction.output.logdet.LogDet(parent=<flax.linen.module._Sentinel object>, name=None)

Compute log-determinant sum from orbital matrices.

Given an orbital tensor of shape (ndets, n, n), computes

\[\log\psi = \log\!\sum_k \operatorname{sign}(\det M_k)\, \exp(\log|\det M_k|)\]

using the logsumexp trick for numerical stability.
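
The logsumexp trick for a signed sum of determinants can be sketched with numpy.linalg.slogdet: factor out the largest log|det M_k|, sum the rescaled signed terms, and add the shift back. log_det_sum is a hypothetical name, and only the real-valued case is shown.

```python
import numpy as np


def log_det_sum(orbitals):
    """Return (log|psi|, sign(psi)) for psi = sum_k det(orbitals[k]).

    orbitals: (ndets, n, n) real orbital matrices.
    """
    signs, logabs = np.linalg.slogdet(orbitals)     # each of shape (ndets,)
    shift = np.max(logabs)                          # the largest |det| sets the scale
    total = np.sum(signs * np.exp(logabs - shift))  # each term has magnitude <= 1
    return np.log(np.abs(total)) + shift, np.sign(total)


rng = np.random.default_rng(1)
orbitals = rng.normal(size=(3, 4, 4))
logpsi, sign = log_det_sum(orbitals)

big = rng.normal(size=(2, 8, 8))
log_big, sign_big = log_det_sum(big)
log_scaled, sign_scaled = log_det_sum(big * 1e40)  # finite despite huge determinants
```

Scaling every 8 × 8 matrix by 1e40 makes each determinant 10^320 times larger, which overflows float64 if formed directly, while the log-space result simply shifts by 320 * log(10).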

Jastrow factor

class jaqmc.wavefunction.jastrow.SimpleEEJastrow(nspins, alpha_init=1.0, parent=<flax.linen.module._Sentinel object>, name=None)

Jastrow factor for electron-electron cusp conditions.

Implements the simple electron-electron Jastrow factor from FermiNet:

\[J = \sum_{i<j} f(r_{ij})\]

where the cusp function is:

\[f(r) = -\frac{c \alpha^2}{\alpha + r}\]

with \(c = 0.25\) for parallel spins and \(c = 0.5\) for antiparallel spins.

This form satisfies the electron-electron cusp condition:

\[\frac{d f}{d r}\bigg|_{r=0} = \frac{c \alpha^2}{\alpha^2} = c\]

The Jastrow factor is applied multiplicatively to orbitals as:

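\[\phi'_{ij} = \phi_{ij} \cdot \exp(J / n_{electrons})\]

Because each orbital matrix is \(n_{electrons} \times n_{electrons}\), scaling every entry by \(\exp(J / n_{electrons})\) multiplies each determinant, and hence \(\psi\), by \(\exp(J)\).

Parameters: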
  • nspins (tuple[int, int]) – Tuple of (num_spin_up, num_spin_down) electrons.

  • alpha_init (float, default: 1.0) – Initial value for the learnable decay parameters (alpha_par and alpha_anti).
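
A NumPy sketch of the Jastrow factor and its multiplicative application, assuming electrons are ordered spin-up first and treating alpha_par and alpha_anti as plain floats rather than learnable Flax parameters. The names simple_ee_jastrow and apply_jastrow are hypothetical.

```python
import numpy as np


def simple_ee_jastrow(electrons, nspins, alpha_par=1.0, alpha_anti=1.0):
    """J = sum_{i<j} -c * alpha**2 / (alpha + r_ij), c = 1/4 parallel, 1/2 antiparallel."""
    n = electrons.shape[0]
    spins = np.array([0] * nspins[0] + [1] * nspins[1])
    r = np.linalg.norm(electrons[:, None, :] - electrons[None, :, :], axis=-1)
    i, j = np.triu_indices(n, k=1)                 # unique pairs with i < j
    parallel = spins[i] == spins[j]
    c = np.where(parallel, 0.25, 0.5)
    alpha = np.where(parallel, alpha_par, alpha_anti)
    return np.sum(-c * alpha**2 / (alpha + r[i, j]))


def apply_jastrow(orbitals, jastrow):
    # Scale every orbital entry by exp(J / n_electrons).
    n = orbitals.shape[-1]
    return orbitals * np.exp(jastrow / n)


rng = np.random.default_rng(0)
electrons = rng.normal(size=(3, 3))
orbitals = rng.normal(size=(3, 3))
J = simple_ee_jastrow(electrons, nspins=(2, 1))
scaled = apply_jastrow(orbitals, J)

# Finite-difference check of the cusp for one antiparallel pair (alpha = 2).
h = 1e-6
pair0 = np.zeros((2, 3))
pair_h = np.array([[0.0, 0.0, 0.0], [h, 0.0, 0.0]])
slope = (simple_ee_jastrow(pair_h, (1, 1), alpha_anti=2.0)
         - simple_ee_jastrow(pair0, (1, 1), alpha_anti=2.0)) / h
```

Because det(cM) = c^n det(M) for an n × n matrix, the scaled orbitals give exactly exp(J) times the original determinant, and the finite-difference slope recovers the antiparallel cusp value c = 0.5 independently of alpha.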

class jaqmc.app.molecule.wavefunction.psiformer.JastrowType(*values)

Available Jastrow factor types for the Psiformer wavefunction.

NONE

Disable the Jastrow factor.

SIMPLE_EE

Enable the built-in electron-electron Jastrow factor.

Data

In most wavefunction and per-walker estimator hooks, Data represents one walker’s structured runtime input. BatchedData pairs a Data-shaped pytree with metadata describing which fields carry a leading walker axis during batched execution.

Most user-defined wavefunctions and per-walker estimators only work with Data. BatchedData is the lower-level representation used when framework code or workflow plumbing needs to manipulate full walker batches explicitly.

For the built-in data-shape convention and the detailed explanation of fields_with_batch, see Runtime Data Conventions.

class jaqmc.data.Data

Base container for structured wavefunction input data.

Data instances behave like lightweight, JAX-compatible dataclasses whose fields can be accessed both as attributes and as mapping keys. They are used to pass structured inputs (e.g. coordinates, atomic positions) between samplers, wavefunctions, and estimators.

property field_names: list[str]

Return dataclass field names in declaration order.

merge(values)

Return a new instance with values merged into this one.

Parameters:

values (Mapping[str, Any]) – Mapping from field names to replacement values.

Return type:

Self

Returns:

A new Data (or subclass) instance where the provided values override the corresponding fields of this instance.

Raises:

KeyError – If any of the keys in values are not valid field names for this dataclass.

subset(fields)

Return a dictionary containing only the selected fields.

Parameters:

fields (Sequence[str]) – Sequence of field names to keep.

Return type:

dict[str, Any]

Returns:

A new dict mapping each requested field name to its value in this instance.

Raises:

KeyError – If any of the requested fields are not valid field names for this dataclass.
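
The merge and subset contracts can be illustrated with a plain frozen dataclass. ToyData is hypothetical: it skips the JAX pytree registration and mapping-style access that the real Data class provides, and only mirrors the documented return types and KeyError behavior.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Mapping, Sequence

import numpy as np


@dataclass(frozen=True)
class ToyData:
    """Hypothetical stand-in for a jaqmc.data.Data subclass (not the real class)."""

    electrons: Any
    atoms: Any

    @property
    def field_names(self) -> list:
        return [f.name for f in fields(self)]

    def merge(self, values: Mapping[str, Any]) -> "ToyData":
        # Reject unknown keys, then build a new instance; self is left unchanged.
        unknown = set(values) - set(self.field_names)
        if unknown:
            raise KeyError(f"unknown fields: {sorted(unknown)}")
        return replace(self, **values)

    def subset(self, names: Sequence[str]) -> dict:
        unknown = set(names) - set(self.field_names)
        if unknown:
            raise KeyError(f"unknown fields: {sorted(unknown)}")
        return {name: getattr(self, name) for name in names}


data = ToyData(electrons=np.zeros((2, 3)), atoms=np.ones((1, 3)))
moved = data.merge({"electrons": np.full((2, 3), 0.5)})
only_atoms = data.subset(["atoms"])
```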

class jaqmc.data.BatchedData(data, fields_with_batch)

Container pairing one data instance with batched-field metadata.

data

Structured runtime data. Batched fields keep the same dataclass structure as one-walker Data, but carry an extra leading walker axis.

fields_with_batch

Field names in data whose leaves carry the leading walker axis. Fields not listed here are shared across walkers.

The type variable DataT is the concrete Data subtype stored in data.

all_gather()

Gather distributed arrays from all devices to each local node.

For fields that are batched (sharded along the batch axis), this collects all shards and materializes the complete array on each device. Unbatched fields are left unchanged.

Return type:

Self

Returns:

A new BatchedData with all-gathered batched fields.

property batch_size: int

Return the leading dimension shared by batched fields.

The size is read from the first leaf of the first field listed in fields_with_batch. Call check() when you need to verify that all batched fields use the same leading size.

Returns:

The detected batch size, or 0 when no fields are marked as batched.

check()

Validate the batched-field metadata against the wrapped data.

The check verifies that every name in fields_with_batch exists on data and, for concrete JAX arrays, that every batched leaf shares the same leading batch size. Shape validation is skipped during JAX tracing and for non-array leaves.

Raises:
  • KeyError – If fields_with_batch names fields that do not exist on data.

  • ValueError – If a batched array leaf does not use the common leading batch size.

fully_batched_data()

Return a new Data with all fields batched.

Fields that are already batched (listed in fields_with_batch) are left unchanged; fields that are unbatched are duplicated along a new leading batch dimension so that their shapes become (batch_size, *orig_shape).

Raises:

ValueError – If the current batch_size is zero.

Return type:

TypeVar(DataT, bound=Data)
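
A minimal sketch of batch_size and fully_batched_data on a hypothetical two-field dataclass, with fields_with_batch passed explicitly. It assumes every field is a single array (the real class reads the first leaf of a possibly nested field) and uses numpy.broadcast_to, which returns a read-only view rather than a copy.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Tuple

import numpy as np


@dataclass(frozen=True)
class ToyData:
    # Hypothetical two-field stand-in for a Data subclass.
    electrons: Any
    atoms: Any


def batch_size(data: ToyData, fields_with_batch: Tuple[str, ...]) -> int:
    # Leading dimension of the first batched field, or 0 if nothing is batched.
    if not fields_with_batch:
        return 0
    return np.shape(getattr(data, fields_with_batch[0]))[0]


def fully_batched_data(data: ToyData, fields_with_batch: Tuple[str, ...]) -> ToyData:
    size = batch_size(data, fields_with_batch)
    if size == 0:
        raise ValueError("batch_size is zero")
    updates = {}
    for f in fields(data):
        if f.name not in fields_with_batch:
            value = np.asarray(getattr(data, f.name))
            # Shared field: repeat along a new leading walker axis.
            updates[f.name] = np.broadcast_to(value, (size,) + value.shape)
    return replace(data, **updates)


batched = ToyData(electrons=np.zeros((4, 2, 3)), atoms=np.ones((1, 3)))
full = fully_batched_data(batched, ("electrons",))  # full.atoms.shape == (4, 1, 3)
```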

property partition_spec

Describe how this batched data should be sharded.

Fields listed in fields_with_batch are assigned a jax.sharding.PartitionSpec over BATCH_AXIS_NAME. Unbatched fields receive an empty partition spec, meaning they are shared rather than sharded over walkers.

Returns:

A new BatchedData whose data fields contain partition specs matching the wrapped data structure.

unbatched_example()

Return a one-walker-shaped example matching this data structure.

Batched fields are replaced with arrays of ones after dropping their leading batch axis. Unbatched fields are replaced with ones_like arrays of the same shape. The result is useful for initialization code that needs representative single-walker input shapes, not actual sampled values.

Return type:

TypeVar(DataT, bound=Data)

Returns:

A new DataT instance with the same fields as data and single-walker shapes for batched fields.

property vmap_axis

Describe batch axes for use with jax.vmap().

Each dataclass field is mapped to an axis specification: 0 for fields that are batched along the leading dimension, and None for fields that are treated as broadcasted or scalar. The returned object is another instance of this dataclass, intended to be used as the in_axes argument to jax.vmap.

Returns:

A new instance of the same dataclass with integer/None values describing the batch axes per field.
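
The vmap_axis and unbatched_example descriptions can likewise be sketched on a hypothetical dataclass. The axis object is just another instance of the same dataclass holding 0 or None per field; per the description above, the real one is meant to be passed as in_axes to jax.vmap, which requires the dataclass to be registered as a JAX pytree (not done here).

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Tuple

import numpy as np


@dataclass(frozen=True)
class ToyData:
    # Hypothetical stand-in for a Data subclass.
    electrons: Any
    atoms: Any


def vmap_axis(data: ToyData, fields_with_batch: Tuple[str, ...]) -> ToyData:
    # Same dataclass, but each field holds its in_axes entry: 0 or None.
    axes = {f.name: (0 if f.name in fields_with_batch else None) for f in fields(data)}
    return replace(data, **axes)


def unbatched_example(data: ToyData, fields_with_batch: Tuple[str, ...]) -> ToyData:
    # Drop the walker axis of batched fields, keep shared shapes, fill with ones.
    out = {}
    for f in fields(data):
        value = np.asarray(getattr(data, f.name))
        shape = value.shape[1:] if f.name in fields_with_batch else value.shape
        out[f.name] = np.ones(shape, dtype=value.dtype)
    return replace(data, **out)


batched = ToyData(electrons=np.arange(24.0).reshape(4, 2, 3), atoms=np.zeros((1, 3)))
axes = vmap_axis(batched, ("electrons",))            # electrons=0, atoms=None
example = unbatched_example(batched, ("electrons",))  # electrons (2, 3), atoms (1, 3)
```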