# Import model
import model
# Data processing and figure generation
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
# Manage file paths
from dataclasses import dataclass
import os
# Hide warnings that appear
import warnings
try:
    # Location of the warning class in more recent pandas releases
    from pandas.errors import SettingWithCopyWarning
except ImportError:
    try:
        # Location used by the original notebook (older pandas releases)
        from pandas.core.common import SettingWithCopyWarning
    except ImportError:
        # NOTE(review): some pandas versions may not define this class at all;
        # the blanket filter below still hides every warning in that case
        SettingWithCopyWarning = None
if SettingWithCopyWarning is not None:
    warnings.simplefilter(action='ignore', category=SettingWithCopyWarning)
warnings.filterwarnings('ignore')
# To record runtime of this notebook
import time
Reproduction
This notebook reproduces all the results from:
Lim CY, Bohn MK, Lippi G, Ferrari M, Loh TP, Yuen K, Adeli K, Horvath AR. Staff Rostering, Split Team Arrangement, Social Distancing (Physical Distancing) and Use of Personal Protective Equipment to Minimize Risk of Workplace Transmission During the COVID-19 Pandemic: A Simulation Study. Clinical Biochemistry 86:15-22 (2020). https://doi.org/10.1016/j.clinbiochem.2020.09.003.
If running all scenarios from scratch, the total run time for this notebook is 49 minutes and 17 seconds. This time was from reproduction on an Intel Core i7-12700H with 32GB RAM running Ubuntu 22.04.4 Linux.
The individual run times for each scenario were:
* Base (15% probability of secondary infection) - 7m 29s
* Base with 5% probability of secondary infection - 5m 49s
* Base with 30% probability of secondary infection - 8m 38s
* Random shift assignment - 1m 45s
* Protective measures: social distancing - 5m 21s
* Protective measures: gloves - 5m 20s
* Protective measures: surgical mask - 5m 5s
* Protective measures: gown - 4m 51s
* Protective measures: N95 mask - 4m 54s
Set-up
Import required packages
Set file paths
@dataclass(frozen=True)
class Paths:
    '''
    Immutable container for the folder and file names used in this notebook.

    The values are plain class attributes (they have no type annotations, so
    they are not dataclass fields). File names are combined with one of the
    two folders (`outputs` or `paper`) using os.path.join() where needed.
    '''
    # Folder for model results, reproduced tables and figures
    outputs = '../outputs'
    # Raw model results for each scenario
    base_5 = 'output_base5.csv'
    base_15 = 'output_base15.csv'
    base_30 = 'output_base30.csv'
    no_rest = 'output_base15_norest.csv'
    contact_half = 'output_base15_contact20.csv'
    gloves = 'output_base15_gloves.csv'
    surgical_mask = 'output_base15_surgicalmask.csv'
    gown = 'output_base15_gown.csv'
    n95_mask = 'output_base15_n95mask.csv'

    # Folder with the reformatted tables from the original study
    paper = '../../original_study'
    # Supplementary table file names (same name used in both folders)
    tab2 = 'supp_tab2_reformat.csv'
    tab3 = 'supp_tab3_reformat.csv'
    tab4 = 'supp_tab4_reformat.csv'
    tab5 = 'supp_tab5_reformat.csv'
    tab6 = 'supp_tab6_reformat.csv'
    # Figure file names
    fig2 = 'fig2.png'
    fig3 = 'fig3.png'
    fig4 = 'fig4.png'
    fig5 = 'fig5.png'


paths = Paths()
Run model
# Whether to re-run every scenario (slow) or just load the saved results
run_model = False

# Start timer if running models
if run_model:
    start = time.time()
Base scenario (15% probability of secondary infection)
if run_model:
    # Run model with its default parameters
    res = model.run_scenarios()
    # Save results to CSV
    res.to_csv(os.path.join(paths.outputs, paths.base_15), index=False)
Base scenario with 5% probability of secondary infection
if run_model:
    # Run model
    res = model.run_scenarios(secondary_attack_rate=0.05)
    # Save results to CSV
    res.to_csv(os.path.join(paths.outputs, paths.base_5), index=False)
Base scenario with 30% probability of secondary infection
if run_model:
    # Run model
    res = model.run_scenarios(secondary_attack_rate=0.3)
    # Save results to CSV
    res.to_csv(os.path.join(paths.outputs, paths.base_30), index=False)
