Toy Example - Deep prior#

import tensorflow_probability as tfp
import tensorflow as tf
import numpy as np
import elicit as el

from bayesflow.inference_networks import InvertibleNetwork

tfd = tfp.distributions

The Model#

Generative model#

\[\begin{align*} (\beta_0, \beta_1, \sigma) &\sim p_\lambda(\boldsymbol{\theta}) \\ \mu &= \beta_0 + \beta_1X \\ y_{pred} &\sim \text{Normal}(\mu, \sigma) \end{align*}\]
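The generative model can be illustrated with a minimal NumPy sketch. It draws \((\beta_0, \beta_1, \sigma)\) from a fixed, independent prior and simulates \(y_{pred}\) at each design point. As an assumption for illustration only, the prior reuses the ground-truth distributions defined for the oracle later in this example: Normal(5, 1), Normal(2, 2) and HalfNormal(5). The function name `simulate_prior_predictive` and the evenly spaced design points are hypothetical; the learned prior \(p_\lambda\) itself is a normalizing flow, not these fixed distributions.

```python
import numpy as np

def simulate_prior_predictive(X, n_sim, rng):
    """Draw (beta0, beta1, sigma) from an assumed independent prior and
    simulate y_pred ~ Normal(beta0 + beta1 * x, sigma) at each design point."""
    beta0 = rng.normal(5.0, 1.0, size=n_sim)            # assumed prior for beta0
    beta1 = rng.normal(2.0, 2.0, size=n_sim)            # assumed prior for beta1
    sigma = np.abs(rng.normal(0.0, 5.0, size=n_sim))    # HalfNormal(5) via |Normal(0, 5)|
    theta = np.stack([beta0, beta1, sigma], axis=-1)    # shape (n_sim, 3)
    # linear predictor mu = beta0 + beta1 * x for every design point
    mu = theta[:, :2] @ X.T                             # shape (n_sim, n_points)
    # sigma (n_sim, 1) broadcasts over the design points
    y = rng.normal(mu, theta[:, 2:3])
    return theta, mu, y

# hypothetical design matrix: intercept column plus 5 evenly spaced predictor values
X = np.column_stack([np.ones(5), np.linspace(-1.5, 1.5, 5)])
rng = np.random.default_rng(0)
theta, mu, y = simulate_prior_predictive(X, n_sim=1000, rng=rng)
print(theta.shape, mu.shape, y.shape)  # (1000, 3) (1000, 5) (1000, 5)
```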

Implementation#

Generative model#

class ToyModel:
    def __call__(self, prior_samples, design_matrix):
        # linear predictor 
        epred=tf.matmul(prior_samples[:, :, :-1], design_matrix, 
                        transpose_b=True)
        
        # data-generating model
        likelihood = tfd.Normal(
            loc=epred, scale=tf.expand_dims(prior_samples[:, :, -1], -1)
        )
        # prior predictive distribution 
        ypred = likelihood.sample()
        
        # selected observations
        y_x0, y_x1, y_x2, y_x3, y_x4 = (ypred[:,:,i] for i in range(5))

        # R2 = Var(mu) / Var(y), computed per simulated dataset over the observations
        var_ypred = tf.math.reduce_variance(ypred, -1)
        var_epred = tf.math.reduce_variance(epred, -1)
        r2 = tf.math.divide(var_epred, var_ypred)

        return dict(
            y_x0=y_x0, y_x1=y_x1, y_x2=y_x2, y_x3=y_x3, y_x4=y_x4,
            r2=r2
        )

Define design matrix#

# create a predictor ranging from 0 to n-1 (here: 0 to 199)
# standardize the predictor
# select the 5th, 25th, 50th, 75th, and 95th quantiles of the standardized predictor for querying the expert
def x_design(n, quantiles):
    x = tf.cast(np.arange(n), tf.float32)
    x_std = (x-tf.reduce_mean(x))/tf.math.reduce_std(x)
    x_sel = tfp.stats.percentile(x_std, quantiles)
    return tf.stack([tf.ones(x_sel.shape), x_sel], -1)

x_design(n=200, quantiles=[5,25,50,75,95])
<tf.Tensor: shape=(5, 2), dtype=float32, numpy=
array([[ 1.        , -1.5502049 ],
       [ 1.        , -0.85737586],
       [ 1.        ,  0.00866036],
       [ 1.        ,  0.85737586],
       [ 1.        ,  1.5502049 ]], dtype=float32)>
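To check the design matrix without TensorFlow, here is a NumPy analogue of `x_design`. It assumes that `tfp.stats.percentile` with its default interpolation picks the nearest-rank element, which this sketch reproduces by rounding the fractional index \(q/100 \cdot (n-1)\). It also uses the population standard deviation, as `tf.math.reduce_std` does. The name `x_design_np` is hypothetical.

```python
import numpy as np

def x_design_np(n, quantiles):
    """NumPy analogue of x_design: standardize 0..n-1 and pick
    nearest-rank percentiles of the standardized predictor."""
    x = np.arange(n, dtype=np.float64)
    x_std = (x - x.mean()) / x.std()   # population std (ddof=0), like tf.math.reduce_std
    # nearest-rank percentile: round the fractional index q/100 * (n - 1)
    idx = np.round(np.asarray(quantiles) / 100 * (n - 1)).astype(int)
    x_sel = np.sort(x_std)[idx]
    # intercept column of ones next to the selected predictor values
    return np.column_stack([np.ones_like(x_sel), x_sel])

D = x_design_np(n=200, quantiles=[5, 25, 50, 75, 95])
print(D.round(4))
```

The second column agrees with the TensorFlow output above to within float32 precision.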

Model input for elicit method#

# specify the model
model=el.model(
        obj=ToyModel,
        design_matrix=x_design(n=200, quantiles=[5,25,50,75,95])
        )

Model parameters#

  • intercept parameter \(\beta_0\)

  • slope parameter \(\beta_1\)

  • error term \(\sigma\)

Hyperparameters \(\lambda\) to be learned: the weights of the deep neural networks within the normalizing flow architecture.

Parameter input for elicit method#

parameters=[
        el.parameter(name="beta0"),
        el.parameter(name="beta1"),
        el.parameter(name="sigma", lower=0.)
    ]
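The `lower=0.` argument restricts \(\sigma\) to positive values. This example does not show how elicit enforces the bound internally. One common approach, assumed in the sketch below, is to pass the unconstrained flow output through a softplus transform shifted by the lower bound. `softplus` and `constrain` are hypothetical helpers, not elicit functions.

```python
import numpy as np

def softplus(u):
    # numerically stable log(1 + exp(u)), strictly positive for moderate u
    return np.logaddexp(0.0, u)

def constrain(raw_samples, lower):
    """Map unconstrained draws (columns: beta0, beta1, sigma) to the support
    implied by the parameter specs; only columns with a lower bound change."""
    out = np.array(raw_samples, dtype=float, copy=True)
    for j, lo in enumerate(lower):
        if lo is not None:
            out[:, j] = lo + softplus(out[:, j])
    return out

raw = np.random.default_rng(1).normal(size=(4, 3))   # stand-in for flow output
theta = constrain(raw, lower=[None, None, 0.0])      # mirrors lower=0. for sigma
```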

Target quantities and elicitation techniques#

Target quantities

  • query the expert about prior predictions \(y \mid X_{i}\), where \(X_i\) is the 5th, 25th, 50th, 75th, and 95th quantile of the standardized predictor.

  • \(R^2 = \frac{\text{Var}(\mu)}{\text{Var}(y)}\) (we use \(\log R^2\) for numerical stability)

  • correlation between model parameters
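The target quantities can be computed from simulated draws. The NumPy sketch below does this using the definition \(R^2 = \text{Var}(\mu)/\text{Var}(y)\) above, taken over the observations of each simulated dataset, together with the pairwise correlations between \(\beta_0, \beta_1, \sigma\). The function name and the simulated inputs are hypothetical. The three parameters are drawn independently here, so the correlation entries should be close to zero.

```python
import numpy as np

def target_quantities(theta, mu, y):
    """Compute R^2 and parameter correlations from simulated draws.

    theta: (S, 3) draws of (beta0, beta1, sigma)
    mu, y: (S, n_obs) linear predictors and predictions, one row per draw
    """
    # R^2 per draw: variance of mu over the observations divided by variance of y
    r2 = mu.var(axis=1) / y.var(axis=1)
    # pairwise correlations between beta0, beta1, sigma across draws
    corr = np.corrcoef(theta, rowvar=False)
    rho = corr[np.triu_indices(3, k=1)]   # (rho_01, rho_02, rho_12)
    return {"r2": r2, "log_r2": np.log(r2), "correlation": rho}

# hypothetical simulated draws, for illustration only
rng = np.random.default_rng(2)
S = 500
theta = np.column_stack([rng.normal(5, 1, S), rng.normal(2, 2, S), np.abs(rng.normal(0, 5, S))])
x = np.linspace(-1.5, 1.5, 5)
mu = theta[:, [0]] + theta[:, [1]] * x
y = rng.normal(mu, theta[:, [2]])
tq = target_quantities(theta, mu, y)
```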

Elicitation technique

  • query each prior prediction with quantile-based elicitation, \(Q_p(y \mid X_i)\) for \(p=5, 25, 50, 75, 95\)

  • query \(R^2\) with quantile-based elicitation, \(Q_p(R^2)\) for \(p=5, 25, 50, 75, 95\)

  • regarding the correlation structure, we assume independence between the model parameters (thus, \(\boldsymbol{\rho}_\boldsymbol{\theta}=\mathbf{0}\))
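The elicitation techniques above can be sketched in NumPy. `elicit_quantiles` computes \(Q_p\) for \(p = 5, 25, 50, 75, 95\) along the last axis of a simulated target quantity, under the assumption that this axis holds the prior draws. `l2_correlation_loss` compares the model's correlation vector with the expert's assumed independence, \(\boldsymbol{\rho}_\boldsymbol{\theta}=\mathbf{0}\). Both names are hypothetical. The L2 loss is taken here as a sum of squared differences; the exact definition of `el.losses.L2` may differ.

```python
import numpy as np

PROBS = np.array([0.05, 0.25, 0.50, 0.75, 0.95])

def elicit_quantiles(samples):
    """Q_p of a target quantity along its last axis (the prior draws).
    Returns an array of shape (..., 5)."""
    return np.moveaxis(np.quantile(samples, PROBS, axis=-1), 0, -1)

def l2_correlation_loss(model_corr, expert_corr):
    # sum of squared differences between correlation vectors (assumed form)
    return float(np.sum((np.asarray(model_corr) - np.asarray(expert_corr)) ** 2))

rng = np.random.default_rng(3)
y_x0 = rng.normal(2.0, 3.0, size=(4, 200))   # hypothetical predictions: (batch, draws)
q = elicit_quantiles(y_x0)                   # shape (4, 5)
loss = l2_correlation_loss([0.1, -0.2, 0.0], np.zeros(3))  # expert assumes independence
```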

Importance of elicited statistics in loss

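  • the quantile-based statistics (prior predictions and \(R^2\)) have equal importance (weight=1.0); the correlation target receives a lower weight (weight=0.1)

  • the discrepancy between expert-elicited statistics and model simulations is measured with the squared Maximum Mean Discrepancy (MMD²) with an energy kernel for the quantile-based targets, and with an L2 loss for the correlation target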
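The loss terms can be illustrated with a minimal NumPy sketch of a biased (V-statistic) MMD² estimator with the energy kernel \(k(a, b) = -|a - b|\), applied to two sets of elicited quantiles. The sketch also combines per-target losses as a weighted sum; that combination is our assumption about how the weights are used. `mmd2_energy` and `total_loss` are hypothetical names and do not reproduce elicit's internal implementation.

```python
import numpy as np

def mmd2_energy(x, y):
    """Biased (V-statistic) squared MMD with the energy kernel k(a, b) = -|a - b|
    between two 1-D samples, e.g. expert vs. model quantiles."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    def k(a, b):
        # pairwise energy kernel matrix
        return -np.abs(a[:, None] - b[None, :])

    return float(k(x, x).mean() + k(y, y).mean() - 2.0 * k(x, y).mean())

def total_loss(losses, weights):
    # weighted sum over targets, e.g. 1.0 for quantile targets, 0.1 for correlation
    return sum(w * l for l, w in zip(losses, weights))

expert_q = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
model_q = expert_q + 0.5
print(mmd2_energy(expert_q, expert_q))      # 0.0
print(mmd2_energy(expert_q, model_q) > 0)   # True
```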

Targets input for elicit method#

targets=[
        el.target(
            name=f"y_x{i}",
            query=el.queries.quantiles((.05, .25, .50, .75, .95)),
            loss=el.losses.MMD2(kernel="energy"),
            weight=1.0
        ) for i in range(5)
        ]+[
        el.target(
            name="r2",
            query=el.queries.quantiles((.05, .25, .50, .75, .95)),
            loss=el.losses.MMD2(kernel="energy"),
            weight=1.0
        ),
        el.target(
            name="correlation",
            query=el.queries.correlation(),
            loss=el.losses.L2,
            weight=0.1
        )
    ]

Expert elicitation#

  • instead of querying a “real” expert, we define a ground truth (i.e., oracle) and simulate the oracle-elicited statistics
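The oracle idea can be sketched with NumPy under stated assumptions. Parameters are sampled from the ground truth defined below (Normal(5, 1), Normal(2, 2), HalfNormal(5)) and pushed through the generative model at each design point. The predictions are then summarized by the same quantiles the expert would report. `oracle_quantiles` is a hypothetical helper and is not the implementation behind `el.expert.simulator`. The design points are the values printed by `x_design` above.

```python
import numpy as np

def oracle_quantiles(X, num_samples, rng, probs=(0.05, 0.25, 0.50, 0.75, 0.95)):
    """Simulate an oracle: sample parameters from the ground truth, push them
    through the generative model and summarize y at each design point."""
    beta0 = rng.normal(5.0, 1.0, num_samples)
    beta1 = rng.normal(2.0, 2.0, num_samples)
    sigma = np.abs(rng.normal(0.0, 5.0, num_samples))    # HalfNormal(scale=5)
    mu = beta0[:, None] + beta1[:, None] * X[:, 1]       # (num_samples, n_points)
    y = rng.normal(mu, sigma[:, None])
    # one row of quantiles per design point: shape (n_points, len(probs))
    return np.quantile(y, probs, axis=0).T

X = np.column_stack([np.ones(5), [-1.5502049, -0.85737586, 0.00866036, 0.85737586, 1.5502049]])
rng = np.random.default_rng(5)
q_oracle = oracle_quantiles(X, num_samples=10_000, rng=rng)
```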

Expert input for elicit method (here: oracle)#

# specify ground truth
ground_truth = {
    "beta0": tfd.Normal(loc=5, scale=1),
    "beta1": tfd.Normal(loc=2, scale=2),
    "sigma": tfd.HalfNormal(scale=5.0),
}

# define oracle
expert=el.expert.simulator(
        ground_truth = ground_truth,
        num_samples = 10_000
    )

Normalizing Flow#

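The prior is parameterized by a normalizing flow built from BayesFlow's `InvertibleNetwork`. It is three-dimensional, with one dimension per model parameter, and stacks three affine coupling layers. Each layer uses two dense layers of 128 ReLU units without dropout or kernel regularization, with a fixed permutation between layers and a normal base distribution (`el.networks.base_normal`).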

network=el.networks.NF(
        inference_network=InvertibleNetwork,
        network_specs=dict(
            num_params=3,
            num_coupling_layers=3,
            coupling_design="affine",
            coupling_settings={
                "dropout": False,
                "dense_args": {
                    "units": 128,
                    "activation": "relu",
                    "kernel_regularizer": None,
                },
                "num_dense": 2,
            },
            permutation="fixed"
        ),
        base_distribution=el.networks.base_normal
    )
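The `coupling_design="affine"` setting refers to affine coupling layers. The toy NumPy layer below illustrates the idea. The first block of dimensions passes through unchanged and feeds a small ReLU network that outputs a log-scale and a shift for the remaining dimensions. This makes the transform cheap to invert and gives a log-determinant equal to the sum of the log-scales. It is a simplified sketch under our own assumptions (one hidden layer, tanh-bounded log-scale, no permutations or actnorm). BayesFlow's `InvertibleNetwork` layers differ in detail, and `AffineCoupling` is a hypothetical class.

```python
import numpy as np

class AffineCoupling:
    """Toy affine coupling layer: the first `split` dims pass through unchanged
    and parameterize a scale and shift for the remaining dims."""

    def __init__(self, dim, split, rng, hidden=16):
        self.split = split
        out = dim - split
        # tiny one-hidden-layer network producing (log_scale, shift)
        self.W1 = rng.normal(0.0, 0.1, (split, hidden))
        self.W2 = rng.normal(0.0, 0.1, (hidden, 2 * out))

    def _scale_shift(self, x1):
        h = np.maximum(0.0, x1 @ self.W1)           # ReLU hidden layer
        s, t = np.split(h @ self.W2, 2, axis=-1)
        return np.tanh(s), t                        # bounded log-scale for stability

    def forward(self, z):
        z1, z2 = z[:, :self.split], z[:, self.split:]
        s, t = self._scale_shift(z1)
        x2 = z2 * np.exp(s) + t
        log_det = s.sum(axis=-1)                    # log|det J| of the transform
        return np.concatenate([z1, x2], axis=-1), log_det

    def inverse(self, x):
        x1, x2 = x[:, :self.split], x[:, self.split:]
        s, t = self._scale_shift(x1)
        z2 = (x2 - t) * np.exp(-s)
        return np.concatenate([x1, z2], axis=-1)

rng = np.random.default_rng(4)
layer = AffineCoupling(dim=3, split=1, rng=rng)
z = rng.standard_normal((8, 3))          # draws from the base N(0, I)
theta, log_det = layer.forward(z)        # transformed samples and log-determinants
```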

Training: Learn prior distributions based on expert data#

All inputs for elicit method

eliobj = el.Elicit(
    model=model,
    parameters=parameters,
    targets=targets,
    expert=expert,
    optimizer=el.optimizer(
        optimizer=tf.keras.optimizers.Adam,
        learning_rate=0.001,
        clipnorm=1.0
        ),
    trainer=el.trainer(
        method="deep_prior",
        seed=4,
        epochs=500
    ),
    network=network
)
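The optimizer is configured with `clipnorm=1.0`. In Keras, this setting clips the gradient of each weight individually so that its norm does not exceed the given value. The NumPy sketch below mimics that per-variable rule on plain arrays; `clip_by_norm` is a hypothetical helper, not a Keras function.

```python
import numpy as np

def clip_by_norm(grads, clipnorm):
    """Rescale each gradient array separately so that its L2 norm
    does not exceed clipnorm (per-variable clipping)."""
    clipped = []
    for g in grads:
        norm = np.linalg.norm(g)
        # shrink only gradients whose norm exceeds the threshold
        clipped.append(g * (clipnorm / norm) if norm > clipnorm else g)
    return clipped

grads = [np.array([3.0, 4.0]), np.array([0.3, 0.4])]   # norms 5.0 and 0.5
clipped = clip_by_norm(grads, clipnorm=1.0)
```

The first gradient is rescaled to unit norm, while the second is already below the threshold and passes through unchanged.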

Run multiple chains

# run method
eliobj.fit(parallel=el.utils.parallel(chains=4))

Results#

Convergence#

el.plots.loss(eliobj, figsize=(7,3))
el.plots.marginals(eliobj, cols=3, figsize=(6,4))

Expert expectations#

el.plots.elicits(eliobj, cols=4,figsize=(7,4))

Learned joint prior#

el.plots.prior_joint(eliobj, idx=2)
el.plots.prior_marginals(eliobj, figsize=(7,2))
INFO: Reset cols=3 (number of elicited statistics)
el.plots.prior_averaging(eliobj, figsize=(7,4))
INFO: Reset cols=3 (number of elicited statistics)

Add-on: Model averaging#

ToDo