# Source code for elicit.targets

# noqa SPDX-FileCopyrightText: 2024 Florence Bockting <florence.bockting@tu-dortmund.de>
#
# noqa SPDX-License-Identifier: Apache-2.0

import tensorflow as tf
import tensorflow_probability as tfp
import bayesflow as bf
import inspect
import pandas as pd

from elicit.extras import utils

tfd = tfp.distributions
bfn = bf.networks


# TODO: Update Custom Target Function
def use_custom_functions(custom_function, model_simulations, output_path=None):
    """
    Helper function that prepares custom functions if specified by checking
    all inputs and extracting the arguments from different sources.

    Each (positional-or-keyword) argument of ``custom_function["function"]``
    is resolved as follows (earlier sources take precedence):

    1. ``custom_function["additional_args"]``, if the name is a key there;
    2. the expert's simulated ground truth (loaded from
       ``{output_path}/expert/model_simulations.pkl``), for every quantity
       listed in the default value of a ``from_simulated_truth`` argument;
    3. ``model_simulations``, looked up by argument name.

    The ``from_simulated_truth`` argument itself is not passed, so the
    function receives its default value.

    Parameters
    ----------
    custom_function : dict
        custom function as specified by the user. Must contain the key
        ``"function"`` (callable) and may contain ``"additional_args"``
        (dict or None) with extra keyword arguments.
    model_simulations : dict
        simulations from the generative model.
    output_path : str or path-like, optional
        directory containing ``expert/model_simulations.pkl``. Only
        required if the custom function declares a ``from_simulated_truth``
        argument.

    Returns
    -------
    custom_quantity : tf.Tensor
        returns the evaluated custom function.

    Raises
    ------
    ValueError
        if the custom function declares ``from_simulated_truth`` without a
        default value, or if no ``output_path`` is given for it.
    """
    custom_func = custom_function["function"]

    # extra keyword arguments supplied by the user ("additional_args" may be
    # missing or None); keys are stringified so they can be passed as **kwargs
    additional_args_dict = {
        str(key): value
        for key, value in (custom_function.get("additional_args") or {}).items()
    }

    # names of the positional-or-keyword arguments of the custom function
    custom_args_keys = list(inspect.getfullargspec(custom_func).args)

    args_dict = dict()
    # check whether expert-specific input has been specified
    if "from_simulated_truth" in custom_args_keys:
        # the default value of `from_simulated_truth` lists the quantities
        # that must be taken from the expert's simulated ground truth
        truth_default = inspect.signature(custom_func).parameters[
            "from_simulated_truth"
        ].default
        if truth_default is inspect.Parameter.empty:
            raise ValueError(
                "The 'from_simulated_truth' argument of the custom function "
                "needs a default value listing the quantities to take from "
                "the simulated truth."
            )
        if output_path is None:
            raise ValueError(
                "Custom function declares 'from_simulated_truth' but no "
                "output_path was provided to locate "
                "'expert/model_simulations.pkl'."
            )
        # accept a single quantity name as well as a sequence of names
        if isinstance(truth_default, str):
            truth_quantities = [truth_default]
        else:
            truth_quantities = list(truth_default)

        # load once (not per quantity). NOTE: unpickling can execute
        # arbitrary code -- only point output_path at trusted files.
        true_model_simulations = pd.read_pickle(
            f"{output_path}/expert/model_simulations.pkl"
        )
        for quantity in truth_quantities:
            if quantity in custom_args_keys:
                args_dict[quantity] = true_model_simulations[quantity]
                # don't look this argument up in model_simulations later
                custom_args_keys.remove(quantity)
        custom_args_keys.remove("from_simulated_truth")

    # TODO: check that all args needed for custom function were detected
    for key in set(custom_args_keys) - set(additional_args_dict):
        args_dict[key] = model_simulations[key]
    # user-supplied arguments override everything else
    args_dict.update(additional_args_dict)

    # evaluate custom function
    custom_quantity = custom_func(**args_dict)
    return custom_quantity
