import os
from typing import List
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from beetroots.inversion.results.utils.abstract_util import ResultsUtil
from beetroots.inversion.results.utils.mc import histograms
class ResultsRegularizationWeights(ResultsUtil):
    """Plot and summarize the sampled spatial regularization weights ``tau``.

    For every chain, the ``list_tau`` dataset of the chain's HDF5 output file
    is read. The class then draws per-chain and all-chains trace plots and
    histograms through ``histograms.plot_1D_chain`` /
    ``histograms.plot_1D_hist`` and a matplotlib figure. It also appends MMSE
    and percentile estimates to ``estimation_tau.csv``.
    """

    __slots__ = (
        "model_name",
        "list_names",
        "path_img",
        "path_data_csv_out_mcmc",
        "N_MCMC",
        "T_MC",
        "T_BI",
        "freq_save",
        "effective_T_BI",
        "effective_T_MC",
        # "D" is not assigned by this class; kept so the slot layout stays
        # backward compatible for code that may set it.
        "D",
        "D_sampling",
        "lower_bounds_lin",
        "upper_bounds_lin",
    )

    def __init__(
        self,
        model_name: str,
        path_img: str,
        path_data_csv_out_mcmc: str,
        N_MCMC: int,
        T_MC: int,
        T_BI: int,
        freq_save: int,
        D_sampling: int,
        list_names: List[str],
    ):
        """
        Parameters
        ----------
        model_name : str
            name of the model; used as a sub-folder name for the figures and
            as the ``model_name`` column of the output CSV
        path_img : str
            root folder under which the figures are saved
        path_data_csv_out_mcmc : str
            folder in which ``estimation_tau.csv`` is created / appended to
        N_MCMC : int
            number of Markov chains
        T_MC : int
            number of iterations per chain, burn-in included
        T_BI : int
            number of burn-in iterations
        freq_save : int
            saving frequency; each stored chain is expected to contain
            ``T_MC // freq_save`` entries
        D_sampling : int
            number of regularization weights stored per saved entry
        list_names : List[str]
            names used in figure titles and in the CSV ``name`` column,
            indexed by the entries of ``list_idx_sampling`` in :meth:`main`
        """
        self.model_name = model_name
        self.list_names = list_names
        self.path_img = path_img
        self.path_data_csv_out_mcmc = path_data_csv_out_mcmc
        self.N_MCMC = N_MCMC
        self.T_MC = T_MC
        self.T_BI = T_BI
        self.freq_save = freq_save
        # sizes of the burn-in and of the full chain, counted in saved entries
        self.effective_T_BI = T_BI // freq_save
        self.effective_T_MC = T_MC // freq_save
        self.D_sampling = D_sampling
        # axis bounds forwarded to the histogram / chain plotting helpers
        self.lower_bounds_lin = 1e-2 * np.ones((1,))
        self.upper_bounds_lin = 1e2 * np.ones((1,))

    def read_data(self, list_mcmc_folders: List[str]) -> np.ndarray:
        """Load the full ``list_tau`` dataset of every chain.

        Parameters
        ----------
        list_mcmc_folders : List[str]
            paths to the HDF5 output files, one per chain; chain ``i`` is read
            from ``list_mcmc_folders[i]``

        Returns
        -------
        np.ndarray
            array of shape ``(N_MCMC, effective_T_MC, D_sampling)``

        Raises
        ------
        ValueError
            if the number of files differs from ``N_MCMC``. Missing chains
            would otherwise stay filled with zeros and bias every estimate.
        """
        if len(list_mcmc_folders) != self.N_MCMC:
            raise ValueError(
                f"expected {self.N_MCMC} chain files, got {len(list_mcmc_folders)}"
            )

        list_tau = np.zeros((self.N_MCMC, self.effective_T_MC, self.D_sampling))
        for seed, mc_path in enumerate(list_mcmc_folders):
            with h5py.File(mc_path, "r") as f:
                list_tau[seed] = np.array(f["list_tau"])
        return list_tau

    def create_folders(self) -> str:
        """Create (if needed) the figure folder of this model and return its path.

        Returns
        -------
        str
            ``f"{path_img}/regularization_weights/{model_name}"``
        """
        folder_path = f"{self.path_img}/regularization_weights/{self.model_name}"
        # exist_ok avoids the check-then-create race of a separate isdir test
        os.makedirs(folder_path, exist_ok=True)
        return folder_path

    def estimate_regu_weight(self, list_mcmc_folders: List[str]) -> np.ndarray:
        """Average the post-burn-in ``list_tau`` samples over all given chains.

        Parameters
        ----------
        list_mcmc_folders : List[str]
            paths to the HDF5 output files; the first ``effective_T_BI``
            entries of each file's ``list_tau`` are discarded

        Returns
        -------
        np.ndarray
            mean over axis 0 of all kept samples, concatenated across chains

        Raises
        ------
        ValueError
            if ``list_mcmc_folders`` is empty
        """
        if len(list_mcmc_folders) == 0:
            raise ValueError("list_mcmc_folders is empty: no chain to average")

        # collect first, concatenate once (avoids quadratic re-copying)
        chains_no_BI = []
        for mc_path in list_mcmc_folders:
            with h5py.File(mc_path, "r") as f:
                chains_no_BI.append(np.array(f["list_tau"][self.effective_T_BI :]))
        return np.concatenate(chains_no_BI).mean(0)

    def main(
        self,
        list_mcmc_folders: List[str],
        list_idx_sampling: List[int],
    ) -> None:
        """Produce all regularization-weight figures and the CSV summary.

        Parameters
        ----------
        list_mcmc_folders : List[str]
            paths to the HDF5 output files, one per chain
        list_idx_sampling : List[int]
            for each column ``idx`` of ``list_tau``, the index ``d`` into
            ``list_names`` of the corresponding weight (``d`` is also
            forwarded to the plotting helpers)
        """
        folder_path = self.create_folders()
        list_tau = self.read_data(list_mcmc_folders)

        print("starting plots of regularization weights")
        self._plot_per_chain(list_tau, list_idx_sampling, folder_path)

        # altogether: chains concatenated along the iteration axis
        list_tau_flatter = list_tau.reshape(
            (self.N_MCMC * self.effective_T_MC, self.D_sampling),
        )
        # burn-in is removed from every chain *before* concatenating; the row
        # count is taken from the sliced array itself so that it is correct
        # even when T_MC or T_BI is not a multiple of freq_save
        tau_no_BI = list_tau[:, self.effective_T_BI :]
        list_tau_flatter_no_BI = tau_no_BI.reshape(
            (tau_no_BI.shape[0] * tau_no_BI.shape[1], self.D_sampling)
        )

        tau_MMSE = list_tau_flatter_no_BI.mean(0)  # (D_sampling,)
        IC_2p5 = np.percentile(list_tau_flatter_no_BI, q=2.5, axis=0)
        IC_97p5 = np.percentile(list_tau_flatter_no_BI, q=97.5, axis=0)
        assert tau_MMSE.shape == (
            self.D_sampling,
        ), f"tau_MMSE {tau_MMSE} should have shape ({self.D_sampling},)"
        assert IC_2p5.shape == (self.D_sampling,)
        assert IC_97p5.shape == (self.D_sampling,)

        self._plot_all_chains_trace(list_tau_flatter, list_idx_sampling, folder_path)
        self._plot_all_chains_hist(
            list_tau_flatter_no_BI,
            list_idx_sampling,
            folder_path,
            tau_MMSE,
            IC_2p5,
            IC_97p5,
        )
        self._save_estimates(list_tau_flatter_no_BI, list_idx_sampling, tau_MMSE)
        print("plots of regularization weights done")

    def _plot_per_chain(
        self,
        list_tau: np.ndarray,
        list_idx_sampling: List[int],
        folder_path: str,
    ) -> None:
        """Plot, for every chain and weight, its trace and post-burn-in histogram."""
        for seed in range(self.N_MCMC):
            for idx, d in enumerate(list_idx_sampling):
                # independent copies: whatever the plotting helpers do to
                # their inputs cannot affect list_tau or each other
                list_tau_sd = list_tau[seed, :, idx].copy()
                assert list_tau_sd.shape == (self.effective_T_MC,)
                list_tau_sd_no_BI = list_tau_sd[self.effective_T_BI :].copy()

                tau_MMSE = list_tau_sd_no_BI.mean(0)
                IC_2p5 = np.percentile(list_tau_sd_no_BI, q=2.5, axis=0)
                IC_97p5 = np.percentile(list_tau_sd_no_BI, q=97.5, axis=0)
                assert isinstance(tau_MMSE, float), tau_MMSE
                assert isinstance(IC_2p5, float), IC_2p5
                assert isinstance(IC_97p5, float), IC_97p5

                title = f"MC {self.list_names[d]} spatial regularization"
                title += f" weight for chain {seed}"
                histograms.plot_1D_chain(
                    list_tau_sd,
                    None,
                    d,
                    folder_path,
                    title,
                    lower_bounds_lin=self.lower_bounds_lin,
                    upper_bounds_lin=self.upper_bounds_lin,
                    N_MCMC=self.N_MCMC,
                    T_MC=self.T_MC,
                    T_BI=self.T_BI,
                )

                title = "posterior distribution of spatial regularization"
                title += f" weight of {self.list_names[d]}"
                histograms.plot_1D_hist(
                    list_tau_sd_no_BI,
                    None,
                    d,
                    folder_path,
                    title=title,
                    lower_bounds_lin=self.lower_bounds_lin,
                    upper_bounds_lin=self.upper_bounds_lin,
                    seed=seed,
                    estimator=tau_MMSE,
                    IC_low=IC_2p5,
                    IC_high=IC_97p5,
                )

    def _plot_all_chains_trace(
        self,
        list_tau_flatter: np.ndarray,
        list_idx_sampling: List[int],
        folder_path: str,
    ) -> None:
        """Draw the concatenated traces of all chains on one log-scale figure."""
        fig = plt.figure(figsize=(8, 6))
        try:
            plt.title("regularization weights sampling")
            for idx, d in enumerate(list_idx_sampling):
                plt.semilogy(list_tau_flatter[:, idx], label=self.list_names[d])

            for seed in range(self.N_MCMC):
                chain_start = seed * self.effective_T_MC
                if seed >= 1:
                    # solid line: start of a new chain (legend entry only once)
                    plt.axvline(
                        chain_start,
                        c="k",
                        ls="-",
                        label="new MC" if seed == 1 else None,
                    )
                # dashed line: end of this chain's burn-in (legend entry once)
                plt.axvline(
                    chain_start + self.effective_T_BI,
                    c="k",
                    ls="--",
                    label="T_BI" if seed == 0 else None,
                )

            plt.grid()
            plt.legend()
            plt.savefig(
                f"{folder_path}/mc_regu_weights.PNG",
                bbox_inches="tight",
            )
        finally:
            # release the figure even if drawing or saving fails
            plt.close(fig)

    def _plot_all_chains_hist(
        self,
        list_tau_flatter_no_BI: np.ndarray,
        list_idx_sampling: List[int],
        folder_path: str,
        tau_MMSE: np.ndarray,
        IC_2p5: np.ndarray,
        IC_97p5: np.ndarray,
    ) -> None:
        """Plot the histogram of each weight over all chains' post-burn-in samples."""
        for idx, d in enumerate(list_idx_sampling):
            title = "posterior distribution of spatial regularization weight"
            title += f" of {self.list_names[d]}"
            # keyword arguments, matching the per-chain plot_1D_hist call
            histograms.plot_1D_hist(
                list_tau_flatter_no_BI[:, idx],
                None,
                d,
                folder_path,
                title=title,
                lower_bounds_lin=self.lower_bounds_lin,
                upper_bounds_lin=self.upper_bounds_lin,
                seed=None,
                estimator=tau_MMSE[idx],
                IC_low=IC_2p5[idx],
                IC_high=IC_97p5[idx],
            )

    def _save_estimates(
        self,
        list_tau_flatter_no_BI: np.ndarray,
        list_idx_sampling: List[int],
        tau_MMSE: np.ndarray,
    ) -> None:
        """Append MMSE and percentiles of each weight to ``estimation_tau.csv``.

        The CSV header is written only when the file does not exist yet.
        """
        rows = []
        for idx, d in enumerate(list_idx_sampling):
            samples = list_tau_flatter_no_BI[:, idx]
            row = {
                "model_name": self.model_name,
                "name": self.list_names[d],
                "MMSE": tau_MMSE[idx],
            }
            for q in [0.5, 2.5, 5, 95, 97.5, 99]:
                row[f"per_{q}"] = np.percentile(samples, q=q)
            rows.append(row)

        path_file = f"{self.path_data_csv_out_mcmc}/estimation_tau.csv"
        pd.DataFrame(rows).to_csv(
            path_file,
            mode="a",
            header=not os.path.exists(path_file),
        )