Base scenario (15% probability of secondary infection) with only one shift per day, and without a predefined minimum rest day after shifts (to simulate random shift assignment after each shift).
if run_model:
    # Run model with a single shift per day and no predefined rest day
    res = model.run_scenarios(shift_day=[1], rest_day=False)
    # Save results to CSV
    res.to_csv(os.path.join(paths.outputs, paths.no_rest), index=False)
Base scenario (15% probability of secondary infection) with contact rate halved (to simulate effect of workplace social distancing)
if run_model:
    # Run model
    res = model.run_scenarios(contact_rate=0.2)
    # Save results to CSV
    res.to_csv(os.path.join(paths.outputs, paths.contact_half), index=False)
Base scenario with altered probability of secondary infection (to simulate workplace protective measures)
# Each protective measure scales the 15% secondary attack rate by a factor

if run_model:
    # Run model
    gloves = model.run_scenarios(secondary_attack_rate=0.15*0.45)
    # Save results to csv
    gloves.to_csv(os.path.join(paths.outputs, paths.gloves), index=False)

if run_model:
    # Run model
    surgical_mask = model.run_scenarios(secondary_attack_rate=0.15*0.32)
    # Save results to csv
    surgical_mask.to_csv(os.path.join(paths.outputs, paths.surgical_mask),
                         index=False)

if run_model:
    # Run model
    gown = model.run_scenarios(secondary_attack_rate=0.15*0.23)
    # Save results to csv
    gown.to_csv(os.path.join(paths.outputs, paths.gown), index=False)

if run_model:
    # Run model
    n95_mask = model.run_scenarios(secondary_attack_rate=0.15*0.09)
    # Save results to csv
    n95_mask.to_csv(os.path.join(paths.outputs, paths.n95_mask), index=False)
Import results
# Each model result is loaded alongside the paper table it is compared with

model_base15 = pd.read_csv(os.path.join(paths.outputs, paths.base_15))
paper_tab2 = pd.read_csv(os.path.join(paths.paper, paths.tab2))

model_base5 = pd.read_csv(os.path.join(paths.outputs, paths.base_5))
paper_tab3 = pd.read_csv(os.path.join(paths.paper, paths.tab3))

model_base30 = pd.read_csv(os.path.join(paths.outputs, paths.base_30))
paper_tab4 = pd.read_csv(os.path.join(paths.paper, paths.tab4))

model_norest = pd.read_csv(os.path.join(paths.outputs, paths.no_rest))
paper_tab5 = pd.read_csv(os.path.join(paths.paper, paths.tab5))

# Five workplace-measure scenarios, all compared against one paper table
model_contact = pd.read_csv(os.path.join(paths.outputs, paths.contact_half))
model_gloves = pd.read_csv(os.path.join(paths.outputs, paths.gloves))
model_surgical = pd.read_csv(os.path.join(paths.outputs, paths.surgical_mask))
model_gown = pd.read_csv(os.path.join(paths.outputs, paths.gown))
model_n95 = pd.read_csv(os.path.join(paths.outputs, paths.n95_mask))
paper_tab6 = pd.read_csv(os.path.join(paths.paper, paths.tab6))
Function to compare tables
def compare_tables(model_tab, paper_tab):
    '''
    Combine the model and paper tables into single dataframe with a diff column

    Parameters:
    -----------
    model_tab : dataframe
        Raw output from model
    paper_tab : dataframe
        Reformatted table from supplementary materials

    Returns:
    --------
    comp : dataframe
        Merged table with the columns 'prop_infected_model' and
        'prop_infected_paper', plus 'diff' (their absolute difference)
    '''
    # Merge the dataframes. 'prop_infected' is renamed in each table first, so
    # the merge joins on the remaining columns that the two tables share
    comp = pd.merge(
        model_tab.rename(columns={'prop_infected': 'prop_infected_model'}),
        paper_tab.rename(columns={'prop_infected': 'prop_infected_paper'}))

    # Calculate difference
    comp['diff'] = abs(comp['prop_infected_model'] -
                       comp['prop_infected_paper'])

    return comp
Examine differences in supplementary tables
Supplementary table 2
# Get table 2 and save to csv
model_tab2 = model_base15[model_base15['end_of_day'].isin([7, 14, 21])]
model_tab2.to_csv(os.path.join(paths.outputs, paths.tab2), index=False)

