Thermal Hamiltonian Monte Carlo

Our implementation of Hamiltonian Monte Carlo (HMC) is a light wrapper around the AdvancedHMC.jl package. For the theory behind HMC, see the AdvancedHMC.jl documentation and the references it provides.
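AdvancedHMC.jl does the actual sampling. The log in the example shows that it uses a No-U-Turn trajectory with Stan-style step-size and metric adaptation. The following plain-Julia sampler is only an illustration of the core HMC idea, with fixed trajectory length. Each step draws a fresh Gaussian momentum, integrates Hamilton's equations with the leapfrog scheme, and then accepts or rejects the endpoint with a Metropolis test on the total energy. This is a sketch under stated assumptions: a unit mass matrix, a fixed step size ϵ and a fixed number of leapfrog steps L. The names hmc_step and hmc_sample are hypothetical, and this is not how AdvancedHMC.jl or NQCDynamics implement sampling.

```julia
using Random

# One HMC transition targeting a density ∝ exp(-U(q)), assuming a unit mass matrix.
# U (potential in units of k_B T) and ∇U (its gradient) are supplied by the caller.
function hmc_step(rng::AbstractRNG, q::Vector{Float64}, U, ∇U, ϵ, L)
    p = randn(rng, length(q))            # fresh Gaussian momentum
    qn, pn = copy(q), copy(p)
    pn .-= (ϵ / 2) .* ∇U(qn)             # leapfrog: initial half kick
    for i in 1:L
        qn .+= ϵ .* pn                   # drift
        if i < L
            pn .-= ϵ .* ∇U(qn)           # full kick between drifts
        end
    end
    pn .-= (ϵ / 2) .* ∇U(qn)             # final half kick
    # Metropolis test on H = U + |p|^2/2 corrects the integration error
    ΔH = (U(qn) + sum(abs2, pn) / 2) - (U(q) + sum(abs2, p) / 2)
    return log(rand(rng)) < -ΔH ? qn : q
end

function hmc_sample(rng::AbstractRNG, q0, U, ∇U; ϵ=0.1, L=11, n=5000)
    samples = Vector{Vector{Float64}}(undef, n)
    q = copy(q0)
    for k in 1:n
        q = hmc_step(rng, q, U, ∇U, ϵ, L)
        samples[k] = q                   # a rejected step repeats the previous q
    end
    return samples
end

# Usage: 12 harmonic coordinates at k_B T = 0.5 (arbitrary units), U = |q|^2 / (2 k_B T)
kT = 0.5
U(q) = sum(abs2, q) / (2kT)
∇U(q) = q ./ kT
samples = hmc_sample(MersenneTwister(1), zeros(12), U, ∇U)
```

For this target, the sampled coordinates should have variance close to kT. A production sampler like AdvancedHMC.jl additionally tunes ϵ, the mass matrix and the trajectory length automatically, which is what the adaptation phase in the log does.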

Currently, the implementation supports only systems with classical nuclei, i.e., Simulation but not RingPolymerSimulation.

Example

In this example, we use Hamiltonian Monte Carlo to sample the canonical distribution of 4 atoms in a three-dimensional harmonic oscillator potential.
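The target of the sampling is the Boltzmann distribution over the 3 × 4 = 12 nuclear coordinates. Assume the Harmonic model has the standard form ½mω²(r − r₀)² per coordinate, with m, ω and r₀ as its parameters; the example leaves them at their defaults. Under that assumption, the distribution factorizes into independent Gaussians, which fixes the expected moments:

```latex
\begin{aligned}
\rho(\mathbf{r}) &\propto \exp\!\left(-\frac{V(\mathbf{r})}{k_B T}\right),
\qquad V(\mathbf{r}) = \sum_{i=1}^{12} \tfrac{1}{2} m\omega^2 (r_i - r_0)^2, \\
\langle (r_i - r_0)^2 \rangle &= \frac{k_B T}{m\omega^2},
\qquad \langle V \rangle = 12 \times \tfrac{1}{2} k_B T = 6\, k_B T .
\end{aligned}
```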

using NQCDynamics
using Unitful
using UnitfulAtomic

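# Two H and two C atoms in a 3D harmonic potential at 300 K
sim = Simulation(Atoms([:H, :H, :C, :C]), Harmonic(dofs=3); temperature=300u"K")
# Random initial positions, one row per degree of freedom and one column per atom
r0 = randn(size(sim))
# Draw 1e4 samples with AdvancedHMC.jl, starting from r0
chain, stats = InitialConditions.ThermalMonteCarlo.run_advancedhmc_sampling(sim, r0, 1e4)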
┌ Info: Finished 1000 adapation steps
  adaptor =
   StanHMCAdaptor(
       pc=WelfordVar,
       ssa=NesterovDualAveraging(γ=0.05, t_0=10.0, κ=0.75, δ=0.5, state.ϵ=1.0813645048439533),
       init_buffer=75, term_buffer=50, window_size=25,
       state=window(76, 950), window_splits(100, 150, 250, 450, 950)
   )
  κ.τ.integrator = Leapfrog(ϵ=1.08)
  h.metric = DiagEuclideanMetric([0.0009007461671184132, 0.0 ...])
┌ Info: Finished 10000 sampling steps for 1 chains in 1.017975053 (s)
  h = Hamiltonian(metric=DiagEuclideanMetric([0.0009007461671184132, 0.0 ...]), kinetic=AdvancedHMC.GaussianKinetic())
  κ = AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.MultinomialTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.GeneralisedNoUTurn{Float64}}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.MultinomialTS}(integrator=Leapfrog(ϵ=1.08), tc=AdvancedHMC.GeneralisedNoUTurn{Float64}(10, 1000.0)))
  EBFMI_est = 0.5083846696215895
  average_acceptance_rate = 0.6364440073258754

The Monte Carlo chain contains the nuclear configurations that we have sampled:

chain
10000-element Vector{Matrix{Float64}}:
 [-0.014415914300573006 0.04805549149568894 -0.014554469031885578 0.1780692153347233; 0.4960784411742638 -0.09169500221789036 -0.4603853830667064 -0.014729414518943124; -0.15655649732347032 -0.4574389000053569 0.23415227762986834 0.2735073117450326]
 [-0.014415914300573006 0.04805549149568894 -0.014554469031885578 0.1780692153347233; 0.4960784411742638 -0.09169500221789036 -0.4603853830667064 -0.014729414518943124; -0.15655649732347032 -0.4574389000053569 0.23415227762986834 0.2735073117450326]
 [-0.014415914300573006 0.04805549149568894 -0.014554469031885578 0.1780692153347233; 0.4960784411742638 -0.09169500221789036 -0.4603853830667064 -0.014729414518943124; -0.15655649732347032 -0.4574389000053569 0.23415227762986834 0.2735073117450326]
 [-0.014415914300573006 0.04805549149568894 -0.014554469031885578 0.1780692153347233; 0.4960784411742638 -0.09169500221789036 -0.4603853830667064 -0.014729414518943124; -0.15655649732347032 -0.4574389000053569 0.23415227762986834 0.2735073117450326]
 [0.016386889424236478 -0.025507566297429948 -0.019289493394010107 -0.12070604394893827; -0.17845220160300657 0.017947548792145565 0.1617058778251958 0.008123706457974698; 0.10200519116624752 0.19007067572754371 -0.055813753669714694 -0.06616320225131689]
 [0.016386889424236478 -0.025507566297429948 -0.019289493394010107 -0.12070604394893827; -0.17845220160300657 0.017947548792145565 0.1617058778251958 0.008123706457974698; 0.10200519116624752 0.19007067572754371 -0.055813753669714694 -0.06616320225131689]
 [-0.03996932749254287 -0.03376138709745012 0.014598694891294155 -0.02221897537869183; -0.003908214251265377 0.016964522336070185 -0.012785669671274383 0.037425813405547716; 0.014370643125647245 0.03342379568054368 -0.02296353558627285 0.0026020632726369702]
 [-0.03996932749254287 -0.03376138709745012 0.014598694891294155 -0.02221897537869183; -0.003908214251265377 0.016964522336070185 -0.012785669671274383 0.037425813405547716; 0.014370643125647245 0.03342379568054368 -0.02296353558627285 0.0026020632726369702]
 [0.019668236153912534 0.019791691042579145 -0.02884535138792994 0.006305772411620121; 7.254567703446173e-5 -0.0024374814412310153 0.027552354803441752 -0.02624824856735041; 0.0034825941447114914 -0.016944681457882062 0.0482098593843805 -0.018837430070690317]
 [-0.019470455736962988 -0.020886686181734393 0.028486266106373132 -0.007038353350674341; 0.0010291828550473975 0.0010217751716543877 -0.02700342479542256 0.026654120765021917; -0.003699574862256751 0.016730623017575243 -0.048368992572413505 0.01910341022202864]
 ⋮
 [0.015286652112319828 -0.0072217926945912225 -0.016750635956701874 -0.021568501472906364; 0.005318982434299421 -0.010323781894769268 -0.02434563283615659 0.022210303447742434; -0.015642375562431088 -0.012078301506710028 -0.001269148969694699 0.02458563914590195]
 [0.022548365261828303 -0.04136263646064135 0.021186680707531696 0.02648438106692125; -0.008845285856261972 -0.027558312135030977 0.05088103568975322 -0.008417979792532605; -0.005780267778076321 0.027388393189339926 0.06500988886628568 -0.03854782496509379]
 [-0.025197923542242956 0.026108002320423387 -0.045567577351901256 0.022106855687389522; 0.02061217185067245 0.00847826115596868 0.030787628382587643 -0.003948159105011277; 0.013098999023239181 -0.004070751980252283 -0.007396668981766746 -0.03977020865465155]
 [0.030282066451434955 -0.025227873788962255 0.012512476706473867 -0.02857564592436943; -0.0128660514688002 -0.024826704339866657 -0.04890738061708657 -0.0111668617604277; -0.019612795872386103 0.0029027748555208004 0.00023975839607488086 0.040420620999674015]
 [-0.028433528763302785 -0.021305086486183452 -0.026759613929544106 0.02730226232857107; 0.018021050265657368 0.016746501885611487 0.019814647287781685 0.019234075777346846; 0.026333700440240786 0.01641388108755759 0.006065542082150951 -0.010247585560298117]
 [-0.0362610472230494 -0.018464974970374443 0.014248629345164566 -0.0037497374742144866; -0.002492976785811254 0.010290471896100703 0.051392932705318695 -0.029022970519149375; -0.024398892379256566 0.047802371630213135 0.0070149364190477004 -0.004320931474233145]
 [-0.0362610472230494 -0.018464974970374443 0.014248629345164566 -0.0037497374742144866; -0.002492976785811254 0.010290471896100703 0.051392932705318695 -0.029022970519149375; -0.024398892379256566 0.047802371630213135 0.0070149364190477004 -0.004320931474233145]
 [0.008716532011451304 0.016737540219491552 -0.016209415102765905 -0.004654244469797407; -0.041701875610082415 -0.029949231775989023 0.026432706829304434 0.03224094638041265; 0.009998420963959084 0.024004710938680476 0.007070102187318481 -0.006393290844057568]
 [-0.00763052872755162 -0.03590836843234825 0.009250240044416092 -0.03369651702030138; -0.04883076305239963 0.01600836952330314 -0.027998087943039027 -0.04168864071912377; 0.056764617636984135 0.006459061698083461 0.020683347653245587 -0.015598702207821944]
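Each element of chain is a 3×4 matrix, with one row per Cartesian degree of freedom and one column per atom. A simple sanity check is to pool the chain into a single matrix and look at the sample variance of every coordinate. For a harmonic potential this should be close to k_BT/(mω²). The helper names flatten_chain and coordinate_variances are hypothetical, not part of NQCDynamics, and the comparison assumes unit m and ω.

```julia
using Statistics: var

# Stack a chain of (dofs × atoms) position matrices into one matrix with
# one row per coordinate (dofs * atoms rows) and one column per sample.
flatten_chain(chain) = reduce(hcat, vec.(chain))

# Sample variance of each coordinate over the whole chain.
coordinate_variances(chain) = vec(var(flatten_chain(chain); dims=2))
```

With the example's chain, coordinate_variances(chain) could be compared with austrip(sim.temperature). Consecutive HMC samples are correlated, and the first entries come from the adaptation phase, so expect only approximate agreement.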

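and stats holds per-step diagnostics from the sampler, one entry per element of chain. These include the acceptance rate, the step size and whether the step belonged to the adaptation phase (is_adapt):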

stats
10000-element Vector{NamedTuple}:
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -454.9817601289648, hamiltonian_energy = 1831.4984566062838, hamiltonian_energy_error = -2635.340816247979, max_hamiltonian_energy_error = -2635.340816247979, tree_depth = 1, numerical_error = false, step_size = 0.05, nom_step_size = 0.05, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -454.9817601289648, hamiltonian_energy = 459.7396866626376, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1.2079775393716478e11, tree_depth = 0, numerical_error = true, step_size = 1.241032542311506, nom_step_size = 1.241032542311506, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -454.9817601289648, hamiltonian_energy = 466.108180419175, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 5.0647379676265377e8, tree_depth = 0, numerical_error = true, step_size = 0.5, nom_step_size = 0.5, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -454.9817601289648, hamiltonian_energy = 458.5787718475182, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 132456.38328382533, tree_depth = 0, numerical_error = true, step_size = 0.13192866018819532, nom_step_size = 0.13192866018819532, is_adapt = true)
 (n_steps = 7, is_accept = true, acceptance_rate = 1.0, log_density = -67.50636017572734, hamiltonian_energy = 375.08694297515956, hamiltonian_energy_error = -84.08101041492642, max_hamiltonian_energy_error = -85.86988950780528, tree_depth = 2, numerical_error = false, step_size = 0.028716309633808685, nom_step_size = 0.028716309633808685, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -67.50636017572734, hamiltonian_energy = 72.49149776449985, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 7584.566422921895, tree_depth = 0, numerical_error = true, step_size = 0.11260612534950616, nom_step_size = 0.11260612534950616, is_adapt = true)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -3.7730786054583803, hamiltonian_energy = 61.43350000256024, hamiltonian_energy_error = -9.183392025837392, max_hamiltonian_energy_error = -9.183392025837392, tree_depth = 2, numerical_error = false, step_size = 0.02340023160450795, nom_step_size = 0.02340023160450795, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 5.297230825587318e-193, log_density = -3.7730786054583803, hamiltonian_energy = 9.015380540698054, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 442.7317387495742, tree_depth = 1, numerical_error = false, step_size = 0.10545494475829686, nom_step_size = 0.10545494475829686, is_adapt = true)
 (n_steps = 7, is_accept = true, acceptance_rate = 0.7696696895898827, log_density = -3.201278257434399, hamiltonian_energy = 10.9547693500218, hamiltonian_energy_error = -0.0700921197481339, max_hamiltonian_energy_error = 0.5881806248321677, tree_depth = 3, numerical_error = false, step_size = 0.021583114937805906, nom_step_size = 0.021583114937805906, is_adapt = true)
 (n_steps = 15, is_accept = true, acceptance_rate = 0.19459869989527698, log_density = -3.2188994250145146, hamiltonian_energy = 7.341042283713571, hamiltonian_energy_error = 0.011506389517477622, max_hamiltonian_energy_error = 5.413804929019674, tree_depth = 4, numerical_error = false, step_size = 0.04981428822470307, nom_step_size = 0.04981428822470307, is_adapt = true)
 ⋮
 (n_steps = 3, is_accept = true, acceptance_rate = 0.47238051736577297, log_density = -1.7099918715287117, hamiltonian_energy = 7.962569937041674, hamiltonian_energy_error = 0.014870667299246243, max_hamiltonian_energy_error = 1.557521750660845, tree_depth = 2, numerical_error = false, step_size = 1.0813645048439533, nom_step_size = 1.0813645048439533, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.45564964433540905, log_density = -7.03272673814734, hamiltonian_energy = 10.249707770475695, hamiltonian_energy_error = 1.554122254414919, max_hamiltonian_energy_error = 1.8124045594771676, tree_depth = 2, numerical_error = false, step_size = 1.0813645048439533, nom_step_size = 1.0813645048439533, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.7222929960054078, log_density = -3.7716294630870055, hamiltonian_energy = 10.247914953828003, hamiltonian_energy_error = -0.891140300551795, max_hamiltonian_energy_error = -0.891140300551795, tree_depth = 2, numerical_error = false, step_size = 1.0813645048439533, nom_step_size = 1.0813645048439533, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.48139167725144355, log_density = -4.132485613977428, hamiltonian_energy = 12.193535879851385, hamiltonian_energy_error = 0.012861584652101854, max_hamiltonian_energy_error = 1.7147668803213616, tree_depth = 2, numerical_error = false, step_size = 1.0813645048439533, nom_step_size = 1.0813645048439533, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.971643840454981, log_density = -2.7347797661041424, hamiltonian_energy = 5.753188801223445, hamiltonian_energy_error = -0.38093470146894326, max_hamiltonian_energy_error = -0.38093470146894326, tree_depth = 2, numerical_error = false, step_size = 1.0813645048439533, nom_step_size = 1.0813645048439533, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.6041342956139565, log_density = -4.429708034180563, hamiltonian_energy = 6.847279719195399, hamiltonian_energy_error = 0.5105714318959453, max_hamiltonian_energy_error = 0.5991013367241713, tree_depth = 2, numerical_error = false, step_size = 1.0813645048439533, nom_step_size = 1.0813645048439533, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.24352744774159432, log_density = -4.429708034180563, hamiltonian_energy = 11.83614311572163, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1.6521857518128726, tree_depth = 2, numerical_error = false, step_size = 1.0813645048439533, nom_step_size = 1.0813645048439533, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.7598628978726358, log_density = -3.042889672893302, hamiltonian_energy = 7.276906168310967, hamiltonian_energy_error = -0.4055604067138727, max_hamiltonian_energy_error = 0.513842351826689, tree_depth = 2, numerical_error = false, step_size = 1.0813645048439533, nom_step_size = 1.0813645048439533, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.271078592681927, log_density = -6.139859467136345, hamiltonian_energy = 12.223265755141123, hamiltonian_energy_error = 0.7714110095772817, max_hamiltonian_energy_error = 2.44864430078062, tree_depth = 2, numerical_error = false, step_size = 1.0813645048439533, nom_step_size = 1.0813645048439533, is_adapt = false)
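The early entries have is_adapt = true and a step size that is still changing. The later entries have is_adapt = false and a fixed step size of 1.0813645048439533. If you want to restrict an analysis to the post-adaptation samples, the is_adapt flag can serve as a mask. The sketch below shows one way to do this. The names production_samples and mean_acceptance are hypothetical helpers, not NQCDynamics functions.

```julia
using Statistics: mean

# Keep only the samples (and their diagnostics) drawn after the adaptation
# phase, using the `is_adapt` flag stored in each stats entry.
function production_samples(chain, stats)
    keep = [!s.is_adapt for s in stats]
    return chain[keep], stats[keep]
end

# Average of the per-step `acceptance_rate` field over a set of stats entries.
mean_acceptance(stats) = mean(s.acceptance_rate for s in stats)
```

For example, `prod_chain, prod_stats = production_samples(chain, stats)` followed by `mean_acceptance(prod_stats)`.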

Finally, we can check that the average potential energy of the generated ensemble agrees with the value predicted by the equipartition theorem:

julia> Estimators.@estimate potential_energy(sim, chain)
0.005893177675448116

julia> austrip(sim.temperature) * 3 * 4 / 2
0.005700260814198944
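The reference value is the equipartition result: 3 × 4 = 12 quadratic coordinates, each contributing k_BT/2. The same number can be reproduced without Unitful or UnitfulAtomic. This sketch hard-codes the CODATA 2018 Boltzmann constant in Hartree per kelvin as an assumption in place of UnitfulAtomic's conversion. The name equipartition_potential_energy is a hypothetical helper.

```julia
# Boltzmann constant in Hartree per kelvin (CODATA 2018 value, hard-coded here)
const KB_HARTREE_PER_KELVIN = 3.166811563e-6

# Equipartition estimate of <V> for a purely quadratic potential:
# each of the dofs * natoms coordinates contributes k_B T / 2.
function equipartition_potential_energy(T_kelvin, dofs, natoms)
    kT = KB_HARTREE_PER_KELVIN * T_kelvin
    return kT * dofs * natoms / 2
end

reference = equipartition_potential_energy(300, 3, 4)
sampled = 0.005893177675448116                  # Monte Carlo estimate printed above
relative_deviation = (sampled - reference) / reference
```

The sampled value lies a few percent above the equipartition reference.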