# Source code for beetroots.modelling.likelihoods.approx_censored_add_mult

"""Implementation of an approximation of the likelihood function of a mixture of Gaussian and multiplicative noises with censorship with a lower limit.
"""
from typing import List, Optional, Union

import numba
import numpy as np
import pandas as pd
from scipy.special import log_ndtr

from beetroots.modelling.likelihoods import utils
from beetroots.modelling.likelihoods.abstract_likelihood import Likelihood


@numba.njit()
def cost_au(
    y: np.ndarray, f_Theta: np.ndarray, m_a: np.ndarray, s_a: np.ndarray
) -> np.ndarray:
    r"""Elementwise cost of the uncensored additive (Gaussian) approximation.

    .. math:: \frac{1}{2} \left( \frac{y - f(\theta) - m_a}{s_a} \right)^2 + \log s_a

    Parameters
    ----------
    y : np.ndarray
        observations
    f_Theta : np.ndarray
        forward map evaluations, broadcast-compatible with ``y``
    m_a : np.ndarray
        bias of the additive approximation
    s_a : np.ndarray
        standard deviation of the additive approximation

    Returns
    -------
    np.ndarray
        elementwise cost (broadcast shape of the inputs)
    """
    standardized = (y - f_Theta - m_a) / s_a
    return 0.5 * standardized**2 + np.log(s_a)
@numba.njit()
def cost_mu(
    log_y: np.ndarray, log_f_Theta: np.ndarray, m_m: np.ndarray, s_m: np.ndarray
) -> np.ndarray:
    r"""Elementwise cost of the uncensored multiplicative (log-space Gaussian) approximation.

    .. math:: \frac{1}{2} \left( \frac{\log y - \log f(\theta) - m_m}{s_m} \right)^2 + \log s_m

    Parameters
    ----------
    log_y : np.ndarray
        logarithm of the observations
    log_f_Theta : np.ndarray
        logarithm of the forward map evaluations, broadcast-compatible with ``log_y``
    m_m : np.ndarray
        bias of the multiplicative approximation (in log space)
    s_m : np.ndarray
        standard deviation of the multiplicative approximation (in log space)

    Returns
    -------
    np.ndarray
        elementwise cost (broadcast shape of the inputs); NaN entries of
        ``log_y`` propagate as NaN
    """
    standardized = (log_y - log_f_Theta - m_m) / s_m
    return 0.5 * standardized**2 + np.log(s_m)
@numba.njit()
def gradient_cost_au(
    y: np.ndarray,
    grad_f_Theta: np.ndarray,
    f_Theta: np.ndarray,
    grad_m_a: np.ndarray,
    m_a: np.ndarray,
    grad_s_a2: np.ndarray,
    s_a2: np.ndarray,
) -> np.ndarray:
    r"""Gradient of the uncensored additive cost with respect to :math:`\theta`.

    Differentiates the cost computed by ``cost_au``, written with the variance
    :math:`s_a^2` and the residual :math:`r = f(\theta) + m_a - y`:

    .. math:: c = \frac{r^2}{2 s_a^2} + \frac{1}{2} \log s_a^2

    .. math:: \nabla c = \frac{\nabla s_a^2}{2 s_a^2}
        + \frac{r}{2 (s_a^2)^2}
        \left( 2 s_a^2 (\nabla f + \nabla m_a) - r \nabla s_a^2 \right)

    Parameters
    ----------
    y : np.ndarray of shape (N, L)
        observations
    grad_f_Theta : np.ndarray of shape (N, D, L)
        gradient of the forward map
    f_Theta : np.ndarray of shape (N, L)
        forward map evaluations
    grad_m_a : np.ndarray of shape (N, D, L)
        gradient of the additive bias
    m_a : np.ndarray of shape (N, L)
        additive bias
    grad_s_a2 : np.ndarray of shape (N, D, L)
        gradient of the additive variance
    s_a2 : np.ndarray of shape (N, L)
        additive variance

    Returns
    -------
    np.ndarray of shape (N, D, L)
        gradient of the cost, per pixel, parameter and line
    """
    # ``cost_au`` uses (y - f_Theta - m_a); its derivative needs the
    # residual f_Theta + m_a - y (the bias enters with a + sign)
    residual = f_Theta + m_a - y  # (N, L)
    u_1 = residual / s_a2**2
    assert u_1.shape == f_Theta.shape  # (N, L)

    u_2 = (
        2 * np.expand_dims(s_a2, axis=1) * (grad_f_Theta + grad_m_a)
        - np.expand_dims(residual, axis=1) * grad_s_a2
    )
    assert u_2.shape == grad_f_Theta.shape  # (N, D, L)

    # derivative of the 0.5 * log(s_a2) normalization term
    g_au = 0.5 * grad_s_a2 / np.expand_dims(s_a2, axis=1)
    g_au += 0.5 * np.expand_dims(u_1, axis=1) * u_2
    assert g_au.shape == grad_f_Theta.shape
    return g_au
@numba.njit()
def gradient_cost_mu(
    log_y: np.ndarray,
    grad_log_f_Theta: np.ndarray,
    log_f_Theta: np.ndarray,
    grad_m_m: np.ndarray,
    m_m: np.ndarray,
    grad_s_m2: np.ndarray,
    s_m2: np.ndarray,
) -> np.ndarray:
    r"""Gradient of the uncensored multiplicative cost with respect to :math:`\theta`.

    Differentiates the cost computed by ``cost_mu``, written with the variance
    :math:`s_m^2` and the residual :math:`r = \log f(\theta) + m_m - \log y`:

    .. math:: c = \frac{r^2}{2 s_m^2} + \frac{1}{2} \log s_m^2

    .. math:: \nabla c = \frac{\nabla s_m^2}{2 s_m^2}
        + \frac{r}{2 (s_m^2)^2}
        \left( 2 s_m^2 (\nabla \log f + \nabla m_m) - r \nabla s_m^2 \right)

    Where ``log_y`` is NaN, only the normalization term is kept (the residual
    term is set to 0).

    Parameters
    ----------
    log_y : np.ndarray of shape (N, L)
        logarithm of the observations
    grad_log_f_Theta : np.ndarray of shape (N, D, L)
        gradient of the log forward map
    log_f_Theta : np.ndarray of shape (N, L)
        log forward map evaluations
    grad_m_m : np.ndarray of shape (N, D, L)
        gradient of the multiplicative bias
    m_m : np.ndarray of shape (N, L)
        multiplicative bias
    grad_s_m2 : np.ndarray of shape (N, D, L)
        gradient of the multiplicative variance
    s_m2 : np.ndarray of shape (N, L)
        multiplicative variance

    Returns
    -------
    np.ndarray of shape (N, D, L)
        gradient of the cost, per pixel, parameter and line
    """
    # ``cost_mu`` uses (log_y - log_f_Theta - m_m); its derivative needs the
    # residual log_f_Theta + m_m - log_y (the bias enters with a + sign)
    residual = log_f_Theta + m_m - log_y  # (N, L)
    u_1 = residual / s_m2**2
    assert u_1.shape == log_f_Theta.shape  # (N, L)

    u_2 = (
        2 * np.expand_dims(s_m2, axis=1) * (grad_log_f_Theta + grad_m_m)
        - np.expand_dims(residual, axis=1) * grad_s_m2
    )
    assert u_2.shape == grad_log_f_Theta.shape  # (N, D, L)

    # derivative of the 0.5 * log(s_m2) normalization term
    g_mu = 0.5 * grad_s_m2 / np.expand_dims(s_m2, axis=1)
    # the NaN mask is broadcast to (N, D, L) by adding zeros so that all
    # np.where arguments share one shape (presumably needed under numba)
    g_mu += np.where(
        np.expand_dims(np.isnan(log_y), axis=1) + np.zeros_like(grad_log_f_Theta),
        np.zeros_like(grad_log_f_Theta),
        0.5 * np.expand_dims(u_1, axis=1) * u_2,
    )
    assert g_mu.shape == grad_log_f_Theta.shape
    return g_mu