# Combine model results alongside results from paper
t2_comp = compare_tables(model_tab2, paper_tab2)

# Descriptive statistics for absolute difference in results
print(t2_comp['diff'].describe())

# Extracting instances where absolute difference is more than 0.05
# (rows shown as exactly 0.05 can also appear due to floating point rounding)
display(t2_comp[t2_comp['diff'] > 0.05])
count 420.000000
mean 0.008405
std 0.017549
min 0.000000
25% 0.000000
50% 0.000000
75% 0.010000
max 0.100000
Name: diff, dtype: float64
strength | staff_change | staff_per_shift | shifts_per_day | end_of_day | prop_infected_model | prop_infected_paper | diff | |
---|---|---|---|---|---|---|---|---|
24 | 2 | 7 | 5 | 1 | 7 | 0.40 | 0.30 | 0.10 |
36 | 2 | 14 | 5 | 1 | 7 | 0.30 | 0.40 | 0.10 |
84 | 4 | 7 | 5 | 1 | 7 | 0.15 | 0.20 | 0.05 |
96 | 4 | 14 | 5 | 1 | 7 | 0.15 | 0.20 | 0.05 |
97 | 4 | 14 | 5 | 2 | 7 | 0.15 | 0.20 | 0.05 |
98 | 4 | 14 | 5 | 3 | 7 | 0.15 | 0.20 | 0.05 |
109 | 4 | 21 | 5 | 2 | 7 | 0.15 | 0.20 | 0.05 |
204 | 2 | 7 | 5 | 1 | 14 | 0.40 | 0.30 | 0.10 |
243 | 4 | 1 | 10 | 1 | 14 | 0.12 | 0.17 | 0.05 |
248 | 4 | 1 | 20 | 3 | 14 | 0.44 | 0.53 | 0.09 |
250 | 4 | 1 | 30 | 2 | 14 | 0.60 | 0.66 | 0.06 |
264 | 4 | 7 | 5 | 1 | 14 | 0.15 | 0.20 | 0.05 |
311 | 6 | 1 | 30 | 3 | 14 | 0.29 | 0.34 | 0.05 |
422 | 4 | 1 | 5 | 3 | 21 | 0.23 | 0.30 | 0.07 |
423 | 4 | 1 | 10 | 1 | 21 | 0.25 | 0.34 | 0.09 |
424 | 4 | 1 | 10 | 2 | 21 | 0.34 | 0.41 | 0.07 |
435 | 4 | 3 | 10 | 1 | 21 | 0.35 | 0.45 | 0.10 |
436 | 4 | 3 | 10 | 2 | 21 | 0.40 | 0.47 | 0.07 |
489 | 6 | 1 | 30 | 1 | 21 | 0.44 | 0.50 | 0.06 |
Supplementary table 3
# Get table 3 and save to csv
model_tab3 = model_base5[model_base5['end_of_day'].isin([7, 14, 21])]
model_tab3.to_csv(os.path.join(paths.outputs, paths.tab3), index=False)

# Combine tables
t3_comp = compare_tables(model_tab3, paper_tab3)

# Descriptive statistics for absolute difference in results
print(t3_comp['diff'].describe())

# Extracting instances where absolute difference is more than 0.05
# (rows shown as exactly 0.05 can also appear due to floating point rounding)
display(t3_comp[t3_comp['diff'] > 0.05])
count 420.000000
mean 0.009905
std 0.017464
min 0.000000
25% 0.000000
50% 0.000000
75% 0.010000
max 0.100000
Name: diff, dtype: float64
strength | staff_change | staff_per_shift | shifts_per_day | end_of_day | prop_infected_model | prop_infected_paper | diff | |
---|---|---|---|---|---|---|---|---|
24 | 2 | 7 | 5 | 1 | 7 | 0.10 | 0.20 | 0.10 |
30 | 2 | 7 | 20 | 1 | 7 | 0.15 | 0.20 | 0.05 |
48 | 2 | 21 | 5 | 1 | 7 | 0.10 | 0.20 | 0.10 |
192 | 2 | 3 | 5 | 1 | 14 | 0.10 | 0.20 | 0.10 |
204 | 2 | 7 | 5 | 1 | 14 | 0.10 | 0.20 | 0.10 |
210 | 2 | 7 | 20 | 1 | 14 | 0.15 | 0.20 | 0.05 |
216 | 2 | 14 | 5 | 1 | 14 | 0.20 | 0.30 | 0.10 |
228 | 2 | 21 | 5 | 1 | 14 | 0.20 | 0.30 | 0.10 |
396 | 2 | 14 | 5 | 1 | 21 | 0.20 | 0.30 | 0.10 |
469 | 4 | 21 | 5 | 2 | 21 | 0.20 | 0.15 | 0.05 |
Supplementary table 4
# Get table 4 and save to csv
model_tab4 = model_base30[model_base30['end_of_day'].isin([7, 14, 21])]
model_tab4.to_csv(os.path.join(paths.outputs, paths.tab4), index=False)

