Source code for beetroots.modelling.priors.l22_prior

"""(deprecated) Implementation of L2 norm on image gradient prior"""

import numpy as np

from beetroots.modelling.priors.abstract_spatial_prior import SpatialPrior


[docs] class L22SpatialPrior(SpatialPrior): r"""L22 smooth spatial prior, valid for both 1D and 2D tensors. Its pdf is defined as .. math:: \forall d \in [1, D], \quad p(x_d) \propto \exp \left[- \tau_d \Vert \nabla Theta_d \Vert_F^2 \right] where :math:`\Vert \cdot \Vert_F` denotes the Fröbenius norm """
[docs] def neglog_pdf( self, Theta: np.ndarray, with_weights: bool = True, **kwargs ) -> np.ndarray: assert Theta.shape == (self.N, self.D) neglog_p = np.zeros((self.D,)) if self.list_edges.size > 0: neglog_p += np.sum( (Theta[self.list_edges[:, 1], :] - Theta[self.list_edges[:, 0], :]) ** 2, axis=0, ) if with_weights: neglog_p *= self.weights return neglog_p # (D,)
[docs] def neglog_pdf_one_pix( self, Theta: np.ndarray, n: int, list_pixel_candidates: np.ndarray, ) -> np.ndarray: """ computes the neg log-prior when only one pixel is modified Parameters ---------- Theta : np.ndarray of shape (N, D) current iterate n : int the index of the pixel to consider (0 <= n <= N - 1) list_pixel_candidates : np.ndarray of shape (N_candidates, D) the list of all candidates for pixel n Returns ------- nl_priors : np.ndarray of shape (N_candidates,) the leg log-prior of the candidates """ assert Theta.shape == (self.N, self.D) assert 0 <= n <= self.N - 1 # TODO : à reprendre N_candidates = list_pixel_candidates.shape[0] neglog_p = np.zeros_like(list_pixel_candidates) # (N_candidates, D) list_edges_pix = self.list_edges[(self.list_edges[:, 0] == n)] if list_edges_pix.size > 0: neglog_p += ( Theta[list_edges_pix[:, 1], :] - Theta[list_edges_pix[:, 0], :] ) ** 2 # (N_candidates, D) neglog_p = neglog_p * self.weights[None, :] neglog_p = np.sum(neglog_p, axis=1) return neglog_p # (D,)
[docs] def gradient_neglog_pdf(self, Theta: np.ndarray) -> np.ndarray: assert Theta.shape == (self.N, self.D) g = np.zeros_like(Theta) for edge in self.list_edges: val = 2 * (Theta[edge[1]] - Theta[edge[0]]) # (D,) g[edge[0]] -= val g[edge[1]] += val g = self.weights[None, :] * g return g # (N, D)
[docs] def hess_diag_neglog_pdf(self, Theta: np.ndarray) -> np.ndarray: hess_diag = np.zeros_like(Theta, dtype=np.float64) if self.list_edges.size > 0: idx, counts = np.unique(self.list_edges.flatten(), return_counts=True) # print(counts.dtype, hess_diag.dtype) hess_diag[idx, :] += 2 * counts[:, None] * np.ones((idx.size, self.D)) hess_diag = self.weights[None, :] * hess_diag return hess_diag # (N, D)