Source code for beetroots.inversion.results.utils.mc

import multiprocessing as mp
import os
import time
import warnings
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List, Optional, Tuple, Union

import h5py
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from beetroots.inversion.results.utils.abstract_util import ResultsUtil
from beetroots.inversion.results.utils.mc_utils import ess, histograms
from beetroots.space_transform.abstract_transform import Scaler


class ResultsMC(ResultsUtil):
    """Pixel-wise analysis of Markov chains saved in HDF5 files.

    For every pixel ``n`` the chains of all runs (one HDF5 file per run, each
    holding a ``list_Theta`` dataset) are loaded after burn-in, then:

    * MMSE estimates and percentile-based credibility intervals are written to
      ``estimation_Theta_<model_name>.csv``;
    * when a true value is available, the first sample index at which each
      chain reaches / crosses it is written to
      ``first_elt_valid_mc_<model_name>.csv``;
    * optionally the ESS is written to ``estimation_ESS_<model_name>.csv``;
    * optionally 1D / 2D chain plots and histograms are drawn.
    """

    __slots__ = (
        "model_name",
        "chain_type",
        "path_img",
        "path_data_csv_out_mcmc",
        "max_workers",
        "N_MCMC",
        "T_MC",
        "T_BI",
        "freq_save",
        "effective_T_BI",
        "lower_bounds_lin",
        "upper_bounds_lin",
        # the three names below are assigned in __init__ and were missing
        "list_idx_sampling",
        "list_fixed_values_scaled",
        "N",
        "D",
        "D_sampling",
        "list_names",
    )

    # column order of the per-pixel CSV outputs (header written by the parent
    # process, rows appended by the workers)
    _FIRST_ELT_COLUMNS = ("seed", "n", "d", "first_elt_valid_mc")
    _ESS_COLUMNS = ("n", "d", "seed", "model_name", "ess")

    def __init__(
        self,
        model_name: str,
        chain_type: str,
        path_img: str,
        path_data_csv_out_mcmc: str,
        max_workers: int,
        N_MCMC: int,
        T_MC: int,
        T_BI: int,
        freq_save: int,
        N: int,
        list_idx_sampling: List,
        list_fixed_values_scaled: List,
        lower_bounds_lin: Union[np.ndarray, List[float]],
        upper_bounds_lin: Union[np.ndarray, List[float]],
        list_names: List[str],
    ):
        """Store the run configuration.

        Parameters
        ----------
        model_name : str
            name used in output folder and file names.
        chain_type : str
            either ``"mcmc"`` or ``"optim_map"``.
        path_img : str
            root folder of the figures.
        path_data_csv_out_mcmc : str
            folder where the CSV outputs are written.
        max_workers : int
            number of worker processes for the pixel-wise extraction.
        N_MCMC : int
            number of independent runs (one HDF5 file per run).
        T_MC, T_BI : int
            total number of iterations and number of burn-in iterations.
        freq_save : int
            one iteration out of ``freq_save`` is stored.
        N : int
            number of pixels.
        list_idx_sampling : List
            indices of the sampled (non-fixed) components.
        list_fixed_values_scaled : List
            per-component fixed value in scaled space, ``None`` if sampled.
        lower_bounds_lin, upper_bounds_lin : array-like
            bounds in linear space; their size defines ``D``.
        list_names : List[str]
            display names of the ``D`` components.
        """
        assert chain_type in ["mcmc", "optim_map"]
        self.model_name = model_name
        self.chain_type = chain_type

        self.path_img = path_img
        self.path_data_csv_out_mcmc = path_data_csv_out_mcmc
        self.max_workers = max_workers

        self.N_MCMC = N_MCMC
        self.T_MC = T_MC
        self.T_BI = T_BI
        self.freq_save = freq_save
        # burn-in length expressed in number of *saved* samples
        self.effective_T_BI = T_BI // freq_save

        # np.asarray returns ndarray inputs unchanged and converts lists
        self.lower_bounds_lin = np.asarray(lower_bounds_lin)
        self.upper_bounds_lin = np.asarray(upper_bounds_lin)

        self.list_idx_sampling = list_idx_sampling
        self.list_fixed_values_scaled = list_fixed_values_scaled

        self.N = N
        self.D = self.upper_bounds_lin.size
        self.D_sampling = len(self.list_idx_sampling)
        self.list_names = list_names

    def read_data(self):
        """No-op: chains are read pixel by pixel in ``full_mc_analysis``."""
        pass

    def create_folders(self) -> Tuple[str, str, str, str, str]:
        """Create the figure folder tree under ``{path_img}/mc``.

        Missing parent folders are created as well; existing folders are kept.

        Returns
        -------
        Tuple[str, str, str, str, str]
            paths of the 1D chain, 1D histogram, 2D chain, 2D histogram and
            2D probability-contour folders.
        """
        folder_path_mc = f"{self.path_img}/mc"

        folder_path_1D = f"{folder_path_mc}/{self.model_name}_1D"
        folder_path_1D_chain = f"{folder_path_1D}/chains"
        folder_path_1D_hist = f"{folder_path_1D}/hist"

        folder_path_2D = f"{folder_path_mc}/{self.model_name}_2D"
        folder_path_2D_chain = f"{folder_path_2D}/chains"
        folder_path_2D_hist = f"{folder_path_2D}/hist"
        folder_path_2D_proba = f"{folder_path_2D}/proba_contours"

        for path_ in [
            folder_path_mc,
            folder_path_1D,
            folder_path_1D_chain,
            folder_path_1D_hist,
            folder_path_2D,
            folder_path_2D_chain,
            folder_path_2D_hist,
            folder_path_2D_proba,
        ]:
            # exist_ok avoids the isdir/mkdir race and creates missing parents
            os.makedirs(path_, exist_ok=True)

        return (
            folder_path_1D_chain,
            folder_path_1D_hist,
            folder_path_2D_chain,
            folder_path_2D_hist,
            folder_path_2D_proba,
        )

    def _csv_path(self, stem: str) -> str:
        """Path of the CSV output ``{stem}_{model_name}.csv``."""
        return f"{self.path_data_csv_out_mcmc}/{stem}_{self.model_name}.csv"

    @staticmethod
    def _percentile_columns(list_CI: List[int]) -> Dict[str, float]:
        """Map CSV column name -> percentile for the requested credibility intervals.

        Percentiles are sorted; duplicate column names keep their first
        position and the last percentile value, mirroring what repeated
        DataFrame column assignment did previously.
        """
        list_per = sorted(
            [(100 - ci) / 2 for ci in list_CI]
            + [100 - (100 - ci) / 2 for ci in list_CI]
        )
        dict_col_per: Dict[str, float] = {}
        for per in list_per:
            dict_col_per[f"per_{per:.1f}".replace(".", "p")] = per
        return dict_col_per

    @staticmethod
    def _init_csv_files(dict_headers: Dict[str, List[str]]) -> None:
        """Create (or truncate) each CSV file and write only its header line.

        Workers then only append rows, so no worker depends on another one
        having written the header first.
        """
        for path, columns in dict_headers.items():
            pd.DataFrame(columns=columns).to_csv(path, mode="w")

    def full_mc_analysis(
        self,
        scaler: Scaler,
        Theta_true_scaled_full: Optional[np.ndarray],
        list_mcmc_folders: List[str],
        plot_ESS: bool,
        plot_1D_chains: bool,
        plot_2D_chains: bool,
        plot_comparisons_yspace: bool,
        #
        folder_path_1D_chain: str,
        folder_path_1D_hist: str,
        folder_path_2D_chain: str,
        folder_path_2D_hist: str,
        folder_path_2D_proba: str,
        #
        point_challenger: Optional[Dict] = None,
        list_CI: Optional[List[int]] = None,
    ) -> None:
        """Run the pixel-wise extraction (in parallel when possible).

        Parameters
        ----------
        scaler : Scaler
            converts between scaled and linear parameter spaces.
        Theta_true_scaled_full : Optional[np.ndarray]
            true parameters in scaled space (indexed by pixel), or None.
        list_mcmc_folders : List[str]
            HDF5 files, one per run; the index in this list is the seed.
        plot_ESS, plot_1D_chains, plot_2D_chains : bool
            which optional outputs to produce.
        plot_comparisons_yspace : bool
            unused here; kept for interface compatibility.
        folder_path_* : str
            output folders returned by ``create_folders``
            (``folder_path_2D_chain`` is currently unused).
        point_challenger : Optional[Dict]
            ``{"name": ..., "value": array indexed [n, :]}`` or empty / None.
        list_CI : Optional[List[int]]
            credibility interval levels (in percent) to extract.
        """
        # the worker is declared global so it can be pickled by name when sent
        # to the (forked) worker processes
        global _one_pixel_mmse_ic_extraction

        point_challenger = {} if point_challenger is None else point_challenger
        list_CI = [] if list_CI is None else list_CI

        # fixed (non-sampled) components, converted to linear space
        list_fixed_values_lin = np.zeros((1, self.D))
        for d, value in enumerate(self.list_fixed_values_scaled):
            if value is not None:
                list_fixed_values_lin[0, d] = value * 1
        list_fixed_values_lin = scaler.from_scaled_to_lin(
            list_fixed_values_lin,
        ).flatten()  # (D,)

        # NOTE(review): assumes the saved chain after burn-in has exactly
        # (T_MC - T_BI) // freq_save samples; with non-divisible T_BI this may
        # differ from the HDF5 slice length - confirm against the sampler.
        len_mc = (self.T_MC - self.T_BI) // self.freq_save

        dict_col_per = self._percentile_columns(list_CI)

        path_estim = self._csv_path("estimation_Theta")
        path_first_elt = self._csv_path("first_elt_valid_mc")
        path_ess = self._csv_path("estimation_ESS")

        def _one_pixel_mmse_ic_extraction(dict_input: dict) -> None:
            """for one pixel n, performs:

            - MMSE and credibility interval extraction
            - ESS computation
            - plot 1D histograms
            - plot 2D histograms
            """
            n = dict_input["n"]
            Theta_n_true = dict_input["Theta_n_true"]
            point_challenger_n = dict_input["point_challenger_n"]

            # * read data: fixed components get their constant value,
            # * sampled ones are read from each run's HDF5 file
            list_Theta_n_lin_full = np.zeros((self.N_MCMC, len_mc, self.D))
            for d in range(self.D):
                if d not in self.list_idx_sampling:
                    list_Theta_n_lin_full[:, :, d] = list_fixed_values_lin[d] * 1

            for seed, mc_path in enumerate(list_mcmc_folders):
                with h5py.File(mc_path, "r") as f:
                    list_Theta_n_lin_sub = np.array(
                        f["list_Theta"][self.effective_T_BI :, n, :]
                    )
                for idx, d in enumerate(self.list_idx_sampling):
                    list_Theta_n_lin_full[seed, :, d] = list_Theta_n_lin_sub[:, idx]

            # * MMSE and IC estimators
            list_Theta_n_lin_full_flatter = list_Theta_n_lin_full.reshape(
                (self.N_MCMC * len_mc, self.D)
            )

            # MMSE is computed in scaled space, then mapped back to linear space
            list_Theta_n_scaled_full_flatter = scaler.from_lin_to_scaled(
                list_Theta_n_lin_full_flatter,
            )
            Theta_n_MMSE_scaled = np.mean(
                list_Theta_n_scaled_full_flatter, axis=0
            )  # (D,)
            Theta_n_MMSE_lin = scaler.from_scaled_to_lin(
                Theta_n_MMSE_scaled.reshape((1, self.D))
            ).flatten()  # (D,)
            assert Theta_n_MMSE_lin.shape == (
                self.D,
            ), f"shape {Theta_n_MMSE_lin.shape}, should be {(self.D,)}"

            # append MMSE and percentiles (header already written by parent)
            df_estim = pd.DataFrame()
            df_estim["n"] = n * np.ones((self.D,), dtype=np.int32)
            df_estim["d"] = np.arange(self.D)
            df_estim["Theta_MMSE"] = Theta_n_MMSE_lin * 1
            for col, per in dict_col_per.items():
                df_estim[col] = np.percentile(
                    list_Theta_n_lin_full_flatter, per, axis=0
                )  # (D,)
            df_estim.to_csv(path_estim, mode="a", header=False)
            del df_estim

            # * index of first element st true val btw [MC first val, elt] or
            # * [elt, MC first val]  (-1 if the chain never reaches it)
            if Theta_n_true is not None:
                first_elt_arr = -np.ones((self.N_MCMC, self.D_sampling))
                for seed in range(self.N_MCMC):
                    for idx_d, d in enumerate(self.list_idx_sampling):
                        chain = list_Theta_n_lin_full[seed, :, d]
                        if chain[0] < Theta_n_true[d]:
                            (idx,) = np.where(chain >= Theta_n_true[d])
                        else:
                            (idx,) = np.where(chain <= Theta_n_true[d])
                        if idx.size > 0:
                            first_elt_arr[seed, idx_d] = idx[0]

                list_dict = [
                    {
                        "seed": seed,
                        "n": n,
                        "d": d,
                        "first_elt_valid_mc": int(first_elt_arr[seed, idx_d]),
                    }
                    for seed in range(self.N_MCMC)
                    for idx_d, d in enumerate(self.list_idx_sampling)
                ]
                df_first_elt_valid_mc = pd.DataFrame.from_records(
                    list_dict, columns=list(self._FIRST_ELT_COLUMNS)
                )
                df_first_elt_valid_mc.to_csv(path_first_elt, mode="a", header=False)
                del df_first_elt_valid_mc

            # * ESS
            if plot_ESS:
                list_Theta_n_scaled_full = list_Theta_n_scaled_full_flatter.reshape(
                    (self.N_MCMC, len_mc, self.D)
                )
                list_dict_output = [
                    {
                        "n": n,
                        "d": d,
                        "seed": "overall",
                        "model_name": self.model_name,
                        "ess": ess.compute_ess(list_Theta_n_scaled_full[:, :, d]),
                    }
                    for d in self.list_idx_sampling
                ]
                df_ess_nd = pd.DataFrame.from_records(
                    list_dict_output, columns=list(self._ESS_COLUMNS)
                )
                df_ess_nd.to_csv(path_ess, mode="a", header=False)
                del df_ess_nd

            # * 1D histograms
            if plot_1D_chains:
                for d in self.list_idx_sampling:
                    true_val = Theta_n_true[d] if Theta_n_true is not None else None

                    title = f"Markov Chain of {self.list_names[d]}"
                    if self.N > 1:
                        title += f" of pixel {n}"

                    histograms.plot_1D_chain(
                        list_Theta_lin_nd=list_Theta_n_lin_full_flatter[:, d],
                        n=n if self.N > 1 else None,
                        d=d,
                        folder_path=folder_path_1D_chain,
                        title=title,
                        lower_bounds_lin=self.lower_bounds_lin,
                        upper_bounds_lin=self.upper_bounds_lin,
                        N_MCMC=self.N_MCMC,
                        T_MC=self.T_MC,
                        T_BI=self.T_BI,
                        true_val=true_val,
                    )

                    title = f"Sample histogram of {self.list_names[d]}"
                    if self.N > 1:
                        title += f" of pixel {n}"

                    histograms.plot_1D_hist(
                        list_Theta_lin_seed=list_Theta_n_lin_full_flatter[:, d],
                        n=n if self.N > 1 else None,
                        d=d,
                        folder_path=folder_path_1D_hist,
                        title=title,
                        lower_bounds_lin=self.lower_bounds_lin,
                        upper_bounds_lin=self.upper_bounds_lin,
                        seed=None,
                        estimator=float(Theta_n_MMSE_lin[d]),
                        true_val=true_val,
                    )

            # * 2D histograms (one per pair of sampled components)
            if plot_2D_chains and self.D > 1:
                for idx_d1, d1 in enumerate(self.list_idx_sampling):
                    for d2 in self.list_idx_sampling[idx_d1 + 1 :]:
                        if Theta_n_true is not None:
                            true_val = Theta_n_true[[d1, d2]] * 1
                        else:
                            true_val = None

                        histograms.plot_2D_hist(
                            list_Theta_n_lin_full_flatter[:, [d1, d2]],
                            n if self.N > 1 else None,
                            d1,
                            d2,
                            self.model_name,
                            folder_path_2D_hist,
                            self.list_names,
                            self.lower_bounds_lin,
                            self.upper_bounds_lin,
                            Theta_MMSE=Theta_n_MMSE_lin[[d1, d2]],
                            true_val=true_val,
                            point_challenger=point_challenger_n,
                        )
                        # contour plots are best-effort: report, don't abort
                        try:
                            histograms.plot_2D_proba_contours(
                                list_Theta_n_lin_full_flatter[:, [d1, d2]],
                                n if self.N > 1 else None,
                                d1,
                                d2,
                                self.model_name,
                                folder_path_2D_proba,
                                self.list_names,
                                self.lower_bounds_lin,
                                self.upper_bounds_lin,
                                Theta_MMSE=Theta_n_MMSE_lin[[d1, d2]],
                                true_val=true_val,
                                point_challenger=point_challenger_n,
                            )
                        except Exception as exc:
                            msg = "Issue with proba contour plot for (n, d1, d2) = "
                            msg += f"({n}, {d1}, {d2}): {exc!r}"
                            print(msg)
            return

        # * global part of the function
        if Theta_true_scaled_full is not None:
            Theta_true_lin = scaler.from_scaled_to_lin(Theta_true_scaled_full)
        else:
            Theta_true_lin = None

        list_params = [
            {
                "n": n,
                "Theta_n_true": Theta_true_lin[n] if Theta_true_lin is not None else None,
                "point_challenger_n": {
                    "name": point_challenger["name"],
                    "value": point_challenger["value"][n, :],
                }
                if len(point_challenger) > 0
                else point_challenger,
            }
            for n in range(self.N)
        ]

        # headers are written once here, before any worker starts, so workers
        # only append and never wait on (or truncate after) each other
        dict_headers = {path_estim: ["n", "d", "Theta_MMSE", *dict_col_per]}
        if Theta_true_lin is not None:
            dict_headers[path_first_elt] = list(self._FIRST_ELT_COLUMNS)
        if plot_ESS:
            dict_headers[path_ess] = list(self._ESS_COLUMNS)
        self._init_csv_files(dict_headers)

        # ? The parallel execution may fail on mac, even with the mp_context
        # ? argument. In case of failure, the extraction is performed in
        # ? series, which is much slower.
        try:
            with ProcessPoolExecutor(
                max_workers=self.max_workers, mp_context=mp.get_context("fork")
            ) as p:
                _ = list(
                    tqdm(
                        p.map(_one_pixel_mmse_ic_extraction, list_params),
                        total=self.N,
                    )
                )
        except Exception as exc:
            warnings.warn(
                "The parallel pixel-wise result extraction failed "
                f"({exc!r}). Extracting in series instead."
            )
            # drop rows partially appended by the failed parallel run
            self._init_csv_files(dict_headers)
            for params in tqdm(list_params):
                _one_pixel_mmse_ic_extraction(params)
        return

    def main(
        self,
        scaler: Scaler,
        Theta_true_scaled_full: Optional[np.ndarray],
        list_mcmc_folders: List[str],
        plot_ESS: bool,
        plot_1D_chains: bool,
        plot_2D_chains: bool,
        plot_comparisons_yspace: bool,
        point_challenger: Optional[Dict] = None,
        list_CI: Optional[List[int]] = None,
    ):
        """Create the figure folders, then run ``full_mc_analysis``.

        Parameters are forwarded unchanged to ``full_mc_analysis``;
        ``None`` for ``point_challenger`` / ``list_CI`` means "empty".
        """
        (
            folder_path_1D_chain,
            folder_path_1D_hist,
            folder_path_2D_chain,
            folder_path_2D_hist,
            folder_path_2D_proba,
        ) = self.create_folders()

        self.full_mc_analysis(
            scaler,
            Theta_true_scaled_full,
            list_mcmc_folders,
            plot_ESS,
            plot_1D_chains,
            plot_2D_chains,
            plot_comparisons_yspace,
            #
            folder_path_1D_chain,
            folder_path_1D_hist,
            folder_path_2D_chain,
            folder_path_2D_hist,
            folder_path_2D_proba,
            #
            point_challenger=point_challenger,
            list_CI=list_CI,
        )
        return