# Copyright © 2023, UChicago Argonne, LLC
# All Rights Reserved
# Math & Science
import numpy as np
# Other imports
import time
import re
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Callable, Sequence
from pyoed.configs import (
validate_key,
aggregate_configurations,
set_configurations,
PyOEDConfigsValidationError,
)
from pyoed.utility import (
isnumber,
plot_sampling_results,
)
from pyoed.optimization.scipy_optimization import ScipyOptimizer
from pyoed.stats.core.sampling import (
Proposal,
Sampler,
SamplerConfigs,
)
from .proposals import GaussianProposal
[docs]
@dataclass(kw_only=True, slots=True)
class RejectionSamplerConfigs(SamplerConfigs):
    """
    Configurations for the rejection sampler :py:class:`RejectionSampler`.

    In addition to the attributes provided by :py:class:`SamplerConfigs`,
    this configurations class provides the following attributes:

    :param size: dimension of the target distribution to sample
    :param log_density: log of the (unscaled) density function to be sampled
    :param proposal: a proposal object to be used for generating new samples
    :param constraint_test: a function that returns `True` if a sample point
        satisfies the constraints, and `False` otherwise; ignored if `None`
        is passed.
    :param bounds: bounds of the domain of the target random variable (if any)
    :param random_seed: random seed used when the object is initiated to keep track
        of random samples. This is useful for reproducibility. If `None`, random seed
        follows `numpy.random.seed` rules.
        NOTE(review): this field is not declared here; presumably it is inherited
        from :py:class:`SamplerConfigs` -- confirm.
    """

    # Every field defaults to `None`; `RejectionSampler.validate_configurations`
    # is where `size` and `log_density` are required to hold usable values.
    size: int | None = None
    log_density: Callable[[np.ndarray], float] | None = None
    proposal: Proposal | None = None
    constraint_test: Callable[[np.ndarray], bool] | None = None
    bounds: Sequence | None = None
[docs]
@set_configurations(configurations_class=RejectionSamplerConfigs)
class RejectionSampler(Sampler):
    """
    Implementation of the rejection sampling algorithm with a predefined proposal.

    :param configs: a configurations object. See :py:class:`RejectionSamplerConfigs`.

    :note:
        - Validation of the configurations dictionary is taken care of in the super class
        - If no proposal is passed in the configurations, a default Gaussian
          proposal is created.
        - If a constraint test is passed both here and in the proposal,
          the one passed here will overwrite the one associated with the proposal.
    """
[docs]
def __init__(self, configs: RejectionSamplerConfigs | dict | None = None):
    """
    Build the sampler from the given configurations.

    :param configs: configurations object, dictionary, or `None`; it is
        converted by `configurations_class.data_to_dataclass` and then handed
        to the super class initializer.
    """
    configs = self.configurations_class.data_to_dataclass(configs)
    super().__init__(configs)

    # If no proposal in the configurations, create default Gaussian
    if self.configurations.proposal is None:
        self.configurations.proposal = GaussianProposal(
            {
                "size": self.configurations.size,
                "mean": 0,
                "variance": 1,
            }
        )

    # Maintain a proper random number generator (here and in the proposal)
    self.update_random_number_generators(
        random_seed=self.configurations.random_seed,
        update_proposal=True,
    )

    # Push the constraint test passed here (if any) to the proposal.
    # The configuration entry IS the callable; it has no `.constraint_test`
    # attribute, so it is forwarded as-is.
    constraint_test = self.configurations.constraint_test
    if constraint_test is not None:
        self.proposal.update_configurations(constraint_test=constraint_test)

    # Define log-density (validates the callable and stores it)
    self._LOG_DENSITY = self._update_log_density()

    # Create an optimizer that will be used to approximate the MAP point of
    # the distribution: it minimizes the negative log-density.
    self._OPTIMIZER = ScipyOptimizer(
        {
            "fun": lambda x: -self.log_density(x),
            "x0": np.zeros(self.configurations.size),
            "bounds": self.configurations.bounds,
        }
    )
