Training

Configuration reference for jaqmc molecule train. This page shows the effective defaults for the train workflow preset. Use --dry-run to see the resolved config for your run, or add workflow.config.verbose=true to include field descriptions. Keys use the same dot notation as CLI overrides, such as train.run.iterations=5000. Defaults are resolved in this order: schema defaults, workflow preset, YAML config, then CLI overrides. For evaluation config, see Evaluation.
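
The resolution order can be pictured as successive merges of nested dictionaries. In this picture, each dotted CLI override such as train.run.iterations=5000 addresses one nested field. The sketch below is a hypothetical illustration only. The helpers resolve_config, deep_merge, set_dotted and parse_value are not jaqmc's API, and the value parsing is a simplified assumption rather than jaqmc's actual CLI type handling.

```python
import ast
import copy
from typing import Any


def deep_merge(base: dict, update: dict) -> dict:
    """Return a copy of base with update merged in; nested dicts merge recursively."""
    merged = copy.deepcopy(base)
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def parse_value(text: str) -> Any:
    """Interpret a CLI value: true/false, then Python literals, else a plain string."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text  # e.g. "sto-3g" or "bohr" stay strings


def set_dotted(config: dict, dotted_key: str, value: Any) -> None:
    """Set config['a']['b']['c'] for dotted_key 'a.b.c', creating levels as needed."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


def resolve_config(schema: dict, preset: dict, yaml_cfg: dict, cli_args: list) -> dict:
    """Apply layers in order: schema defaults, workflow preset, YAML, CLI overrides."""
    resolved = deep_merge(deep_merge(schema, preset), yaml_cfg)
    for arg in cli_args:
        key, _, raw = arg.partition("=")
        set_dotted(resolved, key, parse_value(raw))
    return resolved


schema = {"train": {"run": {"iterations": 200000, "burn_in": 100}}}
yaml_cfg = {"train": {"run": {"burn_in": 50}}}
cfg = resolve_config(schema, {}, yaml_cfg,
                     ["train.run.iterations=5000", "workflow.config.verbose=true"])
print(cfg)
```

Later layers win only for the keys they mention. In this example, burn_in comes from the YAML layer and iterations comes from the CLI.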

Root-level runtime keys such as logging.*, jax.*, and distributed.* are shared by all commands. See Runtime Configuration.

Workflow (workflow.*)

These keys control workflow-level settings shared across all stages.

workflow.seed

Default: None · Type: int | None

Fixed random seed.


workflow.batch_size

Default: 4096 · Type: int

Number of walkers (samples) to use in each iteration.


workflow.save_path

Default: '' · Type: str

Path to save checkpoints and logs.


workflow.restore_path

Default: '' · Type: str

Path to restore checkpoints from.


workflow.config.ignore_extra

Default: False · Type: bool

If True, silently ignore unrecognized config keys.


workflow.config.verbose

Default: False · Type: bool

If True, print the fully resolved config with field descriptions at startup.

System (system.*)

Defines the molecular system to simulate. The implementation is selected by system.module.

  • Default module selection: unset, so system.* is read directly as an arbitrary molecule config. Built-in choices are:

    • unset: arbitrary molecule config

    • atom: single-atom generator

    • diatomic: diatomic generator

Arbitrary molecules (default)

system.atoms

Default: [Atom(symbol='H', coords=[0, 0, 0], atomic_number=1, charge=1)] · Type: list[Atom]

List of atoms in the system.


system.basis

Default: 'sto-3g' · Type: str | dict[str, str]

Basis set for the Hartree-Fock calculation used in pretraining, given as a single name or a per-element mapping.


system.ecp

Default: None · Type: str | dict[str, str] | None

Effective core potential specification.


system.electron_spins

Default: (1, 0) · Type: tuple[int, int]

Number of spin-up and spin-down electrons, given as (n_up, n_down).


system.fixed_spins_per_atom

Default: None · Type: list[tuple[int, int]] | None

Optional per-atom spin assignment, given as one (n_up, n_down) tuple per atom.


system.electron_init_width

Default: 1.0 · Type: float

Width of the Gaussian distribution for initializing electron positions.
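The sketch below shows one way Gaussian initialization could work. It assumes each electron has already been assigned a home atom and is offset from it by isotropic noise with standard deviation electron_init_width. The helper init_electrons and the explicit electron_owner array are hypothetical. jaqmc's actual placement rule, including how fixed_spins_per_atom is used, may differ.

```python
import numpy as np


def init_electrons(
    atom_coords: np.ndarray,     # (n_atoms, 3) nuclear positions in bohr
    electron_owner: np.ndarray,  # (n_electrons,) index of each electron's home atom (assumed given)
    batch_size: int,
    width: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """Sample walker positions as Gaussian noise around each electron's home atom."""
    centers = atom_coords[electron_owner]                       # (n_electrons, 3)
    noise = rng.normal(scale=width, size=(batch_size, *centers.shape))
    return centers[None, :, :] + noise                          # (batch, n_electrons, 3)


rng = np.random.default_rng(0)
h2 = np.array([[0.0, 0.0, -0.7], [0.0, 0.0, 0.7]])
walkers = init_electrons(h2, np.array([0, 1]), batch_size=4096, width=1.0, rng=rng)
print(walkers.shape)  # (4096, 2, 3)
```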

Single atoms (system.module=atom)

system.symbol

Default: 'H' · Type: str

Element symbol (e.g., "H", "Li", "Fe").


system.electron_init_width

Default: 1.0 · Type: float

Width of Gaussian for electron initialization.


system.basis

Default: 'sto-3g' · Type: str

Basis set name.


system.ecp

Default: None · Type: str | None

Effective core potential name.

Diatomic molecules (system.module=diatomic)

system.formula

Default: 'H2' · Type: str

Chemical formula (e.g., "H2", "LiH", "N2", "ClF").


system.bond_length

Default: 1.4 · Type: float

Distance between the two atoms, in system.unit.


system.unit

Default: bohr · Type: LengthUnit

Length unit for bond_length and atom coordinates.
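The sketch below shows how bond_length and unit could translate into nuclear coordinates. It rests on three illustrative assumptions:
- The two atoms sit symmetrically on the z axis.
- Coordinates are converted to bohr internally.
- The unit string "angstrom" is accepted.

The helper diatomic_coords is hypothetical. The conversion factor is the CODATA 2018 Bohr radius, 1 bohr = 0.529177210903 Å.

```python
import numpy as np

BOHR_IN_ANGSTROM = 0.529177210903  # Bohr radius in angstrom (CODATA 2018)


def diatomic_coords(bond_length: float, unit: str = "bohr") -> np.ndarray:
    """Return (2, 3) nuclear coordinates in bohr, placed symmetrically on the z axis."""
    if unit == "bohr":
        length_bohr = bond_length
    elif unit == "angstrom":
        length_bohr = bond_length / BOHR_IN_ANGSTROM
    else:
        raise ValueError(f"unsupported unit: {unit!r}")
    half = length_bohr / 2.0
    return np.array([[0.0, 0.0, -half], [0.0, 0.0, half]])


coords = diatomic_coords(1.4)                   # default H2 geometry, in bohr
coords_ang = diatomic_coords(0.74, "angstrom")  # about 1.398 bohr apart
print(np.linalg.norm(coords[1] - coords[0]), np.linalg.norm(coords_ang[1] - coords_ang[0]))
```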


system.basis

Default: 'sto-3g' · Type: str | dict[str, str]

Basis set name, or per-element mapping (e.g., {"Li": "ccecpccpvdz", "H": "cc-pvdz"}).


system.ecp

Default: None · Type: str | dict[str, str] | None

Effective core potential specification.


system.spin

Default: 0 · Type: int

Number of unpaired electrons, n_up - n_down (that is, twice the total spin S).
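When spin counts unpaired electrons, the up and down counts follow from the total electron count n and the spin s:
- n_up = (n + s) / 2
- n_down = (n - s) / 2

This gives the same (n_up, n_down) pair that system.electron_spins takes for arbitrary molecules. The sketch below assumes a neutral molecule. The electron_spins helper and the small atomic-number table are hypothetical.

```python
ATOMIC_NUMBERS = {"H": 1, "Li": 3, "N": 7, "F": 9, "Cl": 17}  # small illustrative table


def electron_spins(atomic_numbers: list, spin: int) -> tuple:
    """Split the electrons of a neutral molecule into (n_up, n_down) for a given spin."""
    n_electrons = sum(atomic_numbers)
    if (n_electrons + spin) % 2 != 0 or abs(spin) > n_electrons:
        raise ValueError(f"spin={spin} is incompatible with {n_electrons} electrons")
    n_up = (n_electrons + spin) // 2
    return n_up, n_electrons - n_up


print(electron_spins([ATOMIC_NUMBERS["Li"], ATOMIC_NUMBERS["H"]], spin=0))  # (2, 2)
print(electron_spins([ATOMIC_NUMBERS["N"]] * 2, spin=0))                   # (7, 7)
print(electron_spins([ATOMIC_NUMBERS["N"]], spin=3))                       # (5, 2)
```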


system.electron_init_width

Default: 1.0 · Type: float

Width of Gaussian for electron initialization.

Wavefunction (wf.*)

Selects and configures the neural-network ansatz.

  • Default module selection: ferminet. Effective defaults for the built-in architectures are listed below. Built-in choices are ferminet and psiformer.

FermiNet options (wf.*)

wf.ndets

Default: 16 · Type: int

Number of determinants.
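A FermiNet-style wavefunction is a sum of ndets determinants, psi = sum_k det(Phi_k), with any weights absorbed into the orbitals. The NumPy sketch below evaluates log|psi| stably with slogdet and a signed log-sum-exp over random orbital matrices. It also shows that a block-diagonal determinant equals the product of its spin-up and spin-down blocks. Reading full_det=False as "block-diagonal in spin" follows the original FermiNet convention and is an assumption here. The function names are hypothetical.

```python
import numpy as np


def log_abs_sum_dets(orbitals: np.ndarray) -> tuple:
    """Return (sign, log|psi|) for psi = sum_k det(orbitals[k]).

    orbitals: (ndets, n_electrons, n_electrons) orbital matrices.
    """
    signs, logdets = np.linalg.slogdet(orbitals)  # per-determinant sign and log|det|
    shift = logdets.max()                          # subtract the max for numerical stability
    total = np.sum(signs * np.exp(logdets - shift))
    return float(np.sign(total)), float(np.log(np.abs(total)) + shift)


def block_diagonal(up: np.ndarray, down: np.ndarray) -> np.ndarray:
    """Embed spin-up and spin-down orbital blocks in one block-diagonal matrix."""
    n_up, n_down = up.shape[-1], down.shape[-1]
    full = np.zeros(up.shape[:-2] + (n_up + n_down, n_up + n_down))
    full[..., :n_up, :n_up] = up
    full[..., n_up:, n_up:] = down
    return full


rng = np.random.default_rng(0)
ndets, n_up, n_down = 16, 3, 2
up = rng.normal(size=(ndets, n_up, n_up))
down = rng.normal(size=(ndets, n_down, n_down))

# Block-diagonal case: each determinant is det(up_k) * det(down_k).
sign, log_psi = log_abs_sum_dets(block_diagonal(up, down))
direct = np.sum(np.linalg.det(up) * np.linalg.det(down))
print(np.isclose(sign * np.exp(log_psi), direct))  # True
```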


wf.hidden_dims_single

Default: [256, 256, 256, 256] · Type: list[int]

Hidden dimensions for the single-electron stream.


wf.hidden_dims_double

Default: [32, 32, 32, 32] · Type: list[int]

Hidden dimensions for the double-electron stream.


wf.use_last_layer

Default: False · Type: bool

If False, skip the double-electron stream update in the final layer and return single-electron features directly.


wf.envelope

Default: abs_isotropic · Type: EnvelopeType

Type of envelope function to apply to orbitals.


wf.orbitals_spin_split

Default: True · Type: bool

If True, use separate orbital layer and envelope parameters for each spin channel.


wf.full_det

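Default: True · Type: bool

If True, evaluate each determinant over all electrons. If False, block-diagonalize the determinants into separate spin-up and spin-down channels.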

Psiformer options (wf.*)

wf.ndets

Default: 16 · Type: int

Number of determinants.


wf.num_layers

Default: 4 · Type: int

Number of Psiformer layers.


wf.num_heads

Default: 4 · Type: int

Number of attention heads.


wf.heads_dim

Default: 64 · Type: int

Dimension of each attention head.


wf.mlp_hidden_dims

Default: [256] · Type: list[int]

Hidden dimensions for MLP blocks.


wf.layer_norm_mode

Default: pre · Type: LayerNormMode

Where LayerNorm is applied within each layer. With pre, it is applied before the attention and MLP sublayers (pre-LN).


wf.jastrow

Default: SIMPLE_EE · Type: JastrowType

Jastrow factor type.


wf.with_bias

Default: True · Type: bool

Whether to use bias in attention QKV projections.


wf.input_bias

Default: True · Type: bool

Whether to use bias in the input projection layer.


wf.envelope

Default: abs_isotropic · Type: EnvelopeType

Envelope type that makes the orbitals decay at large distances.


wf.orbitals_spin_split

Default: True · Type: bool

If True, use separate orbital layer and envelope parameters for each spin channel.


wf.bias_orbitals

Default: False · Type: bool

If True, include bias in the orbital projection layer.


wf.rescale

Default: True · Type: bool

If True, use log-scaled input features (log(1+r)) instead of linear features.
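The sketch below illustrates log-scaled features. It assumes the Psiformer-paper convention: each relative position vector keeps its direction, and its length is rescaled from r to log(1 + r). The helper log_rescale is hypothetical, and jaqmc may apply the rescaling to a different set of features.

```python
import numpy as np


def log_rescale(vectors: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Rescale vectors r -> r * log(1 + |r|) / |r|, keeping their direction.

    vectors: (..., 3) relative positions, e.g. electron-nucleus displacements.
    """
    norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
    return vectors * np.log1p(norms) / np.maximum(norms, eps)


r = np.array([[0.1, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 1000.0]])
print(np.linalg.norm(log_rescale(r), axis=-1))  # lengths grow only logarithmically
```

Compressing large distances this way keeps distant electrons from producing very large input activations.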


wf.jastrow_alpha_init

Default: 1.0 · Type: float

Initial value for Jastrow alpha parameters.
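For reference, the simple electron-electron Jastrow in the Psiformer paper uses separate alpha parameters for same-spin and opposite-spin pairs. Same-spin pairs contribute -(1/4) α²/(α + r_ij) and opposite-spin pairs contribute -(1/2) α²/(α + r_ij). The derivatives at r = 0 reproduce the cusp values 1/4 and 1/2. The sketch assumes SIMPLE_EE follows this form and that both alphas start at jastrow_alpha_init. The function simple_ee_jastrow is hypothetical.

```python
import numpy as np


def simple_ee_jastrow(
    electrons: np.ndarray,  # (n_electrons, 3); the first n_up electrons are spin-up
    n_up: int,
    alpha_par: float = 1.0,
    alpha_anti: float = 1.0,
) -> float:
    """Psiformer-style e-e Jastrow J; the wavefunction is multiplied by exp(J)."""
    n = electrons.shape[0]
    spins = np.arange(n) < n_up                    # True for spin-up electrons
    i, j = np.triu_indices(n, k=1)                 # each pair i < j once
    r_ij = np.linalg.norm(electrons[i] - electrons[j], axis=-1)
    same = spins[i] == spins[j]
    j_par = -0.25 * alpha_par**2 / (alpha_par + r_ij[same])
    j_anti = -0.5 * alpha_anti**2 / (alpha_anti + r_ij[~same])
    return float(j_par.sum() + j_anti.sum())


def pair(r: float) -> np.ndarray:
    """Two electrons separated by r along z."""
    return np.array([[0.0, 0.0, 0.0], [0.0, 0.0, r]])


# Cusp check: dJ/dr near r = 0 for one opposite-spin pair should be 1/2.
h = 1e-6
slope = (simple_ee_jastrow(pair(2 * h), n_up=1) - simple_ee_jastrow(pair(h), n_up=1)) / h
print(round(slope, 4))  # 0.5
```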


wf.full_det

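Default: True · Type: bool

If True, evaluate each determinant over all electrons. If False, block-diagonalize the determinants into separate spin-up and spin-down channels.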


Train Stage (train.*)

The main VMC optimization loop. Samples electron configurations, computes energy, and updates wavefunction parameters.

Run options (train.run.*)

train.run.check_vma

Default: True · Type: bool

Enable JAX validity checks during shard_map.


train.run.iterations

Default: 200000 · Type: int

Total number of iterations to run.


train.run.burn_in

Default: 100 · Type: int

Sampling iterations to discard before the main loop for MCMC equilibration.


train.run.save_time_interval

Default: 600 · Type: int

Minimum wall-clock seconds between checkpoint saves.


train.run.save_step_interval

Default: 1000 · Type: int

Save checkpoints only at steps that are multiples of this value.
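One way to read the two save settings together is as a joint condition. A checkpoint is written at a step only if both of these hold:
- The step is a multiple of save_step_interval.
- At least save_time_interval seconds have passed since the previous save.

This combination rule is an assumption drawn from the two descriptions, and should_save is a hypothetical helper.

```python
def should_save(step: int, now: float, last_save_time: float,
                step_interval: int = 1000, time_interval: float = 600) -> bool:
    """Return True when both the step-boundary and wall-clock conditions are met."""
    on_step_boundary = step % step_interval == 0
    enough_time_passed = now - last_save_time >= time_interval
    return on_step_boundary and enough_time_passed


print(should_save(step=3000, now=1000.0, last_save_time=0.0))  # True
print(should_save(step=3000, now=300.0, last_save_time=0.0))   # False: too soon
print(should_save(step=3001, now=1000.0, last_save_time=0.0))  # False: off boundary
```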


train.run.stop_on_nan

Default: 'loss' · Type: bool | str

Abort training when NaN is detected in step statistics. True checks all stat keys, and False disables the check. A comma-separated string restricts the check to the listed keys (e.g. "loss").
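The bool-or-string setting can be resolved into a list of monitored keys as sketched below. The helpers nan_keys_to_check and should_abort are hypothetical. Silently skipping listed keys that are absent from the statistics is an assumption.

```python
import math
from typing import Mapping, Union


def nan_keys_to_check(stop_on_nan: Union[bool, str], stats: Mapping[str, float]) -> list:
    """Resolve the stop_on_nan setting into the stat keys to monitor."""
    if stop_on_nan is True:
        return list(stats)
    if stop_on_nan is False:
        return []
    return [key.strip() for key in stop_on_nan.split(",") if key.strip()]


def should_abort(stop_on_nan: Union[bool, str], stats: Mapping[str, float]) -> bool:
    """True if any monitored statistic present in stats is NaN."""
    keys = nan_keys_to_check(stop_on_nan, stats)
    return any(math.isnan(stats[key]) for key in keys if key in stats)


stats = {"loss": float("nan"), "pmove": 0.52}
print(should_abort("loss", stats), should_abort("pmove", stats), should_abort(False, stats))
```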

Optimizer (train.optim.*)

  • Default optimizer module: kfac. Effective defaults for the built-in optimizers are listed below.

KFAC options

train.optim.learning_rate

Default: Standard · Type: swappable

The learning rate. Swappable component; the nested keys below are the options for the current module Standard and change when train.optim.learning_rate.module changes.

train.optim.learning_rate.module

Default: Standard · Type: module path

Select the implementation used for this component.

train.optim.learning_rate.rate

Default: 0.05 · Type: float

Initial learning rate.

train.optim.learning_rate.delay

Default: 2000 · Type: float

Delay in steps before decay starts.

train.optim.learning_rate.decay

Default: 1 · Type: float

Decay rate exponent.
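The sketch below assumes Standard follows the inverse-power schedule from the original FermiNet code. Under that schedule, the learning rate at step t is rate · (1 + t/delay)^(-decay). With the defaults, the rate halves by step 2000 and then falls off as 1/t. In this form, delay sets the timescale of the decay rather than acting as a hard cutoff. The exact formula jaqmc uses is an assumption here, and standard_schedule is a hypothetical helper.

```python
def standard_schedule(step: float, rate: float = 0.05, delay: float = 2000.0,
                      decay: float = 1.0) -> float:
    """FermiNet-style inverse-power decay: rate * (1 + step / delay) ** -decay."""
    return rate * (1.0 + step / delay) ** (-decay)


for step in (0, 2000, 6000, 18000):
    print(step, standard_schedule(step))
```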


train.optim.norm_constraint

Default: 0.001 · Type: float

The update is scaled down so that its approximate squared Fisher norm \(v^T F v\) is at most the specified value.
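The constraint amounts to multiplying the proposed update v by min(1, sqrt(c / vᵀFv)), where c is norm_constraint. The NumPy sketch below applies that rule with an explicit Fisher matrix for illustration. KFAC itself only works with a Kronecker-factored approximation of F, so this shows the idea rather than the optimizer's code. The helper constrain_update is hypothetical.

```python
import numpy as np


def constrain_update(v: np.ndarray, fisher: np.ndarray, norm_constraint: float = 1e-3) -> np.ndarray:
    """Scale v down so that v^T F v <= norm_constraint; leave it unchanged otherwise."""
    sq_norm = float(v @ fisher @ v)
    if sq_norm <= norm_constraint:
        return v
    return v * np.sqrt(norm_constraint / sq_norm)


rng = np.random.default_rng(0)
a = rng.normal(size=(5, 5))
fisher = a @ a.T + 1e-3 * np.eye(5)  # symmetric positive definite stand-in for F
v = rng.normal(size=5)
scaled = constrain_update(v, fisher)
print(scaled @ fisher @ scaled)      # about 1e-3
```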


train.optim.curvature_ema

Default: 0.95 · Type: float

Decay factor used when calculating the covariance estimate moving averages.


train.optim.l2_reg

Default: 0.0 · Type: float

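L2 regularization coefficient used in the loss, if any. KFAC uses it to add a matching diagonal term to its curvature estimate. It does not add the penalty to the loss itself.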


train.optim.inverse_update_period

Default: 1 · Type: int

Number of steps in between updating the inverse curvature approximation.


train.optim.damping

Default: 0.001 · Type: float

Fixed damping parameter.

SR options

train.optim.learning_rate

Default: Standard · Type: swappable

Step size (scalar or schedule). Swappable component; the nested keys below are the options for the current module Standard and change when train.optim.learning_rate.module changes.

train.optim.learning_rate.module

Default: Standard · Type: module path

Select the implementation used for this component.

train.optim.learning_rate.rate

Default: 0.05 · Type: float

Initial learning rate.

train.optim.learning_rate.delay

Default: 2000 · Type: float

Delay in steps before decay starts.

train.optim.learning_rate.decay

Default: 1 · Type: float

Decay rate exponent.


train.optim.max_norm

Default: Constant · Type: swappable

Constrained update norm C (scalar or schedule). Swappable component; the nested keys below are the options for the current module Constant and change when train.optim.max_norm.module changes.

train.optim.max_norm.module

Default: Constant · Type: module path

Select the implementation used for this component.

train.optim.max_norm.rate

Default: 0.05 · Type: float

The constant rate.


train.optim.damping

Default: Constant · Type: swappable

Damping lambda (scalar or schedule). Swappable component; the nested keys below are the options for the current module Constant and change when train.optim.damping.module changes.

train.optim.damping.module

Default: Constant · Type: module path

Select the implementation used for this component.

train.optim.damping.rate

Default: 0.05 · Type: float

The constant rate.


train.optim.max_cond_num

Default: 10000000.0 · Type: float | None

Maximum condition number for adaptive damping.
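The sketch below shows one way an adaptive damping could respect max_cond_num K. Take a symmetric positive semidefinite matrix with extreme eigenvalues σ_max and σ_min. Adding λI gives a condition number of (σ_max + λ)/(σ_min + λ), which is at most K when λ ≥ (σ_max - K·σ_min)/(K - 1). The sketch computes that floor and uses the larger of it and the configured damping. Whether jaqmc uses exactly this rule is an assumption, and so is treating None as "no cap". The helper adaptive_damping is hypothetical.

```python
from typing import Optional

import numpy as np


def adaptive_damping(matrix: np.ndarray, base_damping: float,
                     max_cond_num: Optional[float] = 1e7) -> float:
    """Smallest damping >= base_damping keeping cond(matrix + damping * I) <= max_cond_num."""
    if max_cond_num is None:                  # assumed: None disables the cap
        return base_damping
    eigvals = np.linalg.eigvalsh(matrix)      # ascending; matrix assumed symmetric PSD
    sigma_min, sigma_max = max(eigvals[0], 0.0), eigvals[-1]
    floor = (sigma_max - max_cond_num * sigma_min) / (max_cond_num - 1.0)
    return float(max(base_damping, floor))


rng = np.random.default_rng(0)
o = rng.normal(size=(8, 3))
gram = o @ o.T                                # rank 3, so several zero eigenvalues
lam = adaptive_damping(gram, base_damping=1e-12, max_cond_num=1e4)
print(np.linalg.cond(gram + lam * np.eye(8)))  # about 1e4, the cap
```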


train.optim.spring_mu

Default: Constant · Type: swappable

SPRING momentum coefficient mu (scalar or schedule). Swappable component; the nested keys below are the options for the current module Constant and change when train.optim.spring_mu.module changes.

train.optim.spring_mu.module

Default: Constant · Type: module path

Select the implementation used for this component.

train.optim.spring_mu.rate

Default: 0.05 · Type: float

The constant rate.


train.optim.march_beta

Default: Constant · Type: swappable

Decay factor for the MARCH variance accumulator (scalar or schedule). Swappable component; the nested keys below are the options for the current module Constant and change when train.optim.march_beta.module changes.

train.optim.march_beta.module

Default: Constant · Type: module path

Select the implementation used for this component.

train.optim.march_beta.rate

Default: 0.05 · Type: float

The constant rate.


train.optim.march_mode

Default: 'var' · Type: Literal[var, diff]

MARCH variance mode. "diff" uses update differences and "var" uses score variance along the batch axis.


train.optim.eps

Default: 1e-08 · Type: float

Small numerical constant for stability.


train.optim.mixed_precision

Default: True · Type: bool

Whether to use mixed precision for Gram factorization.


train.optim.score_chunk_size

Default: 128 · Type: int | None

Chunk size for score computation.


train.optim.score_norm_clip

Default: None · Type: float | None

Optional clip value for the mean absolute score per batch row.


train.optim.gram_num_chunks

Default: 4 · Type: int | None

Number of chunks for Gram matrix computation.


train.optim.gram_dot_prec

Default: 'F64' · Type: str | None

Precision mode for Gram matrix dot products.


train.optim.prune_inactive

Default: False · Type: bool

Whether to structurally prune inactive parameter leaves when forming the SR system.
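SR solves a damped linear system built from the per-walker score matrix O, which has shape batch × parameters. When the batch is much smaller than the parameter count, the push-through identity (OᵀO + λI)⁻¹Oᵀ = Oᵀ(OOᵀ + λI)⁻¹ moves the solve into sample space. There it only needs the batch × batch Gram matrix, which is the kind of Gram factorization the options above tune. The NumPy sketch below is a generic illustration of that identity. It accumulates the Gram matrix over parameter chunks as a stand-in for gram_num_chunks, and it omits MARCH, SPRING and max_norm. The helper names and the chunking axis are assumptions.

```python
import numpy as np


def chunked_gram(scores: np.ndarray, num_chunks: int) -> np.ndarray:
    """G = O O^T accumulated over column (parameter) chunks to bound peak memory."""
    gram = np.zeros((scores.shape[0], scores.shape[0]))
    for cols in np.array_split(np.arange(scores.shape[1]), num_chunks):
        block = scores[:, cols]
        gram += block @ block.T
    return gram


def sr_update_sample_space(scores: np.ndarray, residual: np.ndarray,
                           damping: float, num_chunks: int = 4) -> np.ndarray:
    """Solve (O^T O + damping I) x = O^T residual via the batch x batch Gram matrix."""
    gram = chunked_gram(scores, num_chunks)
    alpha = np.linalg.solve(gram + damping * np.eye(gram.shape[0]), residual)
    return scores.T @ alpha


rng = np.random.default_rng(0)
batch, n_params = 16, 200
scores = rng.normal(size=(batch, n_params))
scores -= scores.mean(axis=0)           # centre scores over the batch
residual = rng.normal(size=batch)       # e.g. centred local energies (illustrative)
damping = 1e-3

x_sample = sr_update_sample_space(scores, residual, damping)
x_param = np.linalg.solve(scores.T @ scores + damping * np.eye(n_params), scores.T @ residual)
print(np.allclose(x_sample, x_param))   # True
```

The sample-space route factors a 16 × 16 matrix instead of a 200 × 200 one, and the savings grow with the parameter count.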

Adam options

train.optim.learning_rate

Default: Standard · Type: swappable

A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate(). Swappable component; the nested keys below are the options for the current module Standard and change when train.optim.learning_rate.module changes.

train.optim.learning_rate.module

Default: Standard · Type: module path

Select the implementation used for this component.

train.optim.learning_rate.rate

Default: 0.05 · Type: float

Initial learning rate.

train.optim.learning_rate.delay

Default: 2000 · Type: float

Delay in steps before decay starts.

train.optim.learning_rate.decay

Default: 1 · Type: float

Decay rate exponent.


train.optim.b1

Default: 0.9 · Type: float

Exponential decay rate to track the first moment of past gradients.


train.optim.b2

Default: 0.999 · Type: float

Exponential decay rate to track the second moment of past gradients.


train.optim.eps

Default: 1e-08 · Type: float

A small constant added to the denominator outside the square root (as in the Adam paper) to avoid dividing by zero when rescaling.


train.optim.eps_root

Default: 0.0 · Type: float

A small constant added to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling.
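To make the roles of eps and eps_root concrete, here is one Adam step written out in NumPy. It mirrors the update rule documented for optax.adam: the bias-corrected second moment gets eps_root inside the square root and eps outside it. The adam_step function is a didactic sketch, not a replacement for optax.

```python
import numpy as np


def adam_step(params, grads, m, v, t, lr=0.05, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """One Adam update; t is the 1-based step count used for bias correction."""
    m = b1 * m + (1 - b1) * grads           # first-moment EMA
    v = b2 * v + (1 - b2) * grads**2        # second-moment EMA
    m_hat = m / (1 - b1**t)                 # bias corrections
    v_hat = v / (1 - b2**t)
    update = m_hat / (np.sqrt(v_hat + eps_root) + eps)
    return params - lr * update, m, v


params = np.array([1.0, -2.0])
grads = np.array([0.3, -4.0])
m = np.zeros_like(params)
v = np.zeros_like(params)
new_params, m, v = adam_step(params, grads, m, v, t=1, lr=3e-4)
print(new_params - params)  # about -3e-4 * sign(grads) on the first step
```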

LAMB options

train.optim.learning_rate

Default: Standard · Type: swappable

A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate(). Swappable component; the nested keys below are the options for the current module Standard and change when train.optim.learning_rate.module changes.

train.optim.learning_rate.module

Default: Standard · Type: module path

Select the implementation used for this component.

train.optim.learning_rate.rate

Default: 0.05 · Type: float

Initial learning rate.

train.optim.learning_rate.delay

Default: 2000 · Type: float

Delay in steps before decay starts.

train.optim.learning_rate.decay

Default: 1 · Type: float

Decay rate exponent.


train.optim.b1

Default: 0.9 · Type: float

Exponential decay rate to track the first moment of past gradients.


train.optim.b2

Default: 0.999 · Type: float

Exponential decay rate to track the second moment of past gradients.


train.optim.eps

Default: 1e-06 · Type: float

A small constant added to the denominator outside the square root (as in the Adam paper) to avoid dividing by zero when rescaling.


train.optim.eps_root

Default: 0.0 · Type: float

A small constant added to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling.


train.optim.weight_decay

Default: Constant · Type: swappable

Strength of the weight decay regularization. Swappable component; the nested keys below are the options for the current module Constant and change when train.optim.weight_decay.module changes.

train.optim.weight_decay.module

Default: Constant · Type: module path

Select the implementation used for this component.

train.optim.weight_decay.rate

Default: 0.05 · Type: float

The constant rate.
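LAMB builds its update in three steps:
1. Take the Adam direction.
2. Add weight_decay times the parameters.
3. Rescale the result per parameter array with the trust ratio ‖w‖/‖u‖.

The NumPy sketch below shows that rule for a single parameter array. It follows the structure of optax.lamb, which also falls back to a trust ratio of 1 when either norm is zero. The helper lamb_update is illustrative.

```python
import numpy as np


def lamb_update(params: np.ndarray, adam_direction: np.ndarray,
                lr: float = 0.05, weight_decay: float = 0.05) -> np.ndarray:
    """Apply one LAMB-style update to a single parameter array (one 'layer')."""
    u = adam_direction + weight_decay * params  # Adam direction plus decoupled weight decay
    w_norm, u_norm = np.linalg.norm(params), np.linalg.norm(u)
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    return params - lr * trust_ratio * u


w = np.array([3.0, 4.0])           # ||w|| = 5
direction = np.array([0.6, -0.8])  # e.g. m_hat / (sqrt(v_hat) + eps), unit length here
new_w = lamb_update(w, direction, lr=0.01, weight_decay=0.0)
print(np.linalg.norm(new_w - w))   # step length = lr * ||w||, about 0.05
```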

Sampler (train.sampler.*)

  • Default sampler module: mcmc, and its effective keys are listed below.

train.sampler.steps

Default: 10 · Type: int

Number of Metropolis-Hastings updates per sample draw.


train.sampler.initial_width

Default: 0.1 · Type: float

Initial width (stddev) of the Gaussian proposal.
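The sketch below is a minimal random-walk Metropolis-Hastings sampler. It uses an isotropic Gaussian proposal with standard deviation width and applies the update `steps` times per draw. The target is |ψ|² for a toy hydrogen-like ψ = exp(-r). The target, the acceptance bookkeeping and the function names are illustrative assumptions, not jaqmc's sampler.

```python
import numpy as np


def log_prob(x: np.ndarray) -> np.ndarray:
    """log |psi|^2 for the toy wavefunction psi = exp(-|x|); x has shape (batch, 3)."""
    return -2.0 * np.linalg.norm(x, axis=-1)


def mh_draw(x: np.ndarray, width: float, steps: int, rng: np.random.Generator):
    """Run `steps` Metropolis-Hastings updates; return new walkers and acceptance rate."""
    logp = log_prob(x)
    accepted = 0.0
    for _ in range(steps):
        proposal = x + width * rng.normal(size=x.shape)  # symmetric Gaussian proposal
        logp_new = log_prob(proposal)
        accept = np.log(rng.uniform(size=len(x))) < logp_new - logp
        x = np.where(accept[:, None], proposal, x)
        logp = np.where(accept, logp_new, logp)
        accepted += accept.mean()
    return x, accepted / steps


rng = np.random.default_rng(0)
walkers = rng.normal(size=(4096, 3))
walkers, pmove = mh_draw(walkers, width=0.1, steps=10, rng=rng)
print(f"pmove={pmove:.2f}")
```

Because the proposal is symmetric, the acceptance test only needs the ratio |ψ(new)|²/|ψ(old)|².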


train.sampler.adapt_frequency

Default: 100 · Type: int

Interval, in iterations, between adaptive width updates.


train.sampler.pmove_range

Default: (0.5, 0.55) · Type: tuple[float, float]

Target range for the acceptance rate (pmove). The proposal width is adapted to keep the acceptance rate within this range.
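Width adaptation can be sketched as follows, assuming the rule from the original FermiNet code. Every adapt_frequency iterations, the mean acceptance rate since the last adjustment is compared against pmove_range. If it is above the range, the width is multiplied by 1.1; if below, it is divided by 1.1. The factor 1.1 and the averaging window are assumptions, and adapt_width is hypothetical.

```python
def adapt_width(step: int, width: float, pmoves: list,
                adapt_frequency: int = 100, pmove_range: tuple = (0.5, 0.55)) -> float:
    """Adjust the proposal width every adapt_frequency steps from recent acceptance rates.

    The caller is expected to clear pmoves after each adjustment.
    """
    if step == 0 or step % adapt_frequency != 0:
        return width
    mean_pmove = sum(pmoves) / len(pmoves)
    low, high = pmove_range
    if mean_pmove > high:  # accepting too often: take bigger steps
        return width * 1.1
    if mean_pmove < low:   # rejecting too often: take smaller steps
        return width / 1.1
    return width


print(adapt_width(100, 0.1, [0.7] * 100))   # widened
print(adapt_width(100, 0.1, [0.3] * 100))   # narrowed
print(adapt_width(150, 0.1, [0.3] * 100))   # unchanged between adaptation steps
```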

Writers (train.writers.*)

The train stage enables console, csv, and hdf5 writers by default.

Console writer (train.writers.console.*)

train.writers.console.interval

Default: 1 · Type: int

Step interval for logging.


train.writers.console.fields

Default: 'pmove:.2f,energy=total_energy:.4f,variance=total_energy_var:.4f' · Type: str

Comma-separated list of field specs, each of the form [label=]key[:format] (for example, energy=total_energy:.4f).
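Judging from the default value, each spec takes the form [label=]key[:format], where format is a Python format spec. The parser below is a hypothetical reconstruction of that grammar. The "label=value" output layout is also an assumption, not jaqmc's console format.

```python
def parse_field_specs(spec: str) -> list:
    """Split 'label=key:fmt,...' into (label, key, fmt) triples."""
    fields = []
    for item in spec.split(","):
        item = item.strip()
        if not item:
            continue
        label, sep, rest = item.partition("=")
        if not sep:  # no label given: label defaults to the key
            label, rest = "", item
        key, _, fmt = rest.partition(":")
        fields.append((label or key, key, fmt))
    return fields


def format_line(spec: str, stats: dict) -> str:
    """Render one console line from a field spec string and a stats mapping."""
    return " ".join(f"{label}={stats[key]:{fmt}}" for label, key, fmt in parse_field_specs(spec))


default = "pmove:.2f,energy=total_energy:.4f,variance=total_energy_var:.4f"
print(format_line(default, {"pmove": 0.523, "total_energy": -1.17234, "total_energy_var": 0.01234}))
# pmove=0.52 energy=-1.1723 variance=0.0123
```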

CSV writer (train.writers.csv.*)

train.writers.csv.path_template

Default: '{stage}_stats.csv' · Type: str

Output path template. {stage} is replaced with the stage name (for example, train_stats.csv).

HDF5 writer (train.writers.hdf5.*)

train.writers.hdf5.path_template

Default: '{stage}_stats.h5' · Type: str

Output path template. {stage} is replaced with the stage name (for example, train_stats.h5).

Loss gradients (train.grads.*)

Loss and gradient estimator. Computes the VMC loss and parameter gradients.

train.grads.vmap_chunk_size

Default: None · Type: int | None

Number of walkers to evaluate per vmap chunk.


train.grads.loss_key

Default: 'total_energy' · Type: str

Key in prev_walker_stats to use as the loss.


train.grads.clip_scale

Default: 5.0 · Type: float

Multiplier on the interquartile range (IQR) that sets the clipping window for local energies.
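The sketch below combines IQR-based clipping with the standard VMC gradient estimator, ∇E ≈ 2·mean[(E_L - mean(E_L))·∇log|ψ|], applied to the clipped local energies. Centring the clipping window on the median is an assumption, as are the exact estimator details. jaqmc may differ, and the helper names are hypothetical.

```python
import numpy as np


def clip_local_energies(e_loc: np.ndarray, clip_scale: float = 5.0) -> np.ndarray:
    """Clip local energies to median +/- clip_scale * IQR."""
    q1, median, q3 = np.percentile(e_loc, [25, 50, 75])
    half_width = clip_scale * (q3 - q1)
    return np.clip(e_loc, median - half_width, median + half_width)


def vmc_gradient(e_loc: np.ndarray, grad_log_psi: np.ndarray, clip_scale: float = 5.0) -> np.ndarray:
    """2 * mean((E_L - mean E_L) * d log|psi| / d theta) over walkers.

    e_loc: (batch,) local energies; grad_log_psi: (batch, n_params).
    """
    clipped = clip_local_energies(e_loc, clip_scale)
    centred = clipped - clipped.mean()
    return 2.0 * (centred[:, None] * grad_log_psi).mean(axis=0)


rng = np.random.default_rng(0)
e_loc = rng.normal(-1.17, 0.05, size=4096)
e_loc[0] = 1e6                          # one outlier walker, e.g. near a node
grads = rng.normal(size=(4096, 3))
print(clip_local_energies(e_loc).max())  # outlier pulled back to the window edge
print(vmc_gradient(e_loc, grads))
```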


Pretrain Stage (pretrain.*)

Initializes the neural network to approximate Hartree-Fock orbitals before VMC training. It uses the same run, sampler, and writer schemas as the train stage, but with a different optimizer default and a workflow-wired supervised loss.

Run options (pretrain.run.*)

pretrain.run.check_vma

Default: True · Type: bool

Enable JAX validity checks during shard_map.


pretrain.run.iterations

Default: 2000 · Type: int

Total number of iterations to run.


pretrain.run.burn_in

Default: 100 · Type: int

Sampling iterations to discard before the main loop for MCMC equilibration.


pretrain.run.save_time_interval

Default: 600 · Type: int

Minimum wall-clock seconds between checkpoint saves.


pretrain.run.save_step_interval

Default: 1000 · Type: int

Save checkpoints only at steps that are multiples of this value.


pretrain.run.stop_on_nan

Default: 'loss' · Type: bool | str

Abort training when NaN is detected in step statistics. True checks all stat keys, and False disables the check. A comma-separated string restricts the check to the listed keys (e.g. "loss").

Optimizer (pretrain.optim.*)

  • Default optimizer module: optax:adam, and its effective keys are listed below.

Effective Adam defaults

pretrain.optim.learning_rate

Default: Standard · Type: swappable

A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate(). Swappable component; the nested keys below are the options for the current module Standard and change when pretrain.optim.learning_rate.module changes.

pretrain.optim.learning_rate.module

Default: Standard · Type: module path

Select the implementation used for this component.

pretrain.optim.learning_rate.rate

Default: 0.0003 · Type: float

Initial learning rate.

pretrain.optim.learning_rate.delay

Default: 2000 · Type: float

Delay in steps before decay starts.

pretrain.optim.learning_rate.decay

Default: 1 · Type: float

Decay rate exponent.


pretrain.optim.b1

Default: 0.9 · Type: float

Exponential decay rate to track the first moment of past gradients.


pretrain.optim.b2

Default: 0.999 · Type: float

Exponential decay rate to track the second moment of past gradients.


pretrain.optim.eps

Default: 1e-08 · Type: float

A small constant added to the denominator outside the square root (as in the Adam paper) to avoid dividing by zero when rescaling.


pretrain.optim.eps_root

Default: 0.0 · Type: float

A small constant added to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling.

Sampler (pretrain.sampler.*)

  • Default sampler module: mcmc.

pretrain.sampler.steps

Default: 10 · Type: int

Number of Metropolis-Hastings updates per sample draw.


pretrain.sampler.initial_width

Default: 0.1 · Type: float

Initial width (stddev) of the Gaussian proposal.


pretrain.sampler.adapt_frequency

Default: 100 · Type: int

Interval, in iterations, between adaptive width updates.


pretrain.sampler.pmove_range

Default: (0.5, 0.55) · Type: tuple[float, float]

Target range for the acceptance rate (pmove). The proposal width is adapted to keep the acceptance rate within this range.

Writers (pretrain.writers.*)

The pretrain stage enables console, csv, and hdf5 writers by default.

Console writer (pretrain.writers.console.*)

pretrain.writers.console.interval

Default: 1 · Type: int

Step interval for logging.


pretrain.writers.console.fields

Default: 'loss' · Type: str

Comma-separated list of field specs, each of the form [label=]key[:format] (for example, energy=total_energy:.4f).

CSV writer (pretrain.writers.csv.*)

pretrain.writers.csv.path_template

Default: '{stage}_stats.csv' · Type: str

Output path template. {stage} is replaced with the stage name (for example, pretrain_stats.csv).

HDF5 writer (pretrain.writers.hdf5.*)

pretrain.writers.hdf5.path_template

Default: '{stage}_stats.h5' · Type: str

Output path template. {stage} is replaced with the stage name (for example, pretrain_stats.h5).

Loss gradients

Pretraining does not use configurable pretrain.grads.* settings. The workflow wires a supervised Hartree-Fock orbital-matching loss directly.
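For intuition, a supervised orbital-matching loss can be written as the mean squared difference between the network's orbital matrices and the Hartree-Fock orbitals at the same electron positions. The sketch below assumes that form, averaging over walkers, determinants and matrix entries. The exact loss the workflow wires in, including how multiple determinants are matched, may differ. The helper orbital_matching_loss is hypothetical.

```python
import numpy as np


def orbital_matching_loss(nn_orbitals: np.ndarray, hf_orbitals: np.ndarray) -> float:
    """Mean squared error between network and Hartree-Fock orbital matrices.

    nn_orbitals: (batch, ndets, n, n) network orbitals at the sampled positions.
    hf_orbitals: (batch, n, n) Hartree-Fock orbitals at the same positions,
                 broadcast against every determinant.
    """
    diff = nn_orbitals - hf_orbitals[:, None, :, :]
    return float(np.mean(diff**2))


rng = np.random.default_rng(0)
hf = rng.normal(size=(8, 3, 3))
nn = np.repeat(hf[:, None], 16, axis=1) + 0.1 * rng.normal(size=(8, 16, 3, 3))
print(orbital_matching_loss(nn, hf))  # roughly 0.01, the variance of the added noise
```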


Estimators (estimators.*)

Energy estimators are configured programmatically by the workflow and are not typically overridden via config. The same definitions are used by Evaluation.

  • total_energy and the electron-nuclei potential are always added by the workflow and are not configurable via config keys.

  • estimators.enabled.spin defaults to false.

Kinetic energy (estimators.energy.kinetic.*)

estimators.energy.kinetic.vmap_chunk_size

Default: None · Type: int | None

Number of walkers to evaluate per vmap chunk.


estimators.energy.kinetic.mode

Default: forward_laplacian · Type: LaplacianMode

Laplacian computation strategy. forward_laplacian is the default on JAX 0.7.1 and later; scan is the default on earlier versions.


estimators.energy.kinetic.sparsity_threshold

Default: 0 · Type: int

Sparsity threshold when using forward_laplacian mode.
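Whichever Laplacian mode is selected, the kinetic estimator targets the local kinetic energy T_L = -1/2 (∇² log|ψ| + |∇ log|ψ||²). The NumPy sketch below checks this identity with central finite differences on the hydrogen ground state ψ = exp(-r), whose local energy T_L - 1/r is exactly -0.5 hartree everywhere. Finite differences here only stand in for the forward_laplacian and scan modes, which the sketch does not implement.

```python
import numpy as np


def log_psi(x: np.ndarray) -> float:
    """log|psi| for the hydrogen 1s orbital psi = exp(-r)."""
    return -np.linalg.norm(x)


def local_kinetic(f, x: np.ndarray, h: float = 1e-4) -> float:
    """-1/2 (laplacian of log psi + |grad log psi|^2) via central differences."""
    grad = np.zeros_like(x)
    lap = 0.0
    f0 = f(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = h
        f_plus, f_minus = f(x + step), f(x - step)
        grad[i] = (f_plus - f_minus) / (2 * h)
        lap += (f_plus - 2 * f0 + f_minus) / h**2
    return -0.5 * (lap + grad @ grad)


x = np.array([0.3, -0.4, 1.2])
r = np.linalg.norm(x)
e_local = local_kinetic(log_psi, x) - 1.0 / r  # add the electron-nucleus potential -1/r
print(round(e_local, 4))                       # -0.5
```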

ECP energy (estimators.energy.ecp.*)

Added automatically when system.ecp is set.

estimators.energy.ecp.vmap_chunk_size

Default: None · Type: int | None

Number of walkers to evaluate per vmap chunk.


estimators.energy.ecp.max_core

Default: 2 · Type: int

Maximum number of nearest ECP atoms to consider per electron when evaluating nonlocal integrals.
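The max_core cutoff can be pictured as a nearest-neighbour selection: for each electron, only the max_core closest ECP atoms contribute nonlocal terms. The NumPy sketch below shows that selection. The function name and array layout are illustrative assumptions.

```python
import numpy as np


def nearest_ecp_atoms(electrons: np.ndarray, ecp_atoms: np.ndarray, max_core: int = 2):
    """Indices and distances of the max_core nearest ECP atoms for each electron.

    electrons: (n_electrons, 3); ecp_atoms: (n_ecp_atoms, 3).
    """
    dists = np.linalg.norm(electrons[:, None, :] - ecp_atoms[None, :, :], axis=-1)
    k = min(max_core, ecp_atoms.shape[0])
    idx = np.argsort(dists, axis=1)[:, :k]  # closest atoms first
    return idx, np.take_along_axis(dists, idx, axis=1)


atoms = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 3.0], [0.0, 0.0, 10.0]])
electrons = np.array([[0.0, 0.0, 0.5], [0.0, 0.0, 9.0]])
idx, d = nearest_ecp_atoms(electrons, atoms)
print(idx.tolist())  # [[0, 1], [2, 1]]
```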


estimators.energy.ecp.quadrature_id

Default: None · Type: str | None

Spherical quadrature rule used to evaluate nonlocal ECP integrals.


estimators.energy.ecp.electrons_field

Default: 'electrons' · Type: str

Name of the electron-position field in the data.


estimators.energy.ecp.atoms_field

Default: 'atoms' · Type: str

Name of the atom-position field in the data.