# Combine tables
t4_comp = compare_tables(model_tab4, paper_tab4)

# Descriptive statistics for absolute difference in results
print(t4_comp['diff'].describe())

# Extracting instances where absolute difference is more than 0.05
# (rows shown as exactly 0.05 can also appear due to floating point rounding)
display(t4_comp[t4_comp['diff'] > 0.05])
count 420.000000
mean 0.004762
std 0.012765
min 0.000000
25% 0.000000
50% 0.000000
75% 0.000000
max 0.120000
Name: diff, dtype: float64
strength | staff_change | staff_per_shift | shifts_per_day | end_of_day | prop_infected_model | prop_infected_paper | diff | |
---|---|---|---|---|---|---|---|---|
12 | 2 | 3 | 5 | 1 | 7 | 0.40 | 0.30 | 0.10 |
74 | 4 | 3 | 5 | 3 | 7 | 0.20 | 0.15 | 0.05 |
257 | 4 | 3 | 10 | 3 | 14 | 0.50 | 0.55 | 0.05 |
421 | 4 | 1 | 5 | 2 | 21 | 0.53 | 0.65 | 0.12 |
422 | 4 | 1 | 5 | 3 | 21 | 0.60 | 0.65 | 0.05 |
Supplementary table 5
# Get table 5 and save to csv
model_tab5 = model_norest[model_norest['end_of_day'].isin([7, 14, 21])]
model_tab5.to_csv(os.path.join(paths.outputs, paths.tab5), index=False)

# Combine tables
t5_comp = compare_tables(model_tab5, paper_tab5)

# Descriptive statistics for absolute difference in results
print(t5_comp['diff'].describe())

# Extracting instances where absolute difference is more than 0.05
# (rows shown as exactly 0.05 can also appear due to floating point rounding)
display(t5_comp[t5_comp['diff'] > 0.05])
count 180.000000
mean 0.012056
std 0.022614
min 0.000000
25% 0.000000
50% 0.000000
75% 0.020000
max 0.100000
Name: diff, dtype: float64
strength | staff_change | staff_per_shift | shifts_per_day | end_of_day | prop_infected_model | prop_infected_paper | diff | |
---|---|---|---|---|---|---|---|---|
12 | 2 | 14 | 5 | 1 | 7 | 0.40 | 0.30 | 0.10 |
16 | 2 | 21 | 5 | 1 | 7 | 0.30 | 0.40 | 0.10 |
60 | 2 | 1 | 5 | 1 | 14 | 0.40 | 0.50 | 0.10 |
68 | 2 | 7 | 5 | 1 | 14 | 0.60 | 0.50 | 0.10 |
120 | 2 | 1 | 5 | 1 | 21 | 0.70 | 0.80 | 0.10 |
124 | 2 | 3 | 5 | 1 | 21 | 0.70 | 0.80 | 0.10 |
140 | 4 | 1 | 5 | 1 | 21 | 0.15 | 0.20 | 0.05 |
141 | 4 | 1 | 10 | 1 | 21 | 0.29 | 0.39 | 0.10 |
162 | 6 | 1 | 20 | 1 | 21 | 0.29 | 0.34 | 0.05 |
Supplementary table 6
# Add label to each dataframe
model_contact['workplace_measure'] = 'Social distancing'
model_gloves['workplace_measure'] = 'Gloves'
model_surgical['workplace_measure'] = 'Surgical mask'
model_gown['workplace_measure'] = 'Gown'
model_n95['workplace_measure'] = 'N95 mask'

# Combine results into a single table and filter to day 14
model_tab6 = pd.concat([
    model_contact, model_gloves, model_surgical, model_gown, model_n95])