[docs]
def validate_configurations(
    self,
    configs: dict | RejectionSamplerConfigs,
    raise_for_invalid: bool = True,
) -> bool:
    """
    Check the passed configurations and make sure they are conformable with each
    other, and with current configurations once combined. This guarantees that any
    key-value pair passed in configs can be properly used.

    :param configs: full or partial (subset) configurations to be validated
    :param raise_for_invalid: if `True` raise
        :py:class:`PyOEDConfigsValidationError` for invalid configurations
        type/key. Default `True`

    :returns: flag indicating whether passed configurations dictionary is valid or not

    :raises AttributeError: if any (or a group) of the configurations does not exist
        in the model configurations :py:class:`RejectionSamplerConfigs`.
    :raises PyOEDConfigsValidationError: if the configurations are invalid and
        `raise_for_invalid` is set to True.
    """
    # Fuse configs into current/default configurations
    aggregated_configs = aggregate_configurations(
        obj=self,
        configs=configs,
        configs_class=self.configurations_class,
    )

    # (key, predicate, failure message) triplets, checked in this order.
    # The proposal check also requires the proposal's size to match the
    # aggregated `size`.
    checks = (
        (
            "size",
            lambda x: isnumber(x) and (int(x) == x) and (x > 0),
            "size must be a positive integer",
        ),
        (
            "log_density",
            callable,
            "log_density must be a callable function",
        ),
        (
            "constraint_test",
            lambda x: callable(x) or (x is None),
            "constraint_test must be a callable function or None",
        ),
        (
            "proposal",
            lambda x: (
                isinstance(x, Proposal)
                and x.configurations.size == aggregated_configs.size
            )
            or (x is None),
            "proposal must be an instance of class `Proposal` or None",
        ),
    )
    for key, test, message in checks:
        if not validate_key(
            aggregated_configs,
            configs,
            key,
            test=test,
            message=message,
            raise_for_invalid=raise_for_invalid,
        ):
            return False

    # Remaining (inherited) entries are validated by the super class
    return super().validate_configurations(configs, raise_for_invalid)
[docs]
def update_configurations(self, **kwargs):
    """
    Take any set of keyword arguments, look up each in the configurations,
    and update as necessary/possible/valid.

    After the super class has validated and stored the new values, the
    dependent pieces of the sampler are refreshed:

    - `log_density`: the cached log-density callable is re-validated and replaced
    - `constraint_test`: forwarded to the proposal
    - `log_density` or `bounds`: the MAP optimizer is reconfigured
    - `random_seed`: random number generators (here and in the proposal) are reset

    :raises PyOEDConfigsValidationError: if invalid configurations passed
    """
    # Validate and update
    super().update_configurations(**kwargs)

    ## Special updates based on specific arguments
    if "log_density" in kwargs:
        self._update_log_density(log_density=kwargs["log_density"])

    # Update proposal's constraint test (if passed here)
    if "constraint_test" in kwargs:
        self.proposal.update_configurations(
            constraint_test=kwargs["constraint_test"],
        )

    # Refresh the optimizer used to approximate the MAP point of the distribution
    if "log_density" in kwargs or "bounds" in kwargs:
        self.optimizer.update_configurations(
            fun=lambda x: -self.log_density(x),
            bounds=self.configurations.bounds,
        )

    # Reset the random number generators when a new seed is passed; the seed
    # must be read from `kwargs` (there is no local named `random_seed`).
    if "random_seed" in kwargs:
        self.update_random_number_generators(
            random_seed=kwargs["random_seed"],
            update_proposal=True,
        )
def _update_log_density(
    self,
    log_density=None,
):
    """
    Update the function that evaluates the logarithm of the (unscaled) target
    density function. This can be helpful to avoid recreating the sampler for
    various PDFs.

    This method defines/updates `_LOG_DENSITY`, which evaluates the value of
    the log-density function of the (unscaled) target distribution.

    :param log_density: callable to install; if `None`, the callable found in
        the configurations is used.

    :returns: the callable that has been installed
    :raises TypeError: if the log-density is not callable, or if it fails when
        evaluated on a test vector (the original failure is chained as the cause).
    """
    # Log-Density function
    if log_density is None:
        log_density = self.configurations.log_density
    if not callable(log_density):
        raise TypeError(
            "'log_density' found in the configurations is not a valid callable/function!"
        )

    # Smoke test on a vector of ones of the sampler's dimension. Only
    # `Exception` is intercepted so that KeyboardInterrupt/SystemExit are not
    # swallowed, and the underlying error is preserved through chaining.
    try:
        log_density(np.ones(self.size))
    except Exception as err:
        raise TypeError(
            "Failed to evaluate the log-density using a randomly generated vector"
        ) from err

    self._LOG_DENSITY = log_density
    return log_density
