# Source code for beetroots.modelling.likelihoods.gaussian_censored

"""Implementation of Gaussian likelihood with censorship (with a lower limit)
"""

from typing import Optional, Union

import numpy as np
from scipy.stats import norm as statsnorm

from beetroots.modelling.likelihoods import utils
from beetroots.modelling.likelihoods.abstract_likelihood import Likelihood


class CensoredGaussianLikelihood(Likelihood):
    r"""Gaussian likelihood model with lower censorship.

    Each observation is modelled as
    :math:`y_{n,\ell} = \max(\omega_{n,\ell}, f_\ell(\theta_n) + b_{n,\ell}
    + \varepsilon_{n,\ell})` with
    :math:`\varepsilon_{n,\ell} \sim \mathcal{N}(0, \sigma_{n,\ell}^2)`.
    Entries with :math:`y_{n,\ell} \leq \omega_{n,\ell}` are treated as
    censored (only known to lie below the threshold).
    """

    __slots__ = (
        "forward_map",
        "D",
        "L",
        "N",
        "y",
        "sigma",
        "omega",
        "bias",
    )

    def __init__(
        self,
        forward_map,
        D: int,
        L: int,
        N: int,
        y: np.ndarray,
        sigma: Union[float, np.ndarray],
        omega: Union[float, np.ndarray],
        bias: Union[float, np.ndarray] = 0.0,
    ) -> None:
        """Constructor of the CensoredGaussianLikelihood object.

        Parameters
        ----------
        forward_map : ForwardMap instance
            forward map
        D : int
            number of distinct physical parameters in input space.
        L : int
            number of distinct observed physical parameters.
        N : int
            number of pixels in each physical dimension
        y : np.ndarray of shape (N, L)
            observations (censored entries are those with ``y <= omega``)
        sigma : float or np.ndarray of shape (N, L)
            standard deviation of the Gaussian noise
        omega : float or np.ndarray of shape (N, L)
            censorship threshold
        bias : float or np.ndarray of shape (N, L), optional
            additive bias applied to the forward map output, by default 0.0

        Raises
        ------
        ValueError
            y must have the shape (N, L)
        """
        super().__init__(forward_map, D, L, N, y)

        # ! trigger an error if y does not have the shape (N, L)
        if not (y.shape == (N, L)):
            raise ValueError(
                "y must have the shape (N, L) = ({}, {}) elements".format(
                    self.N, self.L
                )
            )

        self.sigma = self._to_pixel_array(sigma, N, L)
        self.omega = self._to_pixel_array(omega, N, L)
        self.bias = self._to_pixel_array(bias, N, L)

    @staticmethod
    def _to_pixel_array(
        value: Union[float, np.ndarray], N: int, L: int
    ) -> np.ndarray:
        """Broadcast a scalar to an (N, L) array; pass (N, L) arrays through."""
        if isinstance(value, (float, int)):
            return value * np.ones((N, L))
        assert value.shape == (N, L)
        return value

    @staticmethod
    def _censored_mask(y: np.ndarray, omega: np.ndarray) -> np.ndarray:
        """Boolean mask of censored entries (observation at/below threshold).

        Shared by the negative log-pdf, its gradient and its Hessian diagonal
        so that all three agree on which branch each entry uses.
        """
        return y <= omega

    def _params_for_rows(
        self, forward_map_evals: dict, idx: Optional[np.ndarray]
    ) -> tuple:
        """Return ``(y, sigma, omega, bias)`` aligned with ``f_Theta`` rows.

        Without ``idx``, the full (N, L) arrays are returned. With ``idx``
        (pixel indices, assumed to be an integer array), ``f_Theta`` is
        expected to stack ``k_mtm`` candidate rows per selected pixel,
        pixel-major, so each pixel's parameters are repeated ``k_mtm`` times.
        """
        if idx is None:
            return self.y, self.sigma, self.omega, self.bias

        n_pix = idx.size
        n_rows = forward_map_evals["f_Theta"].shape[0]
        k_mtm = n_rows // n_pix
        assert n_pix * k_mtm == n_rows

        # row i_pix * k_mtm + k  <->  pixel idx[i_pix]
        return tuple(
            np.repeat(arr[idx], k_mtm, axis=0)
            for arr in (self.y, self.sigma, self.omega, self.bias)
        )

    def neglog_pdf(
        self,
        forward_map_evals: dict,
        nll_utils: dict,
        pixelwise: bool = False,
        full: bool = False,
        idx: Optional[np.ndarray] = None,
    ) -> Union[float, np.ndarray]:
        r"""Negative log-likelihood of the censored Gaussian model.

        Up to an additive constant independent of :math:`\theta`,

        .. math::

            -\log p(y_{n,\ell} \vert \theta_n) =
            [y_{n,\ell} \leq \omega_{n,\ell}]
            \left( - \log \Phi \left(
            \frac{\omega_{n,\ell} - f_{\ell}(\theta_n) - b_{n,\ell}}
            {\sigma_{n,\ell}} \right) \right)
            + [y_{n,\ell} > \omega_{n,\ell}]
            \frac{\left(f_{\ell}(\theta_n) + b_{n,\ell} - y_{n,\ell}\right)^2}
            {2 \sigma_{n,\ell}^2}

        Parameters
        ----------
        forward_map_evals : dict
            forward map evaluations; reads ``"f_Theta"``
        nll_utils : dict
            utilities from ``evaluate_all_nll_utils`` (unused here)
        pixelwise : bool, optional
            if True, return one value per row (summed over L), by default False
        full : bool, optional
            if True, return the (N_pix, L) array of terms, by default False
        idx : np.ndarray, optional
            indices of the evaluated pixels; ``f_Theta`` then holds
            ``k_mtm`` consecutive rows per pixel, by default None (all pixels)

        Returns
        -------
        float or np.ndarray
            scalar sum, (N_pix,) array if ``pixelwise``, or (N_pix, L) array
            if ``full``
        """
        y, sigma, omega, bias = self._params_for_rows(forward_map_evals, idx)

        nlpdf = np.where(
            self._censored_mask(y, omega),
            self.neglog_pdf_ac(forward_map_evals, nll_utils, y, sigma, omega, bias),
            self.neglog_pdf_au(forward_map_evals, nll_utils, y, sigma, omega, bias),
        )  # (N_pix, L)

        if full:
            return nlpdf  # (N_pix, L)
        if pixelwise:
            return np.sum(nlpdf, axis=1)  # (N_pix,)
        return np.sum(nlpdf)

    def neglog_pdf_ac(
        self,
        forward_map_evals: dict,
        nll_utils: dict,
        y: np.ndarray,
        sigma: np.ndarray,
        omega: np.ndarray,
        bias: np.ndarray,
    ) -> np.ndarray:
        """Censored-branch term: ``-log Phi((omega - f_Theta - bias) / sigma)``."""
        return -statsnorm.logcdf((omega - forward_map_evals["f_Theta"] - bias) / sigma)

    def neglog_pdf_au(
        self,
        forward_map_evals: dict,
        nll_utils: dict,
        y: np.ndarray,
        sigma: np.ndarray,
        omega: np.ndarray,
        bias: np.ndarray,
    ) -> np.ndarray:
        """Uncensored-branch term: ``(f_Theta + bias - y)**2 / (2 sigma**2)``."""
        return (forward_map_evals["f_Theta"] + bias - y) ** 2 / (2 * sigma**2)

    def gradient_neglog_pdf(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        """Gradient of the negative log-likelihood w.r.t. the parameters.

        Parameters
        ----------
        forward_map_evals : dict
            forward map evaluations on all N pixels; reads ``"f_Theta"``
            (N, L) and ``"grad_f_Theta"`` (N, D, L)
        nll_utils : dict
            utilities from ``evaluate_all_nll_utils`` (unused here)

        Returns
        -------
        np.ndarray
            (N, D) when D != L; the un-summed (N, D, L) array when D == L
        """
        grad_ = np.where(
            self._censored_mask(self.y, self.omega)[:, None, :],
            self.gradient_neglog_pdf_ac(forward_map_evals, nll_utils),
            self.gradient_neglog_pdf_au(forward_map_evals, nll_utils),
        )  # (N, D, L)

        # ! issue: do not sum over L if L = D (i.e. identity forward_map)
        # NOTE(review): the check is on D == L only, so a non-identity map
        # with D == L also gets the un-summed array — kept for compatibility.
        if self.D != self.L:
            grad_ = np.sum(grad_, axis=2)  # (N, D)
        return grad_

    def gradient_neglog_pdf_ac(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        """Censored-branch gradient: ``grad_f * r(u) / sigma``.

        ``u = (omega - f_Theta - bias) / sigma`` and ``r`` is
        ``utils.norm_pdf_cdf_ratio`` (presumably phi / Phi).
        """
        u = (self.omega - forward_map_evals["f_Theta"] - self.bias) / self.sigma
        scale = utils.norm_pdf_cdf_ratio(u) / self.sigma  # (N, L)
        return forward_map_evals["grad_f_Theta"] * scale[:, None, :]  # (N, D, L)

    def gradient_neglog_pdf_au(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        """Uncensored-branch gradient: ``grad_f * (f + bias - y) / sigma**2``."""
        residual = (
            forward_map_evals["f_Theta"] + self.bias - self.y
        ) / self.sigma**2  # (N, L)
        return forward_map_evals["grad_f_Theta"] * residual[:, None, :]  # (N, D, L)

    def hess_diag_neglog_pdf(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        """Diagonal of the Hessian of the negative log-likelihood.

        Parameters
        ----------
        forward_map_evals : dict
            forward map evaluations on all N pixels; reads ``"f_Theta"``
            (N, L), ``"grad_f_Theta"`` (N, D, L) and ``"hess_diag_f_Theta"``
            (N, D, L)
        nll_utils : dict
            utilities from ``evaluate_all_nll_utils`` (unused here)

        Returns
        -------
        np.ndarray
            (N, D) when D != L; the un-summed (N, D, L) array when D == L
        """
        hess_diag = np.where(
            self._censored_mask(self.y, self.omega)[:, None, :],
            self.hess_diag_neglog_pdf_ac(forward_map_evals, nll_utils),
            self.hess_diag_neglog_pdf_au(forward_map_evals, nll_utils),
        )  # (N, D, L)

        # ! issue: do not sum over L if L = D (i.e. identity forward_map)
        # NOTE(review): same D == L caveat as in gradient_neglog_pdf.
        if self.D != self.L:
            hess_diag = np.sum(hess_diag, axis=2)  # (N, D)
        return hess_diag

    def hess_diag_neglog_pdf_ac(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        """Censored-branch Hessian diagonal.

        With ``u = (omega - f_Theta - bias) / sigma`` and ``r(u)`` the pdf/cdf
        ratio: ``r / sigma * (hess_f + grad_f**2 * (u + r) / sigma)``.
        The bias must enter ``u`` in both factors (it previously was missing
        from the leading one).
        """
        u = (self.omega - forward_map_evals["f_Theta"] - self.bias) / self.sigma
        ratio = utils.norm_pdf_cdf_ratio(u)  # (N, L)
        hess_diag = (ratio / self.sigma)[:, None, :] * (
            forward_map_evals["hess_diag_f_Theta"]
            + forward_map_evals["grad_f_Theta"] ** 2
            * ((u + ratio) / self.sigma)[:, None, :]
        )
        return hess_diag  # (N, D, L)

    def hess_diag_neglog_pdf_au(
        self, forward_map_evals: dict, nll_utils: dict
    ) -> np.ndarray:
        """Uncensored-branch Hessian diagonal.

        ``(grad_f**2 + hess_f * (f + bias - y)) / sigma**2``.
        """
        return (1 / self.sigma**2)[:, None, :] * (
            forward_map_evals["grad_f_Theta"] ** 2
            + forward_map_evals["hess_diag_f_Theta"]
            * (forward_map_evals["f_Theta"] + self.bias - self.y)[:, None, :]
        )  # (N, D, L)

    def evaluate_all_nll_utils(
        self,
        forward_map_evals: dict,
        idx: Optional[np.ndarray] = None,
        compute_derivatives: bool = True,
        compute_derivatives_2nd_order: bool = True,
    ) -> dict:
        """Return the likelihood utilities; this model needs none (empty dict)."""
        nll_utils = {}
        return nll_utils

    def sample_observation_model(
        self,
        forward_map_evals: dict,
        rng: Optional[np.random.Generator] = None,
    ) -> np.ndarray:
        """Draw observations ``max(omega, f_Theta + bias + sigma * eps)``.

        Parameters
        ----------
        forward_map_evals : dict
            forward map evaluations; reads ``"f_Theta"``
        rng : np.random.Generator, optional
            random generator; a fresh ``np.random.default_rng()`` is created
            when None (avoids a generator shared across all calls through a
            default argument evaluated once at class definition)

        Returns
        -------
        np.ndarray
            sampled censored observations
        """
        if rng is None:
            rng = np.random.default_rng()
        # the bias is part of the modelled mean, consistently with neglog_pdf
        noisy = (
            forward_map_evals["f_Theta"]
            + self.bias
            + rng.normal(loc=0.0, scale=self.sigma)
        )
        return np.maximum(self.omega, noisy)

    def evaluate_all_forward_map(
        self,
        Theta: np.ndarray,
        compute_derivatives: bool,
        compute_derivatives_2nd_order: bool = True,
    ) -> dict:
        """Evaluate the forward map (linear scale only) on a (N_rows, D) Theta.

        Delegates to ``self.forward_map.compute_all`` with
        ``compute_lin=True`` and ``compute_log=False``.
        """
        assert len(Theta.shape) == 2 and Theta.shape[1] == self.D
        forward_map_evals = self.forward_map.compute_all(
            Theta,
            compute_lin=True,
            compute_log=False,
            compute_derivatives=compute_derivatives,
            compute_derivatives_2nd_order=compute_derivatives_2nd_order,
        )
        return forward_map_evals