Writing Wavefunctions

The Writing Workflows tutorial showed how a wavefunction fits into a workflow: subclass Wavefunction, implement __call__, and pass wf.evaluate to the builder. This page clarifies the __call__/apply/evaluate interfaces, then covers the reusable building blocks for complex architectures and how to make a wavefunction YAML-configurable. Both parts build on Flax basics. If terms like jit, vmap, pytrees, or Module.apply still feel unfamiliar, read JAX for JaQMC first.

The __call__ Contract

Wavefunction is a Flax nn.Module with one abstract method:

class Wavefunction[DataT: Data, OutputT](nn.Module, ABC):
    @abstractmethod
    def __call__(self, data: DataT) -> OutputT: ...

The type parameters are:

  • DataT — your data type, a subclass of Data (e.g., HydrogenAtomData, MoleculeData). Data is a JAX-compatible dataclass that flows through jit, grad, and vmap.

  • OutputT — the return type. A scalar jnp.ndarray for simple wavefunctions, or a TypedDict like RealLogDetOutput for architectures that return additional information (sign, per-determinant values).

Both are optional — class MyWF(Wavefunction): works fine when you don’t need the type constraints, as the hydrogen atom example shows.
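To make "JAX-compatible dataclass" concrete, here is a minimal sketch of a pytree-registered dataclass passing through jit, grad, and vmap. ToyData and toy_logpsi are hypothetical stand-ins written for this illustration; JaQMC's own Data class may be implemented differently.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


# Hypothetical stand-in for a Data subclass (not JaQMC's implementation).
@register_pytree_node_class
@dataclass
class ToyData:
    electrons: jnp.ndarray  # one walker: (n_particles, ndim)

    def tree_flatten(self):
        # Array fields are the pytree leaves; there is no static metadata.
        return (self.electrons,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def toy_logpsi(data: ToyData) -> jnp.ndarray:
    """Toy log|psi|: minus the sum of electron distances from the origin."""
    return -jnp.sum(jnp.linalg.norm(data.electrons, axis=-1))


walker = ToyData(electrons=jnp.ones((2, 3)))
value = jax.jit(toy_logpsi)(walker)    # the dataclass passes through jit
grads = jax.grad(toy_logpsi)(walker)   # gradient has the same ToyData structure

batch = ToyData(electrons=jnp.ones((5, 2, 3)))  # leading axis = walkers
values = jax.vmap(toy_logpsi)(batch)   # shape (5,)
```

Because the array field is a pytree leaf, vmap maps over its leading axis while the function body still sees a single walker.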

The base class provides two methods for free:

  • init_params(data, rngs) — initializes parameters by calling Flax’s self.init(rngs, data).

  • evaluate(params, data) — runs the forward pass by calling self.apply(params, data).

These are what the rest of JaQMC (samplers, optimizers, estimators) interact with. You only implement __call__.

__call__ receives a single walker — one Data instance, not a batch. JaQMC handles batching externally with jax.vmap, so your implementation never needs to think about the batch dimension. It also runs inside jax.jit, so avoid Python-level data-dependent control flow — use jax.lax.cond or jnp.where instead of if statements that branch on array values. JAX for JaQMC explains why this single-walker-plus-vmap pattern appears throughout the framework.
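A minimal sketch of the single-walker-plus-vmap pattern with jit-safe branching follows. The function single_walker_logpsi, the params dictionary, and the cutoff value are hypothetical; this is plain JAX, not JaQMC's batching code.

```python
import jax
import jax.numpy as jnp


def single_walker_logpsi(params, electrons):
    """Toy log|psi| for ONE walker; electrons has shape (n_particles, ndim)."""
    r = jnp.linalg.norm(electrons, axis=-1)  # (n_particles,)
    # Array-valued branch: jnp.where traces under jit, whereas a Python
    # `if` on r would raise because r is a traced array.
    r_clipped = jnp.where(r < params["cutoff"], params["cutoff"], r)
    return -params["alpha"] * jnp.sum(r_clipped)


# params are shared across walkers (None); electrons are mapped over axis 0.
batched_logpsi = jax.jit(jax.vmap(single_walker_logpsi, in_axes=(None, 0)))

params = {"alpha": jnp.array(1.0), "cutoff": jnp.array(0.5)}
walkers = jnp.zeros((4, 2, 3)).at[1:].set(1.0)  # walker 0 sits at the origin
batch_values = batched_logpsi(params, walkers)  # shape (4,)
```

The single-walker function never mentions the batch axis; only the in_axes argument to jax.vmap does.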

For built-in-style wavefunctions, treat data.electrons as one walker’s particle coordinates, typically with shape (n_particles, ndim). You only need to think about BatchedData once you start writing lower-level sampler, workflow, or estimator plumbing that manipulates whole walker batches directly; see Runtime Data Conventions. Even if data_init or sampling code is working with batch-shaped arrays elsewhere, the wavefunction contract here stays single-walker.

Execution Interfaces: __call__ vs apply vs evaluate

These three names are related but serve different layers:

| Interface | Signature | Owned by | Typical caller |
| --- | --- | --- | --- |
| __call__ | (data) -> OutputT | Your wavefunction subclass | Flax internals via apply |
| apply | (variables, *args, method=...) -> Any | Flax nn.Module | Advanced users and internal wrappers |
| evaluate | (params, data) -> OutputT | JaQMC Wavefunction base class | Workflow, sampler, optimizer, estimator wiring |

  • Use __call__ to define model math for one walker.

  • Use evaluate as the default JaQMC-facing callable with explicit parameters.

  • Use apply directly only for advanced cases (for example, calling an alternate method like get_orbitals).

Framework call path in practice:

workflow/estimator/sampler -> wf.logpsi or wf.evaluate
                          -> wf.evaluate(params, data)
                          -> wf.apply(params, data)
                          -> wf.__call__(data)

Minimal example:

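import jax.numpy as jnp
from flax import linen as nn

from jaqmc.wavefunction import Wavefunction

# MyData is your Data subclass; data and rngs are assumed to be defined already.
class MyWF(Wavefunction[MyData, jnp.ndarray]):
    @nn.compact
    def __call__(self, data: MyData) -> jnp.ndarray:
        # One learnable scalar, initialized to zero.
        alpha = self.param("alpha", lambda *_: jnp.array(0.0))
        return alpha * jnp.linalg.norm(data.electrons)

wf = MyWF()
params = wf.init_params(data, rngs)
value = wf.evaluate(params, data)           # framework contract
value2 = wf.apply(params, data)             # equivalent low-level Flax call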

Structured Return Types

Production wavefunctions (FermiNet, Psiformer) return more than a scalar — they also provide the sign of the wavefunction, which is needed for energy calculations involving pseudopotentials. They return RealLogDetOutput:

from jaqmc.app.molecule.data import MoleculeData
from jaqmc.wavefunction import Wavefunction
from jaqmc.wavefunction.output.logdet import RealLogDetOutput

class MyWavefunction(Wavefunction[MoleculeData, RealLogDetOutput]):
    def __call__(self, data: MoleculeData) -> RealLogDetOutput:
        ...
        return RealLogDetOutput(
            logpsi=log_amplitude,       # log|psi| (scalar)
            sign_logpsi=sign,           # sign of psi (+1 or -1)
            sign_logdets=signs,         # signs of individual determinants
            abs_logdets=logdets,        # log|det| for each determinant
        )

If your __call__ returns RealLogDetOutput, the extraction methods below become one-liner delegations to evaluate.

Reusable Building Blocks

The jaqmc.wavefunction package provides Flax modules for the common stages of a molecular wavefunction. You can compose them to build new architectures while only implementing the novel part — typically the backbone.

The built-in wavefunctions (FermiNet, Psiformer) follow this pattern:

  1. Input features — construct atom-electron and electron-electron feature vectors from raw positions.

  2. Backbone — transform those features through interaction layers (message-passing in FermiNet, self-attention in Psiformer) to produce per-electron representations.

  3. Orbital projection — project the per-electron representations into orbital matrices, one per determinant.

  4. Envelope — multiply each orbital by a distance-dependent envelope that enforces the correct asymptotic decay.

  5. Log-determinant — compute the log-sum of Slater determinants to produce the final log|ψ| and sign.
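Step 5 can be sketched in plain JAX as a signed log-sum-exp over the per-determinant log|det| values. This is an illustration under stated assumptions, not JaQMC's LogDet module: the function name log_sum_of_dets and the input layout (ndets, n, n) are hypothetical, and the returned keys simply mirror the RealLogDetOutput fields shown earlier.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def log_sum_of_dets(orbitals):
    """orbitals: (ndets, n, n) -> sign and log of |sum_k det(orbitals[k])|."""
    # Per-determinant sign and log|det|, each of shape (ndets,).
    sign_logdets, abs_logdets = jnp.linalg.slogdet(orbitals)
    # Signed log-sum-exp: log|sum_k s_k * exp(l_k)| together with its sign,
    # computed without exponentiating large magnitudes directly.
    logpsi, sign_logpsi = logsumexp(
        abs_logdets, b=sign_logdets, return_sign=True
    )
    return {
        "logpsi": logpsi,
        "sign_logpsi": sign_logpsi,
        "sign_logdets": sign_logdets,
        "abs_logdets": abs_logdets,
    }


orbitals = jnp.array([
    [[2.0, 0.0], [0.0, 3.0]],  # det = +6
    [[0.0, 1.0], [1.0, 0.0]],  # det = -1
])
out = log_sum_of_dets(orbitals)  # psi = 6 - 1 = 5, so logpsi = log(5)
```

Working in log space keeps the sum stable when individual determinants differ by many orders of magnitude.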

This isn’t a required architecture — the only hard contract is __call__. A wavefunction that skips the orbital/determinant machinery entirely (like the hydrogen atom example) is perfectly valid. But when you do want determinant-based antisymmetry, these modules save you from reimplementing the standard stages.

See Wavefunctions for the full list of available modules and their parameters.

Example: Custom Backbone with Standard I/O

The most common extension point is the backbone — the layers that transform input features into single-electron representations. Here’s how to write a custom backbone while reusing the standard input, orbital projection, envelope, and log-determinant layers:

from flax import linen as nn

from jaqmc.app.molecule.data import MoleculeData
from jaqmc.utils.wiring import runtime_dep
from jaqmc.wavefunction import Wavefunction
from jaqmc.wavefunction.input.atomic import MoleculeFeatures
from jaqmc.wavefunction.output.envelope import Envelope, EnvelopeType
from jaqmc.wavefunction.output.logdet import LogDet, RealLogDetOutput
from jaqmc.wavefunction.output.orbital import OrbitalProjection


class MyBackbone(nn.Module):
    """Your custom interaction layers."""
    nspins: tuple[int, int]
    hidden_dim: int = 128

    @nn.compact
    def __call__(self, h_one, h_two):
        # h_one: (n_electrons, feature_dim) — single-electron features
        # h_two: (n_electrons, n_electrons, feature_dim) — pairwise features
        ...
        return h_one  # per-electron representations for orbital projection


class MyWavefunction(Wavefunction[MoleculeData, RealLogDetOutput]):
    nspins: tuple[int, int] = runtime_dep()  # set by workflow; see below
    ndets: int = 8
    hidden_dim: int = 128

    def setup(self):
        self.features = MoleculeFeatures()
        self.backbone = MyBackbone(self.nspins, self.hidden_dim)
        self.orbitals = OrbitalProjection(nspins=self.nspins, ndets=self.ndets)
        self.envelope = Envelope(envelope_type=EnvelopeType.abs_isotropic,
                                 ndets=self.ndets, nspins=self.nspins)
        self.logdet = LogDet()

    def __call__(self, data: MoleculeData):
        emb = self.features(data.electrons, data.atoms)
        h_one = self.backbone(emb["ae_features"], emb["ee_features"])
        orbs = self.orbitals(h_one) * self.envelope(emb["ae_vec"], emb["r_ae"])
        return self.logdet(orbs)

runtime_dep() marks fields whose values come from the workflow rather than from user config. For example, nspins is determined by the molecular system, so the workflow sets it at startup. Accessing a runtime_dep() field before the workflow sets it raises AttributeError with a descriptive message. See Making It YAML-Configurable for more.
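The "raise until the workflow sets it" behavior can be imitated with a small Python descriptor. This is only a sketch of the idea: RuntimeDepSketch and ToyWavefunction are hypothetical names, and JaQMC's actual runtime_dep() may work quite differently under the hood.

```python
class RuntimeDepSketch:
    """Descriptor that raises AttributeError until a value is assigned."""

    def __set_name__(self, owner, name):
        self.name = name
        self.slot = "_" + name  # where the assigned value is stored

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self  # class-level access returns the descriptor itself
        try:
            return obj.__dict__[self.slot]
        except KeyError:
            raise AttributeError(
                f"{self.name!r} is a runtime dependency; "
                "the workflow must set it before use"
            ) from None

    def __set__(self, obj, value):
        obj.__dict__[self.slot] = value


class ToyWavefunction:
    nspins = RuntimeDepSketch()  # set by the workflow, not by user config
    ndets = 8                    # ordinary default, user-configurable


wf = ToyWavefunction()
try:
    wf.nspins  # accessed too early
    message = ""
except AttributeError as err:
    message = str(err)

wf.nspins = (1, 1)  # what the workflow would do at startup
```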

Extraction Methods

For the hydrogen atom, evaluate is all you need — it returns a scalar and every consumer uses it directly. Production wavefunctions return richer output (like RealLogDetOutput), and different consumers need different slices. Extraction methods give each consumer exactly the interface it needs:

| Method | Returns | Used by |
| --- | --- | --- |
| logpsi(params, data) | \(\log\lvert\psi\rvert\) (scalar) | VMC loss, MCMC sampling |
| phase_logpsi(params, data) | \((\operatorname{sgn}\psi,\;\log\lvert\psi\rvert)\) | Pseudopotential estimator (needs sign for wavefunction ratios) |

When __call__ returns RealLogDetOutput, both are one-liner extractions from evaluate:

def logpsi(self, params, data):
    return self.evaluate(params, data)["logpsi"]

def phase_logpsi(self, params, data):
    out = self.evaluate(params, data)
    return out["sign_logpsi"], out["logpsi"]
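To illustrate why a consumer that forms wavefunction ratios needs the sign as well as log|ψ|, here is a small sketch. The helper wavefunction_ratio and the one-dimensional toy_phase_logpsi are hypothetical; only the (sign, log|ψ|) return order is taken from the table above.

```python
import jax.numpy as jnp


def wavefunction_ratio(phase_logpsi, params, data_new, data_old):
    """psi(new) / psi(old), rebuilt from (sign, log|psi|) pairs."""
    sign_new, log_new = phase_logpsi(params, data_new)
    sign_old, log_old = phase_logpsi(params, data_old)
    # Magnitude comes from the log difference; the sign product restores
    # the relative sign that log|psi| alone cannot carry.
    return sign_new * sign_old * jnp.exp(log_new - log_old)


def toy_phase_logpsi(params, x):
    """Toy psi(x) = a * x, which changes sign at x = 0."""
    psi = params["a"] * x
    return jnp.sign(psi), jnp.log(jnp.abs(psi))


params = {"a": jnp.array(2.0)}
# psi(-1) = -2 and psi(4) = 8, so the ratio is -0.25.
ratio = wavefunction_ratio(
    toy_phase_logpsi, params, jnp.array(-1.0), jnp.array(4.0)
)
```

With logpsi alone the same computation would give +0.25, silently dropping the sign change.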

Molecule wavefunction protocol

The molecule and solid apps define a protocol that formalizes which extraction methods they expect. The workflow validates it at startup — if you forget a method, you’ll get a clear error. The protocol also includes an orbitals method for pretraining against Hartree-Fock references — see src/jaqmc/app/molecule/wavefunction/ferminet.py for the implementation pattern.
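Startup validation of this kind can be sketched with typing.Protocol. Everything named here is an assumption for illustration: MoleculeWavefunctionLike, validate_wavefunction, and the orbitals signature are hypothetical, and the real protocol in the molecule app may list different methods or arguments.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class MoleculeWavefunctionLike(Protocol):
    """Hypothetical protocol listing the extraction methods an app expects."""

    def logpsi(self, params: Any, data: Any) -> Any: ...
    def phase_logpsi(self, params: Any, data: Any) -> Any: ...
    def orbitals(self, params: Any, data: Any) -> Any: ...


def validate_wavefunction(wf: object) -> None:
    """Fail fast with a readable message if a required method is missing."""
    if isinstance(wf, MoleculeWavefunctionLike):
        return
    required = ("logpsi", "phase_logpsi", "orbitals")
    missing = [n for n in required if not callable(getattr(wf, n, None))]
    raise TypeError(
        f"{type(wf).__name__} is missing required methods: {missing}"
    )


class Incomplete:
    """Implements only one of the three methods."""

    def logpsi(self, params, data):
        return 0.0


try:
    validate_wavefunction(Incomplete())
    error_message = ""
except TypeError as err:
    error_message = str(err)
```

Note that a runtime_checkable protocol only checks that the methods exist, not that their signatures match.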

Making It YAML-Configurable

To let users select your wavefunction from the CLI, the class must be importable through the module path syntax, for example my_package.my_wf:MyWavefunction. Fields fall into two categories:

  • runtime_dep() — set by the workflow from the system config (e.g., nspins from the molecular geometry). Not user-configurable.

  • Regular fields with defaults (e.g., ndets: int = 8) — configurable via YAML or CLI overrides.

Unlike estimators/optimizers/samplers, wavefunction classes do not need @configurable_dataclass: the Wavefunction base class automatically handles configuration serialization for subclasses and excludes Flax internal fields (parent, name). For non-wavefunction components, see Custom Components, which uses @configurable_dataclass, runtime_dep(), and wire().

On the workflow side, get_module() resolves the class from config and instantiates it with the user’s field overrides. The workflow then sets runtime dependencies before passing the wavefunction to the builder:

wf = cfg.get_module("wf", "jaqmc.app.molecule.wavefunction.ferminet")
wf.nspins = system_config.electron_spins  # set runtime_dep before use

train = VMCWorkStage.builder(cfg.scoped("train"), wf)
sampler = cfg.get("sampler", MCMCSampler)
train.configure_sample_plan(wf.logpsi, {"electrons": sampler})
# ... rest of wiring as in the workflows tutorial

Run with:

jaqmc molecule train wf.module=my_package.my_wf:MyWavefunction wf.ndets=16 wf.hidden_dim=256
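Conceptually, a command like this involves two steps: importing a class from a module:Class string and collecting the remaining wf.* overrides. The sketch below shows one way that could work. It is not JaQMC's config code; resolve_class and parse_overrides are hypothetical helpers, the parser handles integer values only, and a standard-library class stands in for a user wavefunction so the example runs anywhere.

```python
import importlib


def resolve_class(spec: str):
    """Import 'package.module:ClassName' and return the class object."""
    module_name, _, class_name = spec.partition(":")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


def parse_overrides(args, prefix="wf."):
    """Collect 'wf.<field>=<int>' overrides, skipping the module selector."""
    overrides = {}
    for arg in args:
        key, _, raw = arg.partition("=")
        if key.startswith(prefix) and key != prefix + "module":
            overrides[key[len(prefix):]] = int(raw)  # toy: integers only
    return overrides


cli_args = ["wf.module=fractions:Fraction", "wf.ndets=16", "wf.hidden_dim=256"]

# Stand-in target: a standard-library class instead of a real wavefunction.
cls = resolve_class(cli_args[0].partition("=")[2])
overrides = parse_overrides(cli_args)  # fields to pass to the constructor
```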

Where to Look