[docs]
def sample(
    self,
    sample_size=1,
    full_diagnostics=False,
):
    """
    Draw `sample_size` points from the target distribution.

    The work is delegated to :py:meth:`start_batch_rejection_sampling`; only the
    collected ensemble is handed back to the caller.

    :param int sample_size: number of sample points to collect
    :param bool full_diagnostics: if `True` all generated states will be tracked
        and kept for full diagnostics, otherwise, only collected samples are
        kept in memory

    :returns: the `"collected_ensemble"` entry of the sampling results
    """
    outcome = self.start_batch_rejection_sampling(
        full_diagnostics=full_diagnostics,
        sample_size=sample_size,
    )
    ensemble = outcome["collected_ensemble"]
    return ensemble
[docs]
def start_batch_rejection_sampling(
    self,
    sample_size,
    full_diagnostics=False,
):
    """
    Start the rejection sampling procedure. Unlike `start_rejection_sampling`,
    this function works by proposing a batch of samples (from the proposal) of
    size equal to the number of remaining samples to be collected; it keeps those
    that are accepted and satisfy the underlying `constraint_test` if one is set
    in the configurations.

    :param int sample_size: number of sample points to generate/collect from the
        predefined target distribution
    :param bool full_diagnostics: if `True` all generated states will be tracked
        and kept for full diagnostics, otherwise, only collected samples are kept
        in memory

    :returns: a dictionary with keys `proposals_repository`,
        `uniform_random_numbers`, `acceptance_flags`, `collected_ensemble`
        (array of shape `(sample_size, size)`), `chain_diagnostics`
        (`None` unless `full_diagnostics`), and `sampling_time` (seconds).
    :raises ValueError: if the maximum of the (unscaled) density found by the
        optimizer is not a finite positive number.

    :remarks: This will replace `start_rejection_sampling` for efficiency; timing will decide!
    """
    constraint_test = self.configurations.constraint_test
    liner = "=" * 53
    if self.verbose:
        print("\n%s\nStarted Sampling\n%s\n" % (liner, liner))

    # Minimum and maximum probabilities
    pmin = 0  # 0.0 could be generalized by optimization as well!

    # The optimizer minimizes the NEGATIVE LOG-density, so the negated optimum
    # is the maximum log-density. The acceptance test below compares densities,
    # hence the bound must be exponentiated to live on the same scale.
    max_log_density = np.max(-self.optimizer.solve().optimization_results.fun)
    pmax = float(np.exp(max_log_density))
    if not (np.isfinite(pmax) and pmax > pmin):
        raise ValueError(
            "The maximum of the unscaled density must be finite and positive; "
            f"got {pmax} (maximum log-density {max_log_density})"
        )
    # NOTE(review): acceptance uses the target density only (the proposal's
    # density does not enter the ratio) -- this looks exact only for a flat
    # proposal; confirm against the intended algorithm.

    # All generated sample points will be kept for testing and efficiency analysis
    proposals_repository = []
    uniform_random_numbers = []
    acceptance_flags = []

    # Initialize samples container (NaN marks slots not yet filled)
    collected_ensemble = np.full((sample_size, self.size), np.nan)

    start_time = time.time()  # start timing
    collected = 0
    while collected < sample_size:
        # Propose as many states as are still missing
        proposed_states = np.asarray(
            self.proposal.sample(sample_size=sample_size - collected)
        )

        # Drop proposals that violate the constraint (if any)
        if constraint_test is not None:
            meet = np.where(
                [constraint_test(state) for state in proposed_states]
            )[0]
            proposed_states = proposed_states[meet]

        # Retry to get states that satisfy the constraint; the proposal should handle it anyways
        if len(proposed_states) == 0:
            continue

        # Evaluate the (unscaled) density for each of the proposed states
        proposed_states_pdf = np.asarray(
            [np.exp(self.log_density(state)) for state in proposed_states]
        )

        # One uniform random number between pmin and pmax per proposal
        uniform_probability = self.random_number_generator.uniform(
            low=pmin,
            high=pmax,
            size=proposed_states_pdf.size,
        )

        # Accept/Reject
        flags = proposed_states_pdf > uniform_probability
        accepted_states = proposed_states[np.where(flags)[0]]
        num_accepted = len(accepted_states)
        collected_ensemble[collected : collected + num_accepted, :] = accepted_states

        # Update counter
        collected += num_accepted
        if self.verbose:
            print(f"\rSampled {collected}/{sample_size}", end="  ")

        # Update Results Repositories:
        if full_diagnostics:
            proposals_repository.extend(proposed_states)
            uniform_random_numbers.extend(uniform_probability)
            acceptance_flags.extend(flags)

    # Stop timing
    sampling_time = time.time() - start_time

    # Now output diagnostics and show some plots :)
    if full_diagnostics:
        chain_diagnostics = self.diagnostic_statistics(
            proposals_repository=proposals_repository,
            collected_ensemble=collected_ensemble,
            uniform_probabilities=uniform_random_numbers,
            acceptance_flags=acceptance_flags,
            plot_diagnostics=True,
            plot_title=f"Rejection Sampling ('{self.proposal.configurations.name}' Proposal)",
            filename_prefix="Rejection_Sampling",
        )
    else:
        chain_diagnostics = None

    # Output sampling diagnostics
    if self.verbose:
        print("Rejection Sampling:")
        print(f"Time Elapsed for Rejection sampling: {sampling_time} seconds")
        if chain_diagnostics is not None:
            print(f"Acceptance Rate: {chain_diagnostics['acceptance_rate']:.2f}")

    sampling_results = dict(
        proposals_repository=proposals_repository,
        uniform_random_numbers=uniform_random_numbers,
        acceptance_flags=acceptance_flags,
        collected_ensemble=collected_ensemble,
        chain_diagnostics=chain_diagnostics,
        sampling_time=sampling_time,
    )
    return sampling_results
