Thermal Hamiltonian Monte Carlo

Our implementation of Hamiltonian Monte Carlo (HMC) is a light wrapper around the AdvancedHMC.jl package. For the theory behind HMC, refer to the references and documentation provided with AdvancedHMC.jl.

Currently, our implementation works only for systems with classical nuclei, i.e. a Simulation but not a RingPolymerSimulation.
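
HMC works as follows:
- It treats the negative log-density of the canonical distribution, U(r) = V(r)/k_B T, as a potential energy and adds auxiliary momenta.
- It proposes a new configuration by integrating the resulting Hamiltonian dynamics with the leapfrog scheme.
- It accepts or rejects that proposal with a Metropolis test on the change in total energy.

The sketch below is a minimal plain-Julia illustration of these steps, under three assumptions:
- a toy potential V(r) = ½ Σ r² (unit mass and frequency);
- a scalar inverse mass `minv`;
- a fixed trajectory of `nsteps` leapfrog steps.

It is not the adaptive No-U-Turn sampler that AdvancedHMC.jl runs in the example. The names `potential`, `hmc_step` and `run_chain` are hypothetical and are not part of NQCDynamics.

```julia
using Random, Statistics

# Hypothetical toy model (assumption): V(r) = ½ Σ r² with unit mass and unit
# frequency, in atomic units. The canonical target density is ∝ exp(-V(r)/kT).
potential(r) = 0.5 * sum(abs2, r)
grad_potential(r) = r                      # ∇V for this toy potential

# One HMC step: draw momenta, run a leapfrog trajectory, then apply a
# Metropolis test on the fictitious Hamiltonian H = U + K, where U = V/kT
# and K = ½ minv |p|² (minv acts as a scalar metric / inverse mass).
function hmc_step(rng, r, kT; minv = kT, ϵ = 0.2, nsteps = 8)
    gradU(x) = grad_potential(x) ./ kT
    H(x, p) = potential(x) / kT + 0.5 * minv * sum(abs2, p)

    p = randn(rng, size(r)) ./ sqrt(minv)  # momenta p ~ N(0, 1/minv)
    x, q = copy(r), copy(p)
    q .-= (ϵ / 2) .* gradU(x)              # initial half kick
    for i in 1:nsteps
        x .+= (ϵ * minv) .* q              # drift
        if i < nsteps
            q .-= ϵ .* gradU(x)            # full kick between drifts
        end
    end
    q .-= (ϵ / 2) .* gradU(x)              # final half kick

    # Metropolis test: accept with probability min(1, exp(H_old - H_new))
    accept = log(rand(rng)) < H(r, p) - H(x, q)
    return accept ? x : r                  # on rejection, repeat the old point
end

# Run n HMC steps from r0 and store every visited configuration.
function run_chain(rng, r0, kT, n)
    chain = Vector{typeof(r0)}(undef, n)
    r = r0
    for i in 1:n
        r = hmc_step(rng, r, kT)
        chain[i] = r
    end
    return chain
end

kT = 300 * 3.166811563e-6                  # k_B T at 300 K in hartree
rng = MersenneTwister(42)
chain = run_chain(rng, sqrt(kT) .* randn(rng, 3, 4), kT, 10_000)
mean(potential, chain), 3 * 4 * kT / 2     # sampled ⟨V⟩ vs equipartition value
```

With `minv = kT`, the leapfrog dynamics of this toy potential has unit angular frequency. A step of 0.2 is therefore well inside the leapfrog stability limit of 2. The final line puts the sampled ⟨V⟩ next to the equipartition value 12 k_B T / 2.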

Example

In this example we use Hamiltonian Monte Carlo to sample the canonical distribution at 300 K of four atoms in a three-dimensional harmonic potential.

using NQCDynamics
using Unitful
using UnitfulAtomic

sim = Simulation(Atoms([:H, :H, :C, :C]), Harmonic(dofs=3); temperature=300u"K")
r0 = randn(size(sim))
chain, stats = InitialConditions.ThermalMonteCarlo.run_advancedhmc_sampling(sim, r0, 1e4)
┌ Info: Finished 1000 adapation steps
  adaptor =
   StanHMCAdaptor(
       pc=WelfordVar,
       ssa=NesterovDualAveraging(γ=0.05, t_0=10.0, κ=0.75, δ=0.5, state.ϵ=1.0979347828663584),
       init_buffer=75, term_buffer=50, window_size=25,
       state=window(76, 950), window_splits(100, 150, 250, 450, 950)
   )
  κ.τ.integrator = Leapfrog(ϵ=1.1)
  h.metric = DiagEuclideanMetric([0.0009201242897161104, 0.0 ...])
┌ Info: Finished 10000 sampling steps for 1 chains in 1.150642592 (s)
  h = Hamiltonian(metric=DiagEuclideanMetric([0.0009201242897161104, 0.0 ...]), kinetic=AdvancedHMC.GaussianKinetic())
  κ = AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.MultinomialTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.GeneralisedNoUTurn{Float64}}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.MultinomialTS}(integrator=Leapfrog(ϵ=1.1), tc=AdvancedHMC.GeneralisedNoUTurn{Float64}(10, 1000.0)))
  EBFMI_est = 0.44086601048351326
  average_acceptance_rate = 0.6347707721739947

The Monte Carlo chain contains the nuclear configurations that we have sampled:

chain
10000-element Vector{Matrix{Float64}}:
 [0.6189344074500809 0.4513302742119083 0.1979887157904473 0.11682326752684574; -0.19710849062967406 -0.09627390935099156 -0.22255711957722435 0.33385366512099823; -0.5906025230136058 0.10493058463889099 0.2953507170797237 -0.022926308566159345]
 [0.6189344074500809 0.4513302742119083 0.1979887157904473 0.11682326752684574; -0.19710849062967406 -0.09627390935099156 -0.22255711957722435 0.33385366512099823; -0.5906025230136058 0.10493058463889099 0.2953507170797237 -0.022926308566159345]
 [0.6189344074500809 0.4513302742119083 0.1979887157904473 0.11682326752684574; -0.19710849062967406 -0.09627390935099156 -0.22255711957722435 0.33385366512099823; -0.5906025230136058 0.10493058463889099 0.2953507170797237 -0.022926308566159345]
 [0.6189344074500809 0.4513302742119083 0.1979887157904473 0.11682326752684574; -0.19710849062967406 -0.09627390935099156 -0.22255711957722435 0.33385366512099823; -0.5906025230136058 0.10493058463889099 0.2953507170797237 -0.022926308566159345]
 [-0.2418703011629133 -0.1407309824048591 -0.06954459256822769 -0.06641519957041032; 0.10672301583742447 0.010256714633894493 0.07538620644612853 -0.14859963586751518; 0.20724717242463542 -0.10388722669483227 -0.07395546688623197 0.03800421769296676]
 [-0.2418703011629133 -0.1407309824048591 -0.06954459256822769 -0.06641519957041032; 0.10672301583742447 0.010256714633894493 0.07538620644612853 -0.14859963586751518; 0.20724717242463542 -0.10388722669483227 -0.07395546688623197 0.03800421769296676]
 [-0.03966502811982747 -0.00016395771108326795 -0.018882790036993446 -0.043137648778792094; -0.0050297280836662794 0.014754675375953761 0.042781422140524064 0.011564658517054696; 0.0382960612364103 0.010097218871655822 0.08416477105146294 0.005196576918928698]
 [-0.03966502811982747 -0.00016395771108326795 -0.018882790036993446 -0.043137648778792094; -0.0050297280836662794 0.014754675375953761 0.042781422140524064 0.011564658517054696; 0.0382960612364103 0.010097218871655822 0.08416477105146294 0.005196576918928698]
 [0.03272084689223044 0.008862391597306875 0.013890143144455691 0.036205544644997055; 0.025886956625036525 -0.009841696372876149 -0.0643917610163228 -0.007908827873475034; -0.012497608298390847 -0.019710124696029667 -0.06581271714695533 -0.031629928191536816]
 [0.03272084689223044 0.008862391597306875 0.013890143144455691 0.036205544644997055; 0.025886956625036525 -0.009841696372876149 -0.0643917610163228 -0.007908827873475034; -0.012497608298390847 -0.019710124696029667 -0.06581271714695533 -0.031629928191536816]
 ⋮
 [-0.0288618457965239 0.023628647228398643 0.042105376258740515 0.011605328972502637; 0.04510144810032775 -0.0036299101695705834 -0.0493568008296247 -0.05995066121921808; -0.036987760484247324 0.03854136898624236 -0.033531482861628055 -0.00238859837315223]
 [0.016921243656581424 -0.03529928852782575 -0.02930100030446841 -0.014535832539207568; -0.027979209112868664 -0.0013636396528456211 0.052207219653535315 0.06340848918387844; 0.040072171854117045 -0.0405543298426313 0.03534318794057706 0.0039491540317404894]
 [-0.004948651135059632 -0.0036559977993703974 0.04875226280859067 0.036052612288941625; -0.023162614242297916 0.0015070963582364282 -0.03941221122868872 -0.06911339835473237; 0.0035365916214586887 -0.04419203107588196 -0.00839457032310726 0.03255056448815558]
 [-0.015464874145631878 0.01579402939214125 -0.057014974832883064 -0.009418786939642348; 0.031246537364596608 0.015514710440249423 0.014590355683267449 0.05515147936956616; 0.020339714587480733 0.038804161017494364 0.004445161417417767 -0.03701731431288233]
 [-0.013281422139997365 0.0008829436890032578 0.05671123950712173 0.02100935147021773; -0.04167629666111412 -0.01890354311317303 -0.02011901130629531 -0.032713631365874546; 0.004103669028703051 -0.02768022800853363 0.009725480001812982 -0.017232712467957413]
 [0.03383884042034409 -0.02389842374371271 0.03918151204580553 -0.0027507005603847126; -0.01683177766309695 -0.0015177885710130172 0.04720127220480706 -0.010624663950664314; -0.047544031800081606 -0.03459483880120602 0.017288712992063145 -0.028700502371846032]
 [-0.029730293230488244 0.05335935403825188 -0.03583486925303235 -0.01760560725623233; 0.013060162550869116 0.02259930181701638 -0.04922296656989996 0.01201076506401162; 0.054168216357996166 0.03406159448007197 -0.02763912378892264 0.03268571488568224]
 [-0.008117631328601018 0.04812509475870157 -0.032737161663857525 0.03928268716121382; -0.015284652269495788 0.007795835538327991 -0.03150938162258968 0.08392062111625885; 0.0043555999646640794 0.05120848754905767 0.00779144940052787 0.009585218238916521]
 [-0.008280590407594976 -0.04313793495319519 -0.0013419935802528593 0.017818022133304315; -0.032377603576374876 -0.0036981625215300722 0.009024317999107384 -0.012402638733690352; 0.016682453263521503 0.021368175721959218 0.04079580138885776 0.01642326665105025]
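
Each element of `chain` is a matrix with one row per spatial dimension and one column per atom (3 × 4 here). Consecutive entries are identical whenever the sampler stayed at its previous point, as in the first few entries above.

The sketch below shows how such a chain could be post-processed in plain Julia. `moved_fraction` and `ensemble_average` are hypothetical helpers, and `toy_chain` is made-up data rather than real output.

```julia
using Statistics

# Fraction of steps in which the configuration changed from the previous one.
moved_fraction(chain) =
    count(i -> chain[i] != chain[i - 1], 2:length(chain)) / (length(chain) - 1)

# Monte Carlo estimate of ⟨f(r)⟩: the mean of f over all sampled configurations.
ensemble_average(f, chain) = mean(f, chain)

# Made-up 3 × 4 configurations (dimensions × atoms), not real sampler output.
toy_chain = [fill(0.1, 3, 4), fill(0.1, 3, 4), fill(-0.1, 3, 4)]
moved_fraction(toy_chain)                        # one change in two steps
ensemble_average(r -> sum(abs2, r), toy_chain)   # mean squared norm of r
```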

and stats contains diagnostic information about each sampling step:

stats
10000-element Vector{NamedTuple}:
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -682.2402659411766, hamiltonian_energy = 2407.141413581641, hamiltonian_energy_error = -3307.320276950624, max_hamiltonian_energy_error = -3307.320276950624, tree_depth = 2, numerical_error = false, step_size = 0.05, nom_step_size = 0.05, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -682.2402659411766, hamiltonian_energy = 687.5696285944811, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1.8105028129748398e11, tree_depth = 0, numerical_error = true, step_size = 1.241032542311506, nom_step_size = 1.241032542311506, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -682.2402659411766, hamiltonian_energy = 691.1573809476641, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 7.709838235608963e8, tree_depth = 0, numerical_error = true, step_size = 0.5, nom_step_size = 0.5, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -682.2402659411766, hamiltonian_energy = 687.6854321302559, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 206077.0415447624, tree_depth = 0, numerical_error = true, step_size = 0.13192866018819532, nom_step_size = 0.13192866018819532, is_adapt = true)
 (n_steps = 7, is_accept = true, acceptance_rate = 1.0, log_density = -98.6646924027053, hamiltonian_energy = 560.5686401248593, hamiltonian_energy_error = -126.634165375419, max_hamiltonian_energy_error = -143.6436038325727, tree_depth = 2, numerical_error = false, step_size = 0.028716309633808685, nom_step_size = 0.028716309633808685, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -98.6646924027053, hamiltonian_energy = 103.84589910123276, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 10046.392411082063, tree_depth = 0, numerical_error = true, step_size = 0.11260612534950616, nom_step_size = 0.11260612534950616, is_adapt = true)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -7.724385024492403, hamiltonian_energy = 91.99859451962824, hamiltonian_energy_error = -13.103679475275442, max_hamiltonian_energy_error = -13.103679475275442, tree_depth = 2, numerical_error = false, step_size = 0.02340023160450795, nom_step_size = 0.02340023160450795, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 4.984844072331594e-272, log_density = -7.724385024492403, hamiltonian_energy = 13.365844064284822, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 624.6967431708276, tree_depth = 1, numerical_error = false, step_size = 0.10545494475829686, nom_step_size = 0.10545494475829686, is_adapt = true)
 (n_steps = 7, is_accept = true, acceptance_rate = 0.9616571894599834, log_density = -7.107706682003233, hamiltonian_energy = 14.593782470991453, hamiltonian_energy_error = -0.07559332969489674, max_hamiltonian_energy_error = 0.1133281129793211, tree_depth = 3, numerical_error = false, step_size = 0.021583114937805906, nom_step_size = 0.021583114937805906, is_adapt = true)
 (n_steps = 1, is_accept = true, acceptance_rate = 4.3291338430145736e-58, log_density = -7.107706682003233, hamiltonian_energy = 15.372454051842192, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 132.08456790790547, tree_depth = 1, numerical_error = false, step_size = 0.091339028389092, nom_step_size = 0.091339028389092, is_adapt = true)
 ⋮
 (n_steps = 3, is_accept = true, acceptance_rate = 0.37998655084034993, log_density = -8.083817064766594, hamiltonian_energy = 13.54284830681414, hamiltonian_energy_error = 0.9139357456171453, max_hamiltonian_energy_error = 1.5758803366706928, tree_depth = 2, numerical_error = false, step_size = 1.0979347828663584, nom_step_size = 1.0979347828663584, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -7.709280206143042, hamiltonian_energy = 12.386800177891665, hamiltonian_energy_error = -0.1404631377912704, max_hamiltonian_energy_error = -0.8027141885057212, tree_depth = 2, numerical_error = false, step_size = 1.0979347828663584, nom_step_size = 1.0979347828663584, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.6108219602382834, log_density = -7.198956316210327, hamiltonian_energy = 15.346629107319167, hamiltonian_energy_error = -0.15962463364157742, max_hamiltonian_energy_error = 1.3339332673506696, tree_depth = 2, numerical_error = false, step_size = 1.0979347828663584, nom_step_size = 1.0979347828663584, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -6.109804125530404, hamiltonian_energy = 9.158179907600555, hamiltonian_energy_error = -0.24700141596063574, max_hamiltonian_energy_error = -1.2275708497546827, tree_depth = 2, numerical_error = false, step_size = 1.0979347828663584, nom_step_size = 1.0979347828663584, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -4.51481115335423, hamiltonian_energy = 8.182375629760184, hamiltonian_energy_error = -0.41225724170439726, max_hamiltonian_energy_error = -0.6512735102177309, tree_depth = 2, numerical_error = false, step_size = 1.0979347828663584, nom_step_size = 1.0979347828663584, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.7500764995835622, log_density = -5.507785500023664, hamiltonian_energy = 8.992242064170638, hamiltonian_energy_error = 0.2668226414879378, max_hamiltonian_energy_error = 0.3227361792470944, tree_depth = 2, numerical_error = false, step_size = 1.0979347828663584, nom_step_size = 1.0979347828663584, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.5748482381978867, log_density = -7.631401583195358, hamiltonian_energy = 12.795025890175292, hamiltonian_energy_error = 0.6449836831000653, max_hamiltonian_energy_error = 1.6100700544386424, tree_depth = 2, numerical_error = false, step_size = 1.0979347828663584, nom_step_size = 1.0979347828663584, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.9771379827800906, log_density = -8.484108064453917, hamiltonian_energy = 13.29774384462448, hamiltonian_energy_error = 0.07105147287466984, max_hamiltonian_energy_error = -0.4747105129384792, tree_depth = 2, numerical_error = false, step_size = 1.0979347828663584, nom_step_size = 1.0979347828663584, is_adapt = false)
 (n_steps = 3, is_accept = true, acceptance_rate = 0.6767538220303803, log_density = -3.2708535624576918, hamiltonian_energy = 11.570660460641426, hamiltonian_energy_error = -1.3885688394059894, max_hamiltonian_energy_error = -1.3885688394059894, tree_depth = 2, numerical_error = false, step_size = 1.0979347828663584, nom_step_size = 1.0979347828663584, is_adapt = false)
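
Each entry of `stats` is a `NamedTuple` describing one step. Its `is_adapt` field separates the step-size adaptation phase from the production samples.

The sketch below summarises these diagnostics. `summarize_stats` is a hypothetical helper, and `toy_stats` holds invented values that reuse the field names printed above. The mean acceptance rate it reports need not coincide with the `average_acceptance_rate` shown in the log, which may be computed differently.

```julia
using Statistics

# Hypothetical helper: summarise per-step HMC diagnostics stored as NamedTuples
# with the fields printed above (is_adapt, acceptance_rate, numerical_error, tree_depth).
function summarize_stats(stats)
    production = filter(s -> !s.is_adapt, stats)   # drop the adaptation phase
    return (
        n_adapt = length(stats) - length(production),
        n_numerical_errors = count(s -> s.numerical_error, stats),
        mean_acceptance = mean(s.acceptance_rate for s in production),
        mean_tree_depth = mean(s.tree_depth for s in production),
    )
end

# Invented example entries (not real output), only to show the call.
toy_stats = [
    (acceptance_rate = 0.0, numerical_error = true, tree_depth = 0, is_adapt = true),
    (acceptance_rate = 1.0, numerical_error = false, tree_depth = 2, is_adapt = false),
    (acceptance_rate = 0.5, numerical_error = false, tree_depth = 2, is_adapt = false),
]
summarize_stats(toy_stats)
```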

The average potential energy of the sampled ensemble should agree with the equipartition theorem. The theorem assigns k_B T / 2 to each of the 3 × 4 = 12 quadratic degrees of freedom:

julia> Estimators.@estimate potential_energy(sim, chain)
0.005978482920505031

julia> austrip(sim.temperature) * 3 * 4 / 2
0.005700260814198944
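
Written out, the equipartition reference value uses k_B ≈ 3.1668 × 10⁻⁶ E_h/K:

```latex
\langle V \rangle = \frac{N_\mathrm{dof}}{2}\, k_\mathrm{B} T
  = \frac{3 \times 4}{2} \times \left(3.1668 \times 10^{-6}\,E_\mathrm{h}\,\mathrm{K}^{-1} \times 300\,\mathrm{K}\right)
  = 6 \times 9.5004 \times 10^{-4}\,E_\mathrm{h}
  \approx 5.700 \times 10^{-3}\,E_\mathrm{h}
```

The sampled estimate of 5.978 × 10⁻³ E_h lies about 5% above this value.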