Source code for elicit.optimization
# noqa SPDX-FileCopyrightText: 2024 Florence Bockting <florence.bockting@tu-dortmund.de>
#
# noqa SPDX-License-Identifier: Apache-2.0
import logging
import os
import time

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
from tqdm import tqdm

import elicit as el

tfd = tfp.distributions
def sgd_training(
    expert_elicited_statistics,
    prior_model_init,
    trainer,
    optimizer,
    model,
    targets,
):
"""
Wrapper that runs the optimization algorithms for E epochs.
Parameters
----------
expert_elicited_statistics : dict
expert data or simulated data representing a prespecified ground truth.
prior_model_init : class instance
instance of a class that initializes and samples from the prior
distributions.
one_forward_simulation : callable
one forward simulation cycle including: sampling from priors,
simulating model
predictions, computing target quantities and elicited statistics.
compute_loss : callable
sub-dag to compute the loss value including: compute loss components
of model simulations and expert data, compute loss per component,
compute total loss.
global_dict : dict
dictionary including all user-input settings.
"""
    # set seed
    tf.random.set_seed(trainer["seed"])

    # prepare generative model
    prior_model = prior_model_init

    # containers for results accumulated across epochs
    total_loss = []
    component_losses = []
    loss_comp_expert = []
    loss_comp_model = []
    gradients_ep = []
    time_per_epoch = []

    # initialize the user-specified optimizer: the "optimizer" entry holds
    # the optimizer class; all remaining entries are passed to it as
    # keyword arguments
    optimizer_copy = optimizer.copy()
    init_sgd_optimizer = optimizer["optimizer"]
    optimizer_copy.pop("optimizer")
    sgd_optimizer = init_sgd_optimizer(**optimizer_copy)
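    # For illustration (hypothetical values, not part of the original
    # code): an optimizer specification compatible with the unpacking
    # above could look like
    #     optimizer = {
    #         "optimizer": tf.keras.optimizers.Adam,
    #         "learning_rate": 0.01,
    #         "clipnorm": 1.0,
    #     }
    # which resolves to tf.keras.optimizers.Adam(learning_rate=0.01,
    # clipnorm=1.0).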
    # start training loop
    print("Training")
    for epoch in tqdm(tf.range(trainer["epochs"])):
        # disable info-level logging after the first epoch
        if epoch > 0:
            logging.disable(logging.INFO)

        # track runtime of one epoch
        epoch_time_start = time.time()

        with tf.GradientTape() as tape:
            # one forward simulation cycle: sample from the priors,
            # simulate model predictions, and compute target quantities
            # and elicited statistics
            (train_elicits, prior_sim, model_sim, target_quants
             ) = el.utils.one_forward_simulation(
                prior_model, trainer, model, targets
            )

            # compute the total loss as a weighted sum: compute the loss
            # components of model simulations and expert data, the loss
            # per component, and the total loss
            (weighted_total_loss, loss_components_expert,
             loss_components_training, loss_per_component
             ) = el.losses.compute_loss(
                train_elicits,
                expert_elicited_statistics,
                epoch,
                targets,
            )
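        # Note (assumption): weighted_total_loss is a scalar tensor, so
        # tape.gradient below yields one gradient per trainable variable.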
        # compute the gradient of the loss wrt the trainable variables
        gradients = tape.gradient(
            weighted_total_loss, prior_model.trainable_variables
        )

        # update the trainable variables using the gradient information
        # and the user-specified optimizer
        sgd_optimizer.apply_gradients(
            zip(gradients, prior_model.trainable_variables)
        )

        # time at the end of the epoch
        epoch_time_end = time.time()
        epoch_time = epoch_time_end - epoch_time_start

        # stop the training loop if the loss is NaN and inform the user
        if tf.math.is_nan(weighted_total_loss):
            print("Loss is NaN. The training process has been stopped.")
            break
        # %% Saving of results
        if trainer["method"] == "parametric_prior":
            # save the gradients of the current epoch
            gradients_ep.append(gradients)

            # extract the learned hyperparameter values
            hyperparams = prior_model.trainable_variables
            if epoch == 0:
                # create a dict with an empty list for each hyperparameter;
                # TensorFlow appends an output index to variable names
                # (e.g. "mu:0"), which the slice [:-2] strips off
                hyp_list = [
                    hyperparams[i].name[:-2] for i in range(len(hyperparams))
                ]
                res_dict = {k: [] for k in hyp_list}

            # extract the names and values of the hyperparameters
            vars_values = [
                hyperparams[i].numpy().copy() for i in range(len(hyperparams))
            ]
            vars_names = [
                hyperparams[i].name[:-2] for i in range(len(hyperparams))
            ]

            # append the hyperparameter values of the current epoch
            for val, name in zip(vars_values, vars_names):
                res_dict[name].append(val)
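            # For illustration (hypothetical hyperparameter names): with
            # two trainable variables named "mu:0" and "sigma:0", res_dict
            # grows across epochs as
            #     {"mu": [mu_epoch0, mu_epoch1, ...],
            #      "sigma": [sigma_epoch0, sigma_epoch1, ...]}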
        if trainer["method"] == "deep_prior":
            # save the mean and standard deviation of each sampled
            # marginal prior for each epoch
            if epoch == 0:
                res_dict = {"means": [], "stds": []}

            means = tf.reduce_mean(model_sim["prior_samples"], (0, 1))
            sds = tf.reduce_mean(
                tf.math.reduce_std(model_sim["prior_samples"], 1), 0
            )
            for val, name in zip([means, sds], ["means", "stds"]):
                res_dict[name].append(val)
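            # Shape note (assumption): with prior_samples of shape
            # (batches, draws, n_params), "means" and "stds" each receive
            # one tensor of shape (n_params,) per epoch.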
        # savings per epoch (independent of the chosen method)
        time_per_epoch.append(epoch_time)
        total_loss.append(weighted_total_loss)
        component_losses.append(loss_per_component)

    # results accumulated per epoch
    res_ep = {
        "loss": total_loss,
        "loss_component": component_losses,
        "time": time_per_epoch,
        "hyperparameter": res_dict,
    }
    # results of the final epoch
    output_res = {
        "target_quantities": target_quants,
        "elicited_statistics": train_elicits,
        "prior_samples": prior_sim,
        "model_samples": model_sim,
        "model": prior_model,
        "loss_tensor_expert": loss_components_expert,
        "loss_tensor_model": loss_components_training,
    }
    if trainer["method"] == "parametric_prior":
        res_ep["hyperparameter_gradient"] = gradients_ep

    return res_ep, output_res
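

# A minimal usage sketch (hypothetical settings; expert_elicited_statistics,
# prior_model_init, model, and targets are assumed to be prepared by the
# surrounding elicit workflow):
#
#     trainer = {"seed": 2024, "epochs": 400, "method": "parametric_prior"}
#     optimizer = {
#         "optimizer": tf.keras.optimizers.Adam,
#         "learning_rate": 0.05,
#     }
#     res_ep, output_res = sgd_training(
#         expert_elicited_statistics,
#         prior_model_init,
#         trainer,
#         optimizer,
#         model,
#         targets,
#     )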