model_tab6 = model_tab6[model_tab6['end_of_day'] == 14]

# Save to csv
model_tab6.to_csv(os.path.join(paths.outputs, paths.tab6), index=False)

# Confirming how many are NaN, so we are sure the combined table includes all
# the relevant counts being compared
paper_tab6['prop_infected'].isnull().value_counts()
False 700
True 200
Name: prop_infected, dtype: int64
# Combine tables
t6_comp = compare_tables(model_tab6, paper_tab6)

# Descriptive statistics for absolute difference in results
print(t6_comp['diff'].describe())

# Extracting instances where absolute difference is more than 0.05
# (rows shown as exactly 0.05 can also appear due to floating point rounding)
display(t6_comp[t6_comp['diff'] > 0.05])
count 700.000000
mean 0.006214
std 0.011161
min 0.000000
25% 0.000000
50% 0.000000
75% 0.010000
max 0.050000
Name: diff, dtype: float64
strength | staff_change | staff_per_shift | shifts_per_day | end_of_day | prop_infected_model | workplace_measure | prop_infected_paper | diff | |
---|---|---|---|---|---|---|---|---|---|
71 | 4 | 1 | 30 | 3 | 14 | 0.18 | Social distancing | 0.23 | 0.05 |
96 | 4 | 14 | 5 | 1 | 14 | 0.15 | Social distancing | 0.20 | 0.05 |
97 | 4 | 14 | 5 | 2 | 14 | 0.15 | Social distancing | 0.20 | 0.05 |
98 | 4 | 14 | 5 | 3 | 14 | 0.20 | Social distancing | 0.15 | 0.05 |
183 | 2 | 1 | 10 | 1 | 14 | 0.20 | Gloves | 0.15 | 0.05 |
591 | 2 | 21 | 10 | 1 | 14 | 0.15 | Gown | 0.20 | 0.05 |
Figures
Define function to create the subplots
def plot_fig(fig_dict_list, ax, letter, title,
             ylabel='Proportion of staff infected', legend=True,
             ytick_freq=0.2):
    '''
    Create one of the subplots from the article's figures.

    Parameters:
    -----------
    fig_dict_list : list
        List of dictionaries with parameters to filter dataframe by and for
        formatting the figure. Each dictionary must contain the keys
        'shifts_per_day', 'staff_per_shift', 'strength', 'staff_change'
        (values to filter on), 'df' (dataframe to filter), and 'label',
        'color' and 'linestyle' (passed to ax.plot)
    ax : axes object
        To create the plot on
    letter : string
        Letter of subplot (e.g. '(a)', '(b)')
    title : string
        Title for the subplot
    ylabel : string
        Title for Y axis
    legend : boolean
        Whether to include a figure legend for that subplot
    ytick_freq : number
        Frequency of Y axis ticks
    '''
    # Create each of the line plots
    for fig_dict in fig_dict_list:
        # Get the filters for the dataframe
        filt = {key: fig_dict[key] for key in [
            'shifts_per_day', 'staff_per_shift', 'strength', 'staff_change']}

        # Get subset of dataframe meeting the conditions
        query = ' & '.join([f"{col} == {val}" for col, val in filt.items()])
        subset = fig_dict['df'].query(query)

        # Reformat so ready to plot (day on the X axis, proportion on the Y)
        to_plot = subset.set_index('end_of_day')['prop_infected']

        # Plot on ax
        ax.plot(to_plot, label=fig_dict['label'], color=fig_dict['color'],
                linestyle=fig_dict['linestyle'], linewidth=3)

    # Formatting the figure to match the paper
    ax.set_title(title)
    ax.set_xlabel('Day')
    ax.set_xticks(np.arange(0, 21, 5))
    ax.set_ylabel(ylabel)
    # Upper bound of 1.1 so that a tick at 1.0 can be included
    ax.set_yticks(np.arange(0, 1.1, ytick_freq))
    ax.set_ylim(0, 1)
    if legend:
        ax.legend(loc='upper left')
    # Place the subplot letter just outside the top-left corner of the axes
    ax.annotate(letter, xy=(-0.15, 1.1), xycoords='axes fraction')
    ax.grid()
