Thermal Hamiltonian Monte Carlo
Our implementation of Hamiltonian Monte Carlo (HMC) is a light wrapper around the AdvancedHMC.jl package. For background on HMC theory, see the references and documentation provided with AdvancedHMC.jl.
Currently, our implementation works only for systems with classical nuclei (i.e. Simulation but not RingPolymerSimulation).
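AdvancedHMC.jl handles the details, but the core of a single HMC step fits in a few lines of Julia. The sketch below is a simplified stand-in and not the AdvancedHMC.jl algorithm. It uses a fixed step size ϵ and a fixed number of leapfrog steps L. The example that follows instead uses a No-U-Turn trajectory with step-size and metric adaptation, as its log shows. The sketch assumes unit masses and a harmonic potential V(r) = Σ r²/2 in atomic units, and it targets the canonical density exp(-V(r)/k_B T). The names potential, force, hmc_step and run_chain are hypothetical.

```julia
using Random
using Statistics: mean

# Illustrative target (assumption): harmonic potential with unit mass and unit frequency.
potential(r) = sum(abs2, r) / 2
force(r) = -r                                   # F = -dV/dr

# One HMC step targeting the canonical density exp(-β V(r)).
function hmc_step(r, β, ϵ, L, rng)
    p = randn(rng, size(r)) ./ sqrt(β)          # momenta ~ exp(-β p²/2), unit mass
    H0 = potential(r) + sum(abs2, p) / 2
    rnew, pnew = copy(r), copy(p)
    pnew .+= (ϵ / 2) .* force(rnew)             # leapfrog: initial half kick
    for i in 1:L
        rnew .+= ϵ .* pnew                      # full drift
        if i < L
            pnew .+= ϵ .* force(rnew)           # full kick, except on the last step
        end
    end
    pnew .+= (ϵ / 2) .* force(rnew)             # final half kick
    H1 = potential(rnew) + sum(abs2, pnew) / 2
    # Metropolis test on the energy error removes the integrator's bias.
    accepted = log(rand(rng)) < -β * (H1 - H0)
    return (accepted ? rnew : r), accepted
end

# Run nsteps HMC steps and return the chain and the acceptance fraction.
function run_chain(r0, β, ϵ, L, nsteps, rng)
    r = copy(r0)
    chain = typeof(r)[]
    naccept = 0
    for _ in 1:nsteps
        r, accepted = hmc_step(r, β, ϵ, L, rng)
        naccept += accepted
        push!(chain, r)
    end
    return chain, naccept / nsteps
end

rng = MersenneTwister(42)
kT = 300 * 3.166811563e-6                      # k_B T at 300 K in hartree
r0 = sqrt(kT) .* randn(rng, 3, 4)              # 3 dimensions × 4 atoms
chain, acceptance = run_chain(r0, 1 / kT, 0.2, 10, 5000, rng)
mean_V = mean(potential, chain[501:end])       # discard the first 500 samples as warm-up
```

For a harmonic potential the leapfrog energy error is small, so nearly every proposal is accepted. mean_V should come out close to 3 × 4 × k_B T / 2, which is the same equipartition check used at the end of this page.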
Example
In this example we use Hamiltonian Monte Carlo to sample the canonical distribution at 300 K of four atoms in a three-dimensional harmonic potential.
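using NQCDynamics
using Unitful
using UnitfulAtomic

# Four atoms in a three-dimensional harmonic potential at 300 K
sim = Simulation(Atoms([:H, :H, :C, :C]), Harmonic(dofs=3); temperature=300u"K")
# Random initial configuration of size (dofs, atoms) = (3, 4)
r0 = randn(size(sim))
# Draw 10^4 samples; returns the chain of configurations and per-step statistics
chain, stats = InitialConditions.ThermalMonteCarlo.run_advancedhmc_sampling(sim, r0, 1e4)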
┌ Info: Finished 1000 adapation steps
│ adaptor =
│ StanHMCAdaptor(
│ pc=WelfordVar,
│ ssa=NesterovDualAveraging(γ=0.05, t_0=10.0, κ=0.75, δ=0.5, state.ϵ=1.1890316524126845),
│ init_buffer=75, term_buffer=50, window_size=25,
│ state=window(76, 950), window_splits(100, 150, 250, 450, 950)
│ )
│ κ.τ.integrator = Leapfrog(ϵ=1.19)
└ h.metric = DiagEuclideanMetric([0.0007204188837036886, 0.0 ...])
┌ Info: Finished 10000 sampling steps for 1 chains in 1.05244372 (s)
│ h = Hamiltonian(metric=DiagEuclideanMetric([0.0007204188837036886, 0.0 ...]), kinetic=AdvancedHMC.GaussianKinetic())
│ κ = AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.MultinomialTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.GeneralisedNoUTurn{Float64}}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.MultinomialTS}(integrator=Leapfrog(ϵ=1.19), tc=AdvancedHMC.GeneralisedNoUTurn{Float64}(10, 1000.0)))
│ EBFMI_est = 0.6630066575571423
└ average_acceptance_rate = 0.5978659385736624
The Monte Carlo chain contains the nuclear configurations that we have sampled:
chain
10000-element Vector{Matrix{Float64}}:
[-0.3676860541664053 -0.06690552068509159 -0.029657308252405423 0.01853626969246817; -0.014090694647324015 0.1718511992875688 0.016273709998317717 0.2820319717579689; 0.14051543668001687 0.11263033992754079 0.12261703482739439 -0.3122106959041866]
[-0.3676860541664053 -0.06690552068509159 -0.029657308252405423 0.01853626969246817; -0.014090694647324015 0.1718511992875688 0.016273709998317717 0.2820319717579689; 0.14051543668001687 0.11263033992754079 0.12261703482739439 -0.3122106959041866]
[-0.3676860541664053 -0.06690552068509159 -0.029657308252405423 0.01853626969246817; -0.014090694647324015 0.1718511992875688 0.016273709998317717 0.2820319717579689; 0.14051543668001687 0.11263033992754079 0.12261703482739439 -0.3122106959041866]
[-0.3676860541664053 -0.06690552068509159 -0.029657308252405423 0.01853626969246817; -0.014090694647324015 0.1718511992875688 0.016273709998317717 0.2820319717579689; 0.14051543668001687 0.11263033992754079 0.12261703482739439 -0.3122106959041866]
[0.1647923804164668 -0.01957043025632018 -0.015478276472673334 -0.033981777599805627; 0.013516303317958061 -0.02864152128508146 -0.024737075114316644 -0.05865662075728906; -0.06071393238458292 -0.07414190900032785 -0.033971920946719844 0.09745077764592824]
[0.1647923804164668 -0.01957043025632018 -0.015478276472673334 -0.033981777599805627; 0.013516303317958061 -0.02864152128508146 -0.024737075114316644 -0.05865662075728906; -0.06071393238458292 -0.07414190900032785 -0.033971920946719844 0.09745077764592824]
[0.014415435642155347 0.02877780636437332 0.03017327635101399 -0.020326619699334855; -0.0035835508889333215 -0.011522541199289329 -0.0072323336731852095 0.02760324518982845; -0.027477517503105897 0.012225449597785307 0.0003003836043402597 0.04519985793181021]
[0.014415435642155347 0.02877780636437332 0.03017327635101399 -0.020326619699334855; -0.0035835508889333215 -0.011522541199289329 -0.0072323336731852095 0.02760324518982845; -0.027477517503105897 0.012225449597785307 0.0003003836043402597 0.04519985793181021]
[-0.013824343842636061 -0.007943802969074716 0.028595640968442845 -0.007513099985455277; 0.013299379741828266 -0.01566790328455065 -0.04293314587166269 0.017804730344817676; -0.02665963705294158 0.008751489438792889 -0.0028203670927330563 0.039699977577620296]
[-0.013824343842636061 -0.007943802969074716 0.028595640968442845 -0.007513099985455277; 0.013299379741828266 -0.01566790328455065 -0.04293314587166269 0.017804730344817676; -0.02665963705294158 0.008751489438792889 -0.0028203670927330563 0.039699977577620296]
⋮
[0.010311616505844088 0.06299631089902877 0.02589319989427713 -0.003373805878899707; -0.02056283874268322 -0.03702956844366866 0.03170157284023012 -0.05330222050042717; -0.02170766835395073 -0.0452268468322817 0.10078001650093091 -0.017967922926943852]
[-0.017513863951400338 -0.060805178402887516 0.01118666894347009 0.020906533234664893; 0.03261466486728471 -0.012028127438706439 -0.030024098954095156 0.022322435473717317; -0.03172151610825996 0.048306963075682524 -0.09435253494496948 -0.012642642858838538]
[0.07652894524713638 0.013692027937343254 -0.02946617246767299 0.012152269456638887; -0.03176893503145099 0.02625758241074007 0.010290839052775905 -0.0023458785716558017; 0.03682346621766465 -0.021968958595030102 0.016951549967609253 0.014995959099591545]
[-0.0806729239607114 0.014482065948146619 0.009477314901140715 -0.018501445547878755; 0.00906249080718307 -0.0542366240807083 -0.027537454717692953 -0.010391510486747055; -0.0459022150245838 0.062019378667249006 0.003607646114043385 -0.0067192548918338126]
[0.08950353842883435 -0.00563173189568205 -0.0440301202857786 0.007156904860594903; -0.03471901159470063 0.008150497622717759 -0.007671932395639798 0.012656480618025117; 0.03654998062758815 -0.029960413590667007 -0.03328722898316325 0.037408033763411515]
[-0.033758407944495625 -0.0008363532781185956 0.06606379791947428 0.017195973963099596; 0.010830611611571526 0.02015892764920757 -0.011716230476522125 -0.03506102629273891; -0.04053336864059742 0.055486679171167536 0.01293187849313698 -0.010599816950200569]
[-0.013436115073212747 0.03147823927397847 -0.056651255739397124 -0.019126348890753545; -0.028662273482452434 0.01108855976313998 0.03670174016652935 0.049214508987229705; 0.017840038070888763 -0.02699621270351374 -0.003396123676134146 0.01621420212308598]
[-0.0388001039612765 0.03733725849802267 -0.06580048616025073 0.01490399318519985; -0.015129624550815918 -0.032324552531555674 0.017280846907077183 0.07247454697074088; -0.005770972250391363 -0.011477314530600975 0.02023845972931092 -0.013542783320788006]
[0.014542632712774443 -0.020062067748006805 0.03782256233692599 -0.015330374353324672; 0.020288038644573528 -0.0025391040600564047 -0.017500907459827585 -0.05710579785655673; 0.009617960573619147 0.05924354623142822 -0.003273131589407577 -0.027133391011608644]
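Each entry of chain is a 3 × 4 matrix (degrees of freedom × atoms). The sketch below turns the chain into a single array and discards the warm-up samples. It assumes, as the is_adapt flags in stats suggest, that the first 1000 entries come from the adaptation phase reported in the log. stack_chain is a hypothetical helper. The stand-in chain of Gaussian draws exists only so the snippet runs on its own, and its variance k_B T assumes the Harmonic model's default unit mass and frequency. In practice you would pass the chain returned by run_advancedhmc_sampling.

```julia
using Random
using Statistics: mean, var

# Stack a chain of (dofs × atoms) configurations into a dofs × atoms × samples array,
# dropping the first n_warmup entries.
function stack_chain(chain, n_warmup)
    kept = chain[n_warmup+1:end]
    A = Array{Float64}(undef, size(first(kept))..., length(kept))
    for (k, r) in enumerate(kept)
        A[:, :, k] .= r
    end
    return A
end

# Stand-in chain: independent Gaussian draws with variance k_B T (hartree),
# the value expected for a unit-mass, unit-frequency oscillator (assumption).
rng = MersenneTwister(7)
kT = 300 * 3.166811563e-6
stand_in_chain = [sqrt(kT) .* randn(rng, 3, 4) for _ in 1:10_000]

A = stack_chain(stand_in_chain, 1000)          # 1000 = adaptation steps in the log
position_mean = dropdims(mean(A; dims=3); dims=3)
position_var = dropdims(var(A; dims=3); dims=3)
```

On Julia 1.9 or later, stack(chain[1001:end]) builds the same array without a helper.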
Likewise, stats contains extra information about each step of the sampling procedure:
stats
10000-element Vector{NamedTuple}:
(n_steps = 1, is_accept = true, acceptance_rate = 1.0, log_density = -208.08069926699343, hamiltonian_energy = 1159.6025135781542, hamiltonian_energy_error = -1817.5036113901874, max_hamiltonian_energy_error = -1817.5036113901874, tree_depth = 1, numerical_error = false, step_size = 0.05, nom_step_size = 0.05, is_adapt = true)
(n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -208.08069926699343, hamiltonian_energy = 215.1872172000875, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 5.514795206573016e10, tree_depth = 0, numerical_error = true, step_size = 1.241032542311506, nom_step_size = 1.241032542311506, is_adapt = true)
(n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -208.08069926699343, hamiltonian_energy = 213.26978096997246, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 2.366555565887356e8, tree_depth = 0, numerical_error = true, step_size = 0.5, nom_step_size = 0.5, is_adapt = true)
(n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -208.08069926699343, hamiltonian_energy = 217.99994137774829, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 56782.50569605116, tree_depth = 0, numerical_error = true, step_size = 0.13192866018819532, nom_step_size = 0.13192866018819532, is_adapt = true)
(n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -28.326792069381206, hamiltonian_energy = 173.47581712870743, hamiltonian_energy_error = -39.006063726962225, max_hamiltonian_energy_error = -39.006063726962225, tree_depth = 2, numerical_error = false, step_size = 0.028716309633808685, nom_step_size = 0.028716309633808685, is_adapt = true)
(n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -28.326792069381206, hamiltonian_energy = 36.537192410546716, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 3543.482042827406, tree_depth = 0, numerical_error = true, step_size = 0.11260612534950616, nom_step_size = 0.11260612534950616, is_adapt = true)
(n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -3.2982777966109285, hamiltonian_energy = 27.475582643501447, hamiltonian_energy_error = -3.6063835523312875, max_hamiltonian_energy_error = -3.6063835523312875, tree_depth = 2, numerical_error = false, step_size = 0.02340023160450795, nom_step_size = 0.02340023160450795, is_adapt = true)
(n_steps = 1, is_accept = true, acceptance_rate = 6.754241786099173e-173, log_density = -3.2982777966109285, hamiltonian_energy = 8.866444819108741, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 396.43705036769865, tree_depth = 1, numerical_error = false, step_size = 0.10545494475829686, nom_step_size = 0.10545494475829686, is_adapt = true)
(n_steps = 7, is_accept = true, acceptance_rate = 0.9342772892978329, log_density = -3.201096411726343, hamiltonian_energy = 6.913802528471361, hamiltonian_energy_error = -0.01191263574805479, max_hamiltonian_energy_error = 0.1237640945115226, tree_depth = 3, numerical_error = false, step_size = 0.021583114937805906, nom_step_size = 0.021583114937805906, is_adapt = true)
(n_steps = 1, is_accept = true, acceptance_rate = 6.773671848641367e-65, log_density = -3.201096411726343, hamiltonian_energy = 8.574236343924582, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 147.7549877341659, tree_depth = 1, numerical_error = false, step_size = 0.08377338266889917, nom_step_size = 0.08377338266889917, is_adapt = true)
⋮
(n_steps = 3, is_accept = true, acceptance_rate = 0.4527034282366842, log_density = -12.311536249823181, hamiltonian_energy = 19.258587283610822, hamiltonian_energy_error = 0.7584508881557923, max_hamiltonian_energy_error = 0.8782953120443118, tree_depth = 2, numerical_error = false, step_size = 1.1890316524126845, nom_step_size = 1.1890316524126845, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.680940036696048, log_density = -10.302904485702676, hamiltonian_energy = 21.612742373024567, hamiltonian_energy_error = -0.7245797574271293, max_hamiltonian_energy_error = 1.2583423928287267, tree_depth = 2, numerical_error = false, step_size = 1.1890316524126845, nom_step_size = 1.1890316524126845, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.7723370010290377, log_density = -5.905546141675037, hamiltonian_energy = 16.249104498061563, hamiltonian_energy_error = -1.4021599783514205, max_hamiltonian_energy_error = -1.4021599783514205, tree_depth = 2, numerical_error = false, step_size = 1.1890316524126845, nom_step_size = 1.1890316524126845, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.4824403383894231, log_density = -8.974109943619037, hamiltonian_energy = 12.365721204987786, hamiltonian_energy_error = 0.9456077661311006, max_hamiltonian_energy_error = 1.430351511614603, tree_depth = 2, numerical_error = false, step_size = 1.1890316524126845, nom_step_size = 1.1890316524126845, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.5297183547195554, log_density = -8.559758895713404, hamiltonian_energy = 16.23165650088573, hamiltonian_energy_error = -0.17874809564924377, max_hamiltonian_energy_error = 1.4546585818218674, tree_depth = 2, numerical_error = false, step_size = 1.1890316524126845, nom_step_size = 1.1890316524126845, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.7181470214687042, log_density = -6.679695296162311, hamiltonian_energy = 13.55122387133624, hamiltonian_energy_error = -0.4149622471047021, max_hamiltonian_energy_error = 0.6619259971927818, tree_depth = 2, numerical_error = false, step_size = 1.1890316524126845, nom_step_size = 1.1890316524126845, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.8700213680108965, log_density = -5.674295575069492, hamiltonian_energy = 10.732138071399763, hamiltonian_energy_error = -0.23794282415083856, max_hamiltonian_energy_error = -0.23794282415083856, tree_depth = 2, numerical_error = false, step_size = 1.1890316524126845, nom_step_size = 1.1890316524126845, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.8154182881485146, log_density = -7.912457370153743, hamiltonian_energy = 10.395817571263837, hamiltonian_energy_error = 0.8068650452293955, max_hamiltonian_energy_error = -1.0350578805594548, tree_depth = 2, numerical_error = false, step_size = 1.1890316524126845, nom_step_size = 1.1890316524126845, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.7740323951902303, log_density = -5.586157044463571, hamiltonian_energy = 11.818946818251483, hamiltonian_energy_error = -0.8765228345229019, max_hamiltonian_energy_error = -0.8765228345229019, tree_depth = 2, numerical_error = false, step_size = 1.1890316524126845, nom_step_size = 1.1890316524126845, is_adapt = false)
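Each entry of stats is a NamedTuple describing one step. The hypothetical helper summarise_stats below separates the adaptation phase from the sampling phase. It is a sketch that relies only on field names visible in the printout: is_adapt, acceptance_rate, numerical_error and tree_depth. The four stand-in entries keep just those fields, with values copied from the printout, so the snippet runs without NQCDynamics.

```julia
using Statistics: mean

# Summarise per-step statistics; assumes each entry has the fields
# is_adapt, acceptance_rate, numerical_error and tree_depth.
function summarise_stats(stats)
    sampling = filter(s -> !s.is_adapt, stats)
    return (
        n_adapt = count(s -> s.is_adapt, stats),
        n_sampling = length(sampling),
        mean_acceptance = mean(s -> s.acceptance_rate, sampling),
        n_numerical_errors = count(s -> s.numerical_error, stats),
        mean_tree_depth = mean(s -> s.tree_depth, sampling),
    )
end

# Stand-in entries (subset of fields, values copied from the printout above)
example_stats = [
    (is_adapt = true, acceptance_rate = 0.0, numerical_error = true, tree_depth = 0),
    (is_adapt = true, acceptance_rate = 1.0, numerical_error = false, tree_depth = 2),
    (is_adapt = false, acceptance_rate = 0.4527034282366842, numerical_error = false, tree_depth = 2),
    (is_adapt = false, acceptance_rate = 0.680940036696048, numerical_error = false, tree_depth = 2),
]
stats_summary = summarise_stats(example_stats)
```

Applied to the full stats, the post-adaptation mean acceptance can be compared with the target δ = 0.5 of the NesterovDualAveraging step-size adaptor shown in the log. Entries with numerical_error = true flag divergent trajectories. In the printed entries, these appear only during early adaptation.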
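We can check that the average potential energy of the sampled ensemble agrees with the equipartition theorem. For a harmonic potential, the theorem assigns k_B T/2 of potential energy to each of the 3 × 4 = 12 degrees of freedom: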
julia> Estimators.@estimate potential_energy(sim, chain)
0.005817032791921449
julia> austrip(sim.temperature) * 3 * 4 / 2
0.005700260814198944
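The equipartition value printed above works out as follows. austrip converts the temperature into the thermal energy k_B T in hartree (E_h), using k_B ≈ 3.1668 × 10⁻⁶ E_h K⁻¹:

```math
\langle V \rangle = \frac{N_\mathrm{dof}}{2}\,k_\mathrm{B}T
= \frac{3 \times 4}{2} \times \left(3.1668 \times 10^{-6}\,E_\mathrm{h}\,\mathrm{K}^{-1}\right) \times 300\,\mathrm{K}
\approx 5.70 \times 10^{-3}\,E_\mathrm{h}
```

The sampled estimate, 0.00582 E_h, is about 2% higher. Part of this difference may be statistical noise. Part may come from the early, unequilibrated samples that are still present in chain.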