Workflows
Workflows orchestrate the full training or evaluation pipeline. A workflow composes one or more Work Stages — for example, a training workflow might run a pretrain stage followed by a VMC training stage.
Use the built-in VMCWorkflow and EvaluationWorkflow base classes, or write a plain function that creates and returns a workflow instance.
- class jaqmc.workflow.base.Workflow(cfg)
Base class for all workflows.
Subclasses must override run().
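The run() contract can be illustrated with a small self-contained sketch. It does not import jaqmc: `Workflow` below is a hypothetical stand-in for the real base class, and it assumes that calling a workflow instance invokes run(), which is what `MyWorkflow(cfg)()` in the VMCWorkflow example suggests. `EchoWorkflow` and `make_workflow` are illustrative names only. The factory function shows the "plain function that creates and returns a workflow instance" option.

```python
class Workflow:
    """Hypothetical stand-in for jaqmc.workflow.base.Workflow (not the real class)."""

    def __init__(self, cfg):
        self.cfg = cfg

    def run(self):
        raise NotImplementedError("Subclasses must override run().")

    def __call__(self):
        # Assumption: calling the instance runs the workflow, as in `MyWorkflow(cfg)()`.
        return self.run()


class EchoWorkflow(Workflow):
    """Illustrative subclass that only reports its configuration."""

    def run(self):
        return f"running with batch_size={self.cfg['batch_size']}"


def make_workflow(cfg):
    """A plain function that creates and returns a workflow instance."""
    return EchoWorkflow(cfg)


print(make_workflow({"batch_size": 4096})())  # running with batch_size=4096
```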
- class jaqmc.workflow.vmc.VMCWorkflow(cfg)
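VMC workflow with a pretrain -> train -> eval pipeline.
Subclass and set stages in __init__:

```python
from jaqmc.workflow.vmc import VMCWorkflow

# VMCWorkStage and MCMCSampler must also be imported; their module
# paths are not shown on this page.


class MyWorkflow(VMCWorkflow):
    def __init__(self, cfg):
        super().__init__(cfg)
        # `wf` (a wavefunction exposing `logpsi`) and `data_init` (presumably
        # the walker-data initializer) are assumed to be defined earlier.
        train = VMCWorkStage.builder(cfg.scoped("train"), wf)
        # Use the configured sampler, falling back to MCMCSampler.
        sampler = cfg.get("sampler", MCMCSampler)
        train.configure_sample_plan(wf.logpsi, {"electrons": sampler})
        train.configure_optimizer(
            default="jaqmc.optimizer.kfac", f_log_psi=wf.logpsi
        )
        train.configure_estimators(...)
        train.configure_loss_grads(f_log_psi=wf.logpsi)
        self.train_stage = train.build()
        self.data_init = data_init


# Build the workflow and run it.
MyWorkflow(cfg)()
```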
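The pretrain -> train -> eval pipeline amounts to running stages in order, each consuming the state produced by the previous one. The sketch below models that idea with plain functions; `pretrain`, `train`, `evaluate`, and `run_pipeline` are hypothetical names, not jaqmc APIs. It also assumes that a stage left unset is skipped, which would be consistent with the example above setting only `train_stage`, but this page does not confirm it.

```python
from typing import Callable, Optional

State = dict
Stage = Callable[[State], State]


def pretrain(state: State) -> State:
    # Hypothetical pretraining step: produce initial parameters.
    return {**state, "params": "pretrained"}


def train(state: State) -> State:
    # Hypothetical VMC training step: refine whatever parameters exist.
    return {**state, "params": state.get("params", "init") + "+vmc"}


def evaluate(state: State) -> State:
    # Hypothetical evaluation step: mark the state as evaluated.
    return {**state, "evaluated": True}


def run_pipeline(stages: list[Optional[Stage]], state: State) -> State:
    for stage in stages:
        if stage is None:  # assumption: unset stages are skipped
            continue
        state = stage(state)  # each stage consumes the previous stage's output
    return state


full = run_pipeline([pretrain, train, evaluate], {"seed": 0})
train_only = run_pipeline([None, train, None], {"seed": 0})
print(full)
print(train_only)
```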
- restore_checkpoint(checkpoint_path, *, stage='train', rngs=None)
Restore state from a checkpoint.
- Parameters:
- Returns:
Restored state
- Raises:
ValueError – Invalid stage name passed.
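A minimal sketch of the stage-name check behind the ValueError, assuming the valid stage names are the three pipeline stages ("pretrain", "train", "eval"); the real list is not given on this page. The sketch replaces `checkpoint_path` and file loading with an in-memory dict and omits `rngs`, so it only illustrates the keyword-only `stage` argument and its validation.

```python
VALID_STAGES = ("pretrain", "train", "eval")  # assumed, mirroring the pipeline


def restore_checkpoint(checkpoints: dict, *, stage: str = "train"):
    """Hypothetical sketch: `checkpoints` maps a stage name to its saved state."""
    if stage not in VALID_STAGES:
        raise ValueError(f"Invalid stage {stage!r}; expected one of {VALID_STAGES}")
    return checkpoints[stage]


saved = {"train": {"step": 100}}
print(restore_checkpoint(saved))  # {'step': 100}
```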
- class jaqmc.workflow.evaluation.EvaluationWorkflow(cfg)
Evaluation workflow that loads params from a training checkpoint.
Creates fresh evaluation state (data, estimator_state), then loads params, batched_data, and sampler_state from the training checkpoint. The evaluation stage handles its own checkpointing for resumability.
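The loading order described above can be sketched as a dictionary merge: start from the fresh evaluation state and overwrite only the keys taken from training. This is a simplified model using plain dicts; `init_eval_state` and the example values are hypothetical, and the real workflow operates on its own state objects.

```python
LOADED_KEYS = ("params", "batched_data", "sampler_state")


def init_eval_state(fresh_state: dict, train_checkpoint: dict) -> dict:
    """Hypothetical sketch: fresh evaluation state plus selected training entries."""
    state = dict(fresh_state)  # keeps fresh data and estimator_state
    for key in LOADED_KEYS:
        state[key] = train_checkpoint[key]  # only these come from training
    return state


fresh = {"data": "fresh-data", "estimator_state": {}}
ckpt = {
    "params": "trained-params",
    "batched_data": "walkers",
    "sampler_state": {"step_size": 0.02},
    "opt_state": "not-needed-for-eval",
}
print(init_eval_state(fresh, ckpt))
```

Everything else in the training checkpoint (here the hypothetical `opt_state`) is left behind.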
Workflow configuration
- class jaqmc.workflow.base.WorkflowConfig(*, seed=None, batch_size=4096, save_path='', restore_path='', config=<factory>)
Base configuration for workflows.
- Parameters:
  - seed (int | None, default: None) – Fixed random seed. If not provided, the current time is used.
  - batch_size (int, default: 4096) – Number of walkers (samples) to use in each iteration.
  - save_path (str, default: '') – Path to save checkpoints and logs. Can be any path supported by fsspec/universal_pathlib.
  - restore_path (str, default: '') – Path to restore checkpoints from. When set, checkpoints are restored from this path instead of save_path. Can be a directory or a specific checkpoint file.
  - config (ConfigCheck, default: <factory>) – Controls config validation behavior (extra-key warnings, verbose output).
- class jaqmc.workflow.base.ConfigCheck(*, ignore_extra=False, verbose=False)
Controls config validation behavior.
- class jaqmc.workflow.evaluation.EvaluationWorkflowConfig(*, seed=None, batch_size=4096, save_path='', restore_path='', config=<factory>, source_path)
Workflow config for evaluation.
Extends WorkflowConfig.
- Parameters:
  - source_path (str) – Path to the training run directory or checkpoint file to load parameters from.