Figure 2
Set out parameters for each of the lines
# Subplot (a): 1 shift per day
fig2a_black = {
    'shifts_per_day': 1,
    'staff_per_shift': 30,
    'strength': 2,
    'staff_change': 1,
    'df': model_base15,
    'label': '2 x 30 staff/shift',
    'color': 'black',
    'linestyle': '-'
}
fig2a_green = {
    'shifts_per_day': 1,
    'staff_per_shift': 10,
    'strength': 6,
    'staff_change': 1,
    'df': model_base15,
    'label': '6 x 10 staff/shift',
    'color': '#94B454',
    'linestyle': '--'
}

# Subplot (b): 2 shifts per day
fig2b_black = {
    'shifts_per_day': 2,
    'staff_per_shift': 30,
    'strength': 4,
    'staff_change': 1,
    'df': model_base15,
    'label': '4 x 30 staff/shift',
    'color': 'black',
    'linestyle': '-'
}
fig2b_green = {
    'shifts_per_day': 2,
    'staff_per_shift': 20,
    'strength': 6,
    'staff_change': 1,
    'df': model_base15,
    'label': '6 x 20 staff/shift',
    'color': '#94B454',
    'linestyle': '--'
}

# Subplot (c): 3 shifts per day
fig2c_black = {
    'shifts_per_day': 3,
    'staff_per_shift': 30,
    'strength': 4,
    'staff_change': 1,
    'df': model_base15,
    'label': '4 x 30 staff/shift',
    'color': 'black',
    'linestyle': '-'
}
fig2c_green = {
    'shifts_per_day': 3,
    'staff_per_shift': 20,
    'strength': 6,
    'staff_change': 1,
    'df': model_base15,
    'label': '6 x 20 staff/shift',
    'color': '#94B454',
    'linestyle': '--'
}

# Subplot (d): same roster, varying probability of secondary infection
fig2d_blue = {
    'shifts_per_day': 2,
    'staff_per_shift': 30,
    'strength': 4,
    'staff_change': 1,
    'df': model_base5,
    'label': '5%',
    'color': '#0D5BB7',
    'linestyle': ':'
}
fig2d_black = {
    'shifts_per_day': 2,
    'staff_per_shift': 30,
    'strength': 4,
    'staff_change': 1,
    'df': model_base15,
    'label': '15%',
    'color': 'black',
    'linestyle': '-'
}
fig2d_red = {
    'shifts_per_day': 2,
    'staff_per_shift': 30,
    'strength': 4,
    'staff_change': 1,
    'df': model_base30,
    # This line plots the 30% scenario (label previously read '3%')
    'label': '30%',
    'color': '#C14C46',
    'linestyle': '--'
}
Create the figure
# Set up number of subplots and figure size
fig, ax = plt.subplots(2, 2, figsize=(8, 6))

# Create the subplots
plot_fig([fig2a_black, fig2a_green], ax[0, 0], letter='(a)',
         title='1 shift per day\nTotal staff = 60')
plot_fig([fig2b_black, fig2b_green], ax[0, 1], letter='(b)',
         title='2 shifts per day\nTotal staff = 120')
plot_fig([fig2c_black, fig2c_green], ax[1, 0], letter='(c)',
         title='3 shifts per day\nTotal staff = 120')
plot_fig([fig2d_blue, fig2d_black, fig2d_red], ax[1, 1], letter='(d)',
         title='2 shifts per day\nTotal staff = 4 x 30 staff/shift')

# Prevent overlap between subplots
fig.tight_layout()

# Save the figure
plt.savefig(os.path.join(paths.outputs, paths.fig2))