def computation_elicited_statistics(target_quantities: dict, targets):
    """
    Computes the elicited statistics from the target quantities by applying
    the query (elicitation technique) specified for each target.

    Parameters
    ----------
    target_quantities : dict
        simulated target quantities, keyed by target name.
    targets : list of dict
        target specifications. Each entry provides ``"name"`` and a
        ``"query"`` dict whose ``"name"`` selects the technique:
        ``"custom"``, ``"identity"``, ``"pearson_correlation"`` or
        ``"quantiles"`` (the latter reads the quantiles from
        ``query["value"]``). Other query names are skipped.

    Returns
    -------
    elicits_res : dict
        simulated elicited statistics, keyed ``"<prefix>_<target name>"``
        with prefix ``custom``, ``identity``, ``pearson`` or ``quantiles``.

    Raises
    ------
    ValueError
        if a quantile query is applied to a target quantity whose rank is
        neither 2 nor 3.
    """
    # initialize dict for storing results
    elicits_res = dict()
    # loop over elicitation techniques
    for target in targets:
        name = target["name"]
        query = target["query"]
        query_name = query["name"]

        if query_name == "custom":
            # older target specs store the custom function under
            # "elicitation_method"; otherwise use the query's value
            # (consistent with the other query types)
            if "elicitation_method" in target:
                custom_spec = target["elicitation_method"]["value"]
            else:
                custom_spec = query["value"]
            elicits_res[f"custom_{name}"] = use_custom_functions(
                custom_spec, target_quantities
            )
        elif query_name == "identity":
            elicits_res[f"identity_{name}"] = target_quantities[name]
        elif query_name == "pearson_correlation":
            # compute correlation between model parameters (used for
            # learning correlation structure of joint prior)
            elicits_res[f"pearson_{name}"] = utils.pearson_correlation(
                target_quantities[name]
            )
        elif query_name == "quantiles":
            elicits_res[f"quantiles_{name}"] = _compute_quantiles(
                target_quantities[name], query["value"], name
            )
    # return results
    return elicits_res


def _compute_quantiles(quantity, quantiles, name):
    """Compute `quantiles` over all but the first axis of a rank-2/3 quantity."""
    rank = len(quantity.shape)
    if rank == 3:
        # keep the first axis (presumably the batch) and flatten the
        # remaining two so quantiles are taken over all of their entries
        quantity = tf.reshape(
            quantity,
            shape=(quantity.shape[0], quantity.shape[1] * quantity.shape[2]),
        )
    elif rank != 2:
        # previously an unsupported rank left the reshaped tensor unbound
        # (or silently reused the previous target's tensor)
        raise ValueError(
            f"Quantile queries require a target quantity of rank 2 or 3, "
            f"but target '{name}' has rank {rank}."
        )
    computed_quantiles = tfp.stats.percentile(quantity, q=quantiles, axis=-1)
    # percentile puts the quantile axis first; swap it behind the first axis
    return tf.einsum("ij...->ji...", computed_quantiles)
def computation_target_quantities(model_simulations, targets):
    """
    Computes target quantities from model simulations.

    For each target: if its name is ``"correlation"``, the prior samples
    (``model_simulations["prior_samples"]``) are used; otherwise, if a
    ``"target_method"`` is given (and not None), it is evaluated via
    ``use_custom_functions``; otherwise the simulation stored under the
    target's name is used.

    Parameters
    ----------
    model_simulations : dict
        simulations from generative model.
    targets : list of dict
        target specifications, each with a ``"name"`` and an optional
        ``"target_method"`` (custom function specification or None).

    Returns
    -------
    targets_res : dict
        computed target quantities, keyed by target name.
    """
    # initialize dict for storing results
    targets_res = dict()
    # loop over target quantities
    for tar in targets:
        if tar["name"] == "correlation":
            target_quantity = model_simulations["prior_samples"]
        # use custom function for target quantity if it has been defined;
        # "target_method" is optional, so a missing key means "none"
        elif tar.get("target_method") is not None:
            target_quantity = use_custom_functions(
                tar["target_method"], model_simulations
            )
        else:
            target_quantity = model_simulations[tar["name"]]
        # save target quantities
        targets_res[tar["name"]] = target_quantity
    return targets_res