Thermal Hamiltonian Monte Carlo

Our implementation of Hamiltonian Monte Carlo (HMC) is a light wrapper around the AdvancedHMC.jl package. For the theory behind HMC, see the documentation and references provided with AdvancedHMC.jl.
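The basic HMC update can be illustrated with a short standalone sketch. It is a simplified version written for this page, not AdvancedHMC.jl's implementation: it assumes an identity mass matrix, a fixed step size ϵ and a fixed number of leapfrog steps L, and does no adaptation, whereas the AdvancedHMC.jl sampler used here runs a No-U-Turn trajectory with multinomial sampling and adapts the step size and a diagonal metric. The function name `hmc_sample` is hypothetical.

```julia
using Random, LinearAlgebra

"""
    hmc_sample(logdensity, grad_logdensity, θ0, n_samples; ϵ=0.1, L=10, rng)

Plain HMC with identity mass matrix, fixed step size `ϵ` and fixed number of
leapfrog steps `L`. Returns the samples and the fraction of accepted proposals.
"""
function hmc_sample(logdensity, grad_logdensity, θ0::Vector{Float64}, n_samples::Int;
                    ϵ::Float64=0.1, L::Int=10, rng::AbstractRNG=Random.default_rng())
    θ = copy(θ0)
    samples = Vector{Vector{Float64}}(undef, n_samples)
    n_accept = 0
    for i in 1:n_samples
        p = randn(rng, length(θ))               # fresh Gaussian momenta
        H0 = -logdensity(θ) + dot(p, p) / 2     # potential + kinetic energy
        θnew, pnew = copy(θ), copy(p)
        # Leapfrog integration: half kick, alternating drifts and kicks, half kick
        pnew .+= (ϵ / 2) .* grad_logdensity(θnew)
        for step in 1:L
            θnew .+= ϵ .* pnew
            step < L && (pnew .+= ϵ .* grad_logdensity(θnew))
        end
        pnew .+= (ϵ / 2) .* grad_logdensity(θnew)
        H1 = -logdensity(θnew) + dot(pnew, pnew) / 2
        # Metropolis step corrects for the integration error in the energy
        if log(rand(rng)) < H0 - H1
            θ = θnew
            n_accept += 1
        end
        samples[i] = copy(θ)
    end
    return samples, n_accept / n_samples
end

# Usage: harmonic potential V(r) = |r|²/2 at k_B T = 0.5 (arbitrary units);
# the canonical log-density is -V(r) / k_B T.
kT = 0.5
samples, acceptance = hmc_sample(r -> -sum(abs2, r) / (2 * kT), r -> -r ./ kT,
                                 zeros(3), 5000; rng=MersenneTwister(1))
```

With k_B T = 0.5, the mean of r² over all samples and coordinates should come out close to 0.5.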

Currently, our implementation supports only systems with classical nuclei, i.e., Simulation but not RingPolymerSimulation.

Example

In this example we use Hamiltonian Monte Carlo to sample the canonical distribution of four atoms in a three-dimensional harmonic potential at 300 K.
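The target distribution can be written out explicitly. This assumes that the Harmonic model applies the potential ½mω²(r − r₀)² to every Cartesian coordinate of every atom, with its own parameters m, ω and r₀ (these are not shown on this page). Here i runs over the three Cartesian directions and j over the four atoms; the atomic masses do not enter the configurational distribution.

```latex
\pi(\mathbf{r}) \propto \exp\!\left[-\frac{V(\mathbf{r})}{k_B T}\right],
\qquad
V(\mathbf{r}) = \sum_{i=1}^{3}\sum_{j=1}^{4}\frac{1}{2}\, m\omega^2\left(r_{ij}-r_0\right)^2,
\\[4pt]
r_{ij} \sim \mathcal{N}\!\left(r_0,\ \frac{k_B T}{m\omega^2}\right),
\qquad
\langle V\rangle = \frac{3 \times 4}{2}\, k_B T = 6\, k_B T .
```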

using NQCDynamics
using Unitful
using UnitfulAtomic

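# Two hydrogen and two carbon atoms in a 3D harmonic potential at 300 K
sim = Simulation(Atoms([:H, :H, :C, :C]), Harmonic(dofs=3); temperature=300u"K")
# Random starting positions: a 3 × 4 matrix (Cartesian dofs × atoms)
r0 = randn(size(sim))
# Draw 10^4 samples starting from r0
chain, stats = InitialConditions.ThermalMonteCarlo.run_advancedhmc_sampling(sim, r0, 1e4)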
┌ Info: Finished 1000 adapation steps
  adaptor =
   StanHMCAdaptor(
       pc=WelfordVar,
       ssa=NesterovDualAveraging(γ=0.05, t_0=10.0, κ=0.75, δ=0.5, state.ϵ=1.1358019454652142),
       init_buffer=75, term_buffer=50, window_size=25,
       state=window(76, 950), window_splits(100, 150, 250, 450, 950)
   )
  κ.τ.integrator = Leapfrog(ϵ=1.14)
  h.metric = DiagEuclideanMetric([0.0008761264415207006, 0.0 ...])
┌ Info: Finished 10000 sampling steps for 1 chains in 0.979490942 (s)
  h = Hamiltonian(metric=DiagEuclideanMetric([0.0008761264415207006, 0.0 ...]), kinetic=AdvancedHMC.GaussianKinetic())
  κ = AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.MultinomialTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.GeneralisedNoUTurn{Float64}}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.MultinomialTS}(integrator=Leapfrog(ϵ=1.14), tc=AdvancedHMC.GeneralisedNoUTurn{Float64}(10, 1000.0)))
  EBFMI_est = 0.5490644882040283
  average_acceptance_rate = 0.6160005523266187

The Monte Carlo chain contains the nuclear configurations that we have sampled:

chain
10000-element Vector{Matrix{Float64}}:
 [0.17495509156579414 -0.42241566556935095 0.38868650623799894 0.5119404347848917; 0.3599457341474037 -0.38353116694528877 -0.1171962353414514 -0.008240921385553823; 0.1257869333190489 0.4060578305130733 0.37297391702357063 -0.30908081828262357]
 [0.17495509156579414 -0.42241566556935095 0.38868650623799894 0.5119404347848917; 0.3599457341474037 -0.38353116694528877 -0.1171962353414514 -0.008240921385553823; 0.1257869333190489 0.4060578305130733 0.37297391702357063 -0.30908081828262357]
 [0.17495509156579414 -0.42241566556935095 0.38868650623799894 0.5119404347848917; 0.3599457341474037 -0.38353116694528877 -0.1171962353414514 -0.008240921385553823; 0.1257869333190489 0.4060578305130733 0.37297391702357063 -0.30908081828262357]
 [0.17495509156579414 -0.42241566556935095 0.38868650623799894 0.5119404347848917; 0.3599457341474037 -0.38353116694528877 -0.1171962353414514 -0.008240921385553823; 0.1257869333190489 0.4060578305130733 0.37297391702357063 -0.30908081828262357]
 [-0.06843498764080462 0.1097837891441959 -0.19873263321346138 -0.16929978357221287; -0.09093994423080998 0.17624240746173914 0.0038565523110889455 0.007041280134739033; -0.04874481951136514 -0.09382567470860503 -0.1103241951514172 0.1407763684439736]
 [-0.06843498764080462 0.1097837891441959 -0.19873263321346138 -0.16929978357221287; -0.09093994423080998 0.17624240746173914 0.0038565523110889455 0.007041280134739033; -0.04874481951136514 -0.09382567470860503 -0.1103241951514172 0.1407763684439736]
 [-0.09179950553939102 0.03013256218189267 0.0506992802690778 0.021037700460078218; 0.006771237045441192 -0.014604848046416818 0.05264639099706991 -0.005436698032824222; -0.03490255804800346 -0.0076393176067472285 -0.0011362188909672805 0.000782235162648845]
 [-0.09179950553939102 0.03013256218189267 0.0506992802690778 0.021037700460078218; 0.006771237045441192 -0.014604848046416818 0.05264639099706991 -0.005436698032824222; -0.03490255804800346 -0.0076393176067472285 -0.0011362188909672805 0.000782235162648845]
 [-0.012939266331387922 -0.05008794337355776 0.027673992917234456 0.014019911163478347; 0.054948116379555434 -0.020267904520370826 0.0024790702952615165 0.032062669566813334; 0.03656082887568928 -0.013916479976904777 -0.05341130191187886 -0.007036838785653015]
 [-0.012939266331387922 -0.05008794337355776 0.027673992917234456 0.014019911163478347; 0.054948116379555434 -0.020267904520370826 0.0024790702952615165 0.032062669566813334; 0.03656082887568928 -0.013916479976904777 -0.05341130191187886 -0.007036838785653015]
 ⋮
 [0.07236054515278299 -0.017715135244409957 0.010297927307371663 -0.022301894934100088; -0.00616407824834834 -0.0128131613639947 -0.0210571772185496 -0.032819430849950024; -0.0030005788728570803 0.03908118151998814 -0.0528100339053259 -0.0102430271802708]
 [-0.06567417486371571 0.02008357959860345 -0.016285180516398018 0.025238533875129518; 0.0030742803020152067 0.004617923339384766 0.006420937582030207 0.04158733705327716; -0.03669665920963058 -0.04802121227492409 0.02233327370967527 0.009928457678745261]
 [0.0021545536722516606 0.01228045366754727 -0.05948396567904542 0.03812553747592636; -0.01030886672601147 0.023174369755528026 0.004406281374849603 0.033625494357209705; -0.0006979029458509603 -0.013940871819582097 -0.01182021356811365 -0.0655780416268127]
 [0.001974208406628873 0.010629949173161513 0.02993827428233955 -0.03601405285942004; 0.02598971158653519 0.013998228465880315 -0.02615462775641951 -0.036115751040152896; -0.011045486102502092 0.045760597726847825 0.04731965338643348 0.058464462169982016]
 [0.003424305809139395 -0.007610782724422359 -0.022405307564380073 0.020309532145804998; -0.021304903519541864 -0.016744926335761707 0.03765856053350978 0.04371035163216976; 0.033807799887140055 -0.013973243791290814 -0.02198456172575842 -0.04176449322261249]
 [-0.001154361278105173 0.007329064082959487 0.011240178927696343 -0.016995447243032546; 0.019464555669561255 0.020734768446391015 -0.04908499741134013 -0.06601357894229487; -0.0029882066330856882 0.009247993157349987 0.0592401219290701 0.04906138765361412]
 [-0.03464932645981842 0.014455328589195852 -0.003741977457608757 -0.0005431213482706351; 0.00013244349689361598 0.025029357553585707 -0.05933806710511645 -0.040114332876024575; -0.007100081186356624 0.03929686673874391 0.020416992282924547 0.0232678538275159]
 [-0.03464932645981842 0.014455328589195852 -0.003741977457608757 -0.0005431213482706351; 0.00013244349689361598 0.025029357553585707 -0.05933806710511645 -0.040114332876024575; -0.007100081186356624 0.03929686673874391 0.020416992282924547 0.0232678538275159]
 [0.027121544774802237 0.027216427828640033 0.028850246892995743 0.03817654010263323; 0.0395642592112478 0.04624151759177968 0.05908437381785949 0.009384056059970482; 0.017004594870342827 -0.059065775125792905 6.143703632958963e-5 0.010409052386231264]
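Each entry of chain is a 3 × 4 position matrix, so the samples can be checked coordinate by coordinate. Assuming the Harmonic model's defaults m = ω = 1 (not stated on this page), every coordinate should fluctuate with variance k_B T ≈ 9.5 × 10⁻⁴ in atomic units at 300 K, ideally measured after discarding the early warm-up samples. The helper `coordinate_variance` below is hypothetical; the usage example feeds it a synthetic stand-in chain of the same layout rather than the real `chain`.

```julia
using Random

# Sample variance of every coordinate in a chain stored as a vector of
# (dofs × natoms) position matrices, i.e. the layout of `chain` above.
function coordinate_variance(chain::AbstractVector{<:AbstractMatrix})
    n = length(chain)
    μ = sum(chain) / n                               # mean configuration
    return sum(r -> (r .- μ) .^ 2, chain) / (n - 1)  # unbiased variance
end

# Stand-in chain: 20000 independent 3×4 Gaussian configurations
# with variance 9.5e-4 (pass the real `chain` instead).
rng = MersenneTwister(2)
fake_chain = [sqrt(9.5e-4) .* randn(rng, 3, 4) for _ in 1:20_000]
v = coordinate_variance(fake_chain)   # 3×4 matrix, entries close to 9.5e-4
```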

and stats contains diagnostic information for each step of the sampling procedure:

stats
10000-element Vector{NamedTuple}:
 (n_steps = 1, is_accept = true, acceptance_rate = 1.0, log_density = -698.9207852098116, hamiltonian_energy = 3182.9169027460116, hamiltonian_energy_error = -4754.405219124301, max_hamiltonian_energy_error = -4754.405219124301, tree_depth = 1, numerical_error = false, step_size = 0.05, nom_step_size = 0.05, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -698.9207852098116, hamiltonian_energy = 701.5818849735638, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1.852671579751894e11, tree_depth = 0, numerical_error = true, step_size = 1.241032542311506, nom_step_size = 1.241032542311506, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -698.9207852098116, hamiltonian_energy = 706.0024178853896, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 7.885161371510936e8, tree_depth = 0, numerical_error = true, step_size = 0.5, nom_step_size = 0.5, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -698.9207852098116, hamiltonian_energy = 705.970946240534, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 213938.86371376424, tree_depth = 0, numerical_error = true, step_size = 0.13192866018819532, nom_step_size = 0.13192866018819532, is_adapt = true)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -88.13151737272914, hamiltonian_energy = 573.0515480404932, hamiltonian_energy_error = -132.5394561733026, max_hamiltonian_energy_error = -132.5394561733026, tree_depth = 2, numerical_error = false, step_size = 0.028716309633808685, nom_step_size = 0.028716309633808685, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -88.13151737272914, hamiltonian_energy = 91.1447519468772, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 8384.73380464544, tree_depth = 0, numerical_error = true, step_size = 0.11260612534950616, nom_step_size = 0.11260612534950616, is_adapt = true)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -8.782189947300319, hamiltonian_energy = 84.28280799931532, hamiltonian_energy_error = -11.433523628166284, max_hamiltonian_energy_error = -11.433523628166284, tree_depth = 2, numerical_error = false, step_size = 0.02340023160450795, nom_step_size = 0.02340023160450795, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 1.2824063638657347e-244, log_density = -8.782189947300319, hamiltonian_energy = 13.693237231280502, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 561.5820244057808, tree_depth = 1, numerical_error = false, step_size = 0.10545494475829686, nom_step_size = 0.10545494475829686, is_adapt = true)
 (n_steps = 7, is_accept = true, acceptance_rate = 0.9956273742340843, log_density = -6.597343067885718, hamiltonian_energy = 14.45454424473207, hamiltonian_energy_error = -0.2678217137021548, max_hamiltonian_energy_error = -0.2678217137021548, tree_depth = 3, numerical_error = false, step_size = 0.021583114937805906, nom_step_size = 0.021583114937805906, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 2.0e-323, log_density = -6.597343067885718, hamiltonian_energy = 16.542900060107097, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 742.9512342575928, tree_depth = 1, numerical_error = false, step_size = 0.10168221778980614, nom_step_size = 0.10168221778980614, is_adapt = true)
 ⋮
 (n_steps = 3, is_accept = true, acceptance_rate = 0.688686237425633, log_density = -6.476624999341613, hamiltonian_energy = 9.98572802122489, hamiltonian_energy_error = 0.639233153064513, max_hamiltonian_energy_error = 0.639233153064513, tree_depth = 2, numerical_error = false, step_size = 1.1358019454652142, nom_step_size = 1.1358019454652142, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.8257229577582031, log_density = -6.141919270255924, hamiltonian_energy = 12.20892535629821, hamiltonian_energy_error = -0.07327914859253148, max_hamiltonian_energy_error = 0.7396506971161614, tree_depth = 2, numerical_error = false, step_size = 1.1358019454652142, nom_step_size = 1.1358019454652142, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.8735062227075425, log_density = -6.092245303996102, hamiltonian_energy = 10.593134293243033, hamiltonian_energy_error = -0.05935594009506673, max_hamiltonian_energy_error = 0.26678730779048543, tree_depth = 2, numerical_error = false, step_size = 1.1358019454652142, nom_step_size = 1.1358019454652142, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.8219560010221948, log_density = -6.864586489762573, hamiltonian_energy = 10.573621920158356, hamiltonian_energy_error = 0.2999219389090797, max_hamiltonian_energy_error = 0.32159472619015617, tree_depth = 2, numerical_error = false, step_size = 1.1358019454652142, nom_step_size = 1.1358019454652142, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.7990835627023928, log_density = -4.532945746503692, hamiltonian_energy = 11.426878378151198, hamiltonian_energy_error = -0.7460469867286683, max_hamiltonian_energy_error = 0.9231877413901515, tree_depth = 2, numerical_error = false, step_size = 1.1358019454652142, nom_step_size = 1.1358019454652142, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.5442435367616878, log_density = -7.39809846831846, hamiltonian_energy = 10.499465243949517, hamiltonian_energy_error = 0.9107185949431287, max_hamiltonian_energy_error = 1.4675237910466628, tree_depth = 2, numerical_error = false, step_size = 1.1358019454652142, nom_step_size = 1.1358019454652142, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -5.1225989740124245, hamiltonian_energy = 8.770533976366366, hamiltonian_energy_error = -0.7554998036383793, max_hamiltonian_energy_error = -1.69565601322671, tree_depth = 2, numerical_error = false, step_size = 1.1358019454652142, nom_step_size = 1.1358019454652142, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.4420483181163701, log_density = -5.1225989740124245, hamiltonian_energy = 11.754701464067045, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1.0859513113166468, tree_depth = 2, numerical_error = false, step_size = 1.1358019454652142, nom_step_size = 1.1358019454652142, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.3525808483004173, log_density = -7.860168539252162, hamiltonian_energy = 13.315220210688757, hamiltonian_energy_error = 0.7903027859794882, max_hamiltonian_energy_error = 1.38012722547907, tree_depth = 2, numerical_error = false, step_size = 1.1358019454652142, nom_step_size = 1.1358019454652142, is_adapt = false)
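The returned chain appears to include the adaptation phase: the first stats entries have is_adapt = true and log_density ≈ −699, far from the values near −6 seen later. Repeated entries in chain line up with steps whose log_density did not change, for example the first four. A minimal sketch for discarding the warm-up samples and summarising the rest, using only the `is_adapt`, `acceptance_rate` and `numerical_error` fields shown above, could look like this; `drop_warmup` and `acceptance_summary` are hypothetical helpers, demonstrated on a toy stand-in for the real data.

```julia
using Statistics

# Keep only the samples drawn after adaptation finished, using the
# `is_adapt` flag stored in each entry of `stats`.
function drop_warmup(chain, stats)
    keep = [!s.is_adapt for s in stats]
    return chain[keep], stats[keep]
end

# Mean acceptance rate and number of steps flagged with numerical errors.
function acceptance_summary(stats)
    return (mean_acceptance = mean(s.acceptance_rate for s in stats),
            n_numerical_errors = count(s -> s.numerical_error, stats))
end

# Toy stand-in for the real `chain` and `stats`:
toy_stats = [(is_adapt = true,  acceptance_rate = 0.0, numerical_error = true),
             (is_adapt = false, acceptance_rate = 0.8, numerical_error = false),
             (is_adapt = false, acceptance_rate = 0.6, numerical_error = false)]
toy_chain = [fill(Float64(i), 3, 4) for i in 1:3]
kept_chain, kept_stats = drop_warmup(toy_chain, toy_stats)
acc_summary = acceptance_summary(kept_stats)  # mean_acceptance ≈ 0.7, n_numerical_errors = 0
```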

We can check that the mean potential energy of the sampled ensemble agrees with the equipartition theorem, which assigns k_B T/2 to each of the 3 × 4 = 12 quadratic degrees of freedom:

julia> Estimators.@estimate potential_energy(sim, chain)
0.006009909492176189

julia> austrip(sim.temperature) * 3 * 4 / 2
0.005700260814198944
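The comparison above can be reproduced with plain arithmetic, which also quantifies the gap between the two numbers. The sketch hard-codes the Boltzmann constant in Hartree per kelvin (rounded) instead of using Unitful, and copies the Monte Carlo estimate printed above; the names are illustrative only.

```julia
# Boltzmann constant in Hartree per kelvin (rounded)
const KB_HARTREE_PER_K = 3.166811563e-6

# Equipartition: each quadratic degree of freedom contributes k_B T / 2
# to the mean potential energy.
equipartition_potential(T_kelvin, n_dofs) = n_dofs * KB_HARTREE_PER_K * T_kelvin / 2

expected = equipartition_potential(300.0, 3 * 4)       # ≈ 0.0057 Hartree
estimate = 0.006009909492176189                         # estimate printed above
relative_deviation = (estimate - expected) / expected   # ≈ 0.054
```

The estimate is about 5% high. One plausible contributor, assuming log_density equals −V/k_B T (as the equilibrated values near −6 suggest) and that the estimator is a plain mean over chain, is the warm-up phase: the first four samples alone have log_density ≈ −699, i.e. V ≈ 0.66 Hartree each, which raises a 10000-sample mean by roughly 2.7 × 10⁻⁴ Hartree.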