class MixingModelsLikelihood(Likelihood):
    r"""Approximate likelihood for a mixture of additive Gaussian and
    multiplicative lognormal noises, with censorship below a lower limit.

    This likelihood function is introduced in Section II.C of
    :cite:t:`paludEfficientSamplingNon2023`.

    The observation model (see :meth:`sample_observation_model`) is

    .. math::

        y_{n, \ell} = \max\left(\omega_{n, \ell},
        \epsilon^{(m)}_{n, \ell} f_\ell(\theta_n) + \epsilon^{(a)}_{n, \ell}\right)

    with :math:`\epsilon^{(a)} \sim \mathcal{N}(0, \sigma_a^2)` and
    :math:`\epsilon^{(m)}` lognormal with log-mean :math:`-\sigma_m^2 / 2` and
    log-standard deviation :math:`\sigma_m`.

    The negative log-likelihood is approximated by a convex combination, with
    weight :math:`\lambda_{n, \ell}` (see :meth:`model_mixing_param`), of

    * an additive approximation: Gaussian on :math:`y` with bias ``m_a`` and
      standard deviation ``s_a``;
    * a multiplicative approximation: Gaussian on :math:`\log y` with bias
      ``m_m`` and standard deviation ``s_m``.

    Each has a censored variant (normal cdf) used where :math:`y \leq \omega`.

    Note
    ----
    This likelihood is a parametric approximation of the true likelihood
    model. Its parameters (denoted :math:`a_\ell` in the article, as there is
    one set per observable :math:`\ell`) define the interval of
    :math:`\log f_\ell(\theta)` over which the two approximations are
    combined. They should be adjusted before any inversion. To adjust them,
    see the ``beetroots.approx_optim`` subpackage.
    """

    __slots__ = (
        "N",
        "L",
        "D",
        "y",
        "forward_map",
        "log_y",
        "sigma_a",
        "sigma_m",
        "omega",
        "log_fm1",
        "log_fp1",
        "P_lambda",
        "grad_P_lambda",
        "hess_diag_P_lambda",
    )

    def __init__(
        self,
        forward_map,
        D: int,
        L: int,
        N: int,
        y: np.ndarray,
        sigma_a: Union[float, np.ndarray],
        sigma_m: Union[float, np.ndarray],
        omega: Union[float, np.ndarray],
        path_transition_params: str,
        list_lines_fit: List[str],
    ) -> None:
        """Constructor of the MixingModelsLikelihood object.

        Parameters
        ----------
        forward_map : ForwardMap instance
            forward map
        D : int
            number of distinct physical parameters in input space.
        L : int
            number of distinct observed physical parameters.
        N : int
            number of pixels in each physical dimension
        y : np.ndarray of shape (N, L)
            observations (censored entries are expected to satisfy y <= omega)
        sigma_a : float or np.ndarray of shape (N, L)
            standard deviation of the additive Gaussian noise
        sigma_m : float or np.ndarray of shape (N, L)
            scale parameter of the multiplicative lognormal noise
        omega : float or np.ndarray of shape (N, L)
            censorship threshold
        path_transition_params : str
            path to the csv file with the mixing approximation parameters
            (see :meth:`set_likelihood_approx_parameters`)
        list_lines_fit : List[str]
            names of the L observed lines, in the order of the columns of y

        Raises
        ------
        ValueError
            if ``y`` does not have the shape (N, L), if ``sigma_a``,
            ``sigma_m`` or ``omega`` is neither a scalar nor of shape (N, L),
            or if the transition parameter file is inconsistent
        """
        super().__init__(forward_map, D, L, N, y)

        if y.shape != (N, L):
            raise ValueError(
                "y must have the shape (N, L) = ({}, {}) elements".format(
                    self.N, self.L
                )
            )

        self.sigma_a = self._as_pixel_line_array(sigma_a, N, L, "sigma_a")
        self.sigma_m = self._as_pixel_line_array(sigma_m, N, L, "sigma_m")
        self.omega = self._as_pixel_line_array(omega, N, L, "omega")

        self.log_y = np.log(y)

        self.set_likelihood_approx_parameters(path_transition_params, list_lines_fit)

        # transition polynomial P(u) = 1 - 10 u^3 + 15 u^4 - 6 u^5:
        # P(0) = 1, P(1) = 0, and P', P'' vanish at u = 0 and u = 1, so the
        # mixing weight is twice continuously differentiable
        self.P_lambda = np.poly1d(np.array([-6.0, 15.0, -10.0, 0.0, 0.0, 1.0]))
        self.grad_P_lambda = self.P_lambda.deriv(m=1)
        self.hess_diag_P_lambda = self.P_lambda.deriv(m=2)

    @staticmethod
    def _as_pixel_line_array(value, N: int, L: int, name: str) -> np.ndarray:
        """Broadcast a scalar to an (N, L) array, or check that an array has shape (N, L).

        Any 0-d input (Python or numpy scalar, 0-d array) is broadcast.
        Arrays of the right shape are returned as is (no copy).

        Raises
        ------
        ValueError
            if ``value`` is not a scalar and does not have the shape (N, L)
        """
        if np.ndim(value) == 0:
            return value * np.ones((N, L))
        value = np.asarray(value)
        if value.shape != (N, L):
            raise ValueError(
                f"{name} must be a scalar or have the shape (N, L) = ({N}, {L}),"
                f" got {value.shape}"
            )
        return value

    @staticmethod
    def _pixel_params(
        param: np.ndarray, n_rows: int, idx: Optional[np.ndarray]
    ) -> np.ndarray:
        """Align a per-pixel (N, L) parameter with a batch of forward map evaluations.

        Without ``idx``, a copy of ``param`` is returned. With ``idx`` (array of
        ``n_pix`` pixel indices), the ``n_rows`` evaluations are grouped by
        pixel, with ``n_rows // n_pix`` consecutive rows per pixel (e.g. several
        candidates per pixel). Each selected row of ``param`` is repeated that
        many times, giving an array of shape (n_rows, L).
        """
        if idx is None:
            return param * 1
        k_mtm = n_rows // idx.size
        assert idx.size * k_mtm == n_rows
        return np.repeat(param[idx], k_mtm, axis=0)

    def set_likelihood_approx_parameters(
        self, path_transition_params: str, list_lines_fit: List[str]
    ) -> None:
        r"""Set the likelihood approximation parameters from a csv file.

        The file is indicated in the `mixing_model_params_filename` entry of
        the input params yaml file. It is obtained with the Bayesian
        Optimization procedure implemented in the `approx_optim` subpackage
        (see documentation for more details on how to obtain this file).

        The file stores, for each line (and optionally each pixel ``n``), the
        center (``a0_best``) and radius (``a1_best``) of the transition
        interval. Both are multiplied by :math:`\ln 10` (so they are presumably
        stored in :math:`\log_{10}` units). The method sets

        * ``log_fm1 = center - radius``: below it, only the additive model
          is used;
        * ``log_fp1 = center + radius``: above it, only the multiplicative
          model is used.

        Parameters
        ----------
        path_transition_params : str
            path to the file that contains the approximation parameters
        list_lines_fit : List[str]
            list of lines used during the fit

        Raises
        ------
        ValueError
            if a line of ``list_lines_fit`` is missing from the file, or if the
            file contains neither 1 nor N distinct values of ``n``
        """
        df_transition = pd.read_csv(path_transition_params)
        num_distinct_n = np.unique(df_transition["n"]).size

        # every fitted line must have approximation parameters
        list_unique_lines = np.unique(df_transition["line"].values)
        for line in list_lines_fit:
            if line not in list_unique_lines:
                raise ValueError(
                    f"the line {line} is not in the file containing the likelihood approximation parameters"
                )

        if num_distinct_n == 1:
            # one set of parameters per line, shared by all pixels
            df_per_line = df_transition.set_index("line")
            ones = np.ones((self.N, self.L))
            transition_center = (
                df_per_line.loc[list_lines_fit, "a0_best"].values[None, :]
                * ones
                * np.log(10)
            )
            transition_radius = (
                df_per_line.loc[list_lines_fit, "a1_best"].values[None, :]
                * ones
                * np.log(10)
            )
            print(
                "Using the same set of approximation parameters for all pixels. Careful: this can lead to estimation errors."
            )
        else:
            # one set of parameters per (pixel, line) pair
            if num_distinct_n != self.N:
                raise ValueError(
                    f"the transition file {path_transition_params} should have {self.N} or 1 distinct values of n, and has {num_distinct_n}"
                )
            df_per_pixel_line = df_transition.set_index(["n", "line"])
            index = pd.MultiIndex.from_product(
                [list(range(self.N)), list_lines_fit],
                names=["n", "line"],
            )
            transition_center = df_per_pixel_line.loc[
                index, "a0_best"
            ].values.reshape((self.N, self.L)) * np.log(10)
            transition_radius = df_per_pixel_line.loc[
                index, "a1_best"
            ].values.reshape((self.N, self.L)) * np.log(10)
            print(
                "Using sets of approximation parameters defined for each pixel and line."
            )

        self.log_fm1 = transition_center - transition_radius  # (N, L)
        self.log_fp1 = transition_center + transition_radius  # (N, L)

    def sample_observation_model(
        self,
        forward_map_evals: dict,
        rng: Optional[np.random.Generator] = None,
    ) -> np.ndarray:
        r"""Draw one observation per pixel and line from the exact noise model.

        :math:`y = \max(\omega, \epsilon_m f(\theta) + \epsilon_a)`, with
        :math:`\epsilon_a \sim \mathcal{N}(0, \sigma_a^2)` and
        :math:`\log \epsilon_m \sim \mathcal{N}(-\sigma_m^2 / 2, \sigma_m^2)`.

        Parameters
        ----------
        forward_map_evals : dict
            must contain ``"f_Theta"``, broadcast-compatible with (N, L)
        rng : np.random.Generator, optional
            random generator. A fresh unseeded generator is created when
            omitted.

        Returns
        -------
        np.ndarray
            replicated (censored) observations
        """
        if rng is None:
            rng = np.random.default_rng()
        eps_a = rng.normal(loc=0.0, scale=self.sigma_a)
        eps_m = rng.lognormal(
            mean=-(self.sigma_m**2) / 2,
            sigma=self.sigma_m,
        )
        return np.maximum(self.omega, eps_m * forward_map_evals["f_Theta"] + eps_a)

    def model_mixing_param(
        self,
        forward_map_evals: dict,
        idx: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        r"""Compute the weight :math:`\lambda_{n, \ell}` of the additive model.

        :math:`\lambda_{n, \ell}` depends on the current forward map value
        (so it changes during sampling). With
        :math:`u = (\log f_{n, \ell} - \log f_{-1}) / (\log f_{+1} - \log f_{-1})`:

        .. math::

            \lambda_{n, \ell} = \begin{cases}
                1 & \text{if } \log f_{n, \ell} \leq \log f_{-1} \\
                P(u) & \text{if } \log f_{-1} < \log f_{n, \ell} < \log f_{+1} \\
                0 & \text{if } \log f_{n, \ell} \geq \log f_{+1}
            \end{cases}

        with :math:`P(u) = 1 - 10 u^3 + 15 u^4 - 6 u^5`.

        Parameters
        ----------
        forward_map_evals : dict
            must contain ``"f_Theta"`` and ``"log_f_Theta"`` of shape (N_pix, L)
        idx : np.ndarray, optional
            pixel indices of the evaluations. When given, the rows of
            ``forward_map_evals`` are grouped by pixel (several consecutive
            rows per pixel).

        Returns
        -------
        np.ndarray of shape (N_pix, L)
            mixing weights in [0, 1]
        """
        n_rows = forward_map_evals["f_Theta"].shape[0]
        log_fm1 = self._pixel_params(self.log_fm1, n_rows, idx)
        log_fp1 = self._pixel_params(self.log_fp1, n_rows, idx)
        log_f_Theta = forward_map_evals["log_f_Theta"]

        lambda_ = np.where(
            log_f_Theta <= log_fm1,
            1,
            np.where(
                log_f_Theta >= log_fp1,
                0,
                self.P_lambda((log_f_Theta - log_fm1) / (log_fp1 - log_fm1)),
            ),
        )
        return lambda_  # (N_pix, L)

    def grad_model_mixing_param(self, forward_map_evals: dict) -> np.ndarray:
        r"""Gradient of the mixing weight :math:`\lambda_{n, \ell}` w.r.t. :math:`\theta_n`.

        Inside the transition interval,
        :math:`\nabla \lambda = P'(u) \nabla u` with
        :math:`\nabla u = \nabla \log f / (\log f_{+1} - \log f_{-1})`.
        Outside it, the gradient is 0. Only full-image evaluations (all N
        pixels) are supported.

        Parameters
        ----------
        forward_map_evals : dict
            must contain ``"log_f_Theta"`` of shape (N, L) and
            ``"grad_log_f_Theta"`` of shape (N, D, L)

        Returns
        -------
        np.ndarray of shape (N, D, L)
        """
        log_f_Theta = forward_map_evals["log_f_Theta"]
        width = self.log_fp1 - self.log_fm1  # (N, L)
        u = (log_f_Theta - self.log_fm1) / width  # (N, L)
        grad_u = forward_map_evals["grad_log_f_Theta"] / width[:, None, :]  # (N, D, L)

        outside = (log_f_Theta <= self.log_fm1) | (log_f_Theta >= self.log_fp1)
        grad_ = np.where(
            outside[:, None, :],
            0.0,
            grad_u * self.grad_P_lambda(u)[:, None, :],
        )
        assert grad_.shape == (self.N, self.D, self.L), grad_.shape
        return grad_  # (N, D, L)

    def hess_diag_model_mixing_param(self, forward_map_evals: dict) -> np.ndarray:
        r"""Diagonal of the Hessian of :math:`\lambda_{n, \ell}` w.r.t. :math:`\theta_n`.

        Inside the transition interval,
        :math:`\partial^2 \lambda = P''(u) (\partial u)^2 + P'(u) \partial^2 u`.
        Outside it, the result is 0. Only full-image evaluations (all N pixels)
        are supported.

        Parameters
        ----------
        forward_map_evals : dict
            must contain ``"log_f_Theta"`` of shape (N, L), and
            ``"grad_log_f_Theta"`` and ``"hess_diag_log_f_Theta"`` of shape
            (N, D, L)

        Returns
        -------
        np.ndarray of shape (N, D, L)
        """
        log_f_Theta = forward_map_evals["log_f_Theta"]
        width = self.log_fp1 - self.log_fm1  # (N, L)
        u = (log_f_Theta - self.log_fm1) / width  # (N, L)
        grad_u = forward_map_evals["grad_log_f_Theta"] / width[:, None, :]  # (N, D, L)
        hess_diag_u = (
            forward_map_evals["hess_diag_log_f_Theta"] / width[:, None, :]
        )  # (N, D, L)

        outside = (log_f_Theta <= self.log_fm1) | (log_f_Theta >= self.log_fp1)
        hess_diag = np.where(
            outside[:, None, :],
            0.0,
            grad_u**2 * self.hess_diag_P_lambda(u)[:, None, :]
            + hess_diag_u * self.grad_P_lambda(u)[:, None, :],
        )
        assert hess_diag.shape == (self.N, self.D, self.L), hess_diag.shape
        return hess_diag  # (N, D, L)

    def _compute_bias_and_std(
        self,
        forward_map_evals: dict,
        idx: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        r"""Compute the biases and standard deviations of the two approximations.

        Not called by :meth:`evaluate_all_nll_utils`. Note that the
        formulas here differ from the ones used there (non-zero ``m_a``, and
        ``m_m`` without the :math:`-\sigma_m^2/2` term).

        Parameters
        ----------
        forward_map_evals : dict
            must contain ``"f_Theta"`` and ``"log_f_Theta"`` of shape
            (N_pix, L)
        idx : np.ndarray, optional
            pixel indices of the evaluations. When given, the rows are grouped
            by pixel (several consecutive rows per pixel).

        Returns
        -------
        current_bias_std : np.ndarray of shape (4, N_pix, L)
            array with (m_a, s_a, m_m, s_m)
        """
        N_pix = forward_map_evals["f_Theta"].shape[0]
        sigma_a = self._pixel_params(self.sigma_a, N_pix, idx)
        sigma_m = self._pixel_params(self.sigma_m, N_pix, idx)

        log_combination = (
            np.log(sigma_a) - forward_map_evals["log_f_Theta"] - (sigma_m**2) / 2
        )
        assert sigma_a.min() > 0, sigma_a.min()
        assert sigma_m.min() > 0, sigma_m.min()
        assert np.sum(np.isnan(forward_map_evals["log_f_Theta"])) == 0, np.sum(
            np.isnan(forward_map_evals["log_f_Theta"])
        )
        assert np.sum(np.isnan(log_combination)) == 0, np.sum(np.isnan(log_combination))

        m_a = (np.exp(sigma_m**2 / 2) - 1) * forward_map_evals["f_Theta"]
        s_a = sigma_a * np.sqrt(
            (np.exp(sigma_m**2) - 1) * np.exp(-2 * log_combination) + 1
        )
        m_m = -0.5 * np.log(1 + np.exp(2 * log_combination))
        s_m = np.sqrt(sigma_m**2 - 2 * m_m)
        assert s_m.min() > 0, f"{s_m.min()}, {(sigma_m ** 2 - 2 * m_m).min()}"

        return np.stack([m_a, s_a, m_m, s_m]).astype(float)

    def neglog_pdf(
        self,
        forward_map_evals: dict,
        nll_utils: dict,
        pixelwise: bool = False,
        full: bool = False,
        idx: Optional[np.ndarray] = None,
    ) -> Union[float, np.ndarray]:
        r"""Negative log-likelihood of the mixture approximation.

        Per entry:
        :math:`\lambda \cdot \text{nll}_a + (1 - \lambda) \cdot \text{nll}_m`.
        The censored variants are used where ``censored_mask`` is set. NaNs
        are replaced by 0. Then :math:`\log \sigma_a + \log \sigma_m` is
        subtracted; this offset does not depend on :math:`\theta`.

        Parameters
        ----------
        forward_map_evals : dict
            forward map evaluations (``"f_Theta"`` is read for a shape check)
        nll_utils : dict
            output of :meth:`evaluate_all_nll_utils`
        pixelwise : bool, optional
            if True, return the sum over lines, of shape (N_pix,)
        full : bool, optional
            if True, return the (N_pix, L) array. Takes precedence over
            ``pixelwise``.
        idx : np.ndarray, optional
            unused, kept for interface compatibility

        Returns
        -------
        float or np.ndarray
            total negative log-likelihood, or its per-pixel / per-entry
            decomposition
        """
        lambda_ = nll_utils["lambda_"]
        censored = nll_utils["censored_mask"]
        nll_additive = np.where(censored, nll_utils["nll_ac"], nll_utils["nll_au"])
        nll_multiplicative = np.where(
            censored, nll_utils["nll_mc"], nll_utils["nll_mu"]
        )
        nlpdf = np.nan_to_num(
            lambda_ * nll_additive + (1 - lambda_) * nll_multiplicative
        )  # (N_pix, L)
        nlpdf -= np.log(nll_utils["sigma_a"]) + np.log(nll_utils["sigma_m"])

        if full:
            return nlpdf  # (N_pix, L)
        if pixelwise:
            sum_ = np.sum(nlpdf, axis=1)  # (N_pix,)
            assert sum_.size == forward_map_evals["f_Theta"].shape[0]
            return sum_
        return nlpdf.sum()

    def neglog_pdf_ac(
        self,
        forward_map_evals: dict,
        nll_utils: dict,
        omega: np.ndarray,
    ) -> np.ndarray:
        r"""Censored additive cost :math:`-\log \Phi((\omega - f - m_a) / s_a)`.

        Uses ``scipy.special.log_ndtr`` for numerical stability. NaN/inf
        values are replaced by ``np.nan_to_num``.
        """
        z = (omega - forward_map_evals["f_Theta"] - nll_utils["m_a"]) / nll_utils["s_a"]
        return np.nan_to_num(-log_ndtr(z))

    def neglog_pdf_au(
        self,
        forward_map_evals: dict,
        nll_utils: dict,
        y: np.ndarray,
    ) -> np.ndarray:
        r"""Uncensored additive cost :math:`\frac{1}{2}((y - f - m_a)/s_a)^2 + \log s_a`."""
        return cost_au(
            y=y,
            f_Theta=forward_map_evals["f_Theta"],
            m_a=nll_utils["m_a"],
            s_a=nll_utils["s_a"],
        )

    def neglog_pdf_mc(
        self,
        forward_map_evals: dict,
        nll_utils: dict,
        log_omega: np.ndarray,
    ) -> np.ndarray:
        r"""Censored multiplicative cost :math:`-\log \Phi((\log \omega - \log f - m_m) / s_m)`.

        Uses ``scipy.special.log_ndtr`` for numerical stability. NaN/inf
        values are replaced by ``np.nan_to_num``.
        """
        z = log_omega - forward_map_evals["log_f_Theta"] - nll_utils["m_m"]
        z /= nll_utils["s_m"]
        return np.nan_to_num(-log_ndtr(z))

    def neglog_pdf_mu(
        self,
        forward_map_evals: dict,
        nll_utils: dict,
        log_y: np.ndarray,
    ) -> np.ndarray:
        r"""Uncensored multiplicative cost :math:`\frac{1}{2}((\log y - \log f - m_m)/s_m)^2 + \log s_m`."""
        return cost_mu(
            log_y=log_y,
            log_f_Theta=forward_map_evals["log_f_Theta"],
            m_m=nll_utils["m_m"],
            s_m=nll_utils["s_m"],
        )

    def gradient_neglog_pdf(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        r"""Gradient of :meth:`neglog_pdf` w.r.t. :math:`\theta`, summed over lines.

        Product rule on
        :math:`\lambda \cdot \text{nll}_a + (1 - \lambda) \cdot \text{nll}_m`,
        using the censored terms where ``y <= omega`` (the same criterion as
        the ``censored_mask`` of :meth:`evaluate_all_nll_utils`). NaNs are
        replaced by 0 before summing.

        Parameters
        ----------
        forward_map_evals : dict
            forward map evaluations (unused here; derivatives come from
            ``nll_utils``)
        nll_utils : dict
            output of :meth:`evaluate_all_nll_utils` with
            ``compute_derivatives=True`` and ``idx=None``

        Returns
        -------
        np.ndarray of shape (N, D)
        """
        lam = nll_utils["lambda_"][:, None, :]
        grad_lam = nll_utils["grad_lambda_"]

        grad_censored = (
            lam * nll_utils["grad_nll_ac"]
            + grad_lam * nll_utils["nll_ac"][:, None, :]
            + (1 - lam) * nll_utils["grad_nll_mc"]
            - grad_lam * nll_utils["nll_mc"][:, None, :]
        )
        grad_uncensored = (
            lam * nll_utils["grad_nll_au"]
            + grad_lam * nll_utils["nll_au"][:, None, :]
            + (1 - lam) * nll_utils["grad_nll_mu"]
            - grad_lam * nll_utils["nll_mu"][:, None, :]
        )
        grad_ = np.where(
            (self.y <= self.omega)[:, None, :], grad_censored, grad_uncensored
        )  # (N, D, L)
        grad_ = np.nan_to_num(grad_)
        return np.sum(grad_, axis=2)

    def gradient_neglog_pdf_ac(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        r"""Gradient of the censored additive cost, shape (N, D, L).

        With :math:`r = f + m_a - \omega` and :math:`z = -r / s_a`:
        :math:`\nabla \text{nll}_{ac} = \frac{\varphi(z)}{\Phi(z)}
        \cdot \frac{(\nabla f + \nabla m_a) s_a - r \nabla s_a}{s_a^2}`.
        ``utils.norm_pdf_cdf_ratio`` is presumably :math:`\varphi / \Phi`.
        """
        f_Theta_m_a_omega = forward_map_evals["f_Theta"] + nll_utils["m_a"] - self.omega
        u_1 = utils.norm_pdf_cdf_ratio(-f_Theta_m_a_omega / nll_utils["s_a"])  # (N, L)
        u_2 = (1 / nll_utils["s_a2"])[:, None, :] * (
            (forward_map_evals["grad_f_Theta"] + nll_utils["grad_m_a"])
            * nll_utils["s_a"][:, None, :]
            - f_Theta_m_a_omega[:, None, :] * nll_utils["grad_s_a"]
        )  # (N, D, L)
        return u_1[:, None, :] * u_2  # (N, D, L)

    def gradient_neglog_pdf_au(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        """Gradient of the uncensored additive cost, shape (N, D, L) (see ``gradient_cost_au``)."""
        return gradient_cost_au(
            y=self.y,
            grad_f_Theta=forward_map_evals["grad_f_Theta"],
            f_Theta=forward_map_evals["f_Theta"],
            grad_m_a=nll_utils["grad_m_a"],
            m_a=nll_utils["m_a"],
            grad_s_a2=nll_utils["grad_s_a2"],
            s_a2=nll_utils["s_a2"],
        )

    def gradient_neglog_pdf_mc(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        r"""Gradient of the censored multiplicative cost, shape (N, D, L).

        Same form as :meth:`gradient_neglog_pdf_ac`, in log space, with
        :math:`r = \log f + m_m - \log \omega`.
        """
        log_f_Theta_m_m_log_omega = (
            forward_map_evals["log_f_Theta"] + nll_utils["m_m"] - np.log(self.omega)
        )
        u_1 = utils.norm_pdf_cdf_ratio(
            -log_f_Theta_m_m_log_omega / nll_utils["s_m"]
        )  # (N, L)
        u_2 = (1 / nll_utils["s_m2"])[:, None, :] * (
            (forward_map_evals["grad_log_f_Theta"] + nll_utils["grad_m_m"])
            * nll_utils["s_m"][:, None, :]
            - log_f_Theta_m_m_log_omega[:, None, :] * nll_utils["grad_s_m"]
        )  # (N, D, L)
        return u_1[:, None, :] * u_2  # (N, D, L)

    def gradient_neglog_pdf_mu(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        """Gradient of the uncensored multiplicative cost, shape (N, D, L) (see ``gradient_cost_mu``)."""
        return gradient_cost_mu(
            log_y=self.log_y,
            grad_log_f_Theta=forward_map_evals["grad_log_f_Theta"],
            log_f_Theta=forward_map_evals["log_f_Theta"],
            grad_m_m=nll_utils["grad_m_m"],
            m_m=nll_utils["m_m"],
            grad_s_m2=nll_utils["grad_s_m2"],
            s_m2=nll_utils["s_m2"],
        )

    def hess_diag_neglog_pdf(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        r"""Diagonal of the Hessian of :meth:`neglog_pdf` w.r.t. :math:`\theta`, summed over lines.

        Second-order product rule on
        :math:`\lambda A + (1 - \lambda) B`:
        :math:`\lambda'' A + 2 \lambda' A' + \lambda A''
        - \lambda'' B - 2 \lambda' B' + (1 - \lambda) B''`.
        The censored terms are used where ``y <= omega``. NaNs are replaced
        by 0 before summing.

        Parameters
        ----------
        forward_map_evals : dict
            forward map evaluations (unused here; derivatives come from
            ``nll_utils``)
        nll_utils : dict
            output of :meth:`evaluate_all_nll_utils` with both derivative
            flags set and ``idx=None``

        Returns
        -------
        np.ndarray of shape (N, D)
        """
        lam = nll_utils["lambda_"][:, None, :]
        grad_lam = nll_utils["grad_lambda_"]
        hess_lam = nll_utils["hess_diag_lambda_"]

        hess_censored = (
            lam * nll_utils["hess_diag_ac"]
            + hess_lam * nll_utils["nll_ac"][:, None, :]
            + 2 * grad_lam * nll_utils["grad_nll_ac"]
            + (1 - lam) * nll_utils["hess_diag_mc"]
            - hess_lam * nll_utils["nll_mc"][:, None, :]
            - 2 * grad_lam * nll_utils["grad_nll_mc"]
        )
        hess_uncensored = (
            lam * nll_utils["hess_diag_au"]
            + hess_lam * nll_utils["nll_au"][:, None, :]
            + 2 * grad_lam * nll_utils["grad_nll_au"]
            + (1 - lam) * nll_utils["hess_diag_mu"]
            - hess_lam * nll_utils["nll_mu"][:, None, :]
            - 2 * grad_lam * nll_utils["grad_nll_mu"]
        )
        hess_diag = np.where(
            (self.y <= self.omega)[:, None, :], hess_censored, hess_uncensored
        )  # (N, D, L)
        hess_diag = np.nan_to_num(hess_diag)
        return np.sum(hess_diag, axis=2)

    def hess_diag_neglog_pdf_ac(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        r"""Diagonal Hessian of the censored additive cost, shape (N, D, L).

        Derivative of ``u_1 * u_2`` from :meth:`gradient_neglog_pdf_ac`. It
        uses :math:`\frac{d}{dz} \frac{\varphi}{\Phi}(z)
        = -\frac{\varphi}{\Phi}(z) \left(z + \frac{\varphi}{\Phi}(z)\right)`
        and :math:`\partial z = -u_2`.
        """
        f_Theta_m_a_omega = forward_map_evals["f_Theta"] + nll_utils["m_a"] - self.omega
        grad_f_Theta_grad_m_a = forward_map_evals["grad_f_Theta"] + nll_utils["grad_m_a"]

        u_1 = utils.norm_pdf_cdf_ratio(-f_Theta_m_a_omega / nll_utils["s_a"])  # (N, L)
        assert u_1.shape == (self.N, self.L)

        u_2 = (1 / nll_utils["s_a2"])[:, None, :] * (
            grad_f_Theta_grad_m_a * nll_utils["s_a"][:, None, :]
            - f_Theta_m_a_omega[:, None, :] * nll_utils["grad_s_a"]
        )  # (N, D, L)
        assert u_2.shape == (self.N, self.D, self.L)

        grad_u_1 = (
            u_2 * (u_1 * (-f_Theta_m_a_omega / nll_utils["s_a"] + u_1))[:, None, :]
        )  # (N, D, L)
        assert grad_u_1.shape == (self.N, self.D, self.L)

        grad_u_2 = (1 / nll_utils["s_a2"] ** 2)[:, None, :] * (
            nll_utils["s_a2"][:, None, :]
            * (
                nll_utils["s_a"][:, None, :]
                * (forward_map_evals["hess_diag_f_Theta"] + nll_utils["hess_diag_m_a"])
                - f_Theta_m_a_omega[:, None, :] * nll_utils["hess_diag_s_a"]
            )
            - nll_utils["grad_s_a2"]
            * (
                nll_utils["s_a"][:, None, :] * grad_f_Theta_grad_m_a
                - f_Theta_m_a_omega[:, None, :] * nll_utils["grad_s_a"]
            )
        )  # (N, D, L)
        assert grad_u_2.shape == (self.N, self.D, self.L)

        return grad_u_1 * u_2 + u_1[:, None, :] * grad_u_2

    def hess_diag_neglog_pdf_au(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        r"""Diagonal Hessian of the uncensored additive cost, shape (N, D, L).

        Derivative of the gradient
        :math:`\frac{\nabla s_a^2}{2 s_a^2} + \frac{1}{2} u_1 u_2`, with
        :math:`u_1 = r / (s_a^2)^2`,
        :math:`u_2 = 2 s_a^2 \nabla r - r \nabla s_a^2` and
        :math:`r = f + m_a - y`.
        """
        s_a2 = nll_utils["s_a2"]
        f_Theta_m_a_y = forward_map_evals["f_Theta"] + nll_utils["m_a"] - self.y  # (N, L)
        grad_f_Theta_grad_m_a = (
            forward_map_evals["grad_f_Theta"] + nll_utils["grad_m_a"]
        )  # (N, D, L)

        u_1 = f_Theta_m_a_y / s_a2**2  # (N, L)
        grad_u_1 = (1 / s_a2**4)[:, None, :] * (
            grad_f_Theta_grad_m_a * (s_a2**2)[:, None, :]
            - 2 * (f_Theta_m_a_y * s_a2)[:, None, :] * nll_utils["grad_s_a2"]
        )  # (N, D, L)

        u_2 = (
            2 * s_a2[:, None, :] * grad_f_Theta_grad_m_a
            - f_Theta_m_a_y[:, None, :] * nll_utils["grad_s_a2"]
        )  # (N, D, L)
        # d(u_2) = 2 s_a2 d2r + 2 d(s_a2) dr - dr d(s_a2) - r d2(s_a2)
        grad_u_2 = (
            2
            * s_a2[:, None, :]
            * (forward_map_evals["hess_diag_f_Theta"] + nll_utils["hess_diag_m_a"])
            + nll_utils["grad_s_a2"] * grad_f_Theta_grad_m_a
            - f_Theta_m_a_y[:, None, :] * nll_utils["hess_diag_s_a2"]
        )  # (N, D, L)

        hess_diag = (
            0.5
            * (
                nll_utils["hess_diag_s_a2"] * s_a2[:, None, :]
                - nll_utils["grad_s_a2"] ** 2
            )
            / (s_a2**2)[:, None, :]
        )  # (N, D, L)
        hess_diag += 0.5 * (grad_u_1 * u_2 + u_1[:, None, :] * grad_u_2)
        return hess_diag

    def hess_diag_neglog_pdf_mc(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        r"""Diagonal Hessian of the censored multiplicative cost, shape (N, D, L).

        Same form as :meth:`hess_diag_neglog_pdf_ac`, in log space, with
        :math:`r = \log f + m_m - \log \omega`.
        """
        log_f_Theta_m_m_log_omega = (
            forward_map_evals["log_f_Theta"] + nll_utils["m_m"] - np.log(self.omega)
        )
        grad_log_f_Theta_grad_m_m = (
            forward_map_evals["grad_log_f_Theta"] + nll_utils["grad_m_m"]
        )

        u_1 = utils.norm_pdf_cdf_ratio(
            -log_f_Theta_m_m_log_omega / nll_utils["s_m"]
        )  # (N, L)
        assert u_1.shape == (self.N, self.L)

        u_2 = (1 / nll_utils["s_m2"])[:, None, :] * (
            grad_log_f_Theta_grad_m_m * nll_utils["s_m"][:, None, :]
            - log_f_Theta_m_m_log_omega[:, None, :] * nll_utils["grad_s_m"]
        )  # (N, D, L)
        assert u_2.shape == (self.N, self.D, self.L)

        grad_u_1 = (
            u_2
            * (u_1 * (-log_f_Theta_m_m_log_omega / nll_utils["s_m"] + u_1))[:, None, :]
        )  # (N, D, L)
        assert grad_u_1.shape == (self.N, self.D, self.L)

        # second derivatives must be taken in log space, like the gradients
        grad_u_2 = (1 / nll_utils["s_m2"] ** 2)[:, None, :] * (
            nll_utils["s_m2"][:, None, :]
            * (
                nll_utils["s_m"][:, None, :]
                * (
                    forward_map_evals["hess_diag_log_f_Theta"]
                    + nll_utils["hess_diag_m_m"]
                )
                - log_f_Theta_m_m_log_omega[:, None, :] * nll_utils["hess_diag_s_m"]
            )
            - nll_utils["grad_s_m2"]
            * (
                nll_utils["s_m"][:, None, :] * grad_log_f_Theta_grad_m_m
                - log_f_Theta_m_m_log_omega[:, None, :] * nll_utils["grad_s_m"]
            )
        )  # (N, D, L)
        assert grad_u_2.shape == (self.N, self.D, self.L)

        return grad_u_1 * u_2 + u_1[:, None, :] * grad_u_2

    def hess_diag_neglog_pdf_mu(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        r"""Diagonal Hessian of the uncensored multiplicative cost, shape (N, D, L).

        Same form as :meth:`hess_diag_neglog_pdf_au`, in log space, with
        :math:`r = \log f + m_m - \log y`.
        """
        s_m2 = nll_utils["s_m2"]
        log_f_Theta_m_m_log_y = (
            forward_map_evals["log_f_Theta"] + nll_utils["m_m"] - self.log_y
        )  # (N, L)
        grad_log_f_Theta_grad_m_m = (
            forward_map_evals["grad_log_f_Theta"] + nll_utils["grad_m_m"]
        )  # (N, D, L)

        u_1 = log_f_Theta_m_m_log_y / s_m2**2  # (N, L)
        grad_u_1 = (1 / s_m2**4)[:, None, :] * (
            grad_log_f_Theta_grad_m_m * (s_m2**2)[:, None, :]
            - 2 * (log_f_Theta_m_m_log_y * s_m2)[:, None, :] * nll_utils["grad_s_m2"]
        )  # (N, D, L)

        u_2 = (
            2 * s_m2[:, None, :] * grad_log_f_Theta_grad_m_m
            - log_f_Theta_m_m_log_y[:, None, :] * nll_utils["grad_s_m2"]
        )  # (N, D, L)
        grad_u_2 = 2 * (
            nll_utils["grad_s_m2"] * grad_log_f_Theta_grad_m_m
            + s_m2[:, None, :]
            * (forward_map_evals["hess_diag_log_f_Theta"] + nll_utils["hess_diag_m_m"])
        ) - (
            nll_utils["grad_s_m2"] * grad_log_f_Theta_grad_m_m
            + log_f_Theta_m_m_log_y[:, None, :] * nll_utils["hess_diag_s_m2"]
        )  # (N, D, L)

        hess_diag = (
            0.5
            * (
                nll_utils["hess_diag_s_m2"] * s_m2[:, None, :]
                - nll_utils["grad_s_m2"] ** 2
            )
            / (s_m2**2)[:, None, :]
        )  # (N, D, L)
        hess_diag += 0.5 * (grad_u_1 * u_2 + u_1[:, None, :] * grad_u_2)
        return hess_diag

    def evaluate_all_nll_utils(
        self,
        forward_map_evals: dict,
        idx: Optional[np.ndarray] = None,
        compute_derivatives: bool = True,
        compute_derivatives_2nd_order: bool = True,
    ) -> dict:
        r"""Evaluate every intermediate quantity of the likelihood and its derivatives.

        Parameters
        ----------
        forward_map_evals : dict
            must contain ``"f_Theta"`` and ``"log_f_Theta"`` of shape
            (N_pix, L). When derivatives are requested, it must also contain
            ``"grad_f_Theta"``, ``"grad_log_f_Theta"``, ``"hess_diag_f_Theta"``
            and ``"hess_diag_log_f_Theta"`` of shape (N_pix, D, L).
        idx : np.ndarray, optional
            pixel indices of the evaluations. When given, the rows are grouped
            by pixel (several consecutive rows per pixel). The derivative
            helpers read full-size attributes (``self.y``, ``self.log_fm1``,
            ...), so derivatives presumably require ``idx=None``.
        compute_derivatives : bool, optional
            compute first-order derivatives, by default True
        compute_derivatives_2nd_order : bool, optional
            also compute Hessian diagonals (only when
            ``compute_derivatives`` is True), by default True

        Returns
        -------
        dict
            keys include ``censored_mask``, ``sigma_a``, ``sigma_m``, the
            biases and stds (``m_a``, ``s_a``, ``s_a2``, ``m_m``, ``s_m``,
            ``s_m2``), ``lambda_``, the four costs (``nll_au``, ``nll_ac``,
            ``nll_mu``, ``nll_mc``) and the requested derivatives
        """
        assert (
            np.sum(np.isnan(forward_map_evals["log_f_Theta"])) == 0
        ), f"before entering bias and std : {np.sum(np.isnan(forward_map_evals['log_f_Theta']))}"

        f_Theta = forward_map_evals["f_Theta"]
        log_f_Theta = forward_map_evals["log_f_Theta"]

        # * per-row noise parameters and observations
        N_pix = self.N if idx is None else f_Theta.shape[0]
        sigma_a = self._pixel_params(self.sigma_a, N_pix, idx)
        sigma_m = self._pixel_params(self.sigma_m, N_pix, idx)
        y = self._pixel_params(self.y, N_pix, idx)
        log_y = self._pixel_params(self.log_y, N_pix, idx)
        omega = self._pixel_params(self.omega, N_pix, idx)

        nll_utils = {}
        nll_utils["censored_mask"] = (y <= omega) * 1
        nll_utils["sigma_a"] = sigma_a * 1
        nll_utils["sigma_m"] = sigma_m * 1

        sigma_m2 = sigma_m**2
        log_combination = np.log(sigma_a) - log_f_Theta - sigma_m2 / 2
        exp_m2_log_combin = np.exp(-2 * log_combination)
        exp_sigma_m_squared = np.exp(sigma_m2)

        # * bias and variance of both approximations (N_pix, L)
        # additive: mean-preserving (m_a = 0), variance = sigma_a^2 + (e^{sigma_m^2} - 1) f^2
        nll_utils["m_a"] = np.zeros_like(y)
        nll_utils["s_a"] = sigma_a * np.sqrt(
            (exp_sigma_m_squared - 1) * np.exp(2 * (log_f_Theta - np.log(sigma_a)))
            + 1
        )
        nll_utils["s_a2"] = nll_utils["s_a"] ** 2
        nll_utils["m_m"] = -0.5 * (sigma_m2 + np.log(1 + 1 / exp_m2_log_combin))
        nll_utils["s_m2"] = -2 * nll_utils["m_m"]
        nll_utils["s_m"] = np.sqrt(nll_utils["s_m2"])
        assert (
            nll_utils["s_m"].min() > 0
        ), f"{nll_utils['s_m'].min()}, {(sigma_m2 - 2 * nll_utils['m_m']).min()}"

        if compute_derivatives:
            self._add_bias_std_derivatives(
                nll_utils,
                forward_map_evals,
                N_pix,
                exp_sigma_m_squared,
                exp_m2_log_combin,
                compute_derivatives_2nd_order,
            )

        # * mixture weight
        nll_utils["lambda_"] = self.model_mixing_param(forward_map_evals, idx)
        if compute_derivatives:
            nll_utils["grad_lambda_"] = self.grad_model_mixing_param(forward_map_evals)
            if compute_derivatives_2nd_order:
                nll_utils["hess_diag_lambda_"] = self.hess_diag_model_mixing_param(
                    forward_map_evals
                )

        # * costs of the four sub-models
        nll_utils["nll_au"] = self.neglog_pdf_au(forward_map_evals, nll_utils, y)
        # censored cost: probability of falling below the threshold omega
        nll_utils["nll_ac"] = self.neglog_pdf_ac(forward_map_evals, nll_utils, omega)
        nll_utils["nll_mu"] = self.neglog_pdf_mu(forward_map_evals, nll_utils, log_y)
        nll_utils["nll_mc"] = self.neglog_pdf_mc(
            forward_map_evals, nll_utils, np.log(omega)
        )

        if compute_derivatives:
            nll_utils["grad_nll_au"] = self.gradient_neglog_pdf_au(
                forward_map_evals, nll_utils
            )
            nll_utils["grad_nll_mu"] = self.gradient_neglog_pdf_mu(
                forward_map_evals, nll_utils
            )
            nll_utils["grad_nll_ac"] = self.gradient_neglog_pdf_ac(
                forward_map_evals, nll_utils
            )
            nll_utils["grad_nll_mc"] = self.gradient_neglog_pdf_mc(
                forward_map_evals, nll_utils
            )
            if compute_derivatives_2nd_order:
                nll_utils["hess_diag_ac"] = self.hess_diag_neglog_pdf_ac(
                    forward_map_evals, nll_utils
                )
                nll_utils["hess_diag_au"] = self.hess_diag_neglog_pdf_au(
                    forward_map_evals, nll_utils
                )
                nll_utils["hess_diag_mc"] = self.hess_diag_neglog_pdf_mc(
                    forward_map_evals, nll_utils
                )
                nll_utils["hess_diag_mu"] = self.hess_diag_neglog_pdf_mu(
                    forward_map_evals, nll_utils
                )

        return nll_utils

    def _add_bias_std_derivatives(
        self,
        nll_utils: dict,
        forward_map_evals: dict,
        N_pix: int,
        exp_sigma_m_squared: np.ndarray,
        exp_m2_log_combin: np.ndarray,
        compute_derivatives_2nd_order: bool,
    ) -> None:
        """Add (in place) the derivatives of m_a, s_a, m_m, s_m (and squares) to ``nll_utils``.

        All added arrays have shape (N_pix, D, L).
        """
        f_Theta = forward_map_evals["f_Theta"]
        grad_f_Theta = forward_map_evals["grad_f_Theta"]

        # additive bias is identically zero
        nll_utils["grad_m_a"] = np.zeros((N_pix, self.D, self.L))
        if compute_derivatives_2nd_order:
            nll_utils["hess_diag_m_a"] = np.zeros((N_pix, self.D, self.L))

        # s_a2 = sigma_a^2 + (e^{sigma_m^2} - 1) f^2
        nll_utils["grad_s_a2"] = (
            2 * (f_Theta * (exp_sigma_m_squared - 1))[:, None, :] * grad_f_Theta
        )
        if compute_derivatives_2nd_order:
            nll_utils["hess_diag_s_a2"] = (
                2
                * (exp_sigma_m_squared - 1)[:, None, :]
                * (
                    f_Theta[:, None, :] * forward_map_evals["hess_diag_f_Theta"]
                    + grad_f_Theta**2
                )
            )

        # s_a = sqrt(s_a2)
        nll_utils["grad_s_a"] = (1 / (2 * nll_utils["s_a"]))[:, None, :] * nll_utils[
            "grad_s_a2"
        ]
        if compute_derivatives_2nd_order:
            nll_utils["hess_diag_s_a"] = (1 / (2 * nll_utils["s_a2"]))[:, None, :] * (
                nll_utils["hess_diag_s_a2"] * nll_utils["s_a"][:, None, :]
                - nll_utils["grad_s_a2"] * nll_utils["grad_s_a"]
            )

        # m_m = -0.5 * (sigma_m^2 + log(1 + 1 / E)) with E = exp_m2_log_combin,
        # which is proportional to f^2
        nll_utils["grad_m_m"] = (1 / (1 + exp_m2_log_combin) / f_Theta)[
            :, None, :
        ] * grad_f_Theta
        if compute_derivatives_2nd_order:
            nll_utils["hess_diag_m_m"] = (
                1 / (f_Theta * (1 + exp_m2_log_combin)) ** 2
            )[:, None, :] * (
                forward_map_evals["hess_diag_f_Theta"]
                * (f_Theta * (1 + exp_m2_log_combin))[:, None, :]
                - grad_f_Theta**2 * (1 + 3 * exp_m2_log_combin)[:, None, :]
            )

        # s_m2 = -2 m_m
        nll_utils["grad_s_m2"] = -2 * nll_utils["grad_m_m"]
        if compute_derivatives_2nd_order:
            nll_utils["hess_diag_s_m2"] = -2 * nll_utils["hess_diag_m_m"]

        # s_m = sqrt(s_m2)
        nll_utils["grad_s_m"] = (1 / (2 * nll_utils["s_m"]))[:, None, :] * nll_utils[
            "grad_s_m2"
        ]
        if compute_derivatives_2nd_order:
            nll_utils["hess_diag_s_m"] = (1 / (2 * nll_utils["s_m2"]))[:, None, :] * (
                nll_utils["hess_diag_s_m2"] * nll_utils["s_m"][:, None, :]
                - nll_utils["grad_s_m"] * nll_utils["grad_s_m2"]
            )