Source code for beetroots.sampler.saver.abstract_saver

import abc
import os
from typing import List, Optional

import h5py
import numpy as np

from beetroots.space_transform.abstract_transform import Scaler


[docs] class Saver: r"""enable to regularly save the progression of the Markov chain to a ``.hdf5`` file""" __slots__ = ( "N", "D", "D_sampling", "L", "results_path", "batch_size", "freq_save", "scaler", "t_last_save", "t_last_init", "next_batch_size", "final_next_batch_size", "memory", "save_forward_map_evals", "list_idx_sampling", ) def __init__( self, N: int, D: int, D_sampling: int, L: int, scaler: Scaler, results_path: str = "", batch_size: Optional[int] = None, freq_save: int = 1, save_forward_map_evals: bool = False, list_idx_sampling: Optional[List[int]] = None, ): """ Parameters ---------- N : int total number of pixels to reconstruct D : int total number of physical parameters D_sampling : int number of physical parameters that are optimized / sampled L : int number of observed lines scaler : Scaler contains the transformation of the Theta values from their scaled space (in which the sampling happens) to their natural space results_path : str path towards the ``.hdf5`` output file, by default "" batch_size : int, optional number of iterations between two saves on file, by default None freq_save : int, optional save one sample in every (freq_save). Used to save disk space., by default 1 save_forward_map_evals: bool, optional wether to save the forward model evaluations and gradients, by default False list_idx_sampling : Optional[List[int]], optional contains the indices of the physical parameters to be sampled """ self.N = N r"""int: total number of pixels to reconstruct""" self.D = D r"""int: total number of physical parameters""" self.D_sampling = D_sampling r"""int: number of physical parameters that are optimized / sampled""" self.L = L r"""int: number of observed lines""" if list_idx_sampling is None: list_idx_sampling_arr = np.arange(self.D) else: list_idx_sampling_arr = np.array(list_idx_sampling) self.list_idx_sampling = list_idx_sampling_arr r"""1D np.ndarray: contains the indices of the physical parameters to be sampled""" self.results_path = results_path r"""str: path towards the ``.hdf5`` output file""" if len(results_path) > 0 and not (os.path.isdir(results_path)): os.mkdir(results_path) self.batch_size = batch_size r"""int: frequency of saves, i.e., "every ``batch_size`` new iterates to be saved, the memory is saved to an ``.hdf5`` file and re-initialized""" self.freq_save = freq_save r"""int: frequency of saved iterates during the run (1 means that every iteration is saved)""" self.scaler = scaler r"""Scaler: contains the transformation of the Theta values from their natural space to their scaled space (in which the sampling happens)""" # these two attributes are initialized by initialize_memory # updated during sampling self.t_last_save = 0 r"""int: time index of the last save of the memory to ``.hdf5`` file""" self.t_last_init = 0 r"""int: time index of the last memory initialization""" self.next_batch_size = 0 r"""int: number of iterates to be stored in the next batch, i.e., until next save to file""" self.final_next_batch_size = 0 r"""int: """ self.memory = dict() """dict[str, Union[float, np.ndarray]]: stores the values before saving them to file""" self.save_forward_map_evals = save_forward_map_evals r"""bool: wether to save the forward model evaluations and gradients"""
[docs] def set_results_path(self, results_path: str) -> None: r"""sets the path of the ``.hdf5`` file to a new value Parameters ---------- results_path : str path towards the ``.hdf5`` output file """ self.results_path = results_path if len(results_path) > 0 and not (os.path.isdir(results_path)): os.mkdir(results_path) return
[docs] def check_need_to_save(self, t: int) -> bool: """checks wether or not the memory should be saved to a ``.hdf5`` file Parameters ---------- t : int current iteration index Returns ------- bool wether or not to save to disk now """ current_t_in_batch = t - self.t_last_init + 1 return current_t_in_batch == self.next_batch_size * self.freq_save
[docs] def check_need_to_update_memory(self, t: int) -> bool: """checks wether or not the memory should be updated Parameters ---------- t : int current iteration index Returns ------- bool wether or not to save to update the memory """ return t % self.freq_save == 0
[docs] @abc.abstractmethod def initialize_memory( self, T_MC: int, t: int, Theta: np.ndarray, forward_map_evals: dict = dict(), nll_utils: dict = dict(), dict_objective: dict = dict(), additional_sampling_log: dict = dict(), ) -> None: r"""initializes the memory with the correct shapes""" pass
[docs] @abc.abstractmethod def update_memory( self, t: int, Theta: np.ndarray, forward_map_evals: dict = dict(), nll_utils: dict = dict(), dict_objective: dict = dict(), additional_sampling_log: dict = dict(), rng_state_array: Optional[np.ndarray] = None, rng_inc_array: Optional[np.ndarray] = None, ) -> None: """updates the memory with new information. All of the potential entries are optional except for the current iterate.""" pass
[docs] def save_to_file(self): """Saves the current memory content to a ``.hdf5`` file""" if self.t_last_init == 1: # if first writing with h5py.File( os.path.join(self.results_path, "mc_chains.hdf5"), "w", ) as f: for k, v in self.memory.items(): f.create_dataset(k, data=v, maxshape=(None,) + v.shape[1:]) else: # append data to already created file with h5py.File( os.path.join(self.results_path, "mc_chains.hdf5"), "a", ) as f: for k, v in self.memory.items(): f[k].resize( f[k].shape[0] + self.final_next_batch_size, axis=0, ) f[k][-self.final_next_batch_size :] = v self.memory = dict()
[docs] def save_additional( self, list_arrays: List[np.ndarray], list_names: List[str], ) -> None: r"""saves additional content to a ``.hdf5`` file Parameters ---------- list_arrays : List[np.ndarray] list of the arrays to be saved list_names : List[str] list of names for the arrays to be saved in the ``.hdf5`` file """ assert len(list_names) == len(list_arrays) with h5py.File( os.path.join(self.results_path, "mc_chains.hdf5"), "a", ) as f: for name, array_ in zip(list_names, list_arrays): f.create_dataset(name, data=array_) return