Source code for beetroots.simulations.astro.real_data.real_data_nn

from typing import Dict, List, Optional, Union

import numpy as np

from beetroots.modelling.priors.spatial_prior_params import SpatialPriorParams
from beetroots.sampler.utils.my_sampler_params import MySamplerParams
from beetroots.simulations.abstract_simulation import Simulation
from beetroots.simulations.astro import data_validation
from beetroots.simulations.astro.forward_map.abstract_nn import SimulationNN
from beetroots.simulations.astro.observation.abstract_real_data import (
    SimulationRealData,
)
from beetroots.simulations.astro.posterior_type.abstract_direct import (
    SimulationMySampler,
)


[docs] class SimulationRealDataNN(SimulationNN, SimulationRealData, SimulationMySampler): __slots__ = ( "path_output_sim", "path_img", "path_raw", "path_data_csv", "path_data_csv_in", "path_data_csv_out", "path_data_csv_out_mcmc", "path_data_csv_out_optim_map", "path_data_csv_out_optim_mle", "N", "D", "D_no_kappa", "L", "list_names", "list_names_plot", "cloud_name", "max_workers", "list_lines_fit", "list_lines_valid", "Theta_true_scaled", "map_shaper", "plots_estimator", )
[docs] def setup( self, forward_model_name: str, force_use_cpu: bool, fixed_params: Dict[str, Optional[float]], is_log_scale_params: Dict[str, bool], # data_int_path: str, data_err_path: str, sigma_m_float: float, # indicator_margin_scale: float, lower_bounds_lin: np.ndarray, upper_bounds_lin: np.ndarray, # with_spatial_prior: bool = True, spatial_prior_params: Union[None, SpatialPriorParams] = None, list_gaussian_approx_params: List[bool] = [], list_mixing_model_params: List[Dict[str, str]] = [], ): self.list_lines_valid = [] ( df_int_fit, y_fit, sigma_a_fit, omega_fit, y_valid, sigma_a_valid, omega_valid, ) = self.setup_observation( data_int_path=data_int_path, data_err_path=data_err_path, ) scaler, forward_map = self.setup_forward_map( forward_model_name=forward_model_name, force_use_cpu=force_use_cpu, dict_fixed_params=fixed_params, dict_is_log_scale_params=is_log_scale_params, ) sigma_m_fit = sigma_m_float * np.ones((self.N, self.L)) # run setup print(f"lower_bounds_lin = {lower_bounds_lin}") dict_posteriors, scaler, prior_indicator_1pix = self.setup_posteriors( scaler=scaler, forward_map=forward_map, y=y_fit, sigma_a=sigma_a_fit, sigma_m=sigma_m_fit, omega=omega_fit, syn_map=df_int_fit, with_spatial_prior=with_spatial_prior, spatial_prior_params=spatial_prior_params, indicator_margin_scale=indicator_margin_scale, lower_bounds_lin=lower_bounds_lin, upper_bounds_lin=upper_bounds_lin, list_gaussian_approx_params=list_gaussian_approx_params, list_mixing_model_params=list_mixing_model_params, ) return ( dict_posteriors, scaler, prior_indicator_1pix, y_valid, sigma_a_valid, omega_valid, )
[docs] def main( self, params: dict, path_data_cloud: str, point_challenger: dict = {}, ) -> None: if params["with_spatial_prior"]: spatial_prior_params = SpatialPriorParams( **params["spatial_prior"], ) else: spatial_prior_params = None sigma_m_float = np.log(params["sigma_m_float_linscale"]) ( dict_posteriors, scaler, prior_indicator_1pix, y_valid, sigma_a_valid, omega_valid, ) = self.setup( **params["forward_model"], # data_int_path=f"{path_data_cloud}/{params['filename_int']}", data_err_path=f"{path_data_cloud}/{params['filename_err']}", sigma_m_float=sigma_m_float, # **params["prior_indicator"], # with_spatial_prior=params["with_spatial_prior"], spatial_prior_params=spatial_prior_params, list_gaussian_approx_params=params["list_gaussian_approx_params"], list_mixing_model_params=[ {"path_transition_params": f"{path_data_cloud}/{filename}"} for filename in params["mixing_model_params_filename"] ], ) self.save_and_plot_setup( dict_posteriors, params["prior_indicator"]["lower_bounds_lin"], params["prior_indicator"]["upper_bounds_lin"], scaler, ) # * Optim MAP if params["to_run_optim_map"]: self.inversion_optim_map( dict_posteriors=dict_posteriors, scaler=scaler, my_sampler_params=MySamplerParams(**params["sampling_params"]["map"]), can_run_in_parallel=params["forward_model"]["force_use_cpu"], **params["run_params"]["map"], ) # * MCMC if params["to_run_mcmc"]: self.inversion_mcmc( dict_posteriors=dict_posteriors, scaler=scaler, my_sampler_params=MySamplerParams(**params["sampling_params"]["mcmc"]), can_run_in_parallel=params["forward_model"]["force_use_cpu"], # y_valid=y_valid, sigma_a_valid=sigma_a_valid, omega_valid=omega_valid, sigma_m_valid=sigma_m_float * np.ones_like(sigma_a_valid), # point_challenger=point_challenger, **params["run_params"]["mcmc"], ) return