# Display the figure
plt.show()
Figure 3
Formatting for each of the four lines (which are the same for each subplot)
# One entry per line: the staff per shift it shows and how it is drawn
fig3_lines = [
    # Black line
    {
        'staff_per_shift': 5,
        'label': '5 staff/shift',
        'color': 'black',
        'linestyle': '-'
    },
    # Green line
    {
        'staff_per_shift': 10,
        'label': '10 staff/shift',
        'color': '#94B454',
        'linestyle': '--'
    },
    # Purple line
    {
        'staff_per_shift': 20,
        'label': '20 staff/shift',
        'color': '#7F619D',
        'linestyle': ':'
    },
    # Red line
    {
        'staff_per_shift': 30,
        'label': '30 staff/shift',
        'color': '#C14C46',
        'linestyle': '-'
    }
]
Set default parameters for each subplot
# Shared by every line in Figure 3
all_fig3 = {
    'staff_change': 1,
    'df': model_base15
}
# These are default for Figure 3 and 4, so we reuse them below
# (one dictionary per subplot: shifts per day 1-3, each with strength 4 and 6)
all_fig34a = {
    'shifts_per_day': 1,
    'strength': 4
}
all_fig34b = {
    'shifts_per_day': 1,
    'strength': 6
}
all_fig34c = {
    'shifts_per_day': 2,
    'strength': 4
}
all_fig34d = {
    'shifts_per_day': 2,
    'strength': 6
}
all_fig34e = {
    'shifts_per_day': 3,
    'strength': 4
}
all_fig34f = {
    'shifts_per_day': 3,
    'strength': 6
}
Combine dictionaries to get conditions for plots
# For each subplot, merge the figure defaults, the subplot defaults and the
# line-specific settings into one dictionary per line
fig3a = [{**all_fig3, **all_fig34a, **x} for x in fig3_lines]
fig3b = [{**all_fig3, **all_fig34b, **x} for x in fig3_lines]
fig3c = [{**all_fig3, **all_fig34c, **x} for x in fig3_lines]
fig3d = [{**all_fig3, **all_fig34d, **x} for x in fig3_lines]
fig3e = [{**all_fig3, **all_fig34e, **x} for x in fig3_lines]
fig3f = [{**all_fig3, **all_fig34f, **x} for x in fig3_lines]
Create the plot
# Set up number of subplots and figure size
fig, ax = plt.subplots(3, 2, figsize=(8, 9))

# The Y axis labels use mathtext for the bold row heading; the backslashes
# are doubled so that the string contains no invalid escape sequences
plot_fig(fig3a, ax[0, 0], title='Total staff = 4 x no. staff/shift',
         letter='(a)',
         ylabel='$\\bf{1\\ shift\\ per\\ day}$\nProportion of staff infected')

plot_fig(fig3b, ax[0, 1], letter='(b)',
         title='Total staff = 6 x no. staff/shift', legend=False)

plot_fig(fig3c, ax[1, 0], letter='(c)', title='',
         ylabel='$\\bf{2\\ shifts\\ per\\ day}$\nProportion of staff infected',
         legend=False)

plot_fig(fig3d, ax[1, 1], letter='(d)', title='', legend=False)

plot_fig(fig3e, ax[2, 0], letter='(e)', title='',
         ylabel='$\\bf{3\\ shifts\\ per\\ day}$\nProportion of staff infected',
         legend=False)

plot_fig(fig3f, ax[2, 1], letter='(f)', title='', legend=False)

# Prevent overlap between subplots
fig.tight_layout()

# Save the figure
plt.savefig(os.path.join(paths.outputs, paths.fig3))

# Display the figure
plt.show()
Figure 4
Formatting for each of the five lines (which are the same for each subplot)
# One entry per line: the shift change interval it shows and how it is drawn
fig4_lines = [
    # Black line
    {
        'staff_change': 1,
        'label': 'Shift change after 1 day',
        'color': 'black',
        'linestyle': '-'
    },
    # Grey line
    {
        'staff_change': 3,
        'label': 'Shift change after 3 days',
        'color': '#7F7F7F',
        'linestyle': '--'
    },
    # Blue line
    {
        'staff_change': 7,
        'label': 'Shift change after 7 days',
        'color': '#0054B4',
        'linestyle': ':'
    },
    # Red line
    {
        'staff_change': 14,
        'label': 'Shift change after 14 days',
        'color': '#F20000',
        'linestyle': '-'
    },
    # Green line
    {
        'staff_change': 21,
        'label': 'Shift change after 21 days',
        'color': '#9BB05C',
        'linestyle': '--'
    }
]
Set default parameters for subplots (using some defined above, as same as Figure 3 for some)
# Shared by every line in Figure 4
all_fig4 = {
    'staff_per_shift': 20,
    'df': model_base15
}
Combine dictionaries to get conditions for plots
# For each subplot, merge the figure defaults, the subplot defaults and the
# line-specific settings into one dictionary per line
fig4a = [{**all_fig4, **all_fig34a, **x} for x in fig4_lines]
fig4b = [{**all_fig4, **all_fig34b, **x} for x in fig4_lines]
fig4c = [{**all_fig4, **all_fig34c, **x} for x in fig4_lines]
fig4d = [{**all_fig4, **all_fig34d, **x} for x in fig4_lines]
fig4e = [{**all_fig4, **all_fig34e, **x} for x in fig4_lines]
fig4f = [{**all_fig4, **all_fig34f, **x} for x in fig4_lines]
Create the plot
# Set up number of subplots and figure size
fig, ax = plt.subplots(3, 2, figsize=(8, 9))

