Parameter validation

This is not required for reproducibility or for a reproducible analytical pipeline, but it is good practice and can help you avoid mistakes.


1 Introduction

Parameter validation refers to checking that inputs provided to functions or classes are correct and as expected.

This page focuses on two examples of validation we can perform:

  1. Preventing a dangerous but common mistake in discrete-event simulation (DES): accidentally creating new parameters through typos instead of modifying existing ones.
  2. Validating parameter values to ensure they fall within the expected range.
# pylint: disable=missing-module-docstring
# Import required packages
import inspect
# Import required packages
library(R6) # nolint: undesirable_function_linter


2 Accidental creation of new parameters

2.1 The problem

If you mistype a parameter name when changing a parameter value, an error won't be raised. Instead, a new, unused parameter is created. This can silently invalidate your results, as you may not realise that the parameter you intended to change still has its original value.

Note: This parameter typo issue doesn't occur when using R6 classes directly (as in the R class-based example below), as they prevent adding new fields by default. However, if you extract the parameter list from an R6 class and modify that list separately, you'll encounter the same silent error problem. For this reason, it's better to modify parameters through the class interface rather than extracting and modifying the underlying list.

Function-based example:

# pylint: disable=missing-module-docstring
def param_function(transfer_prob=0.3):
    """
    Returns transfer_prob for validation example.

    Parameters
    ----------
    transfer_prob : float
        Transfer probability (0-1).

    Returns
    -------
    Dictionary containing the transfer_prob parameter.
    """
    return {"transfer_prob": transfer_prob}
# pylint: disable=undefined-variable
# Use function to create params dict
params = param_function()

# Mistype transfer_prob
params["transfer_probs"] = 0.4
print(params)
{'transfer_prob': 0.3, 'transfer_probs': 0.4}
#' Returns transfer_prob for validation example.
#'
#' @param transfer_prob Numeric. Transfer probability (0-1).
#'
#' @return A named list containing the transfer_prob parameter.

param_function <- function(transfer_prob = 0.3) {
  list(transfer_prob = transfer_prob)
}
# Use function to create params list
params <- param_function()

# Mistype transfer_prob
params$transfer_probs <- 0.4
params
$transfer_prob
[1] 0.3

$transfer_probs
[1] 0.4

Class-based example:

# pylint: disable=missing-module-docstring, invalid-name, too-few-public-methods
class ParamClass:
    """
    Returns transfer_prob for validation example.
    """
    def __init__(self, transfer_prob=0.3):
        """
        Initialise ParamClass instance.

        Parameters
        ----------
        transfer_prob : float
            Transfer probability (0-1).
        """
        self.transfer_prob = transfer_prob
# pylint: disable=used-before-assignment
# Create instance of ParamClass
params = ParamClass()

# Mistype transfer_prob
params.transfer_probs = 0.4
print(params.__dict__)
{'transfer_prob': 0.3, 'transfer_probs': 0.4}
#' @title Returns transfer_prob for validation example.
#'
#' @field transfer_prob Numeric. Transfer probability (0-1).

ParamClass <- R6Class( # nolint: object_name_linter
  public = list(
    transfer_prob = NULL,

    #' @description
    #' Initialises the R6 object.

    initialize = function(transfer_prob = 0.3) {
      self$transfer_prob <- transfer_prob
    }
  )
)
# Create instance of ParamClass
params <- ParamClass$new()

# Mistype transfer_prob
try({
  params$transfer_probs <- 0.4
})
Error in params$transfer_probs <- 0.4 : 
  cannot add bindings to a locked environment
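The note at the start of this section, about extracting parameters from an R6 class, has a direct Python analogue. The sketch below assumes a hypothetical LockedParams class (not part of the original) whose __setattr__ rejects new attributes: a typo made through the object is caught, but once the attributes are copied into a plain dictionary, the same typo is silently accepted.

```python
class LockedParams:
    """Hypothetical parameter class that rejects new attributes."""

    def __init__(self, transfer_prob=0.3):
        # Bypass the restricted __setattr__ so the attribute can be created
        object.__setattr__(self, "transfer_prob", transfer_prob)

    def __setattr__(self, name, value):
        # Only allow modification of attributes that already exist
        if name not in self.__dict__:
            raise AttributeError(f"Cannot add new attribute '{name}'")
        object.__setattr__(self, name, value)


params = LockedParams()

# Modifying through the object: the typo is caught
try:
    params.transfer_probs = 0.4
except AttributeError as e:
    print(f"Error: {e}")

# Copying the attributes into a plain dict: the typo is silently accepted
extracted = dict(vars(params))
extracted["transfer_probs"] = 0.4
print(extracted)  # {'transfer_prob': 0.3, 'transfer_probs': 0.4}
```

As with R6, the protection lives in the object, not in any data you copy out of it.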


2.2 The solution

In Python, there are two main approaches to prevent these silent failures:

  • When using functions: Implement parameter validation within your model functions.
  • When using classes: Build validation directly into the class structure.

In R, there are two main approaches to prevent these silent failures when using functions:

  • Implement parameter validation within your model functions.
  • Switch to using R6 classes.

Class-based validation is often preferable, as it catches errors at the moment a parameter is set or modified, rather than later when the model runs. However, both approaches are effective.

Functions are more commonly used in R than classes, so you may prefer to stick with the function-based approach for consistency with R conventions.


2.3 Parameter validation within the model functions

In Python, our parameter function returns a dictionary. While you could validate arguments within the parameter function, this would not prevent accidental modification after the dictionary is returned.

When calling the function, you are restricted to the defined arguments, so extra entries cannot be added at that stage. However, once the dictionary has been returned, entries can be freely added or removed. For this reason, we incorporate validation into the model function (rather than the parameter function).

In R, our parameter function returns a list, and the same reasoning applies: validating arguments within the parameter function would not stop entries being added to or removed from the list after it is returned, so we again incorporate validation into the model function.
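If you would rather the returned dictionary itself resisted modification, one option not used in the original is to return a read-only view with types.MappingProxyType from the Python standard library. The sketch below uses a hypothetical param_function_readonly(). The trade-off is that values can no longer be changed in place either, so you would call the function again with new arguments instead.

```python
from types import MappingProxyType


def param_function_readonly(transfer_prob=0.3):
    """Hypothetical variant of param_function returning a read-only view."""
    return MappingProxyType({"transfer_prob": transfer_prob})


params = param_function_readonly()

# Any attempt to add (or change) an entry raises a TypeError
try:
    params["transfer_probs"] = 0.4
except TypeError as e:
    print(f"Error: {e}")

# To change a value, call the function again with a new argument
params = param_function_readonly(transfer_prob=0.4)
print(dict(params))  # {'transfer_prob': 0.4}
```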

As a reminder, this is our parameter function:

# pylint: disable=missing-module-docstring
def param_function(transfer_prob=0.3):
    """
    Returns transfer_prob for validation example.

    Parameters
    ----------
    transfer_prob : float
        Transfer probability (0-1).

    Returns
    -------
    Dictionary containing the transfer_prob parameter.
    """
    return {"transfer_prob": transfer_prob}
#' Returns transfer_prob for validation example.
#'
#' @param transfer_prob Numeric. Transfer probability (0-1).
#'
#' @return A named list containing the transfer_prob parameter.

param_function <- function(transfer_prob = 0.3) {
  list(transfer_prob = transfer_prob)
}

We can write a validation function which checks that all the required parameters are present, and that no extra parameters are provided:

def check_param_names(param_dict, param_function):
    """
    Validate parameter names.

    Ensure that all required parameters are present, and no extra parameters
    are provided.

    Parameters
    ----------
    param_dict : dict
        Dictionary containing parameters for the simulation.
    param_function : function
        Function used to generate the parameter dictionary.
    """
    # Get the set of valid parameter names from the function signature
    valid_params = set(inspect.signature(param_function).parameters)

    # Get the set of input parameter names from the provided dictionary
    input_params = set(param_dict)

    # Identify missing and extra parameters
    missing = valid_params - input_params
    extra = input_params - valid_params

    # If there are any missing or extra parameters, raise an error message
    if missing or extra:
        raise ValueError("; ".join([
            f"Missing keys: {', '.join(missing)}" if missing else "",
            f"Extra keys: {', '.join(extra)}" if extra else ""
        ]).strip("; "))
#' Validate parameter names.
#'
#' Ensure that all required parameters are present, and no extra parameters are
#' provided.
#'
#' @param param List containing parameters for the simulation.
#' @param param_function Function used to generate parameter list.
#'
#' @return Throws an error if there are missing or extra parameters.

check_param_names <- function(param, param_function) {

  # Get valid argument names from the function
  valid_names <- names(formals(param_function))

  # Get names from input parameter list
  input_names <- names(param)

  # Find missing keys (i.e. are there things in valid_names not in input)
  # and extra keys (i.e. are there things in input not in valid_names)
  missing_keys <- setdiff(valid_names, input_names)
  extra_keys <- setdiff(input_names, valid_names)

  # If there are any missing or extra keys, throw an error
  if (length(missing_keys) > 0L || length(extra_keys) > 0L) {
    error_message <- ""
    if (length(missing_keys) > 0L) {
      error_message <- paste0(
        error_message, "Missing keys: ", toString(missing_keys), ". "
      )
    }
    if (length(extra_keys) > 0L) {
      error_message <- paste0(
        error_message, "Extra keys: ", toString(extra_keys), ". "
      )
    }
    stop(error_message, call. = FALSE)
  }
}

Then, in our model function, we call the validation function to check all inputs before proceeding with the simulation:

def model(param_dict, param_function):
    """
    Run simulation after validating parameter names.

    Parameters
    ----------
    param_dict : dict
        Dictionary of parameters.
    param_function : function
        Function used to generate the parameter dictionary.
    """
    # Check all inputs are valid
    check_param_names(param_dict=param_dict, param_function=param_function)

    # Simulation code...


# Example usage
# No extra or missing parameters - model runs without issue
params = param_function()
model(params, param_function)

# Mistype transfer_prob - model returns an error
params["transfer_probs"] = 0.4
try:
    model(params, param_function)
except ValueError as e:
    print(e)
Extra keys: transfer_probs
#' Run simulation after validating parameter names.
#'
#' @param param Named list of model parameters.
#' @param param_function Function used to generate parameter list.
model <- function(param, param_function) {

  # Check all inputs are valid
  check_param_names(param = param, param_function = param_function)

  # Simulation code...
}


# Example usage
# No extra or missing parameters - model runs without issue
params <- param_function()
model(params, param_function)

# Mistype transfer_prob - model returns an error
params$transfer_probs <- 0.4
try(model(params, param_function))
Error : Extra keys: transfer_probs. 
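Because these errors usually come from typos, the error message can also suggest the name you probably meant. The sketch below is a hypothetical variant of check_param_names (not from the original) that uses difflib.get_close_matches from the Python standard library for the suggestion, and sorts the keys so the message is identical on every run (iterating over a set gives no guaranteed order).

```python
import difflib
import inspect


def param_function(transfer_prob=0.3):
    """Parameter function, as defined above."""
    return {"transfer_prob": transfer_prob}


def check_param_names_with_hints(param_dict, param_function):
    """
    Hypothetical variant of check_param_names that suggests likely intended
    names for any unexpected keys.
    """
    valid_params = set(inspect.signature(param_function).parameters)
    input_params = set(param_dict)

    # Sort so the error message is the same on every run
    missing = sorted(valid_params - input_params)
    extra = sorted(input_params - valid_params)

    messages = []
    if missing:
        messages.append(f"Missing keys: {', '.join(missing)}")
    for key in extra:
        # Look for a valid name that closely resembles the unexpected key
        hints = difflib.get_close_matches(key, sorted(valid_params), n=1)
        hint = f" (did you mean '{hints[0]}'?)" if hints else ""
        messages.append(f"Extra key: {key}{hint}")
    if messages:
        raise ValueError("; ".join(messages))


params = param_function()
params["transfer_probs"] = 0.4
try:
    check_param_names_with_hints(params, param_function)
except ValueError as e:
    print(e)  # Extra key: transfer_probs (did you mean 'transfer_prob'?)
```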


2.4 Parameter validation within the class

We can add logic to prevent the addition of new attributes to our Python classes. This can be implemented either:

  • Directly within the class, or
  • Using class inheritance.


Direct implementation within the class:

This approach implements validation logic directly within the class using a custom __setattr__ method.

# pylint: disable=too-few-public-methods
class Param:
    """
    Parameter class with validation to prevent the addition of new attributes.
    """
    def __init__(self, param1="test", param2=42):
        """
        Initialise Param instance.
        """
        # Disable restriction during initialisation
        object.__setattr__(self, "_initialising", True)

        # Set the attributes
        self.param1 = param1
        self.param2 = param2

        # Re-enable attribute checks after initialisation
        object.__setattr__(self, "_initialising", False)

    def __setattr__(self, name, value):
        """
        Prevent addition of new attributes.

        This method overrides the default `__setattr__` behavior to restrict
        the addition of new attributes to the instance. It allows modification
        of existing attributes but raises an `AttributeError` if an attempt is
        made to create a new attribute. This ensures that accidental typos in
        attribute names do not silently create new attributes.

        Parameters
        ----------
        name : str
          The name of the attribute to set.
        value : Any
          The value to assign to the attribute.

        Raises
        ------
        AttributeError
            If `name` is not an existing attribute and an attempt is made
            to add it to the instance.
        """
        # Skip validation if still initialising
        # pylint: disable=maybe-no-member
        if hasattr(self, "_initialising") and self._initialising:
            super().__setattr__(name, value)
        else:
            # Check if attribute already exists
            if name in self.__dict__:
                super().__setattr__(name, value)
            else:
                raise AttributeError(
                    f"Cannot add new attribute '{name}' - only possible to "
                    f"modify existing attributes: {self.__dict__.keys()}"
                )


# Example usage...

# Create an instance of the class
params = Param()

# Successfully modify an existing attribute
params.param1 = "newtest"

# Attempts to add new attributes should raise an error
try:
    params.new_attribute = 3  # pylint: disable=attribute-defined-outside-init
except AttributeError as e:
    print(f"Error: {e}")
Error: Cannot add new attribute 'new_attribute' - only possible to modify existing attributes: dict_keys(['_initialising', 'param1', 'param2'])
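A built-in alternative, not used in the original, is Python's __slots__, which fixes the set of attribute names an instance can ever have. The sketch below uses a hypothetical SlotParam class. Trade-offs to be aware of: instances have no __dict__ (so vars(params) no longer works), the names must be listed separately from __init__, and a subclass that does not define its own __slots__ gets a __dict__ back, losing the restriction.

```python
class SlotParam:
    """
    Hypothetical parameter class using __slots__: only the listed attribute
    names can ever be set on an instance.
    """
    __slots__ = ("param1", "param2")

    def __init__(self, param1="test", param2=42):
        self.param1 = param1
        self.param2 = param2


params = SlotParam()

# Modifying an existing attribute works as normal
params.param1 = "newtest"

# Assigning to a name not listed in __slots__ raises AttributeError
try:
    params.new_attribute = 3
except AttributeError as e:
    print(f"Error: {e}")
```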


Using class inheritance:

Class inheritance allows a class to serve as a blueprint for another, passing down attributes and methods. This approach is useful for sharing logic (such as restricting the creation of new attributes) across multiple classes, reducing code duplication and keeping each class simpler.

Here, we have three classes, each building on the one above it:

  1. RestrictAttributesMeta (metaclass).
  2. RestrictAttributes (parent/base class).
  3. Param (child/derived class).

RestrictAttributesMeta (metaclass): A metaclass controls how classes and instances are created. In this case, it adds an _initialised flag after __init__ completes.

class RestrictAttributesMeta(type):
    """
    Metaclass for attribute restriction.

    A metaclass modifies class construction. It intercepts instance creation
    via __call__, adding the _initialised flag after __init__ completes. This
    is later used by RestrictAttributes to enforce attribute restrictions.
    """
    def __call__(cls, *args, **kwargs):
        # Create instance using the standard method
        instance = super().__call__(*args, **kwargs)
        # Set the "_initialised" flag to True, marking end of initialisation
        instance.__dict__["_initialised"] = True
        return instance

RestrictAttributes (parent/base class): A parent or base class defines behaviour that its child classes inherit. Here, the class does not inherit from RestrictAttributesMeta; instead, it uses it as its metaclass (so every instance receives the _initialised flag), and it defines a __setattr__ method which prevents the addition of new attributes after initialisation.

# pylint: disable=too-few-public-methods
class RestrictAttributes(metaclass=RestrictAttributesMeta):
    """
    Base class that prevents the addition of new attributes after
    initialisation.

    This class uses RestrictAttributesMeta as its metaclass to implement
    attribute restriction. It allows for safe initialisation of attributes
    during the __init__ method, but prevents the addition of new attributes
    afterwards.

    The restriction is enforced through the custom __setattr__ method, which
    checks if the attribute already exists before allowing assignment.
    """
    def __setattr__(self, name, value):
        """
        Prevent addition of new attributes.

        Parameters
        ----------
        name: str
            The name of the attribute to set.
        value: any
            The value to assign to the attribute.

        Raises
        ------
        AttributeError
            If `name` is not an existing attribute and an attempt is made
            to add it to the class instance.
        """
        # Check if the instance is initialised and the attribute doesn't exist
        if hasattr(self, "_initialised") and not hasattr(self, name):
            # Get a list of existing attributes for the error message
            existing = ", ".join(self.__dict__.keys())
            raise AttributeError(
                f"Cannot add new attribute '{name}' - only possible to " +
                f"modify existing attributes: {existing}."
            )
        # If checks pass, set the attribute using the standard method
        object.__setattr__(self, name, value)

As a child or derived class, Param inherits behaviour from RestrictAttributes (and, by extension, its metaclass RestrictAttributesMeta). This means Param automatically gains the validation logic defined in RestrictAttributes and RestrictAttributesMeta.

# pylint: disable=too-few-public-methods,function-redefined
class Param(RestrictAttributes):
    """
    Parameter class with validation to prevent the addition of new attributes.
    """
    def __init__(self, param1="test", param2=42):
        """
        Initialise Param instance.
        """
        self.param1 = param1
        self.param2 = param2


# Example usage...

# Create an instance of the class
params = Param()

# Successfully modify an existing attribute
params.param1 = "newtest"

# Attempts to add new attributes should raise an error
try:
    params.new_attribute = 3
except AttributeError as e:
    print(f"Error: {e}")
Error: Cannot add new attribute 'new_attribute' - only possible to modify existing attributes: param1, param2, _initialised.
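To illustrate the point about sharing logic across multiple classes, the sketch below reuses the same metaclass and base class (with a shortened error message) for two hypothetical classes, ArrivalParam and RunParam. These class names are assumptions for illustration; patient_inter and number_of_runs are borrowed from the nurse visit example at the end of this page.

```python
class RestrictAttributesMeta(type):
    """Metaclass that marks instances as initialised once __init__ returns."""
    def __call__(cls, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        instance.__dict__["_initialised"] = True
        return instance


class RestrictAttributes(metaclass=RestrictAttributesMeta):
    """Base class that blocks new attributes after initialisation."""
    def __setattr__(self, name, value):
        if hasattr(self, "_initialised") and not hasattr(self, name):
            raise AttributeError(f"Cannot add new attribute '{name}'")
        object.__setattr__(self, name, value)


class ArrivalParam(RestrictAttributes):
    """Hypothetical class holding arrival parameters."""
    def __init__(self, patient_inter=4):
        self.patient_inter = patient_inter


class RunParam(RestrictAttributes):
    """Hypothetical class holding run settings."""
    def __init__(self, number_of_runs=31):
        self.number_of_runs = number_of_runs


# Both classes gain the same protection without repeating any code
for params, typo in [(ArrivalParam(), "patient_intr"),
                     (RunParam(), "number_of_run")]:
    try:
        setattr(params, typo, 1)
    except AttributeError as e:
        print(f"{type(params).__name__}: {e}")
```

Each new parameter class only needs to inherit from RestrictAttributes; the restriction logic is written once.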

In R, R6 classes prevent the addition of new fields by default. However, you can lose this protection for your parameters if you override the default behaviour or structure your class differently. Two ways this can happen are:

  • Setting lock_objects = FALSE.
  • Setting parameters within a list.


Setting lock_objects = FALSE

New fields are blocked because of the default lock_objects = TRUE setting. If you override this with lock_objects = FALSE, no error is raised when new fields are added, so it's important not to override this default behaviour.

Original class with lock_objects = TRUE:

#' @title Returns transfer_prob for validation example.
#'
#' @field transfer_prob Numeric. Transfer probability (0-1).

ParamClass <- R6Class( # nolint: object_name_linter
  public = list(
    transfer_prob = NULL,

    #' @description
    #' Initialises the R6 object.

    initialize = function(transfer_prob = 0.3) {
      self$transfer_prob <- transfer_prob
    }
  )
)
# Create instance of ParamClass
params <- ParamClass$new()

# Mistype transfer_prob
try({
  params$transfer_probs <- 0.4
})
Error in params$transfer_probs <- 0.4 : 
  cannot add bindings to a locked environment

Same class with lock_objects = FALSE:

#' @title Returns transfer_prob for validation example.
#'
#' @field transfer_prob Numeric. Transfer probability (0-1).

ParamClass <- R6Class( # nolint: object_name_linter
  lock_objects = FALSE, 
  public = list(
    transfer_prob = NULL,

    #' @description
    #' Initialises the R6 object.

    initialize = function(transfer_prob = 0.3) {
      self$transfer_prob <- transfer_prob
    }
  )
)
# Create instance of ParamClass
params <- ParamClass$new()

# Mistype transfer_prob
try({
  params$transfer_probs <- 0.4
})
params 
<R6>
  Public:
    clone: function (deep = FALSE) 
    initialize: function (transfer_prob = 0.3) 
    transfer_prob: 0.3
    transfer_probs: 0.4


Setting parameters within a list

If you set up your R6 class with each parameter as a class field, then by default, it will have validation to prevent the addition of new fields.

You may choose to store parameters in a list instead, to make it easier to access them all at once - but, if you do this, the validation won’t apply.

Original class with individual fields:

#' @title Returns transfer_prob for validation example.
#'
#' @field transfer_prob Numeric. Transfer probability (0-1).

ParamClass <- R6Class( # nolint: object_name_linter
  public = list(
    transfer_prob = NULL,

    #' @description
    #' Initialises the R6 object.

    initialize = function(transfer_prob = 0.3) {
      self$transfer_prob <- transfer_prob
    }
  )
)
# Create instance of ParamClass
params <- ParamClass$new()

# Mistype transfer_prob
try({
  params$transfer_probs <- 0.4
})
Error in params$transfer_probs <- 0.4 : 
  cannot add bindings to a locked environment

Class with parameters in a list:

#' @title Returns transfer_prob for validation example.
#'
#' @field parameters Named list containing transfer_prob (Numeric. Transfer probability (0-1)).

ParamClass <- R6Class( # nolint: object_name_linter
  public = list(
    parameters = NULL, 

    #' @description
    #' Initialises the R6 object.

    initialize = function(transfer_prob = 0.3) {
      self$parameters <- list( 
        transfer_prob = transfer_prob 
      ) 
    }
  )
)
# Create instance of ParamClass
params <- ParamClass$new()

# Mistype transfer_prob
try({
  params$parameters$transfer_probs <- 0.4
})
params$parameters 
$transfer_prob
[1] 0.3

$transfer_probs
[1] 0.4


However, there is a clean solution that allows you to access parameters easily while maintaining individual fields with built-in validation. By adding a get_params() method, you can extract parameters without storing them in a list.

Original class with individual fields:

#' @title Returns transfer_prob for validation example.
#'
#' @field transfer_prob Numeric. Transfer probability (0-1).

ParamClass <- R6Class( # nolint: object_name_linter
  public = list(
    transfer_prob = NULL,

    #' @description
    #' Initialises the R6 object.

    initialize = function(transfer_prob = 0.3) {
      self$transfer_prob <- transfer_prob
    }
  )
)
# Create instance of ParamClass
params <- ParamClass$new()

# Mistype transfer_prob
try({
  params$transfer_probs <- 0.4
})
Error in params$transfer_probs <- 0.4 : 
  cannot add bindings to a locked environment

Same class with added get_params() method:

#' @title Returns transfer_prob for validation example.
#'
#' @field transfer_prob Numeric. Transfer probability (0-1).

ParamClass <- R6Class( # nolint: object_name_linter
  public = list(
    transfer_prob = NULL,

    #' @description
    #' Initialises the R6 object.

    initialize = function(transfer_prob = 0.3) {
      self$transfer_prob <- transfer_prob
    },

    #' @description 
    #' Returns parameters as a named list. 

    get_params = function() { 
      # Get all non-function fields 
      all_names <- ls(self) 
      is_not_function <- vapply( 
        all_names, 
        function(x) !is.function(self[[x]]), 
        FUN.VALUE = logical(1L) 
      ) 
      param_names <- all_names[is_not_function] 
      mget(param_names, envir = self) 
    } 
  )
)
# Create instance of ParamClass
params <- ParamClass$new()

# Mistype transfer_prob
try({
  params$transfer_probs <- 0.4
})
Error in params$transfer_probs <- 0.4 : 
  cannot add bindings to a locked environment
# Get all parameters 
params$get_params() 
$transfer_prob
[1] 0.3
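The same pattern can be sketched in Python: keep parameters as individual, typo-protected attributes, and add a get_params() method for when you need them all at once. The class name ParamWithGetter is hypothetical and not part of the original. Unlike the R version, no filtering of methods is needed, because vars() only returns instance attributes, not methods defined on the class.

```python
class ParamWithGetter:
    """
    Hypothetical Python analogue: individual attributes (protected from
    typos) plus a get_params() method returning them as a dictionary.
    """
    def __init__(self, transfer_prob=0.3):
        # Bypass the restricted __setattr__ so the attribute can be created
        object.__setattr__(self, "transfer_prob", transfer_prob)

    def __setattr__(self, name, value):
        # Only existing attributes can be modified
        if name not in self.__dict__:
            raise AttributeError(f"Cannot add new attribute '{name}'")
        object.__setattr__(self, name, value)

    def get_params(self):
        """Return a copy of the parameters as a dictionary."""
        # Copying means edits to the result cannot affect the instance
        return dict(vars(self))


params = ParamWithGetter()
print(params.get_params())  # {'transfer_prob': 0.3}
```

Because get_params() returns a copy, it is for reading parameters; changes should still be made through the object so that typos are caught.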


3 Validating parameter values (e.g. range)

You can also check that the provided inputs are valid: for example, that they have the expected type or format, or fall within an expected range.

As before, this can be done in one of two ways:

  • When using functions: Implement parameter validation within your model functions.
  • When using classes: Build validation directly into the class structure.


3.1 Parameter validation within the model functions

As a reminder, this is our parameter function:

# pylint: disable=missing-module-docstring
def param_function(transfer_prob=0.3):
    """
    Returns transfer_prob for validation example.

    Parameters
    ----------
    transfer_prob : float
        Transfer probability (0-1).

    Returns
    -------
    Dictionary containing the transfer_prob parameter.
    """
    return {"transfer_prob": transfer_prob}
#' Returns transfer_prob for validation example.
#'
#' @param transfer_prob Numeric. Transfer probability (0-1).
#'
#' @return A named list containing the transfer_prob parameter.

param_function <- function(transfer_prob = 0.3) {
  list(transfer_prob = transfer_prob)
}

We can write a validation function which checks that the provided transfer probability is between 0 and 1.

def validate_param(parameters):
    """ 
    Check that the transfer probability is between 0 and 1.

    Parameters
    ----------
    parameters : dict
      Dictionary of parameters.
    """
    transfer_prob = parameters["transfer_prob"]
    if transfer_prob < 0 or transfer_prob > 1:
        raise ValueError(
          f"transfer_prob must be between 0 and 1, but is: {transfer_prob}"
        )
#' @title Check that the transfer probability is between 0 and 1.
#'
#' @param parameters Named list of model parameters.

validate_param <- function(parameters) {
  transfer_prob <- parameters$transfer_prob
  if (transfer_prob < 0L || transfer_prob > 1L) {
    stop(
      "transfer_prob must be between 0 and 1, but is: ", transfer_prob,
      call. = FALSE
    )
  }
}

Then, in our model function, we call the validation function to check all inputs before proceeding with the simulation:

def model(param_dict):
    """
    Run simulation.

    Parameters
    ----------
    param_dict : dict
        Dictionary of parameters.
    """
    # Check all inputs are valid
    validate_param(parameters=param_dict)

    # Simulation code...


# Example usage
# Run model() invalid transfer_prob - raises an error
param = param_function(transfer_prob=1.4)
try:
    model(param)
except ValueError as e:
    print(e)
transfer_prob must be between 0 and 1, but is: 1.4
#' Run simulation.
#'
#' @param param Named list of model parameters.

model <- function(param) {

  # Check all inputs are valid
  validate_param(parameters = param)

  # Simulation code...
}


# Example usage
# Run model() invalid transfer_prob - raises an error
params <- param_function(transfer_prob = 1.4)
try(model(params))
Error : transfer_prob must be between 0 and 1, but is: 1.4
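When there are several parameters, a small reusable helper keeps the checks consistent. The sketch below is a hypothetical extension, not from the original: check_number() rejects non-numeric values (including booleans, which Python treats as integers) as well as out-of-range ones. The bounds for patient_inter are assumptions for illustration. One detail worth noting: the chained comparison lower <= value <= upper is False for NaN, so NaN is rejected, whereas a check of the form value < 0 or value > 1 (as in validate_param above) lets NaN through.

```python
import math
from numbers import Real


def check_number(name, value, lower=-math.inf, upper=math.inf):
    """Hypothetical helper: check that value is a real number in [lower, upper]."""
    # bool is a subclass of int, so reject it explicitly
    if isinstance(value, bool) or not isinstance(value, Real):
        raise TypeError(f"{name} must be a number, but is: {value!r}")
    # The chained comparison is False for NaN, so NaN is also rejected
    if not lower <= value <= upper:
        raise ValueError(
            f"{name} must be between {lower} and {upper}, but is: {value}"
        )


def validate_params(parameters):
    """Check every parameter against its expected type and range."""
    check_number("transfer_prob", parameters["transfer_prob"], 0, 1)
    check_number("patient_inter", parameters["patient_inter"], lower=0)


for candidate in [{"transfer_prob": 0.3, "patient_inter": 4},
                  {"transfer_prob": "0.3", "patient_inter": 4},
                  {"transfer_prob": 0.3, "patient_inter": -1}]:
    try:
        validate_params(candidate)
        print("valid")
    except (TypeError, ValueError) as e:
        print(f"{type(e).__name__}: {e}")
```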


3.2 Parameter validation within the class

The new validate_param() method checks whether transfer_prob is between 0 and 1.

Although this is defined within the class, it could also be called from within the model function, so that all parameters are checked before the simulation runs.

Original ParamClass:

# pylint: disable=missing-module-docstring, invalid-name, too-few-public-methods
class ParamClass:
    """
    Returns transfer_prob for validation example.
    """
    def __init__(self, transfer_prob=0.3):
        """
        Initialise ParamClass instance.

        Parameters
        ----------
        transfer_prob : float
            Transfer probability (0-1).
        """
        self.transfer_prob = transfer_prob
# Create instance with invalid transfer_prob - no error raised
param = ParamClass(transfer_prob=1.4)
#' @title Returns transfer_prob for validation example.
#'
#' @field transfer_prob Numeric. Transfer probability (0-1).

ParamClass <- R6Class( # nolint: object_name_linter
  public = list(
    transfer_prob = NULL,

    #' @description
    #' Initialises the R6 object.

    initialize = function(transfer_prob = 0.3) {
      self$transfer_prob <- transfer_prob
    }
  )
)
# Create instance with invalid transfer_prob - no error raised
param <- ParamClass$new(transfer_prob = 1.4)

Same class with added validate_param() method:

class ParamClass:
    """
    Returns transfer_prob for validation example.
    """
    def __init__(self, transfer_prob=0.3):
        """
        Initialise ParamClass instance.

        Parameters
        ----------
        transfer_prob : float
            Transfer probability (0-1).
        """
        self.transfer_prob = transfer_prob

    def validate_param(self): 
        """ 
        Check that transfer_prob is between 0 and 1. 
        """ 
        if self.transfer_prob < 0 or self.transfer_prob > 1: 
            raise ValueError("transfer_prob must be between 0 and 1" + 
                             f", but is: {self.transfer_prob}") 
# Create instance of ParamClass with invalid transfer_prob and run method
param = ParamClass(transfer_prob=1.4)
try: 
    param.validate_param() 
except ValueError as e: 
    print(f"Error: {e}") 
Error: transfer_prob must be between 0 and 1, but is: 1.4
#' @title Returns transfer_prob for validation example.
#'
#' @field transfer_prob Numeric. Transfer probability (0-1).

ParamClass <- R6Class( # nolint: object_name_linter
  public = list(
    transfer_prob = NULL,

    #' @description
    #' Initialises the R6 object.

    initialize = function(transfer_prob = 0.3) {
      self$transfer_prob <- transfer_prob
    },

    #' @description 
    #' Check that transfer_prob is between 0 and 1. 
    #' @return No return value; throws an error if invalid. 

    validate_param = function() { 
      if (self$transfer_prob < 0L || self$transfer_prob > 1L) { 
        stop( 
          "transfer_prob must be between 0 and 1, but is: ", 
          self$transfer_prob, call. = FALSE 
        ) 
      } 
    } 
  )
)
# Create instance with invalid transfer_prob and run method
param <- ParamClass$new(transfer_prob = 1.4)
try(param$validate_param()) 
Error : transfer_prob must be between 0 and 1, but is: 1.4
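In the examples above, validate_param() has to be called explicitly. A Python alternative, not used in the original, is a property setter, which runs the check automatically both when the instance is created and whenever the value is changed later. The class name ValidatedParam is hypothetical. Note that this checks values only: it does not stop typos creating new attributes, so it would need combining with one of the approaches from section 2.

```python
class ValidatedParam:
    """
    Hypothetical parameter class that validates transfer_prob whenever it is
    set: at creation and on any later modification.
    """
    def __init__(self, transfer_prob=0.3):
        # Goes through the property setter below, so it is validated too
        self.transfer_prob = transfer_prob

    @property
    def transfer_prob(self):
        """Transfer probability (0-1)."""
        return self._transfer_prob

    @transfer_prob.setter
    def transfer_prob(self, value):
        if not 0 <= value <= 1:
            raise ValueError(
                f"transfer_prob must be between 0 and 1, but is: {value}"
            )
        self._transfer_prob = value


param = ValidatedParam()
try:
    param.transfer_prob = 1.4
except ValueError as e:
    print(f"Error: {e}")

# The invalid value was never stored
print(param.transfer_prob)  # 0.3
```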


4 Examples

This section contains complete code examples from our example conceptual models.


Example 1: 🩺 Nurse visit simulation


This example is from simulation/param.py in pydesrap_mms.

"""
Param.
"""

from .simlogger import SimLogger


# pylint: disable=too-many-instance-attributes,too-few-public-methods

class Param:
    """
    Default parameters for simulation.

    Attributes
    ----------
    _initialising : bool
        Whether the object is currently initialising.
    patient_inter : float
        Mean inter-arrival time between patients in minutes.
    mean_n_consult_time : float
        Mean nurse consultation time in minutes.
    number_of_nurses : float
        Number of available nurses.
    warm_up_period : int
        Duration of the warm-up period in minutes.
    data_collection_period : int
        Duration of data collection period in minutes.
    number_of_runs : int
        The number of runs (i.e. replications).
    audit_interval : int
        How frequently to audit resource utilisation, in minutes.
    scenario_name : int|float|str
        Label for the scenario.
    cores : int
        Number of CPU cores to use for parallel execution. For all
        available cores, set to -1. For sequential execution, set to 1.
    logger : logging.Logger
        The logging instance used for logging messages.
    """
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        patient_inter=4,
        mean_n_consult_time=10,
        number_of_nurses=5,
        warm_up_period=1440*27,  # 27 days
        data_collection_period=1440*30,  # 30 days
        number_of_runs=31,
        audit_interval=120,  # Every 2 hours
        scenario_name=0,
        cores=-1,
        logger=SimLogger(log_to_console=False, log_to_file=False)
    ):
        """
        Initialise instance of parameters class.

        Parameters
        ----------
        patient_inter : float, optional
            Mean inter-arrival time between patients in minutes.
        mean_n_consult_time : float, optional
            Mean nurse consultation time in minutes.
        number_of_nurses : float, optional
            Number of available nurses.
        warm_up_period : int, optional
            Duration of the warm-up period in minutes.
        data_collection_period : int, optional
            Duration of data collection period in minutes.
        number_of_runs : int, optional
            The number of runs (i.e. replications).
        audit_interval : int, optional
            How frequently to audit resource utilisation, in minutes.
        scenario_name : int|float|str, optional
            Label for the scenario.
        cores : int, optional
            Number of CPU cores to use for parallel execution.
        logger : logging.Logger, optional
            The logging instance used for logging messages.
        """
        # Disable restriction on attribute modification during initialisation
        object.__setattr__(self, "_initialising", True)
        self.patient_inter = patient_inter
        self.mean_n_consult_time = mean_n_consult_time
        self.number_of_nurses = number_of_nurses
        self.warm_up_period = warm_up_period
        self.data_collection_period = data_collection_period
        self.number_of_runs = number_of_runs
        self.audit_interval = audit_interval
        self.scenario_name = scenario_name
        self.cores = cores
        self.logger = logger

        # Re-enable attribute checks after initialisation
        object.__setattr__(self, "_initialising", False)

    def __setattr__(self, name, value):
        """
        Prevent addition of new attributes.

        Parameters
        ----------
        name : str
            The name of the attribute to set.
        value : Any
            The value to assign to the attribute.

        Raises
        ------
        AttributeError
            If `name` is not an existing attribute and an attempt is made
            to add it to the instance.
        """
        # Skip the check if the object is still initialising
        # pylint: disable=maybe-no-member
        if hasattr(self, "_initialising") and self._initialising:
            super().__setattr__(name, value)
        else:
            # Check if attribute of that name is already present
            if name in self.__dict__:
                super().__setattr__(name, value)
            else:
                raise AttributeError(
                    f"Cannot add new attribute '{name}' - only possible to "
                    f"modify existing attributes: {self.__dict__.keys()}"
                )
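The following minimal sketch shows what this guard does in practice. It uses a hypothetical, cut-down `MiniParam` class that keeps only two of the parameters above and drops the logger. It follows the same `_initialising` pattern: attributes can be created freely inside `__init__`, but after that only existing attributes can be changed. A mistyped name therefore raises an `AttributeError` instead of silently creating a new attribute.

```python
class MiniParam:
    """Hypothetical cut-down parameter class using the same
    _initialising flag pattern as the Param class above."""

    def __init__(self, patient_inter=4, number_of_nurses=5):
        # Allow new attributes while __init__ is running
        object.__setattr__(self, "_initialising", True)
        self.patient_inter = patient_inter
        self.number_of_nurses = number_of_nurses
        # From here on, only existing attributes may be modified
        object.__setattr__(self, "_initialising", False)

    def __setattr__(self, name, value):
        # During initialisation, or for existing attributes, set as normal
        if getattr(self, "_initialising", False) or name in self.__dict__:
            super().__setattr__(name, value)
        else:
            raise AttributeError(
                f"Cannot add new attribute '{name}' - only possible to "
                f"modify existing attributes: {list(self.__dict__)}"
            )


param = MiniParam()
param.number_of_nurses = 6  # Existing attribute, so this is allowed

try:
    param.number_of_nurse = 7  # Typo (missing "s"), so this is rejected
except AttributeError as err:
    print(err)
```

Setting the flag with `object.__setattr__` bypasses the custom `__setattr__`, so toggling the flag never triggers the check itself.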

This example is from R/validate_model_inputs.R in rdesrap_mms.

#' Validate input parameters for the simulation.
#'
#' @param run_number Integer representing index of current simulation run.
#' @param param List containing parameters for the simulation.
#'
#' @return Throws an error if any parameter is invalid.
#' @export

valid_inputs <- function(run_number, param) {
  check_run_number(run_number)
  check_param_names(param)
  check_param_values(param)
}


#' Checks if the run number is a non-negative integer.
#'
#' @param run_number Integer representing index of current simulation run.
#'
#' @return Throws an error if the run number is invalid.

check_run_number <- function(run_number) {
  if (run_number < 0L || run_number %% 1L != 0L) {
    stop("The run number must be a non-negative integer. Provided: ",
         run_number, call. = FALSE)
  }
}


#' Validate parameter names.
#'
#' Ensure that all required parameters are present, and no extra parameters are
#' provided.
#'
#' @param param List containing parameters for the simulation.
#'
#' @return Throws an error if there are missing or extra parameters.

check_param_names <- function(param) {
  # Get valid argument names from the function
  valid_names <- names(formals(parameters))

  # Get names from input parameter list
  input_names <- names(param)

  # Find missing keys (i.e. are there things in valid_names not in input)
  # and extra keys (i.e. are there things in input not in valid_names)
  missing_keys <- setdiff(valid_names, input_names)
  extra_keys <- setdiff(input_names, valid_names)

  # If there are any missing or extra keys, throw an error
  if (length(missing_keys) > 0L || length(extra_keys) > 0L) {
    error_message <- ""
    if (length(missing_keys) > 0L) {
      error_message <- paste0(
        error_message, "Missing keys: ", toString(missing_keys), ". "
      )
    }
    if (length(extra_keys) > 0L) {
      error_message <- paste0(
        error_message, "Extra keys: ", toString(extra_keys), ". "
      )
    }
    stop(error_message, call. = FALSE)
  }
}


#' Validate parameter values.
#'
#' Ensure that specific parameters are positive numbers, or non-negative
#' integers.
#'
#' @param param List containing parameters for the simulation.
#'
#' @return Throws an error if any specified parameter value is invalid.

check_param_values <- function(param) {

  # Check that listed parameters are always positive
  p_list <- c("patient_inter", "mean_n_consult_time", "number_of_runs")
  for (p in p_list) {
    if (param[[p]] <= 0L) {
      stop('The parameter "', p, '" must be greater than 0.', call. = FALSE)
    }
  }

  # Check that listed parameters are non-negative integers
  n_list <- c("warm_up_period", "data_collection_period", "number_of_nurses")
  for (n in n_list) {
    if (param[[n]] < 0L || param[[n]] %% 1L != 0L) {
      stop('The parameter "', n,
           '" must be an integer greater than or equal to 0.', call. = FALSE)
    }
  }
}
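For readers working in Python, the same name and value checks can be written with the standard library's `inspect.signature`, which does the job of R's `formals()`. This is an illustrative translation rather than code from rdesrap_mms. The `parameters()` function below is a hypothetical, cut-down stand-in for a function-based parameter set, and errors are raised as `ValueError` in place of `stop()`.

```python
import inspect


def parameters(patient_inter=4, mean_n_consult_time=10, number_of_nurses=5,
               warm_up_period=1440 * 27, data_collection_period=1440 * 30,
               number_of_runs=31):
    """Hypothetical function-based parameter set (cut-down)."""
    return {
        "patient_inter": patient_inter,
        "mean_n_consult_time": mean_n_consult_time,
        "number_of_nurses": number_of_nurses,
        "warm_up_period": warm_up_period,
        "data_collection_period": data_collection_period,
        "number_of_runs": number_of_runs,
    }


def check_param_names(param, param_function=parameters):
    """Raise ValueError if keys are missing from, or extra to, the
    arguments of param_function (like formals() in the R version)."""
    valid_names = list(inspect.signature(param_function).parameters)
    missing = [k for k in valid_names if k not in param]
    extra = [k for k in param if k not in valid_names]
    error_message = ""
    if missing:
        error_message += f"Missing keys: {', '.join(missing)}. "
    if extra:
        error_message += f"Extra keys: {', '.join(extra)}. "
    if error_message:
        raise ValueError(error_message.strip())


def check_param_values(param):
    """Raise ValueError if values fall outside their expected ranges."""
    # Parameters that must be strictly positive
    for p in ["patient_inter", "mean_n_consult_time", "number_of_runs"]:
        if param[p] <= 0:
            raise ValueError(f'The parameter "{p}" must be greater than 0.')
    # Parameters that must be non-negative whole numbers
    for n in ["warm_up_period", "data_collection_period", "number_of_nurses"]:
        if param[n] < 0 or param[n] % 1 != 0:
            raise ValueError(f'The parameter "{n}" must be an integer '
                             'greater than or equal to 0.')


param = parameters()
param["number_of_nurse"] = 6  # Typo: adds a new key instead of updating one
try:
    check_param_names(param)
except ValueError as err:
    print(err)
```

Because the valid names come from the function signature, adding a new argument to `parameters()` updates the check automatically, just as `formals()` does in R.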


Example 2: 🧠 Stroke pathway simulation


This example is from simulation/parameters.py in pydesrap_stroke.

"""
Stroke pathway simulation parameters.

It includes arrival rates, length of stay distributions, and routing
probabilities between different care settings.
"""

import time

from simulation.logging import SimLogger


class RestrictAttributesMeta(type):
    """
    Metaclass for attribute restriction.

    A metaclass modifies class construction. It intercepts instance creation
    via __call__, adding the _initialised flag after __init__ completes. This
    is later used by RestrictAttributes to enforce attribute restrictions.
    """
    def __call__(cls, *args, **kwargs):
        # Create instance using the standard method
        instance = super().__call__(*args, **kwargs)
        # Set the "_initialised" flag to True, marking end of initialisation
        instance.__dict__["_initialised"] = True
        return instance


class RestrictAttributes(metaclass=RestrictAttributesMeta):
    """
    Base class that prevents the addition of new attributes after
    initialisation.

    This class uses RestrictAttributesMeta as its metaclass to implement
    attribute restriction. It allows for safe initialisation of attributes
    during the __init__ method, but prevents the addition of new attributes
    afterwards.

    The restriction is enforced through the custom __setattr__ method, which
    checks if the attribute already exists before allowing assignment.
    """
    def __setattr__(self, name, value):
        """
        Prevent addition of new attributes.

        Parameters
        ----------
        name: str
            The name of the attribute to set.
        value: any
            The value to assign to the attribute.

        Raises
        ------
        AttributeError
            If `name` is not an existing attribute and an attempt is made
            to add it to the class instance.
        """
        # Check if the instance is initialised and the attribute doesn't exist
        if hasattr(self, "_initialised") and not hasattr(self, name):
            # Get a list of existing attributes for the error message
            existing = ", ".join(self.__dict__.keys())
            raise AttributeError(
                f"Cannot add new attribute '{name}' - only possible to " +
                f"modify existing attributes: {existing}."
            )
        # If checks pass, set the attribute using the standard method
        object.__setattr__(self, name, value)


class ASUArrivals(RestrictAttributes):
    """
    Arrival rates for the acute stroke unit (ASU) by patient type.

    These are the average time intervals (in days) between new admissions.
    For example, a value of 1.2 means a new admission every 1.2 days.
    """
    def __init__(self, stroke=1.2, tia=9.3, neuro=3.6, other=3.2):
        """
        Parameters
        ----------
        stroke: float
            Stroke patient.
        tia: float
            Transient ischaemic attack (TIA) patient.
        neuro: float
            Complex neurological patient.
        other: float
            Other patient types (including medical outliers).
        """
        self.stroke = stroke
        self.tia = tia
        self.neuro = neuro
        self.other = other


class RehabArrivals(RestrictAttributes):
    """
    Arrival rates for the rehabilitation unit by patient type.

    These are the average time intervals (in days) between new admissions.
    For example, a value of 21.8 means a new admission every 21.8 days.
    """
    def __init__(self, stroke=21.8, neuro=31.7, other=28.6):
        """
        Parameters
        ----------
        stroke: float
            Stroke patient.
        neuro: float
            Complex neurological patient.
        other: float
            Other patient types.
        """
        self.stroke = stroke
        self.neuro = neuro
        self.other = other


class ASULOS(RestrictAttributes):
    """
    Mean and standard deviation (SD) of length of stay (LOS) in days in the
    acute stroke unit (ASU) by patient type.

    Attributes
    ----------
    stroke_noesd: dict
        Mean and SD of LOS for stroke patients without early support discharge.
    stroke_esd: dict
        Mean and SD of LOS for stroke patients with early support discharge.
    stroke_mortality: dict
        Mean and SD of LOS for stroke patients who pass away.
    tia: dict
        Mean and SD of LOS for transient ischemic attack (TIA) patients.
    neuro: dict
        Mean and SD of LOS for complex neurological patients.
    other: dict
        Mean and SD of LOS for other patients.
    """
    def __init__(
        self,
        stroke_no_esd_mean=7.4, stroke_no_esd_sd=8.61,
        stroke_esd_mean=4.6, stroke_esd_sd=4.8,
        stroke_mortality_mean=7.0, stroke_mortality_sd=8.7,
        tia_mean=1.8, tia_sd=2.3,
        neuro_mean=4.0, neuro_sd=5.0,
        other_mean=3.8, other_sd=5.2
    ):
        """
        Parameters
        ----------
        stroke_no_esd_mean: float
            Mean LOS for stroke patients without early support discharge (ESD)
            services.
        stroke_no_esd_sd: float
            SD of LOS for stroke patients without ESD.
        stroke_esd_mean: float
            Mean LOS for stroke patients with ESD.
        stroke_esd_sd: float
            SD of LOS for stroke patients with ESD.
        stroke_mortality_mean: float
            Mean LOS for stroke patients who pass away.
        stroke_mortality_sd: float
            SD of LOS for stroke patients who pass away.
        tia_mean: float
            Mean LOS for TIA patients.
        tia_sd: float
            SD of LOS for TIA patients.
        neuro_mean: float
            Mean LOS for complex neurological patients.
        neuro_sd: float
            SD of LOS for complex neurological patients.
        other_mean: float
            Mean LOS for other patient types.
        other_sd: float
            SD of LOS for other patient types.
        """
        self.stroke_noesd = {
            "mean": stroke_no_esd_mean,
            "sd": stroke_no_esd_sd
        }
        self.stroke_esd = {
            "mean": stroke_esd_mean,
            "sd": stroke_esd_sd
        }
        self.stroke_mortality = {
            "mean": stroke_mortality_mean,
            "sd": stroke_mortality_sd
        }
        self.tia = {
            "mean": tia_mean,
            "sd": tia_sd
        }
        self.neuro = {
            "mean": neuro_mean,
            "sd": neuro_sd
        }
        self.other = {
            "mean": other_mean,
            "sd": other_sd
        }


class RehabLOS(RestrictAttributes):
    """
    Mean and standard deviation (SD) of length of stay (LOS) in days in the
    rehabilitation unit by patient type.

    Attributes
    ----------
    stroke_noesd: dict
        Mean and SD of LOS for stroke patients without early support discharge.
    stroke_esd: dict
        Mean and SD of LOS for stroke patients with early support discharge.
    tia: dict
        Mean and SD of LOS for transient ischemic attack (TIA) patients.
    neuro: dict
        Mean and SD of LOS for complex neurological patients.
    other: dict
        Mean and SD of LOS for other patients.
    """
    def __init__(
        self,
        stroke_no_esd_mean=28.4, stroke_no_esd_sd=27.2,
        stroke_esd_mean=30.3, stroke_esd_sd=23.1,
        tia_mean=18.7, tia_sd=23.5,
        neuro_mean=27.6, neuro_sd=28.4,
        other_mean=16.1, other_sd=14.1
    ):
        """
        Parameters
        ----------
        stroke_no_esd_mean: float
            Mean LOS for stroke patients without early support discharge (ESD)
            services.
        stroke_no_esd_sd: float
            SD of LOS for stroke patients without ESD.
        stroke_esd_mean: float
            Mean LOS for stroke patients with ESD.
        stroke_esd_sd: float
            SD of LOS for stroke patients with ESD.
        tia_mean: float
            Mean LOS for TIA patients.
        tia_sd: float
            SD of LOS for TIA patients.
        neuro_mean: float
            Mean LOS for complex neurological patients.
        neuro_sd: float
            SD of LOS for complex neurological patients.
        other_mean: float
            Mean LOS for other patient types.
        other_sd: float
            SD of LOS for other patient types.
        """
        self.stroke_noesd = {
            "mean": stroke_no_esd_mean,
            "sd": stroke_no_esd_sd
        }
        self.stroke_esd = {
            "mean": stroke_esd_mean,
            "sd": stroke_esd_sd
        }
        self.tia = {
            "mean": tia_mean,
            "sd": tia_sd
        }
        self.neuro = {
            "mean": neuro_mean,
            "sd": neuro_sd
        }
        self.other = {
            "mean": other_mean,
            "sd": other_sd
        }


class ASURouting(RestrictAttributes):
    """
    Probabilities of each patient type being transferred from the acute
    stroke unit (ASU) to other destinations.

    Attributes
    ----------
    stroke: dict
        Routing probabilities for stroke patients.
    tia: dict
        Routing probabilities for transient ischemic attack (TIA) patients.
    neuro: dict
        Routing probabilities for complex neurological patients.
    other: dict
        Routing probabilities for other patients.
    """
    def __init__(
        self,
        # Stroke patients
        stroke_rehab=0.24, stroke_esd=0.13, stroke_other=0.63,
        # TIA patients
        tia_rehab=0.01, tia_esd=0.01, tia_other=0.98,
        # Complex neurological patients
        neuro_rehab=0.11, neuro_esd=0.05, neuro_other=0.84,
        # Other patients
        other_rehab=0.05, other_esd=0.10, other_other=0.85
    ):
        """
        Parameters
        ----------
        stroke_rehab: float
            Stroke patient to rehabilitation unit.
        stroke_esd: float
            Stroke patient to early support discharge (ESD) services.
        stroke_other: float
            Stroke patient to other destinations (e.g., own home, care
            home, mortality).
        tia_rehab: float
            TIA patient to rehabilitation unit.
        tia_esd: float
            TIA patient to ESD.
        tia_other: float
            TIA patient to other destinations.
        neuro_rehab: float
            Complex neurological patient to rehabilitation unit.
        neuro_esd: float
            Complex neurological patient to ESD.
        neuro_other: float
            Complex neurological patient to other destinations.
        other_rehab: float
            Other patient type to rehabilitation unit.
        other_esd: float
            Other patient type to ESD.
        other_other: float
            Other patient type to other destinations.
        """
        self.stroke = {
            "rehab": stroke_rehab,
            "esd": stroke_esd,
            "other": stroke_other
        }
        self.tia = {
            "rehab": tia_rehab,
            "esd": tia_esd,
            "other": tia_other
        }
        self.neuro = {
            "rehab": neuro_rehab,
            "esd": neuro_esd,
            "other": neuro_other
        }
        self.other = {
            "rehab": other_rehab,
            "esd": other_esd,
            "other": other_other
        }


class RehabRouting(RestrictAttributes):
    """
    Probabilities of each patient type being transferred from the
    rehabilitation unit to other destinations.

    Attributes
    ----------
    stroke: dict
        Routing probabilities for stroke patients.
    tia: dict
        Routing probabilities for transient ischemic attack (TIA) patients.
    neuro: dict
        Routing probabilities for complex neurological patients.
    other: dict
        Routing probabilities for other patients.
    """
    def __init__(
        self,
        # Stroke patients
        stroke_esd=0.40, stroke_other=0.60,
        # TIA patients
        tia_esd=0, tia_other=1,
        # Complex neurological patients
        neuro_esd=0.09, neuro_other=0.91,
        # Other patients
        other_esd=0.13, other_other=0.88
    ):
        """
        Parameters
        ----------
        stroke_esd: float
            Stroke patient to early support discharge (ESD) services.
        stroke_other: float
            Stroke patient to other destinations (e.g., own home, care home,
            mortality).
        tia_esd: float
            TIA patient to ESD.
        tia_other: float
            TIA patient to other destinations.
        neuro_esd: float
            Complex neurological patient to ESD.
        neuro_other: float
            Complex neurological patient to other destinations.
        other_esd: float
            Other patient type to ESD.
        other_other: float
            Other patient type to other destinations.
        """
        self.stroke = {
            "esd": stroke_esd,
            "other": stroke_other
        }
        self.tia = {
            "esd": tia_esd,
            "other": tia_other
        }
        self.neuro = {
            "esd": neuro_esd,
            "other": neuro_other
        }
        self.other = {
            "esd": other_esd,
            "other": other_other
        }


class Param(RestrictAttributes):
    """
    Default parameters for simulation.
    """
    def __init__(
        self,
        asu_arrivals=ASUArrivals(),
        rehab_arrivals=RehabArrivals(),
        asu_los=ASULOS(),
        rehab_los=RehabLOS(),
        asu_routing=ASURouting(),
        rehab_routing=RehabRouting(),
        warm_up_period=365*3,  # 3 years
        data_collection_period=365*5,  # 5 years
        number_of_runs=150,
        audit_interval=1,
        cores=1,
        log_to_console=False,
        log_to_file=False,
        log_file_path=("../outputs/logs/" +
                       f"{time.strftime('%Y-%m-%d_%H-%M-%S')}.log")
    ):
        """
        Initialise a parameter set for the simulation.

        Parameters
        ----------
        asu_arrivals: ASUArrivals
            Arrival rates to the acute stroke unit (ASU) in days.
        rehab_arrivals: RehabArrivals
            Arrival rates to the rehabilitation unit in days.
        asu_los: ASULOS
            Length of stay (LOS) distributions for patients in the ASU in days.
        rehab_los: RehabLOS
            LOS distributions for patients in the rehabilitation unit in days.
        asu_routing: ASURouting
            Transfer probabilities from the ASU to other destinations.
        rehab_routing: RehabRouting
            Transfer probabilities from the rehabilitation unit to other
            destinations.
        warm_up_period: int
            Length of the warm-up period in days.
        data_collection_period: int
            Length of the data collection period in days.
        number_of_runs: int
            The number of runs (i.e. replications), defining how many times to
            re-run the simulation (with different random numbers).
        audit_interval: float
            Frequency of simulation audits in days.
        cores: int
            Number of CPU cores to use for parallel execution. Set to desired
            number, or to -1 to use all available cores. For sequential
            execution, set to 1.
        log_to_console: bool
            Whether to print log messages to the console.
        log_to_file: bool
            Whether to save log to a file.
        log_file_path: str
            Path to save log to file. Note, if you use an existing .log
            file name, it will append to that log.
        """
        # Set parameters
        self.asu_arrivals = asu_arrivals
        self.rehab_arrivals = rehab_arrivals
        self.asu_los = asu_los
        self.rehab_los = rehab_los
        self.asu_routing = asu_routing
        self.rehab_routing = rehab_routing
        self.warm_up_period = warm_up_period
        self.data_collection_period = data_collection_period
        self.number_of_runs = number_of_runs
        self.audit_interval = audit_interval
        self.cores = cores

        # Set up logger
        self.logger = SimLogger(log_to_console=log_to_console,
                                log_to_file=log_to_file,
                                file_path=log_file_path)

    def check_param_validity(self):
        """
        Check the validity of the provided parameters.

        Validates all simulation parameters to ensure they meet requirements:
        - Warm-up period and data collection period must be >= 0
        - Number of runs and audit interval must be > 0
        - Arrival rates must be >= 0
        - Length of stay parameters must be >= 0
        - Routing probabilities must each be between 0 and 1, and must sum
          to approximately 1 (within 0.01)

        Raises
        ------
        ValueError
            If any parameter fails validation with a descriptive error message.
        """
        # Validate parameters that must be >= 0
        for param in ["warm_up_period", "data_collection_period"]:
            self.validate_param(
                param, lambda x: x >= 0,
                "must be greater than or equal to 0")

        # Validate parameters that must be > 0
        for param in ["number_of_runs", "audit_interval"]:
            self.validate_param(
                param, lambda x: x > 0,
                "must be greater than 0")

        # Validate arrival parameters
        for param in ["asu_arrivals", "rehab_arrivals"]:
            self.validate_nested_param(
                param, lambda x: x >= 0,
                "must be greater than or equal to 0")

        # Validate length of stay parameters
        for param in ["asu_los", "rehab_los"]:
            self.validate_nested_param(
                param, lambda x: x >= 0,
                "must be greater than or equal to 0", nested=True)

        # Validate routing parameters
        for param in ["asu_routing", "rehab_routing"]:
            self.validate_routing(param)

    def validate_param(self, param_name, condition, error_msg):
        """
        Validate a single parameter against a condition.

        Parameters
        ----------
        param_name: str
            Name of the parameter being validated.
        condition: callable
            A function that returns True if the value is valid.
        error_msg: str
            Error message to display if validation fails.

        Raises
        ------
        ValueError:
            If the parameter fails the validation condition.
        """
        value = getattr(self, param_name)
        if not condition(value):
            raise ValueError(
                f"Parameter '{param_name}' {error_msg}, but is: {value}")

    def validate_nested_param(
        self, obj_name, condition, error_msg, nested=False
    ):
        """
        Validate parameters within a nested object structure.

        Parameters
        ----------
        obj_name: str
            Name of the object containing parameters.
        condition: callable
            A function that returns True if the value is valid.
        error_msg: str
            Error message to display if validation fails.
        nested: bool, optional
            If True, validates parameters in a doubly-nested structure. If
            False, validates parameters in a singly-nested structure.

        Raises
        ------
        ValueError:
            If any nested parameter fails the validation condition.
        """
        obj = getattr(self, obj_name)
        for key, value in vars(obj).items():
            if key == "_initialised":
                continue
            if nested:
                for sub_key, sub_value in value.items():
                    if not condition(sub_value):
                        raise ValueError(
                            f"Parameter '{sub_key}' for '{key}' in " +
                            f"'{obj_name}' {error_msg}, but is: {sub_value}")
            else:
                if not condition(value):
                    raise ValueError(
                        f"Parameter '{key}' from '{obj_name}' {error_msg}, " +
                        f"but is: {value}")

    def validate_routing(self, obj_name):
        """
        Validate routing probability parameters.

        Performs two validations:
        1. Checks that all probabilities for each routing option sum to 1.
        2. Checks that individual probabilities are between 0 and 1 inclusive.

        Parameters
        ----------
        obj_name: str
            Name of the routing object.

        Raises
        ------
        ValueError:
            If the probabilities don't sum to 1, or if any probability is
            outside [0,1].
        """
        obj = getattr(self, obj_name)
        for key, value in vars(obj).items():
            if key == "_initialised":
                continue

            # Check that probabilities sum to 1
            # Note: In the source article, the rehab routing probabilities
            # for "other" patients are 13% (ESD) and 88% (other), summing to
            # 101%, so a deviation of 1% is allowed
            total_prob = sum(value.values())
            if total_prob < 0.99 or total_prob > 1.01:
                raise ValueError(
                    f"Routing probabilities for '{key}' in '{obj_name}' " +
                    f"should sum to approximately 1 but sum to: {total_prob}")

            # Check that probabilities are between 0 and 1
            for sub_key, sub_value in value.items():
                if sub_value < 0 or sub_value > 1:
                    raise ValueError(
                        f"Parameter '{sub_key}' for '{key}' in '{obj_name}' " +
                        f"must be between 0 and 1, but is: {sub_value}")
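The following sketch applies both ideas from this file on a small scale. The first is the metaclass-based restriction, which means subclasses need no flag handling in `__init__`. The second is the routing check. The names `RestrictMeta`, `Restricted`, `Routing` and `check_routing` are hypothetical simplifications of `RestrictAttributesMeta`, `RestrictAttributes`, `RehabRouting` and `validate_routing`. The 0.99 to 1.01 bounds mirror the 1% tolerance used above.

```python
class RestrictMeta(type):
    """Metaclass that sets an '_initialised' flag once __init__ returns."""
    def __call__(cls, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        instance.__dict__["_initialised"] = True
        return instance


class Restricted(metaclass=RestrictMeta):
    """Base class that blocks new attributes after initialisation."""
    def __setattr__(self, name, value):
        if hasattr(self, "_initialised") and not hasattr(self, name):
            raise AttributeError(f"Cannot add new attribute '{name}'")
        object.__setattr__(self, name, value)


class Routing(Restricted):
    """Hypothetical routing probabilities (rehab 'other' defaults)."""
    def __init__(self, other_esd=0.13, other_other=0.88):
        # No flag handling needed here - the metaclass sets it afterwards
        self.other = {"esd": other_esd, "other": other_other}


def check_routing(routing):
    """Raise ValueError if any probability is outside [0, 1], or if a
    patient type's probabilities do not sum to between 0.99 and 1.01."""
    for key, probs in vars(routing).items():
        if key == "_initialised":
            continue
        for sub_key, prob in probs.items():
            if not 0 <= prob <= 1:
                raise ValueError(
                    f"'{sub_key}' for '{key}' must be between 0 and 1, "
                    f"but is: {prob}")
        total = sum(probs.values())
        if not 0.99 <= total <= 1.01:
            raise ValueError(
                f"Probabilities for '{key}' should sum to approximately 1, "
                f"but sum to: {total}")


routing = Routing()
check_routing(routing)  # 0.13 + 0.88 passes within the 1% tolerance

try:
    routing.othr = {"esd": 0.5, "other": 0.5}  # Typo: rejected
except AttributeError as err:
    print(err)

routing.other = {"esd": 0.5, "other": 0.6}  # Existing attribute: allowed
try:
    check_routing(routing)  # ...but the values sum to 1.1, so this fails
except ValueError as err:
    print(err)
```

The sum is checked against a range rather than with `total == 1`. This accommodates the 13% + 88% figures and also avoids spurious failures from floating-point rounding.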

TODO