Thermal Hamiltonian Monte Carlo
Our implementation of Hamiltonian Monte Carlo (HMC) is a thin wrapper around the AdvancedHMC.jl package. For the theory behind HMC, refer to the references and documentation provided with AdvancedHMC.jl.
Currently, our implementation works only for systems with classical nuclei (i.e. Simulation but not RingPolymerSimulation).
Example
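In this example we use Hamiltonian Monte Carlo to sample the canonical distribution of 4 atoms in a three-dimensional harmonic oscillator potential at 300 K. The sampler is given the simulation, an initial configuration r0, and the number of sampling steps (here 1e4).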
using NQCDynamics
using Unitful
using UnitfulAtomic
sim = Simulation(Atoms([:H, :H, :C, :C]), Harmonic(dofs=3); temperature=300u"K")
r0 = randn(size(sim))
chain, stats = InitialConditions.ThermalMonteCarlo.run_advancedhmc_sampling(sim, r0, 1e4)
┌ Info: Finished 1000 adapation steps
│ adaptor =
│ StanHMCAdaptor(
│ pc=WelfordVar,
│ ssa=NesterovDualAveraging(γ=0.05, t_0=10.0, κ=0.75, δ=0.5, state.ϵ=1.1503832774307403),
│ init_buffer=75, term_buffer=50, window_size=25,
│ state=window(76, 950), window_splits(100, 150, 250, 450, 950)
│ )
│ κ.τ.integrator = Leapfrog(ϵ=1.15)
└ h.metric = DiagEuclideanMetric([0.0010310703178621572, 0.0 ...])
┌ Info: Finished 10000 sampling steps for 1 chains in 1.026149039 (s)
│ h = Hamiltonian(metric=DiagEuclideanMetric([0.0010310703178621572, 0.0 ...]), kinetic=AdvancedHMC.GaussianKinetic())
│ κ = AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.MultinomialTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.GeneralisedNoUTurn{Float64}}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.MultinomialTS}(integrator=Leapfrog(ϵ=1.15), tc=AdvancedHMC.GeneralisedNoUTurn{Float64}(10, 1000.0)))
│ EBFMI_est = 0.4791182901349628
└ average_acceptance_rate = 0.5882873490087689
The Monte Carlo chain contains the nuclear configurations that we have sampled:
chain
10000-element Vector{Matrix{Float64}}:
[-0.4115519261731946 -0.16429320083499177 -0.1535045050240213 0.11133234993635893; -0.12093648315944366 0.3882529377450148 -0.5564318147340577 -0.03840608516306716; 0.8345363540239612 0.39884629001245564 0.06415838611010016 -0.49585808396069475]
[-0.4115519261731946 -0.16429320083499177 -0.1535045050240213 0.11133234993635893; -0.12093648315944366 0.3882529377450148 -0.5564318147340577 -0.03840608516306716; 0.8345363540239612 0.39884629001245564 0.06415838611010016 -0.49585808396069475]
[-0.4115519261731946 -0.16429320083499177 -0.1535045050240213 0.11133234993635893; -0.12093648315944366 0.3882529377450148 -0.5564318147340577 -0.03840608516306716; 0.8345363540239612 0.39884629001245564 0.06415838611010016 -0.49585808396069475]
[-0.4115519261731946 -0.16429320083499177 -0.1535045050240213 0.11133234993635893; -0.12093648315944366 0.3882529377450148 -0.5564318147340577 -0.03840608516306716; 0.8345363540239612 0.39884629001245564 0.06415838611010016 -0.49585808396069475]
[0.12341134205103554 0.03888870520635597 0.10992462550997116 -0.03378038815775743; 0.06588477238201312 -0.12265017512999582 0.1785234962375859 -0.01762501682680865; -0.2906708190632438 -0.132103729369818 -0.019632486647679075 0.10875541145276146]
[0.12341134205103554 0.03888870520635597 0.10992462550997116 -0.03378038815775743; 0.06588477238201312 -0.12265017512999582 0.1785234962375859 -0.01762501682680865; -0.2906708190632438 -0.132103729369818 -0.019632486647679075 0.10875541145276146]
[0.0011319462744826952 -0.009670648806923449 0.029031571515983012 -0.014469341706601972; 0.011386496948077668 0.04091720656871886 0.0045525020582918435 -0.031276189257819866; -0.015088468589209952 -0.04380862794915355 0.013227021894237242 0.026046999264147996]
[0.0011319462744826952 -0.009670648806923449 0.029031571515983012 -0.014469341706601972; 0.011386496948077668 0.04091720656871886 0.0045525020582918435 -0.031276189257819866; -0.015088468589209952 -0.04380862794915355 0.013227021894237242 0.026046999264147996]
[-0.0110331870195901 -0.010086358697373834 0.039574421823282226 0.01942519062680091; -0.004715110423334061 0.053922996275101276 -0.10057077151122487 -0.07567231598942606; 0.004915894315890543 0.0152309111433136 -0.015309554846835749 0.0564002359229844]
[-0.005732397847594111 -0.016989402984785976 0.00900975277516955 -0.062244425730071606; -0.03451956719900249 -0.04471549804475532 0.02255663547797722 0.08709003653803665; 0.04364128737258702 -0.021400322055731794 -0.055285434191656185 -0.0403492219456978]
⋮
[-0.05808290755438083 -0.024734547397725355 0.00510152267789957 -0.0546698256336197; 0.039479054078402756 -0.03832379623107025 0.03275049186725386 0.020571779434351654; -0.04525444725150475 0.008649792986685778 -0.009647781452846536 0.0014089995441683677]
[0.031812247722905626 0.0315373417078571 -0.031504408806122305 0.017372912501771635; 0.008205411164489848 0.007595882488836872 -0.015480572007533628 0.011936734307827157; 0.005466673619934838 0.007932140570719795 0.016438756527119076 -0.006569580755566071]
[0.031812247722905626 0.0315373417078571 -0.031504408806122305 0.017372912501771635; 0.008205411164489848 0.007595882488836872 -0.015480572007533628 0.011936734307827157; 0.005466673619934838 0.007932140570719795 0.016438756527119076 -0.006569580755566071]
[0.031812247722905626 0.0315373417078571 -0.031504408806122305 0.017372912501771635; 0.008205411164489848 0.007595882488836872 -0.015480572007533628 0.011936734307827157; 0.005466673619934838 0.007932140570719795 0.016438756527119076 -0.006569580755566071]
[0.031812247722905626 0.0315373417078571 -0.031504408806122305 0.017372912501771635; 0.008205411164489848 0.007595882488836872 -0.015480572007533628 0.011936734307827157; 0.005466673619934838 0.007932140570719795 0.016438756527119076 -0.006569580755566071]
[0.031812247722905626 0.0315373417078571 -0.031504408806122305 0.017372912501771635; 0.008205411164489848 0.007595882488836872 -0.015480572007533628 0.011936734307827157; 0.005466673619934838 0.007932140570719795 0.016438756527119076 -0.006569580755566071]
[-0.001635879104899915 -0.0013772083305760643 -0.00968191571649639 0.016858642558391153; -0.0010955431366184935 -0.0047577660805826275 -0.005003509023131653 0.019492433407137473; -0.0056399451727705266 -0.0021489193778656025 0.008552498502814178 -0.003378951336983696]
[-0.001635879104899915 -0.0013772083305760643 -0.00968191571649639 0.016858642558391153; -0.0010955431366184935 -0.0047577660805826275 -0.005003509023131653 0.019492433407137473; -0.0056399451727705266 -0.0021489193778656025 0.008552498502814178 -0.003378951336983696]
[-0.03463176676918244 -0.02776913917512677 0.02029794468676728 0.013011573062028235; 0.0061131674630412886 -0.02698762536126174 -0.022263942909869074 0.016444087530498995; -0.009923537810548423 -0.0020224163863082917 -0.007404284433984031 0.0017323572199599894]
The stats vector contains extra information about the sampling procedure, with one entry per step:
stats
10000-element Vector{NamedTuple}:
(n_steps = 5, is_accept = true, acceptance_rate = 1.0, log_density = -954.8531323850946, hamiltonian_energy = 3739.405318388911, hamiltonian_energy_error = -5336.988532229819, max_hamiltonian_energy_error = -5491.180295280743, tree_depth = 2, numerical_error = false, step_size = 0.05, nom_step_size = 0.05, is_adapt = true)
(n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -954.8531323850946, hamiltonian_energy = 967.1518589354084, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 2.528589221649744e11, tree_depth = 0, numerical_error = true, step_size = 1.241032542311506, nom_step_size = 1.241032542311506, is_adapt = true)
(n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -954.8531323850946, hamiltonian_energy = 961.5922707612572, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1.0772496638434496e9, tree_depth = 0, numerical_error = true, step_size = 0.5, nom_step_size = 0.5, is_adapt = true)
(n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -954.8531323850946, hamiltonian_energy = 959.053648065942, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 280544.3862039453, tree_depth = 0, numerical_error = true, step_size = 0.13192866018819532, nom_step_size = 0.13192866018819532, is_adapt = true)
(n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -102.98816742303079, hamiltonian_energy = 775.352552708051, hamiltonian_energy_error = -184.85216609810698, max_hamiltonian_energy_error = -184.85216609810698, tree_depth = 2, numerical_error = false, step_size = 0.028716309633808685, nom_step_size = 0.028716309633808685, is_adapt = true)
(n_steps = 1, is_accept = true, acceptance_rate = 0.0, log_density = -102.98816742303079, hamiltonian_energy = 119.91184513652247, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 8522.396106150087, tree_depth = 0, numerical_error = true, step_size = 0.11260612534950616, nom_step_size = 0.11260612534950616, is_adapt = true)
(n_steps = 3, is_accept = true, acceptance_rate = 1.0, log_density = -3.6577537833757012, hamiltonian_energy = 91.66252020747363, hamiltonian_energy_error = -14.31261824382615, max_hamiltonian_energy_error = -14.31261824382615, tree_depth = 2, numerical_error = false, step_size = 0.02340023160450795, nom_step_size = 0.02340023160450795, is_adapt = true)
(n_steps = 1, is_accept = true, acceptance_rate = 7.52563510236005e-166, log_density = -3.6577537833757012, hamiltonian_energy = 7.745708078964702, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 380.2108102309533, tree_depth = 1, numerical_error = false, step_size = 0.10545494475829686, nom_step_size = 0.10545494475829686, is_adapt = true)
(n_steps = 7, is_accept = true, acceptance_rate = 0.6117405007309478, log_density = -12.951594626650472, hamiltonian_energy = 15.722269366767188, hamiltonian_energy_error = 1.139252551276197, max_hamiltonian_energy_error = 1.139252551276197, tree_depth = 3, numerical_error = false, step_size = 0.021583114937805906, nom_step_size = 0.021583114937805906, is_adapt = true)
(n_steps = 3, is_accept = true, acceptance_rate = 0.9895667617049176, log_density = -11.898756405165251, hamiltonian_energy = 20.629048113727997, hamiltonian_energy_error = -0.2535596153534527, max_hamiltonian_energy_error = -0.6794377307466419, tree_depth = 2, numerical_error = false, step_size = 0.030252478068166968, nom_step_size = 0.030252478068166968, is_adapt = true)
⋮
(n_steps = 3, is_accept = true, acceptance_rate = 0.4144297114459074, log_density = -7.231873370976679, hamiltonian_energy = 10.310915856584028, hamiltonian_energy_error = 1.1653370117619488, max_hamiltonian_energy_error = 1.2185508881764449, tree_depth = 2, numerical_error = false, step_size = 1.1503832774307403, nom_step_size = 1.1503832774307403, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.40791097226669176, log_density = -2.2179639397074067, hamiltonian_energy = 11.726089424782343, hamiltonian_energy_error = -1.5904238246599043, max_hamiltonian_energy_error = 2.3219080442682856, tree_depth = 2, numerical_error = false, step_size = 1.1503832774307403, nom_step_size = 1.1503832774307403, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.2960275662932333, log_density = -2.2179639397074067, hamiltonian_energy = 7.467949843951599, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 1.9397933308146484, tree_depth = 2, numerical_error = false, step_size = 1.1503832774307403, nom_step_size = 1.1503832774307403, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.37108906237430395, log_density = -2.2179639397074067, hamiltonian_energy = 8.160915055937554, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 2.2786041732003444, tree_depth = 2, numerical_error = false, step_size = 1.1503832774307403, nom_step_size = 1.1503832774307403, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.1762423218339326, log_density = -2.2179639397074067, hamiltonian_energy = 10.373128071698703, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 3.744032148103763, tree_depth = 2, numerical_error = false, step_size = 1.1503832774307403, nom_step_size = 1.1503832774307403, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.20505564605738558, log_density = -2.2179639397074067, hamiltonian_energy = 8.729924058200481, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 2.03303825908999, tree_depth = 2, numerical_error = false, step_size = 1.1503832774307403, nom_step_size = 1.1503832774307403, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.47049205711954384, log_density = -0.4906839259545522, hamiltonian_energy = 6.778853548791356, hamiltonian_energy_error = -0.6146854703625282, max_hamiltonian_energy_error = 2.629709398298285, tree_depth = 2, numerical_error = false, step_size = 1.1503832774307403, nom_step_size = 1.1503832774307403, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.2160905254581622, log_density = -0.4906839259545522, hamiltonian_energy = 5.779781632882606, hamiltonian_energy_error = 0.0, max_hamiltonian_energy_error = 2.6164250737053765, tree_depth = 1, numerical_error = false, step_size = 1.1503832774307403, nom_step_size = 1.1503832774307403, is_adapt = false)
(n_steps = 3, is_accept = true, acceptance_rate = 0.36269950724064687, log_density = -2.2335699672598657, hamiltonian_energy = 5.890199819274208, hamiltonian_energy_error = 0.5791603286229448, max_hamiltonian_energy_error = 2.197586623173027, tree_depth = 2, numerical_error = false, step_size = 1.1503832774307403, nom_step_size = 1.1503832774307403, is_adapt = false)
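The is_adapt, acceptance_rate and numerical_error fields shown above can be used to separate the adaptation phase from the rest of the run and to summarise the sampling quality. The following is a minimal sketch rather than part of the package: it assumes that chain[i] and stats[i] describe the same step (both vectors have 10000 elements in the output above), and it uses small hand-written stand-ins for chain and stats so that it runs on its own. The names post_adapt, production_chain and production_stats are illustrative only.

```julia
using Statistics

# Stand-ins for `chain` and `stats`, written by hand for illustration.
# In a real session, use the values returned by the sampler instead.
chain = [fill(0.8, 3, 4), fill(0.8, 3, 4), fill(0.03, 3, 4), fill(-0.01, 3, 4)]
stats = [
    (acceptance_rate = 1.0,  numerical_error = false, is_adapt = true),
    (acceptance_rate = 0.0,  numerical_error = true,  is_adapt = true),
    (acceptance_rate = 0.41, numerical_error = false, is_adapt = false),
    (acceptance_rate = 0.30, numerical_error = false, is_adapt = false),
]

# Mask selecting the steps recorded after adaptation had finished.
post_adapt = [!s.is_adapt for s in stats]

# Assumption: entry i of `chain` belongs to entry i of `stats`.
production_chain = chain[post_adapt]
production_stats = stats[post_adapt]

# Mean acceptance rate over the post-adaptation steps only.
mean_acceptance = mean(s.acceptance_rate for s in production_stats)

# Number of steps (in either phase) flagged with a numerical error.
n_numerical_errors = count(s -> s.numerical_error, stats)
```

If the early entries are far from equilibrium, as the large negative log_density values at the start of the output suggest, it may be worth computing ensemble averages from production_chain instead of the full chain.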
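The expectation value of the potential energy over the generated ensemble should approximately match the equipartition theorem, which gives k_B T / 2 for each of the 3 × 4 = 12 degrees of freedom (k_B = 1 in atomic units):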
julia> Estimators.@estimate potential_energy(sim, chain)
0.006079282006760821
julia> austrip(sim.temperature) * 3 * 4 / 2
0.005700260814198944
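The comparison above can be written out explicitly. This is a worked version of the same arithmetic, assuming a purely harmonic potential so that each degree of freedom contributes k_B T / 2 to the potential energy; the symbols N_dof and E_h (Hartree) are notation introduced here.

```latex
\langle V \rangle
  = \frac{N_{\mathrm{dof}}}{2}\, k_B T
  = \frac{3 \times 4}{2}\, k_B T
  = 6\, k_B T
  \approx 6 \times 9.50 \times 10^{-4}\, E_h
  \approx 5.70 \times 10^{-3}\, E_h
  \qquad (T = 300\ \mathrm{K})

\frac{\langle V \rangle_{\mathrm{sampled}} - \langle V \rangle}{\langle V \rangle}
  \approx \frac{6.08 \times 10^{-3} - 5.70 \times 10^{-3}}{5.70 \times 10^{-3}}
  \approx 0.07
```

The sampled estimate therefore lies within about 7% of the equipartition value for this run.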