[docs]
def start_rejection_sampling(
    self,
    sample_size,
    full_diagnostics=False,
):
    """
    Start the rejection sampling procedure, proposing one state at a time.

    :param int sample_size: number of sample points to generate/collect from the
        predefined target distribution
    :param bool full_diagnostics: if `True` all generated states will be tracked and
        kept for full diagnostics, otherwise, only collected samples are kept in memory

    :returns: a dictionary with keys `proposals_repository`,
        `uniform_random_numbers`, `acceptance_flags`, `collected_ensemble`
        (list of accepted states), `chain_diagnostics` (`None` unless
        `full_diagnostics`), and `sampling_time` (seconds).
    :raises ValueError: if the maximum of the (unscaled) density found by the
        optimizer is not a finite positive number.
    """
    liner = "=" * 53
    if self.verbose:
        print("\n%s\nStarted Sampling\n%s\n" % (liner, liner))

    # Minimum and maximum probabilities
    pmin = 0  # 0.0 could be generalized by optimization as well!

    # The optimizer minimizes the NEGATIVE LOG-density, so the negated optimum
    # is the maximum log-density. The acceptance test below compares densities,
    # hence the bound must be exponentiated to live on the same scale.
    max_log_density = np.max(-self.optimizer.solve().optimization_results.fun)
    pmax = float(np.exp(max_log_density))
    if not (np.isfinite(pmax) and pmax > pmin):
        raise ValueError(
            "The maximum of the unscaled density must be finite and positive; "
            f"got {pmax} (maximum log-density {max_log_density})"
        )
    # NOTE(review): acceptance uses the target density only (the proposal's
    # density does not enter the ratio) -- this looks exact only for a flat
    # proposal; confirm against the intended algorithm.

    # Extract configurations
    constraint_test = self.configurations.constraint_test

    # All generated sample points will be kept for testing and efficiency analysis
    proposals_repository = []
    uniform_random_numbers = []
    acceptance_flags = []
    collected_ensemble = []

    start_time = time.time()  # start timing
    collected = 0
    while collected < sample_size:
        # Propose a single state
        proposed_state = self.proposal.sample()[0]

        # A state violating the constraint gets zero density (never accepted)
        if constraint_test is not None and not constraint_test(proposed_state):
            proposal_pdf = 0
        else:
            proposal_pdf = np.exp(self.log_density(proposed_state))

        # a uniform random number between pmin and pmax
        uniform_probability = self.random_number_generator.uniform(
            low=pmin, high=pmax
        )

        # Accept/Reject
        accept_proposal = bool(proposal_pdf > uniform_probability)
        if accept_proposal:
            collected_ensemble.append(proposed_state)
            collected += 1

        if self.verbose:
            print(f"\rSampled {collected}/{sample_size}", end="  ")

        # Update Results Repositories:
        if full_diagnostics:
            proposals_repository.append(proposed_state)
            uniform_random_numbers.append(uniform_probability)
            acceptance_flags.append(1 if accept_proposal else 0)

    # Stop timing
    sampling_time = time.time() - start_time

    # Now output diagnostics and show some plots :)
    if full_diagnostics:
        chain_diagnostics = self.diagnostic_statistics(
            proposals_repository=proposals_repository,
            collected_ensemble=collected_ensemble,
            uniform_probabilities=uniform_random_numbers,
            acceptance_flags=acceptance_flags,
            plot_diagnostics=True,
            plot_title=f"Rejection Sampling ('{self.proposal.configurations.name}' Proposal)",
            filename_prefix="Rejection_Sampling",
        )
    else:
        chain_diagnostics = None

    # Output sampling diagnostics
    if self.verbose:
        print("Rejection Sampling:")
        print(f"Time Elapsed for Rejection sampling: {sampling_time} seconds")
        if chain_diagnostics is not None:
            print(f"Acceptance Rate: {chain_diagnostics['acceptance_rate']:.2f}")

    sampling_results = dict(
        proposals_repository=proposals_repository,
        uniform_random_numbers=uniform_random_numbers,
        acceptance_flags=acceptance_flags,
        collected_ensemble=collected_ensemble,
        chain_diagnostics=chain_diagnostics,
        sampling_time=sampling_time,
    )
    return sampling_results
