Thermal Hamiltonian Monte Carlo

Our implementation of Hamiltonian Monte Carlo (HMC) is a light wrapper around the AdvancedHMC.jl package. For the theory behind HMC, refer to the references and documentation provided with AdvancedHMC.jl.

Currently, our implementation works only for systems with classical nuclei (i.e. Simulation, but not RingPolymerSimulation).

Example

In this example we use Hamiltonian Monte Carlo to sample the canonical distribution of 4 atoms in a three-dimensional harmonic oscillator potential.

using NQCDynamics
using Unitful
using UnitfulAtomic

sim = Simulation(Atoms([:H, :H, :C, :C]), Harmonic(dofs=3); temperature=300u"K")
r0 = randn(size(sim))
chain, stats = InitialConditions.ThermalMonteCarlo.run_advancedhmc_sampling(sim, r0, 1e4)
┌ Info: Finished 1000 adapation steps
  adaptor =
   StanHMCAdaptor(
       pc=WelfordVar,
       ssa=NesterovDualAveraging(γ=0.05, t_0=10.0, κ=0.75, δ=0.5, state.ϵ=1.1556299855213854),
       init_buffer=75, term_buffer=50, window_size=25,
       state=window(76, 950), window_splits(100, 150, 250, 450, 950)
   )
  κ.τ.integrator = Leapfrog(ϵ=1.16)
  h.metric = DiagEuclideanMetric([0.0009929229195338072, 0.0 ...])
┌ Info: Finished 10000 sampling steps for 1 chains in 1.147652449 (s)
  h = Hamiltonian(metric=DiagEuclideanMetric([0.0009929229195338072, 0.0 ...]), kinetic=AdvancedHMC.GaussianKinetic())
  κ = AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.MultinomialTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.GeneralisedNoUTurn{Float64}}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.MultinomialTS}(integrator=Leapfrog(ϵ=1.16), tc=AdvancedHMC.GeneralisedNoUTurn{Float64}(10, 1000.0)))
  EBFMI_est = 0.5507155334869924
  average_acceptance_rate = 0.5757917486804809

The Monte Carlo chain contains the nuclear configurations that we have sampled:

chain
10000-element Vector{Matrix{Float64}}:
 [0.4770541496303966 -0.008292282719788324 -0.3988100227093445 -0.26448192592997066; 0.600055051193342 0.41420154234410944 -0.41607400250985904 0.35134255056026564; -0.05944985300620953 0.21937436901997975 -0.18108760557359582 0.21185891475741758]
 [0.4770541496303966 -0.008292282719788324 -0.3988100227093445 -0.26448192592997066; 0.600055051193342 0.41420154234410944 -0.41607400250985904 0.35134255056026564; -0.05944985300620953 0.21937436901997975 -0.18108760557359582 0.21185891475741758]
 [0.4770541496303966 -0.008292282719788324 -0.3988100227093445 -0.26448192592997066; 0.600055051193342 0.41420154234410944 -0.41607400250985904 0.35134255056026564; -0.05944985300620953 0.21937436901997975 -0.18108760557359582 0.21185891475741758]
 [0.4770541496303966 -0.008292282719788324 -0.3988100227093445 -0.26448192592997066; 0.600055051193342 0.41420154234410944 -0.41607400250985904 0.35134255056026564; -0.05944985300620953 0.21937436901997975 -0.18108760557359582 0.21185891475741758]
 [-0.13486164616293933 -0.06713540759503968 0.09155009559610189 0.10329517507990718; -0.2285829954856472 -0.15393395264167295 0.124443175769824 -0.1501227918667077; 0.008659899763511336 -0.10087348428585637 0.07606374741563447 -0.0905785254615186]
 [-0.13486164616293933 -0.06713540759503968 0.09155009559610189 0.10329517507990718; -0.2285829954856472 -0.15393395264167295 0.124443175769824 -0.1501227918667077; 0.008659899763511336 -0.10087348428585637 0.07606374741563447 -0.0905785254615186]
 [-0.009833429927409723 -0.05683682956278916 -0.02094944798321998 0.02480581069881105; 0.028267412875527698 0.003619728995138932 -0.009194725806110124 -0.004451357045489421; -0.025360874476927264 -0.02284654196412153 -0.03881519960109297 0.006616997844048476]
 [-0.009833429927409723 -0.05683682956278916 -0.02094944798321998 0.02480581069881105; 0.028267412875527698 0.003619728995138932 -0.009194725806110124 -0.004451357045489421; -0.025360874476927264 -0.02284654196412153 -0.03881519960109297 0.006616997844048476]
 [0.003807240860798364 0.019128716011714935 0.061748472106802364 0.02674153950776929; -0.008069617911881242 -0.04253440855837008 -0.010580831024973055 -0.03241250148602971; 0.06000040795892202 -0.006459726816069463 -0.058497172364492526 -0.03861640463746929]
 [0.003807240860798364 0.019128716011714935 0.061748472106802364 0.02674153950776929; -0.008069617911881242 -0.04253440855837008 -0.010580831024973055 -0.03241250148602971; 0.06000040795892202 -0.006459726816069463 -0.058497172364492526 -0.03861640463746929]
 ⋮
 [0.012275953709583183 -0.03794044457088789 -0.026816678738858023 0.01247004288725917; 0.0783453106158807 0.04486830191526203 -0.032742770666345976 0.04051892233612195; 0.0372098366523395 0.001750992429792847 0.03124929695352502 0.004939158570444144]
 [-0.01548080229683786 0.029729846620187038 0.033568090426716754 0.03369934101848647; -0.055282342356072135 -0.021474295045428228 0.01285684561359389 -0.005711783494681412; -0.03862141111590476 -0.015996993779593818 0.00918516588684138 0.02909819669140583]
 [0.01768974230015598 -0.015734726229467713 -0.021354071541445516 -0.0391765304812437; 0.055706087842677776 0.01692176572663743 -0.03157634546681528 -0.010883133619811691; 0.02939383889085586 0.0019531470856290895 -0.009539485614463224 -0.04139057390714823]
 [-0.02677586753332599 -0.0007280955261372499 0.004806089944612197 0.019256762912543746; -0.051755767160736585 -0.01348056922688343 0.011213599752808556 -0.004877441143529747; -0.03101130818134649 0.007267320959521737 -0.016701784753896293 0.030172473944394144]
 [0.01587403484019645 0.002199174258576836 0.0010242924210963188 0.0053580429584118155; 0.06897666676260229 0.02957283468361099 -0.014398309637122178 0.05017490211630942; 0.004830363298812555 0.021797081930839783 0.005612089983150909 -0.009107498946800359]
 [-0.042236657579483715 0.017199150309697632 0.024228876647715503 0.011362344450648362; -0.0006297381690202192 0.030240087515389563 -0.0032066263893823672 -0.0004943361734538257; -0.013925138914203068 0.052207076371137376 -0.007007913195222903 -0.01659067123933492]
 [0.030650729217522887 -0.00083809574465641 -0.04500142629850365 -0.012555365243819021; -0.024608626156945732 -0.0374505309737697 0.008366556358392192 0.012340831427493763; 0.02979870292045403 -0.006716581213957601 -0.008512592287767441 -0.02541634949388881]
 [-0.0010221971392386786 -0.023460449630736888 0.017771265635716474 -0.04154449449553619; 0.020298861967536986 0.04785953187125548 -0.010345513568654253 0.007817933420001942; -0.046844072368523955 -0.005554142698632825 0.015905998938741634 0.0072495578152451945]
 [0.0116766026296365 0.03868884721987655 -0.04046120404709249 -0.00956302363364487; -0.015245863550601382 -0.035560869000242835 0.018160338262392883 0.030428171113836073; 0.03181528317260037 0.01688936407350467 -0.04579959580147623 -0.05334736362971225]

and stats contains extra information about the sampling procedure:

stats
10000-element Vector{NamedTuple}:
 (n_steps = 1, is_accept = true, acceptance_rate = 1.0, log_density = -744.2680561922676, hamiltonian_energy = 3419.698407693535, hamiltonian_energy_error = -5130.836514858936, max_hamiltonian_energy_error = -5130.836514858936, tree_depth = 1, numerical_error = false, step_size = 0.05, nom_step_size = 0.05, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -744.2680561922676, hamiltonian_energy = 748.1552258529941, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1.9789308666050836e11, tree_depth = 0, numerical_error = true, step_size = 1.241032542311506, nom_step_size = 1.241032542311506, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -744.2680561922676, hamiltonian_energy = 749.3976649990748, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 8.358842074263023e8, tree_depth = 0, numerical_error = true, step_size = 0.5, nom_step_size = 0.5, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -744.2680561922676, hamiltonian_energy = 747.6079204814434, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 226542.08714372545, tree_depth = 0, numerical_error = true, step_size = 0.13192866018819532, nom_step_size = 0.13192866018819532, is_adapt = true)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -94.70911937140173, hamiltonian_energy = 608.6939455259768, hamiltonian_energy_error = -140.95235913953513, max_hamiltonian_energy_error = -140.95235913953513, tree_depth = 2, numerical_error = false, step_size = 0.028716309633808685, nom_step_size = 0.028716309633808685, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -94.70911937140173, hamiltonian_energy = 97.65492398959718, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 9619.312682503896, tree_depth = 0, numerical_error = true, step_size = 0.11260612534950616, nom_step_size = 0.11260612534950616, is_adapt = true)
 (n_steps = 7, is_accept = true, acceptance_rate = 0.9989927425715656, log_density = -4.217374710039435, hamiltonian_energy = 85.3012434509035, hamiltonian_energy_error = -13.039045626593534, max_hamiltonian_energy_error = -13.041299453213597, tree_depth = 3, numerical_error = false, step_size = 0.02340023160450795, nom_step_size = 0.02340023160450795, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 3.1116153013525464e-146, log_density = -4.217374710039435, hamiltonian_energy = 7.872429967351225, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 335.04228159627314, tree_depth = 1, numerical_error = false, step_size = 0.10512483611836129, nom_step_size = 0.10512483611836129, is_adapt = true)
 (n_steps = 7, is_accept = true, acceptance_rate = 0.689784177472622, log_density = -8.68388075376553, hamiltonian_energy = 14.231291269836156, hamiltonian_energy_error = 0.5440554716554562, max_hamiltonian_energy_error = 0.7404513199971792, tree_depth = 3, numerical_error = false, step_size = 0.021514901507904364, nom_step_size = 0.021514901507904364, is_adapt = true)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.021311930682430452, log_density = -8.68388075376553, hamiltonian_energy = 22.632433195237788, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 6.455466509839358, tree_depth = 2, numerical_error = false, step_size = 0.03858451398313371, nom_step_size = 0.03858451398313371, is_adapt = true)
 ⋮
 (n_steps = 3, is_accept = true, acceptance_rate = 0.38928482997387737, log_density = -8.272457624496953, hamiltonian_energy = 22.068843714809304, hamiltonian_energy_error = -1.257653327327116, max_hamiltonian_energy_error = 2.6554078986471517, tree_depth = 2, numerical_error = false, step_size = 1.1556299855213854, nom_step_size = 1.1556299855213854, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.8409640930529405, log_density = -5.147015279872097, hamiltonian_energy = 11.833314730275227, hamiltonian_energy_error = -1.0179059992050572, max_hamiltonian_energy_error = -1.0179059992050572, tree_depth = 2, numerical_error = false, step_size = 1.1556299855213854, nom_step_size = 1.1556299855213854, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -5.11993520933762, hamiltonian_energy = 6.963650783455326, hamiltonian_energy_error = -0.04947566066045006, max_hamiltonian_energy_error = -0.8710013971341013, tree_depth = 2, numerical_error = false, step_size = 1.1556299855213854, nom_step_size = 1.1556299855213854, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -3.3288798088679936, hamiltonian_energy = 6.250416612198777, hamiltonian_energy_error = -0.5939305066759868, max_hamiltonian_energy_error = -0.8213610105520344, tree_depth = 2, numerical_error = false, step_size = 1.1556299855213854, nom_step_size = 1.1556299855213854, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.39038443977217147, log_density = -4.8716884877330555, hamiltonian_energy = 10.576421333122003, hamiltonian_energy_error = 0.46854698808306416, max_hamiltonian_energy_error = 1.9762064400373927, tree_depth = 2, numerical_error = false, step_size = 1.1556299855213854, nom_step_size = 1.1556299855213854, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.9186137096411361, log_density = -3.6656865699283823, hamiltonian_energy = 7.329481954964212, hamiltonian_energy_error = -0.3189829838519014, max_hamiltonian_energy_error = -0.3189829838519014, tree_depth = 2, numerical_error = false, step_size = 1.1556299855213854, nom_step_size = 1.1556299855213854, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.5270491438038841, log_density = -3.6866142154764874, hamiltonian_energy = 8.63885225606998, hamiltonian_energy_error = 0.005307230926476336, max_hamiltonian_energy_error = 1.543252058849319, tree_depth = 2, numerical_error = false, step_size = 1.1556299855213854, nom_step_size = 1.1556299855213854, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.41055218070092686, log_density = -4.207545012230052, hamiltonian_energy = 9.782546456927312, hamiltonian_energy_error = 0.23725923921344538, max_hamiltonian_energy_error = 2.049867494878516, tree_depth = 2, numerical_error = false, step_size = 1.1556299855213854, nom_step_size = 1.1556299855213854, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.25286403802785223, log_density = -6.502555457424634, hamiltonian_energy = 12.320309324672829, hamiltonian_energy_error = 0.7169759830534748, max_hamiltonian_energy_error = 2.143765385119373, tree_depth = 2, numerical_error = false, step_size = 1.1556299855213854, nom_step_size = 1.1556299855213854, is_adapt = false)
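
The per-step entries shown above are easier to read once condensed into a few numbers. The following is a minimal sketch, assuming only that each entry is a NamedTuple with the fields `is_adapt`, `acceptance_rate`, `numerical_error` and `tree_depth`, as in the printed output. The helper name `summarise_sampling` and the hand-written `example_stats` are hypothetical and are not part of NQCDynamics or AdvancedHMC.

```julia
using Statistics

# Condense per-step sampler statistics into a short report.
# `stats` is assumed to be a vector of NamedTuples with the fields
# `is_adapt`, `acceptance_rate`, `numerical_error` and `tree_depth`.
function summarise_sampling(stats)
    # Steps taken after step-size/metric adaptation has finished.
    production = filter(s -> !s.is_adapt, stats)
    return (
        n_adapt = count(s -> s.is_adapt, stats),
        n_production = length(production),
        # Errors if there are no production steps at all.
        mean_acceptance = mean(s.acceptance_rate for s in production),
        n_numerical_errors = count(s -> s.numerical_error, stats),
        max_tree_depth = maximum(s.tree_depth for s in stats),
    )
end

# Hand-written entries with the same field names (not real sampling output).
example_stats = [
    (is_adapt = true,  acceptance_rate = 0.0, numerical_error = true,  tree_depth = 0),
    (is_adapt = true,  acceptance_rate = 1.0, numerical_error = false, tree_depth = 2),
    (is_adapt = false, acceptance_rate = 0.5, numerical_error = false, tree_depth = 2),
    (is_adapt = false, acceptance_rate = 1.0, numerical_error = false, tree_depth = 3),
]
report = summarise_sampling(example_stats)
```

Calling `summarise_sampling(stats)` on the real output should work the same way. Since `chain` and `stats` both have 10000 elements here, it seems reasonable to assume they align one-to-one, in which case `chain[[!s.is_adapt for s in stats]]` would keep only the configurations sampled after adaptation.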

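The expectation value of the potential energy over the generated ensemble should be close to the value given by the equipartition theorem, which is half of the thermal energy for each of the 3 × 4 degrees of freedom: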

julia> Estimators.@estimate potential_energy(sim, chain)
0.005978047555796304

julia> austrip(sim.temperature) * 3 * 4 / 2
0.005700260814198944
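
As a worked check of the reference value: in a harmonic potential each position degree of freedom contributes half of the thermal energy to the mean potential energy. The symbols below are introduced here for illustration, and the numerical value assumes atomic units, where the Boltzmann constant is 1 and 300 K corresponds to roughly 9.50e-4 Hartree.

```latex
\langle V \rangle
  = \frac{N_\text{atoms}\, N_\text{dofs}}{2}\, k_\mathrm{B} T
  = \frac{4 \times 3}{2}\, k_\mathrm{B} T
  = 6\, k_\mathrm{B} T
  \approx 6 \times 9.50 \times 10^{-4}\, E_\mathrm{h}
  \approx 5.70 \times 10^{-3}\, E_\mathrm{h}
```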