Source code for beetroots.sampler.my_sampler

r"""Contains a class of sampler used in the Meudon PDR code Bayesian inversion problems"""

import copy
from typing import Optional, Tuple, Union

import numpy as np
from scipy.special import softmax
from tqdm.auto import tqdm

from beetroots.modelling.likelihoods.abstract_likelihood import Likelihood
from beetroots.modelling.posterior import Posterior
from beetroots.sampler.abstract_sampler import Sampler
from beetroots.sampler.saver.abstract_saver import Saver
from beetroots.sampler.utils import utils
from beetroots.sampler.utils.mml import EBayesMMLELogRate
from beetroots.sampler.utils.my_sampler_params import MySamplerParams


class MySampler(Sampler):
    r"""Defines the sampler proposed in :cite:t:`paludEfficientSamplingNon2023`
    that randomly combines two transition kernels:

    1. an independent MTM-chromatic Gibbs transition kernel
    2. a position-dependent MALA transition kernel with the RMSProp
       pre-conditioner

    At each iteration, one of the two kernels is drawn at random according to
    ``selection_probas``. When ``is_stochastic`` is False, the same kernels are
    run in an optimization mode that only keeps improving moves.
    """

    def __init__(
        self,
        my_sampler_params: MySamplerParams,
        D: int,
        L: int,
        N: int,
        rng: Optional[np.random.Generator] = None,
    ):
        r"""
        Parameters
        ----------
        my_sampler_params : MySamplerParams
            contains the main parameters of the algorithm
        D : int
            total number of physical parameters to reconstruct
        L : int
            number of observables per component :math:`n`
        N : int
            total number of pixels to reconstruct
        rng : numpy.random.Generator, optional
            random number generator (for reproducibility). By default None, in
            which case a new ``np.random.default_rng(42)`` is created for this
            instance (a Generator used as a default argument would be shared
            by every sampler built without an explicit ``rng``).

        Raises
        ------
        AssertionError
            if ``my_sampler_params.selection_probas`` does not sum to 1
        """
        # P-MALA params
        # ! redefine size of params
        self.eps0 = my_sampler_params.initial_step_size
        r"""float: step size used in the Position-dependent MALA transition kernel, denoted :math:`\epsilon > 0` in the article"""

        self.lambda_ = my_sampler_params.extreme_grad
        r"""float: limit value that avoids division by zero when computing the RMSProp preconditioner, denoted :math:`\eta > 0` in the article"""

        self.alpha = my_sampler_params.history_weight
        r"""float: weight of past values of :math:`v` in the exponential decay (cf RMSProp preconditioner), denoted :math:`\alpha \in ]0,1[` in the article"""

        # MTM params
        self.k_mtm = my_sampler_params.k_mtm
        r"""int: number of candidates in the MTM kernel, denoted :math:`K` in the article"""

        # overall
        self.selection_probas = my_sampler_params.selection_probas
        r"""np.ndarray: vector of selection probabilities for the MTM and PMALA kernels, respectively, i.e., :math:`[p_{MTM}, 1 - p_{MTM}]`"""
        # tolerance-based check: exact float equality is fragile under rounding
        assert np.isclose(
            np.sum(self.selection_probas), 1.0
        ), f"{self.selection_probas} should sum to 1"

        self.stochastic = my_sampler_params.is_stochastic
        r"""bool: if True, the algorithm performs sampling, and optimization otherwise"""

        self.compute_correction_term = my_sampler_params.compute_correction_term
        r"""bool: whether or not to use the correction term (denoted :math:`\gamma` in the article) during the sampling (only used if `is_stochastic=True`)"""

        self.compute_derivatives_2nd_order = (
            self.stochastic and self.compute_correction_term
        )
        r"""bool: whether to compute the expensive second order derivatives terms. Only true when the sampler runs a Markov chain (2nd order never used in optimization) and when the correction term denoted :math:`\gamma` is to be computed."""

        self.D = D
        r"""int: total number of physical parameters to reconstruct"""
        self.L = L
        r"""int: number of observables per component :math:`n`"""
        self.N = N
        r"""int: total number of pixels to reconstruct"""

        self.rng = rng if rng is not None else np.random.default_rng(42)
        r"""numpy.random.Generator: random number generator (for reproducibility)"""

        # initialization values, not to be kept during sampling
        self.v = np.zeros((N * D,))
        r"""np.ndarray: RMSProp gradient variance vector, denoted :math:`v` in the article"""

        # per-coordinate number of iterations since the last accepted move;
        # reset at the start of ``sample``
        self.j_t = np.zeros((N * D,))

        self.current = {}
        r"""dict: contains all the data about the current iterate (including the evaluations of the forward map and derivatives, etc.)"""

    def generate_random_start_Theta_1pix(
        self, Theta: np.ndarray, posterior: Posterior, idx_pix: np.ndarray
    ) -> np.ndarray:
        r"""draws ``self.k_mtm`` random vectors for each component :math:`n`
        in ``idx_pix`` (e.g., a pixel :math:`\theta_n`).

        The distribution used to draw these vectors is:

        * the smooth indicator prior, if the posterior has no spatial prior
        * otherwise, the conditional spatial prior computed by
          ``utils.sample_conditional_spatial_prior`` from the current iterate
          and the spatial prior edges

        Parameters
        ----------
        Theta : np.ndarray
            current iterate
        posterior : Posterior
            contains the lower and upper bounds of the hypercube
        idx_pix : np.ndarray
            indices of the pixels

        Returns
        -------
        np.ndarray of shape (n_pix, self.k_mtm, D)
            random candidates for the selected pixels

        Raises
        ------
        ValueError
            if ``posterior.prior_indicator`` is None
        """
        # a fresh seed is drawn from self.rng so that results stay
        # reproducible given the sampler's generator
        seed = self.rng.integers(0, 1_000_000_000)
        n_pix = idx_pix.size

        if posterior.prior_indicator is None:
            raise ValueError("The Posterior has no specified smooth indicator prior")

        if posterior.prior_spatial is None:
            # * sample from smooth indicator prior
            return utils.sample_smooth_indicator(
                posterior.prior_indicator.lower_bounds,
                posterior.prior_indicator.upper_bounds,
                posterior.prior_indicator.indicator_margin_scale,
                size=(n_pix * self.k_mtm, self.D),
                seed=seed,
            ).reshape((n_pix, self.k_mtm, self.D))

        return utils.sample_conditional_spatial_prior(
            Theta,
            posterior.prior_spatial.list_edges,
            np.minimum(
                posterior.prior_spatial.initial_weights,
                posterior.prior_spatial.weights,
            ),
            idx_pix=idx_pix,
            k_mtm=self.k_mtm,
            seed=seed,
        )  # (n_pix, self.k_mtm, D)

    def _update_model_check_values(
        self,
        dict_model_check: dict,
        likelihood: Likelihood,
        nll_full: np.ndarray,
        objective: float,
    ) -> dict:
        r"""updates the online model-checking statistics with the current iterate.

        In sampling mode, running averages are updated:

        * ``clppd_online``: mean of :math:`\exp(-\text{nll\_full})`
        * ``p_values_y``: mean of the indicator
          :math:`y^{rep}_{n,\ell} \leq y_{n,\ell}`, with :math:`y^{rep}` drawn
          from the observation model at the current forward map evaluations
        * ``p_values_llh``: mean of the indicator
          :math:`\text{nll}(y^{rep}_n) \geq \text{nll}(y_n)`, per pixel

        In optimization mode, ``clppd_online`` is replaced by
        :math:`\exp(-\text{nll\_full})` whenever ``objective`` improves on
        ``best_objective``; p-values are computed at the end of the run.

        Parameters
        ----------
        dict_model_check : dict
            model-checking state (keys ``clppd_online``, ``best_objective``,
            ``count_pval``, ``p_values_y``, ``p_values_llh``)
        likelihood : Likelihood
            likelihood of the observations
        nll_full : np.ndarray
            per-observation negative log-likelihood at the current iterate,
            summed over axis 1 to get per-pixel values
        objective : float
            objective value at the current iterate

        Returns
        -------
        dict
            the updated ``dict_model_check``
        """
        if self.stochastic:
            count_pval = dict_model_check["count_pval"] * 1
            y_copy = likelihood.y * 1

            dict_model_check["clppd_online"] *= count_pval / (count_pval + 1)
            dict_model_check["clppd_online"] += np.exp(-nll_full) / (count_pval + 1)

            y_rep = likelihood.sample_observation_model(
                self.current["forward_map_evals"],
                self.rng,
            )
            # a replicated likelihood whose observations are y_rep
            likelihood_rep = copy.deepcopy(likelihood)
            likelihood_rep.y = y_rep * 1
            # the original observations must not be modified by the copy
            assert np.allclose(likelihood.y, y_copy), "nooooo"

            nll_utils_rep = likelihood_rep.evaluate_all_nll_utils(
                self.current["forward_map_evals"],
                idx=None,
                compute_derivatives=False,
                compute_derivatives_2nd_order=False,
            )
            nll_y_rep_full = likelihood_rep.neglog_pdf(
                self.current["forward_map_evals"],
                nll_utils_rep,
                full=True,
            )

            # p-value per (N, L) with y_rep_{n,ell} <= y_{n,ell}
            dict_model_check["p_values_y"] *= count_pval / (count_pval + 1)
            dict_model_check["p_values_y"] += (y_rep <= likelihood.y) / (count_pval + 1)

            # p-value per (N,) with
            # p(y_rep_n \vert theta_n) <= p(y_n \vert theta_n)
            nll_y = np.sum(nll_full, axis=1)  # (N,)
            nll_y_rep = np.sum(nll_y_rep_full, axis=1)  # (N,)
            dict_model_check["p_values_llh"] *= count_pval / (count_pval + 1)
            dict_model_check["p_values_llh"] += (nll_y_rep >= nll_y) / (count_pval + 1)

            dict_model_check["count_pval"] += 1

        elif objective < dict_model_check["best_objective"]:
            dict_model_check["best_objective"] = objective * 1
            dict_model_check["clppd_online"] = np.exp(-nll_full)
            # p-values are computed at the end of the optimisation process.

        return dict_model_check

    def _finalize_model_check_values(
        self,
        dict_model_check: dict,
        likelihood: Likelihood,
        forward_map_evals: dict,
        nll_full: np.ndarray,
        disable_progress_bar: bool = False,
    ) -> dict:
        r"""finalizes the model-checking statistics at the end of a run.

        In optimization mode, the p-values are estimated from
        ``self.ESS_OPTIM`` replicated observation sets drawn at the final
        estimate (``forward_map_evals``). In both modes, ``p_values_y`` is
        folded to :math:`[0, 0.5]` (values :math:`p > 0.5` become
        :math:`1 - p`).

        Parameters
        ----------
        dict_model_check : dict
            model-checking state
        likelihood : Likelihood
            likelihood of the observations
        forward_map_evals : dict
            forward map evaluations at the final iterate
        nll_full : np.ndarray
            per-observation negative log-likelihood at the final iterate
            (must correspond to ``forward_map_evals``)
        disable_progress_bar : bool, optional
            whether to disable the progress bar, by default False

        Returns
        -------
        dict
            the updated ``dict_model_check``
        """
        if not self.stochastic:
            # optimization p-value computations on estimated \hat{\theta}
            # loop invariants: observed per-pixel nll and a reusable copy of
            # the likelihood whose observations are overwritten with y_rep
            nll_y = np.sum(nll_full, axis=1)  # (N,)
            likelihood_rep = copy.deepcopy(likelihood)

            for count_pval in tqdm(
                range(self.ESS_OPTIM), disable=disable_progress_bar
            ):
                y_rep = likelihood.sample_observation_model(
                    forward_map_evals,
                    self.rng,
                )
                likelihood_rep.y = y_rep * 1

                nll_utils_rep = likelihood_rep.evaluate_all_nll_utils(
                    forward_map_evals,
                    idx=None,
                    compute_derivatives=False,
                    compute_derivatives_2nd_order=False,
                )
                nll_y_rep_full = likelihood_rep.neglog_pdf(
                    forward_map_evals,
                    nll_utils_rep,
                    full=True,
                )

                # p-value per (N, L) with y_rep_{n,ell} <= y_{n,ell}
                dict_model_check["p_values_y"] *= count_pval / (count_pval + 1)
                dict_model_check["p_values_y"] += (y_rep <= likelihood.y) / (
                    count_pval + 1
                )

                # p-value per (N,) with
                # p(y_rep_n \vert theta_n) <= p(y_n \vert theta_n)
                nll_y_rep = np.sum(nll_y_rep_full, axis=1)  # (N,)
                dict_model_check["p_values_llh"] *= count_pval / (count_pval + 1)
                dict_model_check["p_values_llh"] += (nll_y_rep >= nll_y) / (
                    count_pval + 1
                )

        # this p-value should be between 0 and 0.5
        dict_model_check["p_values_y"] = np.where(
            dict_model_check["p_values_y"] > 0.5,
            1 - dict_model_check["p_values_y"],
            dict_model_check["p_values_y"],
        )
        return dict_model_check

    def sample(
        self,
        posterior: Posterior,
        saver: Saver,
        max_iter: int,
        Theta_0: Optional[np.ndarray] = None,
        disable_progress_bar: bool = False,
        #
        regu_spatial_N0: Union[int, float] = np.inf,
        regu_spatial_scale: float = 1.0,
        regu_spatial_vmin: float = 1e-8,
        regu_spatial_vmax: float = 1e8,
        #
        T_BI: int = 0,  # used only for clppd
    ) -> None:
        r"""main method of the class, runs the sampler

        Parameters
        ----------
        posterior : Posterior
            probability distribution to be sampled
        saver : Saver
            object responsible for progressively saving the Markov chain data
            during the run
        max_iter : int
            total duration of a Markov chain
        Theta_0 : Optional[np.ndarray], optional
            starting point, by default None
        disable_progress_bar : bool, optional
            whether to disable the progress bars, by default False
        regu_spatial_N0 : Union[int, float], optional
            number of iterations defining the initial update phase (for
            spatial regularization weight optimization). np.inf means that
            the optimization phase never starts, and that the weight
            optimization is not applied. by default np.inf
        regu_spatial_scale : Optional[float], optional
            scale parameter involved in the definition of the projected
            gradient step size (for spatial regularization weight
            optimization). by default 1.0
        regu_spatial_vmin : Optional[float], optional
            lower limit of the admissible interval (for spatial
            regularization weight optimization), by default 1e-8
        regu_spatial_vmax : Optional[float], optional
            upper limit of the admissible interval (for spatial
            regularization weight optimization), by default 1e8
        T_BI : int, optional
            duration of the `Burn-in` phase, by default 0. Model-checking
            statistics are only accumulated for iterations ``t > T_BI``.
        """
        additional_sampling_log = {}

        if Theta_0 is None:
            print("starting from a random point")
            Theta_0 = self.generate_random_start_Theta(posterior)  # (N, D)

        assert Theta_0 is not None
        assert Theta_0.shape == (self.N, self.D)

        self.current = posterior.compute_all(
            Theta_0,
            compute_derivatives_2nd_order=self.compute_derivatives_2nd_order,
        )
        assert not np.isnan(self.current["objective_global"])
        assert np.sum(np.isnan(self.current["grad"])) == 0

        # RMSProp state initialized with the squared gradient at the start
        self.v = self.current["grad"].flatten() ** 2
        assert np.sum(np.isnan(self.v)) == 0.0
        assert np.sum(np.isinf(self.v)) == 0.0

        self.j_t = np.zeros((self.N * self.D,))

        regu_weights_optimizer = EBayesMMLELogRate(
            scale=regu_spatial_scale,
            N0=regu_spatial_N0,
            N1=np.inf,
            dim=self.D * self.N,
            vmin=regu_spatial_vmin,
            vmax=regu_spatial_vmax,
            homogeneity=2.0,
            exponent=0.8,
        )
        optimize_regu_weights = regu_weights_optimizer.N0 < np.inf

        # clppd = computed log point-wise predictive density.
        # if self.stochastic : avg of all pred. likelihood terms for t > T_BI
        # else : predictive likelihood with best param theta
        dict_model_check = {
            "clppd_online": np.zeros((self.N, self.L)),
            # used only if not self.stochastic
            "best_objective": np.inf,
            # used only if self.stochastic
            "count_pval": 0,
            # p(y^{rep}_{n,\ell} <= y_{n,\ell} | y)
            "p_values_y": np.zeros((self.N, self.L)),
            # p(nll(y^{rep}_n | theta_n) >= nll(y_n | theta_n) | y)
            "p_values_llh": np.zeros((self.N,)),
        }

        for t in tqdm(range(1, max_iter + 1), disable=disable_progress_bar):
            if optimize_regu_weights and (self.N > 1):
                assert posterior.prior_spatial is not None
                if t >= regu_weights_optimizer.N0:
                    tau_t = self.sample_regu_hyperparams(
                        posterior,
                        regu_weights_optimizer,
                        t,
                        self.current["Theta"] * 1,
                    )
                    posterior.prior_spatial.weights = tau_t * 1

                    # recompute posterior neg log pdf and gradients with
                    # new spatial regularization parameter
                    self.current = posterior.compute_all(
                        Theta=self.current["Theta"],
                        forward_map_evals=self.current["forward_map_evals"],
                        nll_utils=self.current["nll_utils"],
                        compute_derivatives_2nd_order=self.compute_derivatives_2nd_order,
                    )

                additional_sampling_log["tau"] = posterior.prior_spatial.weights * 1

            # ------ draw the transition kernel used at this iteration
            type_t = np.argmax(
                self.rng.multinomial(
                    1,
                    pvals=self.selection_probas,
                )
            )
            if type_t == 0:
                accepted_t, log_proba_accept_t = self.generate_new_sample_mtm(
                    t, posterior
                )
            else:
                assert type_t == 1
                (
                    accepted_t,
                    log_proba_accept_t,
                ) = self.generate_new_sample_pmala_rmsprop(t, posterior)

            # * if the memory is empty : initialize it, then record the
            # * iterate; otherwise record it only when the saver asks for it
            need_init = saver.memory == {}
            if need_init or saver.check_need_to_update_memory(t):
                additional_sampling_log["v"] = self.v.reshape((self.N, self.D)) * 1
                additional_sampling_log["type_t"] = type_t
                additional_sampling_log["accepted_t"] = accepted_t
                additional_sampling_log["log_proba_accept_t"] = log_proba_accept_t

                dict_objective, nll_full = posterior.compute_all_for_saver(
                    self.current["Theta"],
                    self.current["forward_map_evals"],
                    self.current["nll_utils"],
                )

                if t > T_BI:
                    assert isinstance(dict_objective["objective"], float)
                    dict_model_check = self._update_model_check_values(
                        dict_model_check,
                        posterior.likelihood,
                        nll_full,
                        dict_objective["objective"] * 1,
                    )

                if need_init:
                    saver.initialize_memory(
                        max_iter,
                        t,
                        Theta=self.current["Theta"],
                        forward_map_evals=self.current["forward_map_evals"],
                        nll_utils=self.current["nll_utils"],
                        dict_objective=dict_objective,
                        additional_sampling_log=additional_sampling_log,
                    )

                rng_state_array, rng_inc_array = self.get_rng_state()
                saver.update_memory(
                    t,
                    Theta=self.current["Theta"],
                    forward_map_evals=self.current["forward_map_evals"],
                    nll_utils=self.current["nll_utils"],
                    dict_objective=dict_objective,
                    additional_sampling_log=additional_sampling_log,
                    rng_state_array=rng_state_array,
                    rng_inc_array=rng_inc_array,
                )

            if saver.check_need_to_save(t):
                saver.save_to_file()

        # ---------
        # The nll from the last memory update may correspond to an earlier
        # iterate than self.current (or be undefined when max_iter == 0), so
        # recompute it at the final iterate before the final model checks.
        _, nll_full = posterior.compute_all_for_saver(
            self.current["Theta"],
            self.current["forward_map_evals"],
            self.current["nll_utils"],
        )
        dict_model_check = self._finalize_model_check_values(
            dict_model_check,
            likelihood=posterior.likelihood,
            forward_map_evals=self.current["forward_map_evals"],
            nll_full=nll_full,
            disable_progress_bar=disable_progress_bar,
        )

        saver.save_additional(
            list_arrays=[
                dict_model_check["clppd_online"],
                dict_model_check["p_values_y"],
                dict_model_check["p_values_llh"],
            ],
            list_names=["clppd", "p-values-y", "p-values-llh"],
        )
        return

    def generate_new_sample_pmala_rmsprop(self, t: int, posterior: Posterior):
        """generates a new sample using the position-dependent MALA transition
        kernel with the RMSProp pre-conditioner

        In sampling mode, the sites of ``posterior.dict_sites`` are updated one
        after the other, with a per-pixel Metropolis-Hastings accept/reject.
        In optimization mode, the full map is moved at once and the move is
        kept only if it decreases ``objective_global``.

        Parameters
        ----------
        t : int
            current iteration index
        posterior : Posterior
            negative log posterior class

        Returns
        -------
        accepted : float or bool
            sampling: fraction of pixels whose candidate was accepted;
            optimization: whether the candidate was accepted
        log_proba_accept : float or int
            sampling: mean over pixels of the log acceptance probability;
            optimization: 1 if accepted, 0 otherwise
        """
        if self.stochastic:
            accept_total = np.zeros((self.N,))
            log_proba_accept_total = np.zeros((self.N,))

            list_idx = np.array(list(posterior.dict_sites.keys()))
            if len(list_idx) > 1:
                objective_type = "objective_pix_chromatic"
            else:
                objective_type = "objective_pix_global"

            for idx_site in list_idx:
                idx_pix = posterior.dict_sites[idx_site]
                n_pix = idx_pix.size

                new_Theta = self.current["Theta"] * 1  # (N, D)
                grad_t = self.current["grad"][idx_pix, :] * 1
                v_current = self.v.reshape((self.N, self.D))[idx_pix, :] * 1

                # diagonal RMSProp pre-conditioner
                diag_G_t = 1 / (self.lambda_ + np.sqrt(v_current))  # (n_pix, D)
                assert np.all(
                    diag_G_t > 0
                ), f"{diag_G_t}, {self.lambda_ + np.sqrt(self.v)}, {self.v}"

                # generate random
                z_t = self.rng.standard_normal(size=(n_pix, self.D))
                z_t *= np.sqrt(self.eps0 * diag_G_t)  # (n_pix, D)

                # bias correction term
                if self.compute_correction_term:
                    hess_diag_t = self.current["hess_diag"][idx_pix, :] * 1
                    j_t = self.j_t.reshape((self.N, self.D))[idx_pix, :] * 1
                    correction = (
                        -(1 - self.alpha)
                        * self.alpha**j_t
                        * (diag_G_t**2)
                        / np.sqrt(v_current)
                        * grad_t
                        * hess_diag_t
                    )  # (n_pix, D)
                    if np.sum(~np.isfinite(correction)) > 0:
                        n_inf = np.sum(~np.isfinite(correction))
                        print(f"num of nan in correction term: {n_inf}")
                        correction = np.nan_to_num(correction)
                else:
                    correction = np.zeros((n_pix, self.D))

                # drift of the proposal from the current point
                mu_current = (
                    new_Theta[idx_pix, :]
                    - self.eps0 / 2 * diag_G_t * grad_t
                    + self.eps0 * correction
                )  # (n_pix, D)
                candidate = mu_current + z_t  # (n_pix, D)

                log_q_candidate_given_current = -1 / 2 * np.sum(
                    np.log(diag_G_t), axis=1
                ) - 1 / (2 * self.eps0) * np.sum(
                    (candidate - mu_current) ** 2 / diag_G_t, axis=1
                )  # (n_pix,)
                shape_q = log_q_candidate_given_current.shape
                assert shape_q == (n_pix,), f"{shape_q}"

                # * compute log_q of current given candidate
                candidate_full = new_Theta * 1
                candidate_full[idx_pix, :] = candidate * 1
                candidate_all = posterior.compute_all(
                    candidate_full,
                    compute_derivatives_2nd_order=self.compute_derivatives_2nd_order,
                )

                grad_cand = candidate_all["grad"][idx_pix, :] * 1
                v_cand = (
                    self.alpha * v_current + (1 - self.alpha) * grad_cand**2
                )  # (n_pix, D)
                diag_G_cand = 1 / (self.lambda_ + np.sqrt(v_cand))  # (n_pix, D)

                if self.compute_correction_term:
                    hess_diag_cand = candidate_all["hess_diag"][idx_pix, :] * 1
                    correction_cand = -(
                        (1 - self.alpha)
                        * diag_G_cand**2
                        / np.sqrt(v_cand)
                        * grad_cand
                        * hess_diag_cand
                    )
                else:
                    correction_cand = np.zeros((n_pix, self.D))

                mu_cand = (
                    candidate
                    - self.eps0 / 2 * diag_G_cand * grad_cand
                    + self.eps0 * correction_cand
                )  # (n_pix, D)

                log_q_current_given_candidate = -1 / 2 * np.sum(
                    np.log(diag_G_cand), axis=1
                ) - 1 / (2 * self.eps0) * np.sum(
                    (new_Theta[idx_pix, :] - mu_cand) ** 2 / diag_G_cand, axis=1
                )  # (n_pix,)
                shape_q = log_q_current_given_candidate.shape
                assert shape_q == (n_pix,), f"{shape_q}"

                # * compute proba accept
                logpdf_current = -self.current[objective_type][idx_pix]
                logpdf_candidate = -candidate_all[objective_type][idx_pix]
                shape_1 = logpdf_current.shape
                shape_2 = logpdf_candidate.shape
                assert shape_1 == (n_pix,), f"{shape_1}"
                assert shape_2 == (n_pix,), f"{shape_2}"

                log_proba_accept = (
                    logpdf_candidate
                    - logpdf_current
                    + log_q_current_given_candidate
                    - log_q_candidate_given_current
                )
                assert log_proba_accept.shape == (n_pix,)

                log_u = np.log(self.rng.uniform(0, 1, size=n_pix))
                accept_arr = log_u < log_proba_accept

                new_Theta[idx_pix, :] = np.where(
                    accept_arr[:, None],
                    candidate,  # (n_pix, D)
                    new_Theta[idx_pix, :],  # (n_pix, D)
                )
                accept_total[idx_pix] = accept_arr * 1
                log_proba_accept_total[idx_pix] = log_proba_accept * 1

                # update v (whatever the outcome) and j
                v = self.v.reshape((self.N, self.D)) * 1
                v[idx_pix, :] = v_cand * 1
                self.v = v.flatten()

                j = self.j_t.reshape((self.N, self.D)) * 1
                j[idx_pix, :] = np.where(
                    accept_arr[:, None],
                    0.0,  # reset to 0 if accept
                    j[idx_pix, :] + 1,  # else add 1
                )
                self.j_t = j.flatten()

                if accept_arr.max() > 0:  # if at least one accept
                    self.current = posterior.compute_all(
                        new_Theta,
                        compute_derivatives_2nd_order=self.compute_derivatives_2nd_order,
                    )

            # after loop
            return accept_total.mean(), log_proba_accept_total.mean()

        # * optimization: move the whole map at once
        objective_type = "objective_global"

        grad_t = self.current["grad"].flatten()
        diag_G_t = 1 / (self.lambda_ + np.sqrt(self.v))
        assert np.all(
            diag_G_t > 0
        ), f"{diag_G_t}, {self.lambda_ + np.sqrt(self.v)}, {self.v}"

        # generate random
        z_t = self.rng.standard_normal(size=self.N * self.D)
        z_t *= np.sqrt(self.eps0 * diag_G_t)

        # deterministic pre-conditioned gradient step first
        mu_current = self.current["Theta"].flatten() - self.eps0 / 2 * diag_G_t * grad_t
        candidate_all = posterior.compute_all(
            mu_current.reshape((self.N, self.D)),
            compute_derivatives_2nd_order=self.compute_derivatives_2nd_order,
        )
        if candidate_all[objective_type] < self.current[objective_type]:
            self.current = copy.copy(candidate_all)
            accept = True
            proba = 1
        else:
            # then a noisy step if the deterministic one did not improve
            candidate = mu_current + z_t  # (N * D,)
            candidate_all = posterior.compute_all(
                candidate.reshape((self.N, self.D)),
                compute_derivatives_2nd_order=self.compute_derivatives_2nd_order,
            )
            if candidate_all[objective_type] < self.current[objective_type]:
                self.current = copy.copy(candidate_all)
                accept = True
                proba = 1
            else:
                accept = False
                proba = 0

        grad_tp1 = candidate_all["grad"].flatten()
        self.v = self.alpha * self.v + (1 - self.alpha) * grad_tp1**2
        assert np.sum(np.isnan(self.v)) == 0.0
        assert (
            np.sum(np.isinf(self.v)) == 0.0
        ), f"{candidate_all['Theta']}, {candidate_all['grad']}"
        assert np.sum(np.isnan(self.current["Theta"])) == 0.0
        return accept, proba

    def generate_new_sample_mtm(
        self,
        t: int,
        posterior: Posterior,
    ):
        r"""generates a new sample using the MTM transition kernel

        The sites of ``posterior.dict_sites`` are processed one after the
        other. For each site, ``self.k_mtm`` candidates are drawn per pixel and
        evaluated together with the current value (stored in the last
        candidate slot).

        Parameters
        ----------
        t : int
            current iteration index
        posterior : Posterior
            target posterior distribution to sample from

        Returns
        -------
        accepted : float
            fraction of pixels whose candidate was accepted
        log_proba_accept : float
            sampling: mean over pixels of the log MTM acceptance ratio;
            optimization: the fraction of accepted pixels again
        """
        new_Theta = self.current["Theta"] * 1  # (N, D)

        accept_total = np.zeros((self.N,))
        log_rg_total = np.zeros((self.N,))

        list_idx = np.array(list(posterior.dict_sites.keys()))
        chromatic_gibbs = len(list_idx) != 1

        for idx_site in list_idx:
            idx_pix = posterior.dict_sites[idx_site]
            n_pix = idx_pix.size

            # * generate and evaluate candidates
            # slots 0..k_mtm-1 : new candidates; slot k_mtm : current value
            candidates = np.zeros((self.N, self.k_mtm + 1, self.D))
            candidates += new_Theta[:, None, :] * 1
            candidates[idx_pix, :-1, :] = self.generate_random_start_Theta_1pix(
                new_Theta, posterior, idx_pix
            )

            neglogpdf_priors = posterior.mtm_neglog_pdf_priors(
                candidates, idx_pix, with_weights=True, chromatic_gibbs=chromatic_gibbs
            )  # (n_pix, k_mtm+1)

            neglogpdf_likelihood = posterior.likelihood.neglog_pdf_candidates(
                candidates[idx_pix].reshape((n_pix * (self.k_mtm + 1), self.D)),
                idx=idx_pix,
                Theta_t=new_Theta * 1,
            )  # (n_pix * (k_mtm+1),)
            assert neglogpdf_likelihood.shape == (n_pix * (self.k_mtm + 1),)

            neglogpdf_candidates = neglogpdf_priors + neglogpdf_likelihood.reshape(
                (n_pix, self.k_mtm + 1)
            )

            # * if optimization: define challenger with conditional posterior
            # * instead of likelihood, and only keep if better than current
            if not self.stochastic:
                # one candidate index shared by all pixels of the site: the one
                # minimizing the neg-log-pdf summed over the site
                idx_challengers = np.argmin(
                    neglogpdf_candidates[:, :-1].sum(axis=0)
                )  # integer
                neglogpdf_candidates_challengers = neglogpdf_candidates[
                    :, idx_challengers
                ]
                challengers = candidates[idx_pix, idx_challengers]
                assert neglogpdf_candidates_challengers.shape == (
                    n_pix,
                ), neglogpdf_candidates_challengers.shape
                assert challengers.shape == (n_pix, self.D), challengers.shape

                # * compute values of corresponding pixels in current x
                candidates_already_Theta = candidates[idx_pix, -1, :] * 1
                neglogpdf_already_Theta = neglogpdf_candidates[:, -1] * 1
                assert candidates_already_Theta.shape == (n_pix, self.D)
                assert neglogpdf_already_Theta.shape == (n_pix,)

                # * select best pixels
                accept_arr = (
                    (neglogpdf_candidates_challengers < neglogpdf_already_Theta)
                    & np.isfinite(neglogpdf_candidates_challengers)
                    & np.isfinite(neglogpdf_already_Theta)
                )
                new_Theta[idx_pix, :] = np.where(
                    accept_arr[:, None],
                    challengers,  # (n_pix, D)
                    candidates_already_Theta,  # (n_pix, D)
                )
                accept_total[idx_pix] = accept_arr * 1

            # *------
            # * if sampling
            else:
                if posterior.prior_spatial is not None:
                    nlratio_prior_proposal = utils.compute_nlpdf_spatial_proposal(
                        candidates,
                        posterior.prior_spatial.list_edges,
                        np.minimum(
                            posterior.prior_spatial.initial_weights,
                            posterior.prior_spatial.weights,
                        ),
                        idx_pix,
                    )
                    shape_ = nlratio_prior_proposal.shape
                    assert shape_ == (n_pix, self.k_mtm + 1)
                    neglogpdf_candidates += nlratio_prior_proposal

                # shift so that the best entry of each row has pdf 1
                # (avoids overflow in exp)
                neglogpdf_candidates_min = np.amin(
                    neglogpdf_candidates, axis=1, keepdims=True
                )
                neglogpdf_candidates -= neglogpdf_candidates_min
                pdf_candidates = np.exp(-neglogpdf_candidates)  # (n_pix, k_mtm+1)

                log_numerators = np.log(np.sum(pdf_candidates[:, :-1], axis=1))
                assert log_numerators.shape == (n_pix,), log_numerators.shape

                # * choose challenger candidate
                weights = softmax(-neglogpdf_candidates[:, :-1], axis=1)
                assert np.all(np.isfinite(weights)), weights

                idx_challengers = np.zeros((n_pix,), dtype=int)
                for i in range(n_pix):
                    idx_challengers[i] = self.rng.choice(
                        self.k_mtm,
                        p=weights[i],
                    )

                challengers = candidates[idx_pix, idx_challengers, :]  # (n_pix, D)
                neglogpdf_challengers = neglogpdf_candidates[
                    np.arange(n_pix), idx_challengers
                ]
                shape_ = neglogpdf_challengers.shape
                assert shape_ == (n_pix,), shape_

                # * denominator: reference set = other candidates + current
                log_denominators = np.log(
                    np.sum(pdf_candidates, axis=1) - np.exp(-neglogpdf_challengers)
                )
                shape_ = log_denominators.shape
                assert shape_ == (n_pix,), shape_

                # * accept-reject
                log_rg = log_numerators - log_denominators
                # Non-finite ratios must not cause spurious accepts:
                # * NaN (e.g. all neg-log-pdfs infinite) -> reject
                # * -inf (all new candidates have vanishing mass) -> reject
                # * +inf (reference set has vanishing mass) -> accept
                # They are mapped to large finite values so that the logged
                # mean stays finite.
                log_rg = np.nan_to_num(log_rg, nan=-1e15, posinf=1e15, neginf=-1e15)

                log_u = np.log(self.rng.uniform(0, 1, size=n_pix))
                accept_arr = log_u < log_rg

                new_Theta[idx_pix, :] = np.where(
                    accept_arr[:, None],
                    challengers,  # (n_pix, D)
                    candidates[idx_pix, -1, :],  # (n_pix, D)
                )
                accept_total[idx_pix] = accept_arr * 1
                log_rg_total[idx_pix] = log_rg * 1

            # * re-initialize j for new point
            new_j_t = self.j_t.reshape((self.N, self.D))
            new_j_t[idx_pix, :] = np.where(accept_arr[:, None], 0.0, new_j_t[idx_pix, :])
            self.j_t = new_j_t.flatten()  # (ND,)

        # *------
        # * once all sites have been dealt with, update global parameters
        if accept_total.max() > 0:  # if at least one accept
            self.current = posterior.compute_all(
                new_Theta,
                compute_derivatives_2nd_order=self.compute_derivatives_2nd_order,
            )

            # RMSProp update only for pixels that moved
            new_v = self.v.reshape((self.N, self.D))
            new_v = np.where(
                accept_total[:, None],
                self.alpha * new_v + (1 - self.alpha) * self.current["grad"] ** 2,
                new_v,
            )
            self.v = new_v.flatten()
            assert np.sum(np.isnan(self.v)) == 0.0
            assert np.sum(np.isinf(self.v)) == 0.0

        if not self.stochastic:
            return np.mean(accept_total), np.mean(accept_total)
        return np.mean(accept_total), np.mean(log_rg_total)