[docs]
def log_density(self, state):
    """
    Evaluate the value of the logarithm of the target unscaled posterior
    density function.

    :param state: point at which the log-density is evaluated; passed as-is
        to the callable stored in `_LOG_DENSITY`.

    :returns: the value returned by the underlying callable; if that value
        supports indexing (e.g., a list or a non-scalar array), its first
        flattened entry is returned instead.
    """
    val = self._LOG_DENSITY(state)
    # Probe whether the result is an indexable container and, if so, unwrap it.
    # Plain scalars raise TypeError/IndexError on indexing and are returned
    # untouched. Only probing-related errors are intercepted so that
    # interrupts and unrelated failures are not silently swallowed.
    try:
        val[0]
        val = np.asarray(val).flatten()[0]
    except (TypeError, IndexError, KeyError, ValueError):
        pass
    return val
[docs]
def diagnostic_statistics(
    self,
    proposals_repository,
    uniform_probabilities,
    acceptance_flags,
    collected_ensemble,
    plot_diagnostics=True,
    output_dir=None,
    plot_title="Rejection Sampling",
    filename_prefix="Rejection_Sampling",
):
    """
    Return diagnostic statistics of the sampler (acceptance and rejection
    rates, in percent), and optionally plot the collected sample.

    :param proposals_repository: all proposed states (currently unused here)
    :param uniform_probabilities: uniform random numbers used in the
        accept/reject tests (currently unused here)
    :param acceptance_flags: one truthy/falsy flag per proposed state
    :param collected_ensemble: accepted sample points; passed to the plotter
    :param bool plot_diagnostics: if `True`, call `plot_sampling_results`
    :param output_dir: where plots are saved; if `None`, the `output_dir`
        found in the configurations is used
    :param str plot_title: title of the plot
    :param str filename_prefix: prefix of the plot file name(s)

    :returns: a dictionary with keys `acceptance_rate` and `rejection_rate`;
        both are `nan` when `acceptance_flags` is empty.
    """
    flags = np.asarray(acceptance_flags).flatten()

    # An empty record carries no rate information; report NaN rather than
    # dividing by zero.
    if flags.size == 0:
        acceptance_rate = rejection_rate = float("nan")
    else:
        acceptance_rate = float(flags.sum()) / flags.size * 100.0
        rejection_rate = 100.0 - acceptance_rate

    # TODO: Add More; e.g., effective sample size, etc.
    # Return all diagnostics in a dictionary
    chain_diagnostics = dict(
        acceptance_rate=acceptance_rate,
        rejection_rate=rejection_rate,
    )

    # Plots & autocorrelation, etc.
    if plot_diagnostics:
        if output_dir is None:
            output_dir = self.configurations.output_dir
        plot_sampling_results(
            sample=collected_ensemble,
            log_density=self.log_density,
            title=plot_title,
            output_dir=output_dir,
            filename_prefix=filename_prefix,
        )

    return chain_diagnostics
