Training#
Configuration reference for jaqmc molecule train.
This page shows the effective defaults for the train workflow preset. Use
--dry-run to see the resolved config for your run, or add
workflow.config.verbose=true to include field descriptions. Keys use the same
dot notation as CLI overrides, such as train.run.iterations=5000. Defaults
are resolved in this order: schema defaults, workflow preset, YAML config, then
CLI overrides. For evaluation config, see Evaluation.
Root-level runtime keys such as logging.*, jax.*, and distributed.* are
shared by all commands. See Runtime Configuration.
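The resolution order described above can be pictured as a sequence of dictionary merges in which later layers win. The following is a minimal sketch under stated assumptions, not jaqmc's actual implementation: `deep_merge`, `apply_override`, and `resolve` are hypothetical helpers, override values are kept as plain strings, and the numbers are made up for illustration.

```python
from copy import deepcopy


def deep_merge(base, update):
    """Return a new dict where nested keys in `update` override `base`."""
    merged = deepcopy(base)
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


def apply_override(config, override):
    """Apply one 'a.b.c=value' override in place (value kept as a string)."""
    dotted_key, _, raw_value = override.partition("=")
    *parents, leaf = dotted_key.split(".")
    node = config
    for name in parents:
        node = node.setdefault(name, {})
    node[leaf] = raw_value
    return config


def resolve(schema_defaults, preset, yaml_config, cli_overrides):
    """Later layers win: schema defaults < preset < YAML < CLI."""
    config = deep_merge(deep_merge(schema_defaults, preset), yaml_config)
    for override in cli_overrides:
        apply_override(config, override)
    return config


# Illustrative values only; these are not jaqmc's real defaults.
resolved = resolve(
    schema_defaults={"train": {"run": {"iterations": 1, "burn_in": 0}}},
    preset={"train": {"run": {"iterations": 100}}},
    yaml_config={"train": {"run": {"burn_in": 10}}},
    cli_overrides=["train.run.iterations=5000"],
)
```

In this sketch the CLI override replaces the preset's iteration count, while the YAML value for `burn_in` survives because no later layer touches it.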
Workflow (workflow.*)#
These keys control workflow-level settings shared across all stages.
workflow.seed
Fixed random seed.
workflow.batch_size
Number of walkers (samples) to use in each iteration.
workflow.save_path
Path to save checkpoints and logs.
workflow.restore_path
Path to restore checkpoints from.
workflow.config.ignore_extra
If True, silently ignore unrecognized config keys.
workflow.config.verbose
If True, print the fully resolved config with field descriptions at startup.
System (system.*)#
Defines the molecular system to simulate. The implementation is selected by
system.module.
Default module selection: unset, so system.* is read directly as an arbitrary molecule config. Built-in choices are:
unset: arbitrary molecule config
atom: single-atom generator
diatomic: diatomic generator
Arbitrary molecules (default)#
system.atoms
List of atoms in the system.
system.basis
The basis set for Hartree-Fock pretrain.
system.ecp
Effective core potential specification.
system.electron_spins
Tuple of two integers representing the number of up and down electrons.
system.fixed_spins_per_atom
Optional list of fixed spin configurations per atom.
system.electron_init_width
Width of the Gaussian distribution for initializing electron positions.
Single atoms (system.module=atom)#
system.symbol
Element symbol (e.g., “H”, “Li”, “Fe”).
system.electron_init_width
Width of Gaussian for electron initialization.
system.basis
Basis set name.
system.ecp
Effective core potential name.
Diatomic molecules (system.module=diatomic)#
system.formula
Chemical formula (e.g., "H2", "LiH", "N2", "ClF").
system.bond_length
Distance between the two atoms.
system.unit
Length unit for bond_length and atom coordinates.
system.basis
Basis set name, or per-element mapping (e.g., {"Li": "ccecpccpvdz", "H": "cc-pvdz"}).
system.ecp
Effective core potential specification.
system.spin
Total spin (number of unpaired electrons).
system.electron_init_width
Width of Gaussian for electron initialization.
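The diatomic generator takes system.spin (the number of unpaired electrons), whereas arbitrary molecules take system.electron_spins (the up and down counts). The relationship between the two can be sketched as follows; `electron_spins_from_spin` is a hypothetical helper, and whether jaqmc performs exactly this conversion internally is an assumption.

```python
def electron_spins_from_spin(n_electrons, spin):
    """Split n_electrons into (n_up, n_down) with n_up - n_down == spin."""
    if spin < 0 or spin > n_electrons or (n_electrons - spin) % 2 != 0:
        raise ValueError(
            "spin must be non-negative, at most n_electrons, "
            "and have the same parity as n_electrons"
        )
    n_down = (n_electrons - spin) // 2
    return n_down + spin, n_down


# N2 has 14 electrons; a closed-shell singlet has no unpaired electrons.
n2_spins = electron_spins_from_spin(14, 0)
# A lithium atom has 3 electrons, one of them unpaired.
li_spins = electron_spins_from_spin(3, 1)
```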
Wavefunction (wf.*)#
Selects and configures the neural-network ansatz.
Default module selection: ferminet. Effective defaults for the built-in architectures are listed below. Built-in choices are ferminet and psiformer.
FermiNet options (wf.*)#
wf.ndets
Number of determinants.
Hidden dimensions for single-electron stream.
Hidden dimensions for double-electron stream.
wf.use_last_layer
If False, skip the double-electron stream update in the final layer and return single-electron features directly.
wf.envelope
Type of envelope function to apply to orbitals.
wf.orbitals_spin_split
If True, use separate orbital layer and envelope parameters for each spin channel.
wf.full_det
Psiformer options (wf.*)#
wf.ndets
Number of determinants.
wf.num_layers
Number of Psiformer layers.
wf.num_heads
Number of attention heads.
wf.heads_dim
Dimension of each attention head.
Hidden dimensions for MLP blocks.
wf.layer_norm_mode
LayerNorm application mode.
wf.jastrow
Jastrow factor type.
wf.with_bias
Whether to use bias in attention QKV projections.
wf.input_bias
Whether to use bias in the input projection layer.
wf.envelope
Envelope type for orbital decay at infinity.
wf.orbitals_spin_split
If True, use separate orbital layer and envelope parameters for each spin channel.
wf.bias_orbitals
If True, include bias in the orbital projection layer.
wf.rescale
If True, use log-scaled input features (log(1+r)) instead of linear features.
wf.jastrow_alpha_init
Initial value for Jastrow alpha parameters.
wf.full_det
Train Stage (train.*)#
The main VMC optimization loop. Samples electron configurations, computes energy, and updates wavefunction parameters.
Run options (train.run.*)#
train.run.check_vma
Enable JAX validity checks during shard_map.
train.run.iterations
Total number of iterations to run.
train.run.burn_in
Sampling iterations to discard before the main loop for MCMC equilibration.
train.run.save_time_interval
Minimum wall-clock seconds between checkpoint saves.
train.run.save_step_interval
Save checkpoints only at steps that are multiples of this value.
train.run.stop_on_nan
Abort training when NaN is detected in step statistics. True checks all stat keys, False disables the check, or pass a comma-separated string of specific keys to monitor (e.g. "loss").
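The three accepted forms of train.run.stop_on_nan can be illustrated with a small sketch. The helper names `keys_to_monitor` and `should_abort` are hypothetical, and the stat names in the example are placeholders rather than jaqmc's actual statistics.

```python
import math


def keys_to_monitor(stop_on_nan, stat_keys):
    """Translate a stop_on_nan setting into the list of stat keys to check."""
    if stop_on_nan is True:
        return list(stat_keys)
    if stop_on_nan is False:
        return []
    # Otherwise treat the value as a comma-separated string of key names.
    return [key.strip() for key in stop_on_nan.split(",") if key.strip()]


def should_abort(stats, stop_on_nan):
    """Return True if any monitored statistic is NaN."""
    monitored = keys_to_monitor(stop_on_nan, stats.keys())
    return any(math.isnan(stats[key]) for key in monitored if key in stats)


# Placeholder statistics for one step: the loss is finite, another stat is NaN.
stats = {"loss": 0.5, "pmove": float("nan")}
```

With these placeholder statistics, `True` triggers an abort, `False` never does, and `"loss"` ignores the NaN in the other key.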
Optimizer (train.optim.*)#
Default optimizer module: kfac. Effective defaults for the built-in optimizers are listed below.
KFAC options#
train.optim.learning_rate
The learning rate. Swappable component; the nested keys below are the options for the current module Standard and change when train.optim.learning_rate.module changes.
train.optim.learning_rate.module
Select the implementation used for this component.
train.optim.learning_rate.rate
Initial learning rate.
train.optim.learning_rate.delay
Delay in steps before decay starts.
train.optim.learning_rate.decay
Decay rate exponent.
train.optim.norm_constraint
The update is scaled down so that its approximate squared Fisher norm \(v^T F v\) is at most the specified value.
train.optim.curvature_ema
Decay factor used when calculating the covariance estimate moving averages.
train.optim.l2_reg
The L2 regularization coefficient in use, passed to the optimizer so that it can account for it.
train.optim.inverse_update_period
Number of steps in between updating the inverse curvature approximation.
train.optim.damping
Fixed damping parameter.
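The effect of train.optim.norm_constraint can be sketched on a toy problem. This is an illustration only: `constrain_update` is a hypothetical helper, the Fisher matrix is written out densely, and the learning rate is left out. Practical K-FAC implementations use a Kronecker-factored approximation of \(F\) instead of a dense matrix.

```python
import math


def constrain_update(update, fisher, norm_constraint):
    """Scale `update` so that update^T F update <= norm_constraint."""
    # Quadratic form v^T F v with F given as a dense list-of-lists matrix.
    fisher_times_update = [
        sum(f_ij * v_j for f_ij, v_j in zip(row, update)) for row in fisher
    ]
    sq_norm = sum(v_i * fv_i for v_i, fv_i in zip(update, fisher_times_update))
    if sq_norm <= norm_constraint:
        return list(update)  # already inside the trust region
    scale = math.sqrt(norm_constraint / sq_norm)
    return [scale * v_i for v_i in update]


fisher = [[2.0, 0.0], [0.0, 2.0]]
# v^T F v = 4 here, so a constraint of 1 halves the update.
scaled = constrain_update([1.0, 1.0], fisher, norm_constraint=1.0)
```

Updates whose squared Fisher norm is already below the constraint pass through unchanged, so the constraint only acts on unusually large steps.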
SR options#
train.optim.learning_rate
Step size (scalar or schedule). Swappable component; the nested keys below are the options for the current module Standard and change when train.optim.learning_rate.module changes.
train.optim.learning_rate.module
Select the implementation used for this component.
train.optim.learning_rate.rate
Initial learning rate.
train.optim.learning_rate.delay
Delay in steps before decay starts.
train.optim.learning_rate.decay
Decay rate exponent.
train.optim.max_norm
Constrained update norm C (scalar or schedule). Swappable component; the nested keys below are the options for the current module Constant and change when train.optim.max_norm.module changes.
train.optim.max_norm.module
Select the implementation used for this component.
train.optim.max_norm.rate
The constant rate.
train.optim.damping
Damping lambda (scalar or schedule). Swappable component; the nested keys below are the options for the current module Constant and change when train.optim.damping.module changes.
train.optim.damping.module
Select the implementation used for this component.
train.optim.damping.rate
The constant rate.
train.optim.max_cond_num
Maximum condition number for adaptive damping.
train.optim.spring_mu
SPRING momentum coefficient mu (scalar or schedule). Swappable component; the nested keys below are the options for the current module Constant and change when train.optim.spring_mu.module changes.
train.optim.spring_mu.module
Select the implementation used for this component.
train.optim.spring_mu.rate
The constant rate.
train.optim.march_beta
Decay factor for the MARCH variance accumulator (scalar or schedule). Swappable component; the nested keys below are the options for the current module Constant and change when train.optim.march_beta.module changes.
train.optim.march_beta.module
Select the implementation used for this component.
train.optim.march_beta.rate
The constant rate.
train.optim.march_mode
MARCH variance mode. "diff" uses update differences and "var" uses score variance along the batch axis.
train.optim.eps
Small numerical constant for stability.
train.optim.mixed_precision
Whether to use mixed precision for Gram factorization.
train.optim.score_chunk_size
Chunk size for score computation.
train.optim.score_norm_clip
Optional clip value for the mean absolute score per batch row.
train.optim.gram_num_chunks
Number of chunks for Gram matrix computation.
train.optim.gram_dot_prec
Precision mode for Gram matrix dot products.
train.optim.prune_inactive
Whether to structurally prune inactive parameter leaves when forming the SR system.
Adam options#
train.optim.learning_rate
A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate(). Swappable component; the nested keys below are the options for the current module Standard and change when train.optim.learning_rate.module changes.
train.optim.learning_rate.module
Select the implementation used for this component.
train.optim.learning_rate.rate
Initial learning rate.
train.optim.learning_rate.delay
Delay in steps before decay starts.
train.optim.learning_rate.decay
Decay rate exponent.
train.optim.b1
Exponential decay rate to track the first moment of past gradients.
train.optim.b2
Exponential decay rate to track the second moment of past gradients.
train.optim.eps
A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
train.optim.eps_root
A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling.
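The difference between eps and eps_root is easiest to see in the update formula itself. The sketch below writes out one Adam step for a single scalar parameter in plain Python; it follows the standard Adam update with bias correction and is not taken from jaqmc's code. The function name `adam_step` is hypothetical.

```python
import math


def adam_step(grad, m, v, step, b1, b2, eps, eps_root):
    """One scalar Adam update; `step` counts from 1.

    Returns (direction, new_m, new_v). The parameter change is
    -learning_rate * direction.
    """
    m = b1 * m + (1.0 - b1) * grad          # first-moment moving average
    v = b2 * v + (1.0 - b2) * grad * grad   # second-moment moving average
    m_hat = m / (1.0 - b1 ** step)          # bias correction
    v_hat = v / (1.0 - b2 ** step)
    # eps_root is added inside the square root, eps outside it.
    direction = m_hat / (math.sqrt(v_hat + eps_root) + eps)
    return direction, m, v


direction, m, v = adam_step(
    grad=2.0, m=0.0, v=0.0, step=1, b1=0.9, b2=0.999, eps=0.0, eps_root=0.0
)
```

On the first step the bias-corrected moments equal the gradient and its square, so with both constants at zero the direction has magnitude one regardless of the gradient's scale.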
LAMB options#
train.optim.learning_rate
A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate(). Swappable component; the nested keys below are the options for the current module Standard and change when train.optim.learning_rate.module changes.
train.optim.learning_rate.module
Select the implementation used for this component.
train.optim.learning_rate.rate
Initial learning rate.
train.optim.learning_rate.delay
Delay in steps before decay starts.
train.optim.learning_rate.decay
Decay rate exponent.
train.optim.b1
Exponential decay rate to track the first moment of past gradients.
train.optim.b2
Exponential decay rate to track the second moment of past gradients.
train.optim.eps
A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
train.optim.eps_root
A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling.
train.optim.weight_decay
Strength of the weight decay regularization. Swappable component; the nested keys below are the options for the current module Constant and change when train.optim.weight_decay.module changes.
train.optim.weight_decay.module
Select the implementation used for this component.
train.optim.weight_decay.rate
The constant rate.
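The rate, delay, and decay keys of the Standard learning-rate component appear under every optimizer above. This page does not state the formula that combines them, so the sketch below assumes the inverse-time decay commonly used in FermiNet-style training, \(\text{rate} \cdot (1 + t / \text{delay})^{-\text{decay}}\). Treat both the formula and the function name `standard_learning_rate` as assumptions, and the numbers as illustrative.

```python
def standard_learning_rate(step, rate, delay, decay):
    """Assumed schedule: rate * (1 / (1 + step / delay)) ** decay."""
    return rate * (1.0 / (1.0 + step / delay)) ** decay


# Illustrative values only; not jaqmc's defaults.
lr_start = standard_learning_rate(0, rate=0.05, delay=10000.0, decay=1.0)
lr_later = standard_learning_rate(10000, rate=0.05, delay=10000.0, decay=1.0)
```

Under this assumed form, the learning rate starts at `rate` and, with `decay=1`, has halved once `step` reaches `delay`.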
Sampler (train.sampler.*)#
Default sampler module: mcmc. Its effective keys are listed below.
train.sampler.steps
Number of Metropolis-Hastings updates per sample draw.
train.sampler.initial_width
Initial width (stddev) of the Gaussian proposal.
train.sampler.adapt_frequency
Frequency of adaptive width updates.
train.sampler.pmove_range
Target range for acceptance rate.
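How these four keys interact can be sketched with a one-dimensional Metropolis-Hastings sampler. This is a simplified illustration, not jaqmc's sampler: `mh_sample` and `adapt_width` are hypothetical names, the adjustment factor of 1.1 is an assumption, and the rule that the width is adjusted every adapt_frequency iterations from the recent acceptance rate is also assumed.

```python
import math
import random


def mh_sample(log_prob, x, width, steps, rng):
    """Run `steps` Metropolis-Hastings updates; return (x, acceptance_rate)."""
    accepted = 0
    logp = log_prob(x)
    for _ in range(steps):
        proposal = x + rng.gauss(0.0, width)  # Gaussian proposal of given width
        logp_new = log_prob(proposal)
        # Symmetric proposal, so accept with probability min(1, p_new / p_old).
        if rng.random() < math.exp(min(0.0, logp_new - logp)):
            x, logp = proposal, logp_new
            accepted += 1
    return x, accepted / steps


def adapt_width(width, pmove, pmove_range, factor=1.1):
    """Nudge the proposal width to keep pmove inside pmove_range."""
    low, high = pmove_range
    if pmove < low:
        return width / factor  # too many rejections: take smaller moves
    if pmove > high:
        return width * factor  # too many acceptances: take larger moves
    return width


rng = random.Random(0)
# Sample from a standard normal target, given by its unnormalized log density.
x, pmove = mh_sample(lambda x: -0.5 * x * x, 0.0, width=1.0, steps=100, rng=rng)
new_width = adapt_width(1.0, pmove, pmove_range=(0.5, 0.55))
```

A wider proposal explores faster but is rejected more often, which is why the width is steered toward a target acceptance range instead of being fixed.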
Writers (train.writers.*)#
The train stage enables console, csv, and hdf5 writers by default.
Console writer (train.writers.console.*)#
train.writers.console.interval
Step interval for logging.
train.writers.console.fields
Comma-separated list of field specs.
CSV writer (train.writers.csv.*)#
train.writers.csv.path_template
Output path template.
HDF5 writer (train.writers.hdf5.*)#
train.writers.hdf5.path_template
Output path template.
Loss gradients (train.grads.*)#
Loss and gradient estimator. Computes the VMC loss and parameter gradients.
train.grads.vmap_chunk_size
Number of walkers to evaluate per vmap chunk.
train.grads.loss_key
Key in prev_walker_stats to use as the loss.
train.grads.clip_scale
Multiplier on the interquartile range (IQR) that sets the clipping window for local energies.
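The role of train.grads.clip_scale can be illustrated with a plain-Python sketch of IQR-based clipping. The window is centered on the median here, which is an assumption; this page does not say which center jaqmc uses. `clip_local_energies` is a hypothetical helper and the energies are made-up numbers.

```python
import statistics


def clip_local_energies(energies, clip_scale):
    """Clip energies to the window median +/- clip_scale * IQR."""
    q1, median, q3 = statistics.quantiles(energies, n=4)
    half_width = clip_scale * (q3 - q1)
    low, high = median - half_width, median + half_width
    return [min(max(e, low), high) for e in energies]


# Eleven well-behaved local energies near -1.0 plus one extreme outlier.
energies = [-1.0 + 0.01 * i for i in range(-5, 6)] + [100.0]
clipped = clip_local_energies(energies, clip_scale=5.0)
```

Because the IQR is insensitive to a few extreme values, the outlier is pulled back to the edge of the window while the typical energies pass through unchanged.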
Pretrain Stage (pretrain.*)#
Initializes the neural network to approximate Hartree-Fock orbitals before VMC training. It uses the same run, sampler, and writer schemas as the train stage, but with a different optimizer default and a workflow-wired supervised loss.
Run options (pretrain.run.*)#
pretrain.run.check_vma
Enable JAX validity checks during shard_map.
pretrain.run.iterations
Total number of iterations to run.
pretrain.run.burn_in
Sampling iterations to discard before the main loop for MCMC equilibration.
pretrain.run.save_time_interval
Minimum wall-clock seconds between checkpoint saves.
pretrain.run.save_step_interval
Save checkpoints only at steps that are multiples of this value.
pretrain.run.stop_on_nan
Abort training when NaN is detected in step statistics. True checks all stat keys, False disables the check, or pass a comma-separated string of specific keys to monitor (e.g. "loss").
Optimizer (pretrain.optim.*)#
Default optimizer module: optax:adam. Its effective keys are listed below.
Effective Adam defaults#
pretrain.optim.learning_rate
A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate(). Swappable component; the nested keys below are the options for the current module Standard and change when pretrain.optim.learning_rate.module changes.
pretrain.optim.learning_rate.module
Select the implementation used for this component.
pretrain.optim.learning_rate.rate
Initial learning rate.
pretrain.optim.learning_rate.delay
Delay in steps before decay starts.
pretrain.optim.learning_rate.decay
Decay rate exponent.
pretrain.optim.b1
Exponential decay rate to track the first moment of past gradients.
pretrain.optim.b2
Exponential decay rate to track the second moment of past gradients.
pretrain.optim.eps
A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
pretrain.optim.eps_root
A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling.
Sampler (pretrain.sampler.*)#
Default sampler module: mcmc.
pretrain.sampler.steps
Number of Metropolis-Hastings updates per sample draw.
pretrain.sampler.initial_width
Initial width (stddev) of the Gaussian proposal.
pretrain.sampler.adapt_frequency
Frequency of adaptive width updates.
pretrain.sampler.pmove_range
Target range for acceptance rate.
Writers (pretrain.writers.*)#
The pretrain stage enables console, csv, and hdf5 writers by default.
Console writer (pretrain.writers.console.*)#
pretrain.writers.console.interval
Step interval for logging.
pretrain.writers.console.fields
Comma-separated list of field specs.
CSV writer (pretrain.writers.csv.*)#
pretrain.writers.csv.path_template
Output path template.
HDF5 writer (pretrain.writers.hdf5.*)#
pretrain.writers.hdf5.path_template
Output path template.
Loss gradients#
Pretraining does not use configurable pretrain.grads.* settings. The workflow
wires a supervised Hartree-Fock orbital-matching loss directly.
Estimators (estimators.*)#
Energy estimators are configured programmatically by the workflow and are not typically overridden via config. The same definitions are used by Evaluation.
total_energy and the electron-nuclei potential are always added by the workflow and are not configurable via config keys. estimators.enabled.spin defaults to false.
Kinetic energy (estimators.energy.kinetic.*)#
estimators.energy.kinetic.vmap_chunk_size
Number of walkers to evaluate per vmap chunk.
estimators.energy.kinetic.mode
Laplacian computation strategy. forward_laplacian is the default for JAX 0.7.1 and later, scan for earlier versions.
estimators.energy.kinetic.sparsity_threshold
Sparsity threshold when using forward_laplacian mode.
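The vmap_chunk_size keys trade speed for peak memory by evaluating walkers a slice at a time. The idea can be sketched in plain Python as follows; in a JAX program the batched function would typically be a vmapped function, but `evaluate_in_chunks` and the toy `batched_square` below are hypothetical stand-ins, not jaqmc code.

```python
def evaluate_in_chunks(batched_fn, walkers, chunk_size):
    """Apply batched_fn to successive slices of at most chunk_size walkers."""
    outputs = []
    for start in range(0, len(walkers), chunk_size):
        outputs.extend(batched_fn(walkers[start:start + chunk_size]))
    return outputs


chunk_sizes_seen = []


def batched_square(chunk):
    """Toy batched function that records how many walkers it received."""
    chunk_sizes_seen.append(len(chunk))
    return [w * w for w in chunk]


results = evaluate_in_chunks(batched_square, [1, 2, 3, 4, 5], chunk_size=2)
```

The results match a single full-batch call; only the number of walkers held in flight at once changes.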
ECP energy (estimators.energy.ecp.*)#
Added automatically when system.ecp is set.
estimators.energy.ecp.vmap_chunk_size
Number of walkers to evaluate per vmap chunk.
estimators.energy.ecp.max_core
Maximum number of nearest ECP atoms to consider per electron when evaluating nonlocal integrals.
estimators.energy.ecp.quadrature_id
Spherical quadrature rule used to evaluate nonlocal ECP integrals.
estimators.energy.ecp.electrons_field
Name of electron position field in data.
estimators.energy.ecp.atoms_field
Name of atom position field in data.
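The selection implied by estimators.energy.ecp.max_core can be sketched as a nearest-neighbor lookup per electron. This is an illustration only: `nearest_ecp_atoms` is a hypothetical helper, the coordinates are made up, and how jaqmc breaks ties or batches this step is not covered here.

```python
import math


def nearest_ecp_atoms(electron, ecp_atoms, max_core):
    """Return indices of the max_core ECP atoms closest to one electron."""
    distances = [math.dist(electron, atom) for atom in ecp_atoms]
    order = sorted(range(len(ecp_atoms)), key=distances.__getitem__)
    return order[:max_core]


electron = (0.0, 0.0, 0.0)
ecp_atoms = [(5.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
selected = nearest_ecp_atoms(electron, ecp_atoms, max_core=2)
```

Limiting the nonlocal integral to the closest atoms keeps the cost per electron bounded as the number of ECP atoms grows.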