import os
from typing import List, Tuple
import h5py
import matplotlib.pyplot as plt
import numpy as np
from beetroots.inversion.results.utils.abstract_util import ResultsUtil
# TODO: for now, this assumes that the Sampler has an MTM kernel and a PMALA
# kernel -> generalize?
class ResultsKernels(ResultsUtil):
    """Plot per-kernel acceptance diagnostics of sampler runs.

    For every run (seed) and every kernel listed in ``_KERNELS``, two figures
    are produced: a moving average of the acceptance indicators and a moving
    average of the log acceptance probabilities. The figures are written as
    PNG files below ``path_img``.
    """

    __slots__ = (
        "model_name",
        "chain_type",
        "path_img",
        "N_run",
        "effective_T",
    )

    # One entry per kernel: (value identifying the kernel in ``list_type_t``,
    # label used in plot titles and file names, moving-average window length).
    _KERNELS = ((0, "MTM", 20), (1, "PMALA", 20))

    def __init__(
        self,
        model_name: str,
        chain_type: str,
        path_img: str,
        N_run: int,
        T: int,
        freq_save: int,
    ):
        """Store the run description used to locate and size the data.

        Parameters
        ----------
        model_name : str
            name of the model, used as the innermost output folder name
        chain_type : str
            either ``"mcmc"`` or ``"optim_map"``
        path_img : str
            root folder in which the figures are saved
        N_run : int
            number of runs (one row per run in the arrays that are read)
        T : int
            total number of iterations per run
        freq_save : int
            only ``T // freq_save`` entries per run are expected in the files

        Raises
        ------
        AssertionError
            if ``chain_type`` is not one of the two accepted values
        """
        assert chain_type in ["mcmc", "optim_map"], (
            f"chain_type must be 'mcmc' or 'optim_map', got {chain_type!r}"
        )
        self.model_name = model_name
        self.chain_type = chain_type
        self.path_img = path_img
        self.N_run = N_run
        self.effective_T = T // freq_save

    def read_data(
        self,
        list_chains_folders: List[str],
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Read the kernel diagnostics of every run from its HDF5 file.

        Parameters
        ----------
        list_chains_folders : List[str]
            paths of the HDF5 files, one per run; the position of a path in
            the list is the row it fills in the returned arrays

        Returns
        -------
        Tuple[np.ndarray, np.ndarray, np.ndarray]
            three arrays of shape ``(N_run, effective_T)`` holding the
            datasets ``list_type_t``, ``list_accepted_t`` and
            ``list_log_proba_accept_t``, in this order. Rows without a
            corresponding file keep their initial value of zero.
        """
        shape = (self.N_run, self.effective_T)
        list_type = np.zeros(shape)
        list_accepted = np.zeros(shape)
        list_log_proba = np.zeros(shape)

        for seed, mc_path in enumerate(list_chains_folders):
            with h5py.File(mc_path, "r") as f:
                list_type[seed] = np.array(f["list_type_t"])
                list_accepted[seed] = np.array(f["list_accepted_t"])
                list_log_proba[seed] = np.array(f["list_log_proba_accept_t"])

        return list_type, list_accepted, list_log_proba

    def create_folders(self) -> Tuple[str, str]:
        """Create the two output folders and return their paths.

        The folders are
        ``{path_img}/accepted_freq/{chain_type}/{model_name}`` and
        ``{path_img}/log_proba_accept/{chain_type}/{model_name}``.
        Missing parent folders are created too, and already existing folders
        are left untouched.

        Returns
        -------
        Tuple[str, str]
            path of the acceptance-frequency folder, then path of the
            log-acceptance-probability folder
        """
        folders = []
        for kind in ("accepted_freq", "log_proba_accept"):
            path_ = f"{self.path_img}/{kind}/{self.chain_type}/{self.model_name}"
            # exist_ok avoids the check-then-create race of isdir() + mkdir()
            os.makedirs(path_, exist_ok=True)
            folders.append(path_)

        folder_path_accept_freq, folder_path_log_p_accept = folders
        return (folder_path_accept_freq, folder_path_log_p_accept)

    @staticmethod
    def _save_moving_average_plot(
        values: np.ndarray,
        window: int,
        title: str,
        filename: str,
        symlog: bool = False,
    ) -> None:
        """Plot the moving average of ``values`` and save the figure."""
        smoothed = np.convolve(values, np.ones(window) / window, mode="valid")

        plt.figure(figsize=(8, 6))
        plt.title(title)
        plt.plot(smoothed, label="mobile mean")
        plt.grid()
        plt.legend()
        if symlog:
            plt.yscale("symlog")
        plt.xticks(rotation=45)
        plt.savefig(filename, bbox_inches="tight")
        plt.close()

    def plot_accept_freq(
        self,
        folder_path: str,
        list_type: np.ndarray,
        list_accepted: np.ndarray,
    ) -> None:
        """Plot the moving average of the acceptance indicators per kernel.

        One figure is saved per run and per kernel, as
        ``{folder_path}/freq_accept_seed{seed}_{kernel}.PNG``. A run is
        skipped for a kernel when that kernel has no more entries than the
        moving-average window.

        Parameters
        ----------
        folder_path : str
            folder in which the figures are saved
        list_type : np.ndarray
            kernel identifier of each entry, indexed by run first
        list_accepted : np.ndarray
            acceptance indicator of each entry, same layout as ``list_type``
        """
        print("starting plot of accepted frequencies")

        for kernel_id, kernel_name, window in self._KERNELS:
            for seed in range(self.N_run):
                accepted = list_accepted[seed, list_type[seed] == kernel_id]
                if accepted.size <= window:
                    continue

                pct_accepted = 100 * np.nanmean(accepted)
                self._save_moving_average_plot(
                    accepted,
                    window,
                    title=f"{kernel_name} : {pct_accepted:.2f} % accepted",
                    filename=(
                        f"{folder_path}/freq_accept_seed{seed}_{kernel_name}.PNG"
                    ),
                )

        print("plots of accepted frequencies done")

    def plot_log_proba_accept(
        self,
        folder_path: str,
        list_type: np.ndarray,
        list_log_proba: np.ndarray,
    ) -> None:
        """Plot the moving average of the log proba accept per kernel.

        One figure (symlog y-axis) is saved per run and per kernel, as
        ``{folder_path}/log_proba_accept_seed{seed}_{kernel}.PNG``. The title
        reports the NaN-ignoring mean and median of the values of that run
        and that kernel. A run is skipped for a kernel when that kernel has
        no more entries than the moving-average window.

        Parameters
        ----------
        folder_path : str
            folder in which the figures are saved
        list_type : np.ndarray
            kernel identifier of each entry, indexed by run first
        list_log_proba : np.ndarray
            log acceptance probability of each entry, same layout as
            ``list_type``
        """
        print("starting plot of log proba accept")

        for kernel_id, kernel_name, window in self._KERNELS:
            for seed in range(self.N_run):
                log_proba = list_log_proba[seed, list_type[seed] == kernel_id]
                if log_proba.size <= window:
                    continue

                # statistics are computed on the values of the kernel being
                # plotted (they must not be reused from another kernel)
                mean = np.nanmean(log_proba)
                median = np.nanmedian(log_proba)
                title = (
                    f"{kernel_name}: log proba accept avg: {mean:.3e}, "
                    f"median: {median:.3e}"
                )
                self._save_moving_average_plot(
                    log_proba,
                    window,
                    title=title,
                    filename=(
                        f"{folder_path}/log_proba_accept_"
                        f"seed{seed}_{kernel_name}.PNG"
                    ),
                    symlog=True,
                )

        print("plots of log proba accept done")

    def main(self, list_chains_folders: List[str]) -> None:
        """Read the diagnostics of all runs and save both sets of figures.

        Parameters
        ----------
        list_chains_folders : List[str]
            paths of the HDF5 files, one per run
        """
        list_type, list_accepted, list_log_proba = self.read_data(
            list_chains_folders,
        )
        folder_accept_freq, folder_log_proba_accept = self.create_folders()

        self.plot_accept_freq(folder_accept_freq, list_type, list_accepted)
        self.plot_log_proba_accept(
            folder_log_proba_accept,
            list_type,
            list_log_proba,
        )