[docs]
def update_random_number_generators(
    self,
    random_seed,
    update_proposal=True,
):
    """
    Re-seed the random number generator of this sampler and, optionally, the
    one of the underlying proposal.

    The sampler's own generator is replaced through
    `update_random_number_generator`; the proposal receives the same seed
    through its `update_configurations` method.

    :param random_seed: seed used to create the new generator(s)
    :param bool update_proposal: if `True` (default) the proposal is re-seeded
        with the same `random_seed`
    """
    # Own generator first, then (optionally) the proposal's
    self.update_random_number_generator(random_seed=random_seed)
    if not update_proposal:
        return
    proposal = self.proposal
    proposal.update_configurations(random_seed=random_seed)
@property
def size(self):
    """Return the dimension (`size` in the configurations) of the underlying probability space"""
    return self.configurations.size
@property
def proposal(self):
    """Get a handle of the proposal stored in the configurations"""
    return self.configurations.proposal
@proposal.setter
def proposal(self, value):
    """Update the proposal; the new value goes through `update_configurations`"""
    self.update_configurations(proposal=value)
@property
def optimizer(self):
    """Get a handle of the optimizer stored in `_OPTIMIZER`"""
    return self._OPTIMIZER
## Simple interfaces (to generate instances from classes developed here).
[docs]
def create_rejection_sampler(
    size,
    log_density,
    proposal=None,
    bounds=None,
    constraint_test=None,
    output_dir=None,
    random_seed=None,
):
    """
    Convenience factory: bundle the arguments into a configurations dictionary
    and return a :py:class:`RejectionSampler` built from it.
    Configurations/settings can be updated after instantiation.

    :param int size: dimension of the target distribution to sample
    :param log_density: a callable (function) to evaluate the logarithm of the
        target density (unscaled); this function takes one vector (1d
        array/iterable) of length equal to `size`, and returns a scalar
    :param Proposal proposal: proposal instance (to generate candidate samples)
    :param bounds: bounds of the domain of the target random variable (if any)
    :param constraint_test: a function that returns `True` if a sample point
        satisfies the constraints, and `False` otherwise; ignored if `None`
        is passed.
    :param output_dir: forwarded as the `output_dir` configuration
    :param random_seed: random seed used when the object is initiated to keep
        track of random samples. This is useful for reproducibility. If `None`,
        random seed follows `numpy.random.seed` rules

    :return: instance of :py:class:`RejectionSampler` (with some or all
        configurations passed)
    """
    settings = {
        "size": size,
        "log_density": log_density,
        "proposal": proposal,
        "bounds": bounds,
        "constraint_test": constraint_test,
        "output_dir": output_dir,
        "random_seed": random_seed,
    }
    sampler = RejectionSampler(settings)
    return sampler