Thermal Hamiltonian Monte Carlo

Our implementation of Hamiltonian Monte Carlo (HMC) is a light wrapper around the AdvancedHMC.jl package. For the theory behind HMC, see the documentation and references provided with AdvancedHMC.jl.

Currently, the implementation supports only systems with classical nuclei (i.e. a Simulation, not a RingPolymerSimulation).
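To show what the wrapped sampler does at its core, here is a minimal, self-contained sketch of basic HMC in plain Julia. It is not the AdvancedHMC.jl algorithm: it uses a fixed step size and a fixed number of leapfrog steps, with no No-U-Turn trajectories and no step-size or metric adaptation. The harmonic potential with unit masses and frequencies, and the names toy_potential, toy_force, toy_leapfrog and toy_hmc, are hypothetical choices made only for this illustration.

```julia
using Random, Statistics

# Hypothetical target for illustration: V(r) = ½ Σ r² (unit mass and frequency,
# atomic units). HMC samples positions from the canonical density ∝ exp(-V(r)/kT).
toy_potential(r) = 0.5 * sum(abs2, r)
toy_force(r) = -r                        # F = -∂V/∂r for the potential above

# Leapfrog (velocity Verlet) trajectory of `nsteps` steps of size ϵ, unit masses.
function toy_leapfrog(r, p, ϵ, nsteps)
    r, p = copy(r), copy(p)
    p .+= 0.5ϵ .* toy_force(r)           # initial half kick
    for i in 1:nsteps
        r .+= ϵ .* p                     # drift
        i < nsteps && (p .+= ϵ .* toy_force(r))  # full kick between drifts
    end
    p .+= 0.5ϵ .* toy_force(r)           # final half kick
    return r, p
end

# Basic HMC: draw fresh Maxwell-Boltzmann momenta, integrate one trajectory,
# then accept or reject it with a Metropolis test on the energy error.
function toy_hmc(r0, kT, nsamples; ϵ=0.1, nsteps=20, rng=Random.default_rng())
    r = copy(r0)
    samples = Vector{typeof(r)}(undef, nsamples)
    naccept = 0
    for n in 1:nsamples
        p = sqrt(kT) .* randn(rng, size(r))      # p ~ N(0, m kT) with m = 1
        H_old = toy_potential(r) + 0.5 * sum(abs2, p)
        r_new, p_new = toy_leapfrog(r, p, ϵ, nsteps)
        H_new = toy_potential(r_new) + 0.5 * sum(abs2, p_new)
        if log(rand(rng)) < -(H_new - H_old) / kT
            r = r_new
            naccept += 1
        end
        samples[n] = copy(r)             # a rejected move repeats the previous configuration
    end
    return samples, naccept / nsamples
end

kT = 300 * 3.166811563e-6                # 300 K in Hartree (k_B in Hartree/K, CODATA 2018)
toy_chain, acceptance = toy_hmc(zeros(3, 4), kT, 5_000; rng=MersenneTwister(1))
mean(toy_potential.(toy_chain))          # should be close to 3 * 4 * kT / 2
```

The fixed ϵ and nsteps are arbitrary; choosing them automatically is exactly the tuning that AdvancedHMC.jl performs during its adaptation phase.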

Example

In this example we use Hamiltonian Monte Carlo to sample the canonical distribution of four atoms in a three-dimensional harmonic oscillator potential.

using NQCDynamics
using Unitful
using UnitfulAtomic

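# Two hydrogen and two carbon atoms in a 3D harmonic potential at 300 K
sim = Simulation(Atoms([:H, :H, :C, :C]), Harmonic(dofs=3); temperature=300u"K")
# Random starting configuration; size(sim) is (dofs, natoms) = (3, 4)
r0 = randn(size(sim))
# Run 10^4 HMC steps; returns the sampled configurations and per-step statistics
chain, stats = InitialConditions.ThermalMonteCarlo.run_advancedhmc_sampling(sim, r0, 1e4)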
┌ Info: Finished 1000 adapation steps
  adaptor =
   StanHMCAdaptor(
       pc=WelfordVar,
       ssa=NesterovDualAveraging(γ=0.05, t_0=10.0, κ=0.75, δ=0.5, state.ϵ=1.1551106684592662),
       init_buffer=75, term_buffer=50, window_size=25,
       state=window(76, 950), window_splits(100, 150, 250, 450, 950)
   )
  κ.τ.integrator = Leapfrog(ϵ=1.16)
  h.metric = DiagEuclideanMetric([0.001002628941135932, 0.00 ...])
┌ Info: Finished 10000 sampling steps for 1 chains in 1.029648794 (s)
  h = Hamiltonian(metric=DiagEuclideanMetric([0.001002628941135932, 0.00 ...]), kinetic=AdvancedHMC.GaussianKinetic())
  κ = AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.MultinomialTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.GeneralisedNoUTurn{Float64}}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.MultinomialTS}(integrator=Leapfrog(ϵ=1.16), tc=AdvancedHMC.GeneralisedNoUTurn{Float64}(10, 1000.0)))
  EBFMI_est = 0.5329133569314974
  average_acceptance_rate = 0.6014288968410406

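The Monte Carlo chain contains the sampled nuclear configurations, one 3 × 4 (dofs × atoms) matrix per step. Consecutive identical entries are steps at which the sampler did not move to a new configuration: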

chain
10000-element Vector{Matrix{Float64}}:
 [0.10904280412524114 0.18381867141703256 0.24620289341762946 -0.09124234160068409; 0.09734755530911476 0.26457854438036243 0.48282402661358725 0.23277144818802797; -0.22568279427211557 0.20054285757592866 0.04051124223704414 0.26339246731328847]
 [0.10904280412524114 0.18381867141703256 0.24620289341762946 -0.09124234160068409; 0.09734755530911476 0.26457854438036243 0.48282402661358725 0.23277144818802797; -0.22568279427211557 0.20054285757592866 0.04051124223704414 0.26339246731328847]
 [0.10904280412524114 0.18381867141703256 0.24620289341762946 -0.09124234160068409; 0.09734755530911476 0.26457854438036243 0.48282402661358725 0.23277144818802797; -0.22568279427211557 0.20054285757592866 0.04051124223704414 0.26339246731328847]
 [0.10904280412524114 0.18381867141703256 0.24620289341762946 -0.09124234160068409; 0.09734755530911476 0.26457854438036243 0.48282402661358725 0.23277144818802797; -0.22568279427211557 0.20054285757592866 0.04051124223704414 0.26339246731328847]
 [-0.03447241840195614 -0.09300614895969457 -0.17807443612092422 -0.0018732269397189888; -0.029018154825651968 -0.08706749089244087 -0.22332307726904277 -0.07323020014721238; 0.03539583627781864 -0.0673994518691274 0.00242478781975293 -0.0815192928216473]
 [-0.03447241840195614 -0.09300614895969457 -0.17807443612092422 -0.0018732269397189888; -0.029018154825651968 -0.08706749089244087 -0.22332307726904277 -0.07323020014721238; 0.03539583627781864 -0.0673994518691274 0.00242478781975293 -0.0815192928216473]
 [-0.03447218915370925 -0.02950426138457924 -0.018798427872207826 0.03021582181959529; -0.027520619721023504 0.020543952087105896 0.0328027494524433 0.0007933848840523394; -0.03067326325990996 0.013315670598410212 -0.0007708942190339128 0.009608140379620551]
 [-0.03447218915370925 -0.02950426138457924 -0.018798427872207826 0.03021582181959529; -0.027520619721023504 0.020543952087105896 0.0328027494524433 0.0007933848840523394; -0.03067326325990996 0.013315670598410212 -0.0007708942190339128 0.009608140379620551]
 [0.03825833260198358 0.01673674168884911 0.018886783747162572 -0.028074092918619606; 0.03322596749327229 -0.0011027108630984675 -0.012839947503663216 -0.02656217603970181; 0.015497097027458797 0.01615729339008508 0.0014799031791261367 -0.023069448729405858]
 [0.03825833260198358 0.01673674168884911 0.018886783747162572 -0.028074092918619606; 0.03322596749327229 -0.0011027108630984675 -0.012839947503663216 -0.02656217603970181; 0.015497097027458797 0.01615729339008508 0.0014799031791261367 -0.023069448729405858]
 ⋮
 [0.023546037633665445 -0.03757859961613142 -0.0059401390558872814 0.007507447613220718; 0.0013628622323561337 0.011677738937532968 0.04548324144074999 -0.04677030825079163; 0.01268159963106496 0.012594294295081292 -7.149210681415066e-5 0.0004370191210083599]
 [0.02050571022303989 0.02801452518083079 0.010655868375294628 -0.03532868265938635; -0.0057102000198347505 0.02571558553330424 -0.021587143721842474 0.05027115013584364; -0.005946751938087257 0.024836722891709552 0.028657049150507208 0.002975453555006939]
 [0.02050571022303989 0.02801452518083079 0.010655868375294628 -0.03532868265938635; -0.0057102000198347505 0.02571558553330424 -0.021587143721842474 0.05027115013584364; -0.005946751938087257 0.024836722891709552 0.028657049150507208 0.002975453555006939]
 [-0.002700094963242746 -0.05143933591421823 -0.020124770522521256 0.025988791774087222; -0.0034499672785555744 -0.05706642437184204 0.012372654827208444 0.010004885881620816; 0.0030310083964881647 -0.01481157372745456 -0.000583406334748085 -0.02890079455965569]
 [-0.002700094963242746 -0.05143933591421823 -0.020124770522521256 0.025988791774087222; -0.0034499672785555744 -0.05706642437184204 0.012372654827208444 0.010004885881620816; 0.0030310083964881647 -0.01481157372745456 -0.000583406334748085 -0.02890079455965569]
 [-0.052990311938270925 -0.015458323305490587 0.05263431108240653 -0.052490752305166026; 0.006782778775774042 0.06489775768418721 -0.03623307496963398 -0.022989879693289902; 0.00029691702467216953 0.04029521620829195 0.007212801501385167 0.014766586523503988]
 [-0.06880231652342111 -0.038884614082064235 -0.004020684294475996 -0.0064699666901786995; -0.04989542663926319 0.09050161443778412 -0.01896426029863396 0.05239946155490549; 0.0036338863012466698 -0.013060480433651128 0.00929444545763588 0.031188044849077737]
 [0.05772056165768034 0.021602336495570848 -0.01287597618359437 0.016370023796592803; 0.054238790687792644 -0.03079590764527243 0.030436834012725265 -0.05492472450230437; -0.014541820097458927 0.03983476474902724 -0.015419565222381686 -0.022388230849241866]
 [-0.008853322352113176 0.005812518583800809 0.027791297314379636 0.009434234292916555; 0.02492337535439339 0.01956344743627865 0.03015869713810539 -0.013887177426433576; -0.02879645823860405 -0.0021259172222257913 -0.01244162634590053 -0.004650893912662837]
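Before computing averages, it is common to discard the configurations generated while the sampler was still adapting its step size. A minimal sketch, assuming (as the log above suggests) that the first 1000 entries of chain belong to the adaptation phase; coordinate_moments is a hypothetical helper, not part of NQCDynamics:

```julia
using Statistics

# Hypothetical helper: drop the first `nwarmup` configurations (generated while the
# sampler was still adapting) and return the per-coordinate mean and variance.
# Each sample is a dofs × natoms matrix, so both results are dofs × natoms too.
function coordinate_moments(chain; nwarmup=1000)
    samples = chain[nwarmup+1:end]
    m = mean(samples)                          # element-wise mean configuration
    v = mean(x -> (x .- m) .^ 2, samples)      # element-wise (biased) variance
    return m, v
end
```

With the chain from the example, m, v = coordinate_moments(chain) gives one mean and one variance per Cartesian coordinate of each atom.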

and stats contains per-step diagnostics from the sampler, such as the acceptance rate, the step size, and whether the step belonged to the adaptation phase:

stats
10000-element Vector{NamedTuple}:
 (n_steps = 1, is_accept = true, acceptance_rate = 1.0, log_density = -338.7042498982266, hamiltonian_energy = 1419.860453320751, hamiltonian_energy_error = -2066.7466667904428, max_hamiltonian_energy_error = -2066.7466667904428, tree_depth = 1, numerical_error = false, step_size = 0.05, nom_step_size = 0.05, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -338.7042498982266, hamiltonian_energy = 348.1189729584576, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 8.97597021775412e10, tree_depth = 0, numerical_error = true, step_size = 1.241032542311506, nom_step_size = 1.241032542311506, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -338.7042498982266, hamiltonian_energy = 343.33891237678694, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 3.768646430550283e8, tree_depth = 0, numerical_error = true, step_size = 0.5, nom_step_size = 0.5, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -338.7042498982266, hamiltonian_energy = 350.52588051334595, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 99594.66687750978, tree_depth = 0, numerical_error = true, step_size = 0.13192866018819532, nom_step_size = 0.13192866018819532, is_adapt = true)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -61.922425320163185, hamiltonian_energy = 285.8563460817301, hamiltonian_energy_error = -60.06083348212326, max_hamiltonian_energy_error = -60.06083348212326, tree_depth = 2, numerical_error = false, step_size = 0.028716309633808685, nom_step_size = 0.028716309633808685, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -61.922425320163185, hamiltonian_energy = 65.0866873864314, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 7651.377176672548, tree_depth = 0, numerical_error = true, step_size = 0.11260612534950616, nom_step_size = 0.11260612534950616, is_adapt = true)
 (n_steps = 7, is_accept = true, acceptance_rate = 1.0, log_density = -3.574765458290204, hamiltonian_energy = 56.52321549814825, hamiltonian_energy_error = -8.407372429285957, max_hamiltonian_energy_error = -8.416161802109109, tree_depth = 3, numerical_error = false, step_size = 0.02340023160450795, nom_step_size = 0.02340023160450795, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 6.669027429256939e-208, log_density = -3.574765458290204, hamiltonian_energy = 9.015463753844049, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 477.0402253061708, tree_depth = 1, numerical_error = false, step_size = 0.10545494475829686, nom_step_size = 0.10545494475829686, is_adapt = true)
 (n_steps = 7, is_accept = true, acceptance_rate = 0.7659509609065804, log_density = -3.1050610212632668, hamiltonian_energy = 10.365319601703712, hamiltonian_energy_error = -0.05757705422897352, max_hamiltonian_energy_error = 0.583690341836606, tree_depth = 3, numerical_error = false, step_size = 0.021583114937805906, nom_step_size = 0.021583114937805906, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 2.2400958631897312e-5, log_density = -3.1050610212632668, hamiltonian_energy = 10.654947530808721, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 10.706406803952161, tree_depth = 1, numerical_error = false, step_size = 0.049232722864258616, nom_step_size = 0.049232722864258616, is_adapt = true)
 ⋮
 (n_steps = 3, is_accept = true, acceptance_rate = 0.6487655910569137, log_density = -3.564186040209527, hamiltonian_energy = 6.064107915726854, hamiltonian_energy_error = 0.4189094581270849, max_hamiltonian_energy_error = 0.7291672931717548, tree_depth = 2, numerical_error = false, step_size = 1.1551106684592662, nom_step_size = 1.1551106684592662, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.3553005700370642, log_density = -4.071581385708382, hamiltonian_energy = 14.086319309758636, hamiltonian_energy_error = 0.1063053557826823, max_hamiltonian_energy_error = 2.871982314465642, tree_depth = 2, numerical_error = false, step_size = 1.1551106684592662, nom_step_size = 1.1551106684592662, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.5515713279213217, log_density = -4.071581385708382, hamiltonian_energy = 9.26337213206329, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1.0259413668807813, tree_depth = 2, numerical_error = false, step_size = 1.1551106684592662, nom_step_size = 1.1551106684592662, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.6448239080781902, log_density = -4.378507489839383, hamiltonian_energy = 8.904538168491904, hamiltonian_energy_error = 0.23153572271460732, max_hamiltonian_energy_error = 1.0758373805534838, tree_depth = 2, numerical_error = false, step_size = 1.1551106684592662, nom_step_size = 1.1551106684592662, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.34942544046170926, log_density = -4.378507489839383, hamiltonian_energy = 10.603362692612674, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1.3446290888387118, tree_depth = 2, numerical_error = false, step_size = 1.1551106684592662, nom_step_size = 1.1551106684592662, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.11992906536405723, log_density = -8.718310141505059, hamiltonian_energy = 19.844820574005315, hamiltonian_energy_error = 1.458139623371821, max_hamiltonian_energy_error = 4.81932913787367, tree_depth = 2, numerical_error = false, step_size = 1.1551106684592662, nom_step_size = 1.1551106684592662, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.7873099252361221, log_density = -11.226908941406803, hamiltonian_energy = 17.040606174638448, hamiltonian_energy_error = 0.8022548902181406, max_hamiltonian_energy_error = 0.8022548902181406, tree_depth = 2, numerical_error = false, step_size = 1.1551106684592662, nom_step_size = 1.1551106684592662, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -7.685291543919762, hamiltonian_energy = 15.163158916840853, hamiltonian_energy_error = -1.35833329016738, max_hamiltonian_energy_error = -1.35833329016738, tree_depth = 2, numerical_error = false, step_size = 1.1551106684592662, nom_step_size = 1.1551106684592662, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -2.1525393528888666, hamiltonian_energy = 7.471768221730121, hamiltonian_energy_error = -1.7367248845930643, max_hamiltonian_energy_error = -1.7367248845930643, tree_depth = 2, numerical_error = false, step_size = 1.1551106684592662, nom_step_size = 1.1551106684592662, is_adapt = false)
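The per-step NamedTuples in stats can be summarised with plain Julia. A minimal sketch that uses only the fields visible in the output above (is_adapt, acceptance_rate, numerical_error, tree_depth); summarize_stats is a hypothetical helper, not part of NQCDynamics or AdvancedHMC:

```julia
using Statistics

# Hypothetical helper: count the adaptation steps, then summarise the remaining
# (post-adaptation) steps. Assumes at least one step has is_adapt == false.
function summarize_stats(stats)
    post = filter(s -> !s.is_adapt, stats)
    return (
        n_adapt = count(s -> s.is_adapt, stats),
        mean_acceptance = mean(s.acceptance_rate for s in post),
        n_numerical_errors = count(s -> s.numerical_error, post),
        mean_tree_depth = mean(s.tree_depth for s in post),
    )
end
```

Calling summarize_stats(stats) makes it easy to spot, for example, divergent trajectories (numerical_error = true) that occur after adaptation has finished, as opposed to the ones visible in the first few steps above.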

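Finally, we can check that the average potential energy of the sampled ensemble agrees with the equipartition theorem, which for a harmonic potential predicts kT/2 per coordinate, i.e. 3 × 4 × kT/2 for four atoms in three dimensions: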

julia> Estimators.@estimate potential_energy(sim, chain)
0.005788476487684865

julia> austrip(sim.temperature) * 3 * 4 / 2
0.005700260814198944
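The expected value is plain arithmetic, so it can be reproduced without any packages. A minimal sketch, assuming the CODATA 2018 value of the Boltzmann constant in Hartree per kelvin in place of the austrip conversion; equipartition_potential_energy is a hypothetical helper. It also quantifies the agreement: the two numbers above differ by about 1.5%.

```julia
# Boltzmann constant in Hartree per kelvin (CODATA 2018 value)
const kB_hartree_per_K = 3.166811563e-6

# Classical average potential energy of a harmonic system: kT/2 per coordinate.
function equipartition_potential_energy(T_kelvin, ndofs, natoms)
    kT = kB_hartree_per_K * T_kelvin       # thermal energy in Hartree
    return ndofs * natoms * kT / 2
end

expected = equipartition_potential_energy(300, 3, 4)
estimate = 0.005788476487684865            # value printed by Estimators.@estimate above
rel_dev = abs(estimate - expected) / expected
```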