Source code for elicit.networks

# noqa SPDX-FileCopyrightText: 2024 Florence Bockting <florence.bockting@tu-dortmund.de>
#
# noqa SPDX-License-Identifier: Apache-2.0

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def NF(
    inference_network: callable,
    network_specs: dict,
    base_distribution: callable,
):
    """Bundle the normalizing-flow settings used from the BayesFlow library.

    The arguments are not validated or copied; they are stored as-is in the
    returned dictionary.

    Parameters
    ----------
    inference_network : callable
        Type of inference network as specified by
        ``bayesflow.inference_networks``.
    network_specs : dict
        Specification of the normalizing-flow architecture. The accepted
        arguments are those of the chosen ``bayesflow.inference_networks``
        class.
    base_distribution : callable
        Base distribution from which samples are drawn during learning.
        Typically a multivariate normal. NOTE(review): presumably called
        with the number of model parameters, as ``BaseNormal.__call__`` in
        this module is; confirm against the caller.

    Returns
    -------
    nf_dict : dict
        Dictionary with the keys ``"inference_network"``,
        ``"network_specs"`` and ``"base_distribution"``, mapping to the
        corresponding arguments (same objects, not copies).
    """
    nf_dict = {
        "inference_network": inference_network,
        "network_specs": network_specs,
        "base_distribution": base_distribution,
    }
    return nf_dict
class BaseNormal:
    """Factory for a standard multivariate-normal base distribution.

    Calling an instance builds a ``tfd.MultivariateNormalDiag`` whose
    location is all zeros and whose diagonal scale is all ones.
    """

    def __call__(self, num_params):
        """Return a standard diagonal multivariate normal.

        Parameters
        ----------
        num_params :
            Shape passed to ``tf.zeros`` / ``tf.ones`` for the location and
            scale vectors. NOTE(review): presumably an int giving the number
            of model parameters; confirm against the caller.

        Returns
        -------
        tfd.MultivariateNormalDiag
            Distribution with ``loc = 0`` and ``scale_diag = 1``
            (TensorFlow's default float32 dtype).
        """
        zero_mean = tf.zeros(num_params)
        unit_scale = tf.ones(num_params)
        return tfd.MultivariateNormalDiag(loc=zero_mean, scale_diag=unit_scale)
# Ready-to-use instance of BaseNormal; calling it with a parameter count
# returns a zero-mean, unit-scale tfd.MultivariateNormalDiag.
base_normal: BaseNormal = BaseNormal()