# The Y axis labels use mathtext for the bold row heading; the backslashes
# are doubled so that the string contains no invalid escape sequences.
# Only subplot (b) shows the legend.
plot_fig(fig4a, ax[0, 0], letter='(a)',
         title='Total staff = 4 x no. staff/shift',
         ylabel='$\\bf{1\\ shift\\ per\\ day}$\nProportion of staff infected',
         legend=False)

plot_fig(fig4b, ax[0, 1], letter='(b)',
         title='Total staff = 6 x no. staff/shift', legend=True)

plot_fig(fig4c, ax[1, 0], letter='(c)', title='',
         ylabel='$\\bf{2\\ shifts\\ per\\ day}$\nProportion of staff infected',
         legend=False)

plot_fig(fig4d, ax[1, 1], letter='(d)', title='', legend=False)

plot_fig(fig4e, ax[2, 0], letter='(e)', title='',
         ylabel='$\\bf{3\\ shifts\\ per\\ day}$\nProportion of staff infected',
         legend=False)

plot_fig(fig4f, ax[2, 1], letter='(f)', title='', legend=False)

# Prevent overlap between subplots
fig.tight_layout()

# Save the figure
plt.savefig(os.path.join(paths.outputs, paths.fig4))

# Display the figure
plt.show()
Figure 5
Set default parameters for all lines
# Roster settings shared by every line in Figure 5
all_fig5 = {
    'shifts_per_day': 1,
    'staff_per_shift': 30,
    'strength': 2,
    'staff_change': 1
}
Then set specific parameters for each line, and combine with the default parameters
# Subplot (a): roster arrangements
fig5a_black = {
    **all_fig5,
    'df': model_base15,
    'label': 'Split team',
    'color': 'black',
    'linestyle': '-'
}
fig5a_green = {
    **all_fig5,
    'df': model_contact,
    'label': 'Split team with social distancing',
    'color': '#008E27',
    'linestyle': '--'
}
fig5a_blue = {
    **all_fig5,
    'df': model_norest,
    'label': 'Random roster assignment',
    'color': '#0054B4',
    'linestyle': ':'
}

# Subplot (b): protective equipment
fig5b_black = {
    **all_fig5,
    'df': model_base15,
    'label': 'Baseline',
    'color': 'black',
    'linestyle': '-'
}
fig5b_green = {
    **all_fig5,
    'df': model_gloves,
    'label': 'Wearing gloves',
    'color': '#008E27',
    'linestyle': '--'
}
fig5b_blue = {
    **all_fig5,
    'df': model_surgical,
    'label': 'Wearing surgical mask',
    'color': '#0054B4',
    'linestyle': ':'
}
fig5b_grey = {
    **all_fig5,
    'df': model_gown,
    'label': 'Wearing gown',
    'color': '#7F7F7F',
    'linestyle': '--'
}
fig5b_red = {
    **all_fig5,
    'df': model_n95,
    'label': 'Wearing N95 mask',
    'color': '#E71111',
    'linestyle': '-'
}
# Set up number of subplots and figure size
fig, ax = plt.subplots(2, 1, figsize=(6, 9))

plot_fig([fig5a_black, fig5a_green, fig5a_blue], ax[0], letter='(a)',
         title='', legend=True, ytick_freq=0.1)

plot_fig([fig5b_black, fig5b_green, fig5b_blue, fig5b_grey, fig5b_red], ax[1],
         letter='(b)', title='', legend=True, ytick_freq=0.1)

# Prevent overlap between subplots
fig.tight_layout()

# Save the figure
plt.savefig(os.path.join(paths.outputs, paths.fig5))

# Display the figure
plt.show()
Time elapsed
# Only reported when the models were run, as the timer is only started then
if run_model:
    # Find run time in seconds
    end = time.time()
    runtime = round(end-start)

    # Display converted to minutes and seconds
    print(f'Notebook run time: {runtime//60}m {runtime%60}s')