Source code for beetroots.sampler.utils.my_sampler_params

"""Defines a Dataclass that stores the parameters of the augmented PSGLD sampler
"""
from typing import List, Union

import numpy as np
import pandas as pd


[docs] class MySamplerParams(object): r"""Dataclass that stores the parameters of the sampler proposed in :cite:t:`paludEfficientSamplingNon2023`""" __slots__ = ( "initial_step_size", "extreme_grad", "history_weight", "selection_probas", "k_mtm", "is_stochastic", "compute_correction_term", ) def __init__( self, initial_step_size: float, extreme_grad: float, history_weight: float, selection_probas: Union[np.ndarray, List[float]], k_mtm: int, is_stochastic: bool = True, compute_correction_term: bool = True, ) -> None: r""" Parameters ---------- initial_step_size : float step size used in the Position-dependent MALA transition kernel, denoted :math:`\epsilon` in the article extreme_grad : float limit value that avoids division by zero when computing the RMSProp preconditioner, denoted :math:`\eta` in the article history_weight : float weight of past values of :math:`v` in the exponential decay (cf RMSProp preconditioner), denoted :math:`\alpha` in the article selection_probas : np.ndarray of shape (2,) vector of selection probabilities for the MTM and PMALA kernels, respectively, i.e., :math:`[p_{MTM}, 1 - p_{MTM}]` k_mtm : int number of candidates in the MTM kernel, denoted :math:`K` in the article is_stochastic : bool if True, the algorithm performs sampling, and optimization otherwise, by default True compute_correction_term : bool wether or not to use the correction term (denoted :math:`\gamma` in the article) during the sampling (only used if `is_stochastic=True`), by default True """ assert initial_step_size > 0 assert extreme_grad > 0 assert 0 < history_weight < 1 assert isinstance(k_mtm, int) and k_mtm >= 1 if isinstance(selection_probas, list): selection_probas = np.array(selection_probas) assert np.all(0 <= selection_probas) assert selection_probas.sum() == 1 assert isinstance(is_stochastic, bool) assert isinstance(compute_correction_term, bool) self.initial_step_size = initial_step_size self.extreme_grad = extreme_grad self.history_weight = history_weight self.selection_probas = selection_probas self.k_mtm = k_mtm self.is_stochastic = is_stochastic self.compute_correction_term = compute_correction_term