from typing import Dict, List, Optional
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors
from beetroots.inversion.plots.map_shaper import MapShaper
from beetroots.inversion.plots.plots_2d import AbstractPlots2D
[docs]
class PlotsEstimator(AbstractPlots2D):
r"""utilitary class that draws and saves maps related to the inference results."""
__slots__ = (
"map_shaper",
"list_names",
"lower_bounds_lin",
"upper_bounds_lin",
"list_idx_sampling",
"pixels_of_interest_names",
"pixels_of_interest_coords",
)
def __init__(
self,
map_shaper: MapShaper,
list_names: List[str],
lower_bounds_lin: np.ndarray,
upper_bounds_lin: np.ndarray,
list_idx_sampling: List[int],
pixels_of_interest: Dict[int, str] = {},
):
super().__init__(map_shaper, pixels_of_interest)
self.list_names = list_names
r"""list: physical parameter names"""
self.lower_bounds_lin = lower_bounds_lin
r"""1D np.ndarray: contains the lower bounds in linear scale on the physical parameters"""
self.upper_bounds_lin = upper_bounds_lin
r"""1D np.ndarray: contains the upper bounds in linear scale on the physical parameters"""
self.list_idx_sampling = list_idx_sampling
r"""1D np.ndarray: contains the indices of the physical parameters to be sampled"""
[docs]
def plot_estimator(
self,
Theta_estimated: np.ndarray,
estimator_name: str,
folder_path: str,
model_name: Optional[str] = "",
):
"""plots and saves the 2D map of an estimated physical parameter
Parameters
----------
Theta_estimated : np.ndarray of shape (N, D)
vector of the estimated physical parameter
estimator_name : str
name of the estimator, e.g., "MAP" or "MMSE"
folder_path : str
path to the folder where the figure is to be saved
model_name : Optional[str], optional
name of the model (not used here, kept for compatibility), by default ""
"""
Theta_estimated_plot = self.map_shaper.from_vector_to_map(Theta_estimated)
for d, name in enumerate(self.list_names):
vmin = self.lower_bounds_lin[d] / 1.1
vmax = self.upper_bounds_lin[d] * 1.1
x_estimator_d_plot = Theta_estimated_plot[:, :, d] * 1
title = f"{estimator_name} for {name}"
if d not in self.list_idx_sampling:
title += " (fixed, not estimated)"
if vmin > 0:
plt.figure(figsize=(8, 6))
plt.title(title)
plt.imshow(
x_estimator_d_plot,
norm=colors.LogNorm(vmin, vmax),
origin="lower",
cmap="viridis",
)
plt.colorbar()
self._draw_rect_on_pixels_of_interest()
# plt.tight_layout()
filename = f"{folder_path}/{estimator_name}_{d}"
filename = filename.replace("%", "").replace(".", "p")
filename = filename.replace(" ", "_")
filename += ".PNG"
plt.savefig(
filename,
bbox_inches="tight",
)
plt.close()
# * same in linear scale
plt.figure(figsize=(8, 6))
plt.title(title)
plt.imshow(
x_estimator_d_plot,
origin="lower",
cmap="viridis",
)
plt.colorbar()
self._draw_rect_on_pixels_of_interest()
filename = f"{folder_path}/"
filename += f"{estimator_name}_linscale_{d}"
filename = filename.replace("%", "").replace(".", "p")
filename = filename.replace(" ", "_")
filename += ".PNG"
plt.savefig(filename, bbox_inches="tight")
plt.close()
[docs]
def plot_estimator_u(
self,
u_estimated: np.ndarray,
estimator_name: str,
folder_path: str,
model_name: Optional[str] = "",
list_lines: List[str] = [],
):
"""Only used in hierarchical models. The sampling of such model is not implemented."""
u_estimated_plot = self.map_shaper.from_vector_to_map(u_estimated)
for ell, line in enumerate(list_lines):
u_estimator_ell_plot = u_estimated_plot[:, :, ell] * 1
plt.figure(figsize=(8, 6))
plt.title(f"{estimator_name} for {line}")
plt.imshow(
u_estimator_ell_plot,
norm=colors.LogNorm(),
origin="lower",
cmap="viridis",
)
plt.colorbar()
self._draw_rect_on_pixels_of_interest()
# plt.tight_layout()
filename = f"{folder_path}/{estimator_name}_{ell}_{line}"
filename = filename.replace("%", "").replace(".", "p")
filename += ".PNG"
plt.savefig(filename, bbox_inches="tight")
plt.close()
# * same in linear scale
plt.figure(figsize=(8, 6))
plt.title(f"{estimator_name} for {line}")
plt.imshow(u_estimator_ell_plot, origin="lower", cmap="viridis")
plt.colorbar()
filename = f"{folder_path}/"
filename += f"{estimator_name}_linscale_{ell}_{line}"
filename = filename.replace("%", "").replace(".", "p")
filename += ".PNG"
plt.savefig(filename, bbox_inches="tight")
plt.close()
[docs]
def plot_CI_size(
self,
Theta_ci_size: np.ndarray,
CI_name: str,
folder_path: str,
) -> None:
r"""plots the map of credibility interval sizes for a physical parameter
Parameters
----------
Theta_ci_size : np.ndarray of shape (N, D)
vector of credibility interval sizes for the D physical parameters
CI_name : str
name of the credibility interval, e.g., "95\%" or "99\%"
folder_path : str
path to the folder where the figure is to be saved
"""
Theta_ci_size_plot = self.map_shaper.from_vector_to_map(Theta_ci_size)
for d, name in enumerate(self.list_names):
if d in self.list_idx_sampling:
x_ci_size_d_plot = Theta_ci_size_plot[:, :, d] * 1
if self.upper_bounds_lin[-1] - self.lower_bounds_lin[-1] > 20:
vmin = 1
vmax = None
# vmax = self.upper_bounds_lin[d] / self.lower_bounds_lin[d]
else:
vmin = 0
vmax = None
# vmax = self.upper_bounds_lin[d] - self.lower_bounds_lin[d]
plt.figure(figsize=(8, 6))
plt.title(f"{CI_name} for {name}")
plt.imshow(
x_ci_size_d_plot,
norm=colors.LogNorm(vmin, vmax),
origin="lower",
cmap="viridis",
)
plt.colorbar()
self._draw_rect_on_pixels_of_interest()
filename = f"{folder_path}/{CI_name}_d{d}.PNG"
filename = filename.replace("%", "").replace(" ", "_")
plt.savefig(filename, bbox_inches="tight")
plt.close()
# * same in linear scale
plt.figure(figsize=(8, 6))
plt.title(f"{CI_name} for {name}")
plt.imshow(
x_ci_size_d_plot,
origin="lower",
cmap="viridis",
)
plt.colorbar()
self._draw_rect_on_pixels_of_interest()
filename = f"{folder_path}/{CI_name}_linscale_d{d}.PNG"
filename = filename.replace("%", "").replace(" ", "_")
plt.savefig(filename, bbox_inches="tight")
plt.close()
return
[docs]
def plot_CI_size_u(
self,
u_ci_size,
CI_name,
folder_path: str,
model_name: Optional[str] = "",
list_lines: List[str] = [],
):
"""
Only used in hierarchical models. The sampling of such model is not implemented.
"""
u_ci_size_plot = self.map_shaper.from_vector_to_map(u_ci_size)
for ell, line in enumerate(list_lines):
u_ci_size_d_plot = u_ci_size_plot[:, :, ell] * 1
plt.figure(figsize=(8, 6))
plt.title(f"{CI_name} for {line}")
plt.imshow(
u_ci_size_d_plot,
norm=colors.LogNorm(),
origin="lower",
cmap="viridis",
)
plt.colorbar()
self._draw_rect_on_pixels_of_interest()
filename = f"{folder_path}/{CI_name}_{ell}.PNG"
filename = filename.replace("%", "")
plt.savefig(filename, bbox_inches="tight")
plt.close()
# * same in linear scale
plt.figure(figsize=(8, 6))
plt.title(f"{CI_name} for {line}")
plt.imshow(
u_ci_size_d_plot,
origin="lower",
cmap="viridis",
)
plt.colorbar()
self._draw_rect_on_pixels_of_interest()
filename = f"{folder_path}/{CI_name}_linscale_{ell}.PNG"
filename = filename.replace("%", "")
plt.savefig(filename, bbox_inches="tight")
plt.close()