Skip to content

Commit

Permalink
pack nuts results some more and update readme example
Browse files Browse the repository at this point in the history
  • Loading branch information
zoj613 authored and brandonwillard committed Jun 22, 2023
1 parent 0d90fbb commit ece0aaf
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 100 deletions.
16 changes: 4 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,12 @@ initial_state = nuts.new_state(y_vv, logprob_fn)

step_size = at.as_tensor(1e-2)
inverse_mass_matrix=at.as_tensor(1.0)
(
next_state,
potential_energy,
potential_energy_grad,
acceptance_prob,
num_doublings,
is_turning,
is_diverging,
), updates = kernel(*initial_state, step_size, inverse_mass_matrix)

next_step_fn = aesara.function([y_vv], next_state, updates=updates)
chain_info, updates = kernel(initial_state, step_size, inverse_mass_matrix)

next_step_fn = aesara.function([y_vv], chain_info.state.position, updates=updates)

print(next_step_fn(0))
# 0.14344008534533775
# 1.1034719409361107
```

## Install
Expand Down
19 changes: 12 additions & 7 deletions aehmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def step(
step_size: TensorVariable,
inverse_mass_matrix: TensorVariable,
num_integration_steps: int,
) -> Tuple[Tuple[IntegratorState, TensorVariable, bool], Dict]:
) -> Tuple[trajectory.Diagnostics, Dict]:
"""Perform a single step of the HMC algorithm.
Parameters
Expand Down Expand Up @@ -120,10 +120,8 @@ def step(
divergence_threshold,
)
updated_state = state._replace(momentum=momentum_generator(srng))
new_state, acceptance_proba, is_divergent, updates = proposal_generator(
srng, updated_state, step_size
)
return (new_state, acceptance_proba, is_divergent), updates
chain_info, updates = proposal_generator(srng, updated_state, step_size)
return chain_info, updates

return step

Expand Down Expand Up @@ -158,7 +156,7 @@ def hmc_proposal(

def propose(
srng: RandomStream, state: IntegratorState, step_size: TensorVariable
) -> Tuple[IntegratorState, TensorVariable, bool, Dict]:
) -> Tuple[trajectory.Diagnostics, Dict]:
"""Use the HMC algorithm to propose a new state.
Parameters
Expand Down Expand Up @@ -195,7 +193,14 @@ def propose(
p_accept = at.clip(at.exp(delta_energy), 0, 1.0)
do_accept = srng.bernoulli(p_accept)
final_state = IntegratorState(*ifelse(do_accept, new_state, state))
chain_info = trajectory.Diagnostics(
state=final_state,
acceptance_probability=p_accept,
is_diverging=is_transition_divergent,
num_doublings=None,
is_turning=None,
)

return final_state, p_accept, is_transition_divergent, updates
return chain_info, updates

return propose
66 changes: 30 additions & 36 deletions aehmc/nuts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable
from typing import Callable, Dict, Tuple

import aesara.tensor as at
import numpy as np
Expand All @@ -9,7 +9,7 @@
from aehmc.integrators import IntegratorState
from aehmc.proposals import ProposalState
from aehmc.termination import iterative_uturn
from aehmc.trajectory import dynamic_integration, multiplicative_expansion
from aehmc.trajectory import Diagnostics, dynamic_integration, multiplicative_expansion

new_state = hmc.new_state

Expand Down Expand Up @@ -54,12 +54,10 @@ def potential_fn(x):
return -logprob_fn(x)

def step(
q: TensorVariable,
potential_energy: TensorVariable,
potential_energy_grad: TensorVariable,
state: IntegratorState,
step_size: TensorVariable,
inverse_mass_matrix: TensorVariable,
):
) -> Tuple[Diagnostics, Dict]:
"""Use the NUTS algorithm to propose a new state.
Parameters
Expand Down Expand Up @@ -112,50 +110,46 @@ def step(
max_num_expansions,
)

p = momentum_generator(srng)
initial_state = IntegratorState(
position=q,
momentum=p,
potential_energy=potential_energy,
potential_energy_grad=potential_energy_grad,
initial_state = state._replace(momentum=momentum_generator(srng))
initial_termination_state = new_termination_state(
initial_state.position, max_num_expansions
)
initial_energy = initial_state.potential_energy + kinetic_energy_fn(
initial_state.momentum
)
initial_termination_state = new_termination_state(q, max_num_expansions)
initial_energy = potential_energy + kinetic_energy_fn(p)
initial_proposal = ProposalState(
state=initial_state,
energy=initial_energy,
weight=at.as_tensor(0.0, dtype=np.float64),
sum_log_p_accept=at.as_tensor(-np.inf, dtype=np.float64),
)
result, updates = expand(

results, updates = expand(
initial_proposal,
initial_state,
initial_state,
p,
initial_state.momentum,
initial_termination_state,
initial_energy,
step_size,
)

# New MCMC proposal
q_new = result[0][-1]
potential_energy_new = result[2][-1]
potential_energy_grad_new = result[3][-1]

# Diagnostics
is_turning = result[-1][-1]
is_diverging = result[-2][-1]
num_doublings = result[-3][-1]
acceptance_probability = result[-4][-1]

return (
q_new,
potential_energy_new,
potential_energy_grad_new,
acceptance_probability,
num_doublings,
is_turning,
is_diverging,
), updates
# extract the last iteration from multiplicative_expansion chain diagnostics
chain_info = Diagnostics(
state=IntegratorState(
position=results.diagnostics.state.position[-1],
momentum=results.diagnostics.state.momentum[-1],
potential_energy=results.diagnostics.state.potential_energy[-1],
potential_energy_grad=results.diagnostics.state.potential_energy_grad[
-1
],
),
acceptance_probability=results.diagnostics.acceptance_probability[-1],
num_doublings=results.diagnostics.num_doublings[-1],
is_turning=results.diagnostics.is_turning[-1],
is_diverging=results.diagnostics.is_diverging[-1],
)

return chain_info, updates

return step
104 changes: 97 additions & 7 deletions aehmc/trajectory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Dict, Tuple
from typing import Callable, Dict, NamedTuple, Tuple

import aesara
import aesara.tensor as at
Expand Down Expand Up @@ -376,6 +376,23 @@ def add_one_state(
return integrate


class Diagnostics(NamedTuple):
state: IntegratorState
acceptance_probability: TensorVariable
num_doublings: TensorVariable
is_turning: TensorVariable
is_diverging: TensorVariable


class MultiplicativeExpansionResult(NamedTuple):
proposals: ProposalState
right_states: IntegratorState
left_states: IntegratorState
momentum_sums: TensorVariable
termination_states: TerminationState
diagnostics: Diagnostics


def multiplicative_expansion(
srng: RandomStream,
trajectory_integrator: Callable,
Expand Down Expand Up @@ -416,7 +433,7 @@ def expand(
termination_state: TerminationState,
initial_energy,
step_size,
):
) -> Tuple[MultiplicativeExpansionResult, Dict]:
"""Expand the current trajectory multiplicatively.
At each step we draw a direction at random, build a subtrajectory starting
Expand Down Expand Up @@ -465,7 +482,7 @@ def expand_once(
momentum_sum_ckpts,
idx_min,
idx_max,
):
) -> Tuple[Tuple[TensorVariable, ...], Dict, until]:
left_state = (
q_left,
p_left,
Expand Down Expand Up @@ -591,7 +608,33 @@ def expand_once(
)

expansion_steps = at.arange(0, max_num_expansions)
results, updates = aesara.scan(
# results, updates = aesara.scan(
(
proposal_state_position,
proposal_state_momentum,
proposal_state_potential_energy,
proposal_state_potential_energy_grad,
proposal_energy,
proposal_weight,
proposal_sum_log_p_accept,
left_state_position,
left_state_momentum,
left_state_potential_energy,
left_state_potential_energy_grad,
right_state_position,
right_state_momentum,
right_state_potential_energy,
right_state_potential_energy_grad,
momentum_sum,
momentum_checkpoints,
momentum_sum_checkpoints,
min_indices,
max_indices,
acceptance_probability,
num_doublings,
is_diverging,
is_turning,
), updates = aesara.scan(
expand_once,
outputs_info=(
proposal.state.position,
Expand All @@ -610,16 +653,63 @@ def expand_once(
right_state.potential_energy,
right_state.potential_energy_grad,
momentum_sum,
*termination_state,
termination_state.momentum_checkpoints,
termination_state.momentum_sum_checkpoints,
termination_state.min_index,
termination_state.max_index,
None,
None,
None,
None,
),
sequences=expansion_steps,
)

return results, updates
# Ensure each item of the returned result sequence is packed into the appropriate namedtuples.
typed_result = MultiplicativeExpansionResult(
proposals=ProposalState(
state=IntegratorState(
position=proposal_state_position,
momentum=proposal_state_momentum,
potential_energy=proposal_state_potential_energy,
potential_energy_grad=proposal_state_potential_energy_grad,
),
energy=proposal_energy,
weight=proposal_weight,
sum_log_p_accept=proposal_sum_log_p_accept,
),
left_states=IntegratorState(
position=left_state_position,
momentum=left_state_momentum,
potential_energy=left_state_potential_energy,
potential_energy_grad=left_state_potential_energy_grad,
),
right_states=IntegratorState(
position=right_state_position,
momentum=right_state_momentum,
potential_energy=right_state_potential_energy,
potential_energy_grad=right_state_potential_energy_grad,
),
momentum_sums=momentum_sum,
termination_states=TerminationState(
momentum_checkpoints=momentum_checkpoints,
momentum_sum_checkpoints=momentum_sum_checkpoints,
min_index=min_indices,
max_index=max_indices,
),
diagnostics=Diagnostics(
state=IntegratorState(
position=proposal_state_position,
momentum=proposal_state_momentum,
potential_energy=proposal_state_potential_energy,
potential_energy_grad=proposal_state_potential_energy_grad,
),
acceptance_probability=acceptance_probability,
num_doublings=num_doublings,
is_turning=is_turning,
is_diverging=is_diverging,
),
)
return typed_result, updates

return expand

Expand Down
34 changes: 23 additions & 11 deletions aehmc/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from aehmc.algorithms import DualAveragingState
from aehmc.integrators import IntegratorState
from aehmc.mass_matrix import covariance_adaptation
from aehmc.nuts import Diagnostics
from aehmc.step_size import dual_averaging_adaptation


Expand Down Expand Up @@ -42,7 +43,13 @@ def one_step(
step_size, # parameters
inverse_mass_matrix,
):
chain_state = (q, potential_energy, potential_energy_grad)
chain_state = IntegratorState(
position=q,
momentum=None,
potential_energy=potential_energy,
potential_energy_grad=potential_energy_grad,
)

warmup_state = (
DualAveragingState(
step=step,
Expand All @@ -56,17 +63,17 @@ def one_step(
parameters = (step_size, inverse_mass_matrix)

# Advance the chain by one step
chain_state, inner_updates = kernel(*chain_state, *parameters)
chain_info, inner_updates = kernel(chain_state, *parameters)

# Update the warmup state and parameters
warmup_state, parameters = update_adapt(
warmup_step, warmup_state, parameters, chain_state
warmup_step, warmup_state, parameters, chain_info
)
da_state = warmup_state[0]
return (
chain_state[0], # q
chain_state[1], # potential_energy
chain_state[2], # potential_energy_grad
chain_info.state.position, # q
chain_info.state.potential_energy, # potential_energy
chain_info.state.potential_energy_grad, # potential_energy_grad
da_state.step,
da_state.iterates, # log_step_size
da_state.iterates_avg, # log_step_size_avg
Expand Down Expand Up @@ -182,14 +189,19 @@ def final(
step_size = at.exp(da_state.iterates_avg) # return stepsize_avg at the end
return step_size, inverse_mass_matrix

def update(step: int, warmup_state: Tuple, parameters: Tuple, chain_state: Tuple):
position, _, _, p_accept, *_ = chain_state

def update(
step: int, warmup_state: Tuple, parameters: Tuple, chain_state: Diagnostics
):
stage = schedule_stage[step]
warmup_state, parameters = where_warmup_state(
at.eq(stage, 0),
fast_update(p_accept, warmup_state, parameters),
slow_update(position, p_accept, warmup_state, parameters),
fast_update(chain_state.acceptance_probability, warmup_state, parameters),
slow_update(
chain_state.state.position,
chain_state.acceptance_probability,
warmup_state,
parameters,
),
)

is_middle_window_end = schedule_middle_window[step]
Expand Down
Loading

0 comments on commit ece0aaf

Please sign in to comment.