Workflows

Workflows orchestrate the full training or evaluation pipeline. A workflow composes one or more Work Stages — for example, a training workflow might run a pretrain stage followed by a VMC training stage.

Use the built-in VMCWorkflow and EvaluationWorkflow base classes, or write a plain function that creates and returns a workflow instance.

class jaqmc.workflow.base.Workflow(cfg)

Base class for all workflows.

Subclasses must override run().

prepare(dry_run=False)

Finalize config and log startup info.

On the master process, validates unused config keys and writes the resolved config to disk.

Return type:

None

run()

Execute the workflow.

Subclasses must override this method.

Return type:

None
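The subclassing contract above can be sketched as follows. The `Workflow` class here is a local stand-in written for illustration, not the real `jaqmc.workflow.base.Workflow`. The `HelloWorkflow` name, the dict-based config, and the assumption that an un-overridden `run()` raises `NotImplementedError` are all hypothetical.

```python
class Workflow:
    """Local stand-in mimicking the documented interface (not jaqmc's class)."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.prepared = False

    def prepare(self, dry_run=False):
        # The documented method validates unused config keys and writes the
        # resolved config to disk; this stand-in only records the call.
        self.prepared = True

    def run(self):
        # Assumption: the base implementation fails loudly when not overridden.
        raise NotImplementedError("Subclasses must override run().")


class HelloWorkflow(Workflow):
    """Hypothetical subclass that overrides run()."""

    def run(self):
        self.result = f"ran with batch_size={self.cfg['batch_size']}"


wf = HelloWorkflow({"batch_size": 4096})
wf.prepare()
wf.run()
```

Keeping `prepare()` separate from `run()` lets config validation finish before any expensive work starts.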

class jaqmc.workflow.vmc.VMCWorkflow(cfg)

VMC workflow with pretrain -> train -> eval pipeline.

Subclass and set stages in __init__:

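from jaqmc.workflow.vmc import VMCWorkflow

# VMCWorkStage and MCMCSampler must also be imported; their module paths
# are not given in this section.

class MyWorkflow(VMCWorkflow):
    def __init__(self, cfg):
        super().__init__(cfg)
        # `wf` (a wavefunction object exposing `logpsi`) and `data_init` (the
        # electron-configuration initializer) are not defined in this
        # snippet; construct them for your system before this point.
        train = VMCWorkStage.builder(cfg.scoped("train"), wf)
        sampler = cfg.get("sampler", MCMCSampler)
        train.configure_sample_plan(wf.logpsi, {"electrons": sampler})
        train.configure_optimizer(
            default="jaqmc.optimizer.kfac", f_log_psi=wf.logpsi
        )
        train.configure_estimators(...)
        train.configure_loss_grads(f_log_psi=wf.logpsi)
        self.train_stage = train.build()
        self.data_init = data_init

MyWorkflow(cfg)()

train_stage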

The main training stage.

pretrain_stage

Optional pretraining stage.

data_init

Function to initialize electron configurations.

restore_checkpoint(checkpoint_path, *, stage='train', rngs=None)

Restore state from a checkpoint.

Parameters:
  • checkpoint_path (str | Path) – Path to checkpoint file or directory.

  • stage (Literal['pretrain', 'train'], default: 'train') – Name of the stage to restore ("train" or "pretrain").

  • rngs (PRNGKey | None, default: None) – Random key for create_state. Defaults to PRNGKey(0).

Returns:

Restored state

Raises:

ValueError – Invalid stage name passed.
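The stage check behind the documented `ValueError` might look like the sketch below. `resolve_stage` and `VALID_STAGES` are hypothetical names introduced for illustration; the real method's internals and error message are not shown in this reference.

```python
from pathlib import Path

# The two stage names accepted by the documented signature.
VALID_STAGES = ("pretrain", "train")


def resolve_stage(checkpoint_path, *, stage="train"):
    """Hypothetical helper: normalize the path and validate the stage name."""
    if stage not in VALID_STAGES:
        raise ValueError(
            f"Invalid stage {stage!r}; expected one of {VALID_STAGES}."
        )
    # Accept both str and Path, as the documented parameter type allows.
    return Path(checkpoint_path), stage


path, stage = resolve_stage("runs/h2o/ckpt", stage="pretrain")
```

Validating the name up front keeps a typo such as `stage="training"` from surfacing later as a confusing missing-attribute error.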

run()

Execute the pretrain -> train -> eval pipeline.

Override in subclasses to inject pre-run logic (e.g. SCF), then call super().run().

Return type:

None
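A minimal sketch of the override pattern, using a local stand-in rather than the real `VMCWorkflow`. The `SCFThenVMC` name and the `log` list are hypothetical. Because `run()` is documented with no parameters, the sketch calls `super().run()` without arguments.

```python
class VMCWorkflow:
    """Local stand-in for illustration only (not jaqmc's VMCWorkflow)."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.log = []

    def run(self):
        # Stand-in for the documented pretrain -> train -> eval pipeline.
        self.log.extend(["pretrain", "train", "eval"])


class SCFThenVMC(VMCWorkflow):
    """Hypothetical subclass that injects a step before the pipeline."""

    def run(self):
        # Pre-run logic, for example an SCF calculation.
        self.log.append("scf")
        # Hand off to the standard pipeline.
        super().run()


wf = SCFThenVMC(cfg={})
wf.run()
```

After `wf.run()`, `wf.log` holds the pre-run step first, followed by the three pipeline stages in order.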

class jaqmc.workflow.evaluation.EvaluationWorkflow(cfg)

Evaluation workflow that loads params from a training checkpoint.

Creates fresh evaluation state (data, estimator_state), then loads params, batched_data, and sampler_state from the training checkpoint. The evaluation stage handles its own checkpointing for resumability.

eval_stage

The evaluation stage.

data_init

Function to initialize electron configurations.

run()

Execute the evaluation workflow.

  1. Create fresh eval state (data + estimator_state as template)

  2. Load params, data, sampler_state from training checkpoint

  3. Run evaluation (the stage handles its own checkpoint for resumability)

Return type:

None
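Steps 1 and 2 above can be sketched with plain dictionaries. The state layout, helper names, and checkpoint contents are hypothetical stand-ins; only the choice of which entries are loaded (`params`, `batched_data`, `sampler_state`) and which stay fresh (`estimator_state`) follows the description. Step 3, running the evaluation stage, is omitted.

```python
def fresh_eval_state():
    # Step 1: a template state with placeholders (hypothetical layout).
    return {
        "params": None,
        "batched_data": None,
        "sampler_state": None,
        "estimator_state": {"n_steps": 0},
    }


def load_from_training(state, checkpoint):
    # Step 2: copy only the three documented entries from the training
    # checkpoint; estimator_state keeps its fresh value.
    loaded = dict(state)
    for key in ("params", "batched_data", "sampler_state"):
        loaded[key] = checkpoint[key]
    return loaded


training_checkpoint = {
    "params": {"w": 1.0},
    "batched_data": [0.1, 0.2],
    "sampler_state": {"step_size": 0.02},
    "estimator_state": {"n_steps": 500},  # training statistics, not reused
}

state = load_from_training(fresh_eval_state(), training_checkpoint)
```

Starting from a fresh template means evaluation statistics never mix with accumulators left over from training.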

Workflow configuration

class jaqmc.workflow.base.WorkflowConfig(*, seed=None, batch_size=4096, save_path='', restore_path='', config=<factory>)

Base configuration for workflows.

Parameters:
  • seed (int | None, default: None) – Fixed random seed. If not provided, the seed is derived from the current time.

  • batch_size (int, default: 4096) – Number of walkers (samples) to use in each iteration.

  • save_path (str, default: '') – Path to save checkpoints and logs. Can be any path supported by fsspec/universal_pathlib.

  • restore_path (str, default: '') – Path to restore checkpoints from. When set, checkpoints are restored from this path instead of save_path. Can be a directory or a specific checkpoint file.

  • config (ConfigCheck, default: <factory>) – Controls config validation behavior (extra-key warnings, verbose output).

class jaqmc.workflow.base.ConfigCheck(*, ignore_extra=False, verbose=False)

Controls config validation behavior.

Parameters:
  • ignore_extra (bool, default: False) – If True, silently ignore unrecognized config keys. If False, raise an error on extra keys.

  • verbose (bool, default: False) – If True, print the fully resolved config with field descriptions at startup.

class jaqmc.workflow.evaluation.EvaluationWorkflowConfig(*, seed=None, batch_size=4096, save_path='', restore_path='', config=<factory>, source_path)

Workflow config for evaluation.

Extends WorkflowConfig.

Parameters:

source_path (str) – Path to the training run directory or checkpoint file to load parameters from.