"""
Core APLUS simulation engine which progresses patients through a given workflow
"""
import io
import random
from types import CodeType
from typing import Any, Callable, Optional, Tuple, Dict, List
import numpy as np
import pandas as pd
import pickle
import pydot
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from tqdm import tqdm
from aplusml.config import Config, ConfigUtility
import aplusml.draw as draw
from aplusml.models import Patient, State, Transition, History, Utility
from aplusml.parse import is_valid_config_yaml, load_yaml
[docs]
class Simulation(object):
"""The core APLUS simulation engine which progresses patients through a given workflow.
This class manages the entire simulation process including:
* Patient state transitions and history tracking
* Variable evaluation and management
* Resource allocation and replenishment
* Utility calculations
* Workflow visualization
The simulation operates on a timestep basis, processing patients through states and transitions
based on conditional and probabilistic rules defined in the workflow configuration.
Attributes:
metadata (dict): Configuration metadata for the simulation
variables (Dict[str, Dict]): Variables used in the simulation, keyed by ID
variable_history (Dict[List[Tuple[int, Any]]]): History of variable values over time
states (Dict[str, State]): States in the workflow, keyed by ID
current_timestep (int): Current simulation timestep
"""
[docs]
def __init__(self):
"""Initializes the simulation with default values"""
self.metadata = {}
self.variables: Dict[str, Dict] = {} # [key] = id, [value] = dict
self.variable_history: Dict[List[Tuple[int, Any]]] = {} # [key] = id, [value] = tuple(timestep, value)
self.states: Dict[str, State] = {} # [key] = id, [value] = State
self.current_timestep: int = None
def __repr__(self):
metadata_str = f"self.metadata:\n"
for k, v in self.metadata.items():
metadata_str += f" - {k}: {v}\n"
variables_str = "self.variables:\n"
for var in self.variables:
variables_str += f" - {var}: {self.variables[var]}\n"
states_str = "self.states:\n"
for state in self.states:
states_str += f" - {state}\n"
for t in self.states[state].transitions:
states_str += f" - {t.dest}\n"
if t.is_conditional_if():
states_str += f" if: {t.if_}\n"
elif t.is_conditional_prob():
states_str += f" prob: {t.prob}\n"
return f"{metadata_str}\n{variables_str}\n{states_str}"
[docs]
def evaluate_variables(self, patient: Patient) -> Dict[str, Any]:
"""Evaluates all variables for a given patient at the current timestep.
Processes different variable types:
* `scalar`: Returns constant value
* `resource`: Returns current resource level
* `property`: Returns patient-specific property value
* `simulation`: Returns simulation-specific values (time remaining, elapsed time, current timestep)
* `function`: Executes and returns function result
Args:
patient (Patient): The patient to evaluate variables for
Returns:
Dict[str, Any]: Dictionary mapping variable IDs to their evaluated values
Raises:
AssertionError: If number of evaluated variables doesn't match total variables
"""
variable_to_value = {}
for v_id, v in self.variables.items():
if v['type'] == 'scalar':
variable_to_value[v_id] = v['value']
elif v['type'] == 'resource':
variable_to_value[v_id] = v['level']
elif v['type'] == 'property':
variable_to_value[v_id] = patient.properties[v_id]
elif v['type'] == 'simulation':
if v_id == 'time_left_in_sim':
# If patient LOS = 1, t = 0, patient start = 0, then 'time_left_in_sim' = 1
variable_to_value[v_id] = max(0, patient.start_timestep + patient.properties['total_duration_in_sim'] - self.current_timestep)
elif v_id == 'time_already_in_sim':
# If t = 1, patient start = 0, then 'time_already_in_sim' = 1
variable_to_value[v_id] = max(0, self.current_timestep - patient.start_timestep)
elif v_id == 'sim_current_timestep':
variable_to_value[v_id] = self.current_timestep
elif v['type'] == 'function':
# TODO
# variable_to_value[v_id] = v_id()
raise NotImplementedError("Function variables are not yet implemented")
else:
print("ERROR - Invalid 'type' for variable:", v_id)
assert len(variable_to_value) == len(self.variables)
return variable_to_value
[docs]
def evaluate_expression(self,
patient: Patient,
expression: Any,
variables: dict,
expression_compiled: Optional[CodeType] = None) -> Any:
"""Evaluates a Python expression in the context of a patient and variables.
Handles evaluation of:
* Literal values (bool, int, float)
* String expressions using Python's eval()
* Pre-compiled expressions for better performance
The expression can reference any variables passed in the variables dict.
Args:
patient (Patient): The patient context for evaluation
expression (Any): The expression to evaluate - can be literal value or string expression
variables (dict): Dictionary of variables available during evaluation
expression_compiled (CodeType, optional): Pre-compiled version of the expression
Returns:
Any: The result of evaluating the expression
Raises:
NameError: If expression references undefined variables
"""
if type(expression) in [bool, int, float]:
return expression
elif type(expression) != str:
# NOTE: This is necessary b/c eval() won't accept a bool
return None
# Evaluate expression
try:
if expression_compiled:
return eval(expression_compiled, {}, variables)
else:
return eval(expression, {}, variables)
except NameError as e:
print(f"ERROR - Missing variable ({e}) for expression: {expression}")
return None
except Exception as e:
print(f"ERROR - Error evaluating expression: {expression}")
print(f" - Error: {e}")
print(f" - Variables: {variables}")
return None
[docs]
def evaluate_transition_if(self, patient: Patient,
transition: Transition,
variables: dict) -> bool:
"""Evaluates whether a conditional transition should be taken.
Checks if the transition has an 'if' condition and evaluates it in the context
of the given patient and variables. If there is no condition, returns True.
Args:
patient (Patient): The patient attempting the transition
transition (Transition): The transition to evaluate
variables (dict): Variables available for condition evaluation
Returns:
bool: True if transition condition is met or no condition exists, False otherwise
"""
if not transition.is_conditional_if():
# If there is no 'if', return TRUE
return True
return self.evaluate_expression(patient, transition.if_, variables, transition.if_compiled)
[docs]
def evaluate_transition_prob(self, patient: Patient, transition: Transition, variables: Dict[str, Any]) -> float:
"""Evaluates whether a probabilistic transition should be taken.
For transitions with a 'prob' value, evaluates the probability expression.
For transitions without a 'prob', returns None (will be set to remaining probability).
Args:
patient (Patient): The patient attempting the transition
transition (Transition): The transition to evaluate
variables (Dict[str, Any]): Variables available for probability evaluation
Returns:
float: Evaluated probability value between 0 and 1, or None if no probability specified
"""
if not transition.is_conditional_prob():
# If there is no 'prob', return None (default will be set to: 1 - (sum of other probs))
return None
return self.evaluate_expression(patient, transition.prob, variables, transition.prob_compiled)
[docs]
def evaluate_utility_if(self, patient: Patient, utility: Utility, variables: Dict[str, Any]) -> bool:
"""Evaluates whether a utility should be applied based on its condition.
Similar to evaluate_transition_if, but for utility conditions. Returns True
if utility has no condition or if condition evaluates to True.
Args:
patient (Patient): The patient to evaluate utility for
utility (Utility): The utility to evaluate
variables (Dict[str, Any]): Variables available for condition evaluation
Returns:
bool: True if utility should be applied, False otherwise
"""
if not utility.is_conditional_if():
# If there is no 'if', return TRUE
return True
return self.evaluate_expression(patient, utility.if_, variables, utility.if_compiled)
[docs]
def evaluate_utility_value(self, patient: Patient, utility: Utility, variables: Dict[str, Any]) -> Any:
"""Evaluates the actual value of a utility.
Evaluates the utility's value expression in the context of the patient and variables.
The value can be a constant or an expression using available variables.
Args:
patient (Patient): The patient context for evaluation
utility (Utility): The utility whose value should be evaluated
variables (Dict[str, Any]): Variables available for value evaluation
Returns:
Any: The evaluated utility value
"""
return self.evaluate_expression(patient, utility.value, variables, utility.value_compiled)
[docs]
def evaluate_duration(self, patient: Patient, duration: str, variables: Dict[str, Any]) -> int:
"""Evaluates how long a patient should remain in a state or transition.
Duration can be a constant or an expression using available variables.
Args:
patient (Patient): The patient to evaluate duration for
duration (str): The duration expression to evaluate
variables (Dict[str, Any]): Variables available for duration evaluation
Returns:
int: Number of timesteps the duration should last
"""
return self.evaluate_expression(patient, duration, variables)
[docs]
def select_transition(self, patient: Patient,
transitions: List[Transition],
variables: Dict[str, Any]) -> int:
"""Determines which transition a patient should take from their current state.
First evaluates all conditional ('if') transitions in order. If none are true,
then evaluates probabilistic transitions. For probabilistic transitions,
ensures probabilities sum to 1 and handles default probability case.
Args:
patient (Patient): The patient attempting to transition
transitions (List[Transition]): Available transitions from current state
variables (Dict[str, Any]): Variables available for transition evaluation
Returns:
int: Index of selected transition, or None if no valid transition found
Raises:
AssertionError: If no valid transition can be selected
"""
# Conditional transition
start_idx_for_probs = 0
for t_idx, t in enumerate(transitions):
if t.is_conditional_prob():
start_idx_for_probs = t_idx
break
value = self.evaluate_transition_if(patient, t, variables)
if value:
# NOTE: Must just be `value` to evaluate Truthiness -- i.e. can't have "value is True" here,
# otherwise code like "X is True" will return false when eval() returns an np.bool_(True)
# b/c "is" compares X to the Python True, otherwise returns false
# See: https://stackoverflow.com/questions/27276610/boolean-identity-true-vs-is-true
return t_idx
elif value is None:
return None
# Probabilistic transition
prob_transitions = transitions[start_idx_for_probs:]
probs = [ self.evaluate_transition_prob(patient, t, variables) for t in prob_transitions ]
if probs.count(None) > 1:
# If there are multiple Nones, throw error
print(f"ERROR - Multiple transitions are missing a 'prob' value: {transitions}")
return None
elif probs.count(None) == 1:
# If there is one None, set it = 1 - (sum of other probs)
probs[probs.index(None)] = 1 - sum(filter(None, probs))
else:
# If there is no catch-all (i.e. no `None`)...
if sum(probs) != 1:
# ...And probs don't sum to 1, throw error
print(f"ERROR - Values for 'prob' don't sum to 1: {transitions}")
return None
t_idx = start_idx_for_probs + random.choices(range(len(prob_transitions)), weights = probs)[0]
return t_idx
[docs]
def init_variables(self, variables: Dict[str, dict]):
"""Initializes resource variables with their starting values.
Note: Modifies 'variables' in place
For each resource variable, sets:
* Initial resource level from init_amount
* Last refill timestep to 0
* Empty history tracking list
Args:
variables (Dict[str, dict]): Dictionary of variables from config, modified in place
"""
# Resource initial amounts
for v_id, v in variables.items():
if v['type'] == 'resource':
variables[v_id]['level'] = v['init_amount']
variables[v_id]['last_refill_timestep'] = 0
self.variable_history[v_id] = [ (0, variables[v_id]['level']) ]
[docs]
def init_run(self, random_seed: int = 0):
"""Initializes the simulation for a new run.
Resets simulation state by:
- Setting current timestep to 0
- Initializing all variables
- Setting random seeds for reproducibility
Args:
random_seed (int, optional): Seed for random number generators. Defaults to 0.
"""
# Track simulation timesteps
self.current_timestep: int = 0 # Current simulation timestep
# Initialize variables
self.init_variables(self.variables)
# Random seed
np.random.seed(random_seed)
random.seed(random_seed)
[docs]
def is_valid_patients(self, patients: List[Patient]) -> Tuple[bool, str]:
"""Validates a list of patients for simulation.
Currently checks:
- All patients have unique IDs
Args:
patients (List[Patient]): List of patients to validate
Returns:
Tuple[bool, str]: (is_valid, error_message) where is_valid is True if validation passes,
and error_message contains details if validation fails
"""
is_unique_ids = len(set([ p.id for p in patients ])) == len(patients)
if not is_unique_ids:
return False, f"Patients must all have unique IDs"
return True, ""
[docs]
def init_patients(self, patients: List[Patient]):
"""Initializes patients for simulation.
For each patient:
- Sets initial state to 'start'
- Initializes empty history list
Args:
patients (List[Patient]): List of patients to initialize, modified in place
"""
for p in patients:
# Current state = 'start'
p.current_state = 'start'
# History
p.history = []
[docs]
def log(self, string: str):
"""Logs a message if logging is enabled.
Format: t=<current timestep of simulation> | {string}
Args:
string (str): Message to log
"""
if self.is_print_log:
print(f"t={self.current_timestep} | {string}")
[docs]
def run(self,
all_patients: List[Patient],
max_timesteps: int = None,
random_seed: int = 0,
is_print_log: bool = False,
is_print_tqdm: bool = True) -> List[Patient]:
"""Runs the simulation by progressing all patients through the workflow.
Core simulation loop that:
1. Processes patients in order of admission time and preference sorting
2. Moves each patient through states based on transitions
3. Tracks utilities and resource usage
4. Handles patient pausing/unpausing for state/transition durations
5. Continues until all patients finish or max timesteps reached
Performance: Takes ~3 seconds for 15,000 patients, ~10 seconds for 50,000 patients.
Args:
all_patients (List[Patient]): All patients to simulate across all admit days.
Modified in place - only history and current_state attributes are changed.
max_timesteps (int, optional): Maximum timesteps to simulate. If reached, simulation stops
with current_timestep = max_timesteps - 1. Defaults to None.
random_seed (int, optional): Random seed for reproducibility. Defaults to 0.
is_print_log (bool, optional): Whether to print debug logs. Defaults to False.
is_print_tqdm (bool, optional): Whether to print a progress bar. Defaults to True.
Returns:
List[Patient]: The patients after simulation with updated histories.
Raises:
AssertionError: If patients are invalid or simulation state becomes invalid
"""
assert isinstance(all_patients, list), f"ERROR - 'all_patients' must be a list, but instead got: {type(all_patients)}"
assert isinstance(random_seed, int), f"ERROR - 'random_seed' must be an int, but instead got: {type(random_seed)}"
self.is_print_log = is_print_log
# Deep copy patients
if is_print_log:
print(f"Deep copying patients...")
all_patients = pickle.loads(pickle.dumps(all_patients))
if is_print_log:
print(f"Done deep copying patients")
# Double check that all patients have all properties set
for p in all_patients:
for v_id, v in self.variables.items():
if v['type'] == 'property':
assert v_id in p.properties, f"ERROR - Patient '{p.id}' does not have property '{v_id}' set"
# Track patients admitted after current value of 'self.current_timestep'
admitted_patients: List[Patient] = []
# Track patients waiting some # of timesteps
paused_patients: dict[tuple[int, str]] = {} # [key] = patient ID, [value] = (time remaining in state/transition, state ID / transition dest)
# Track patients who were just unpaused
unpaused_patients: dict[str] = {} # [key] = patient ID, [value] = state ID/transition dest where patient is paused
# Track patients who hit an end state
finished_patients: dict[bool] = {} # [key] = patient ID, [value] = TRUE if patient hit an end state
# Initialize progress bar
if is_print_tqdm:
pbar = tqdm(total=len(all_patients), desc='Simulating patients')
else:
pbar = None
prev_finished_count = 0
# Reset simulation to initialize for run
self.init_patients(all_patients)
self.init_run(random_seed)
# Validate input
is_valid_patients, msg = self.is_valid_patients(all_patients)
assert is_valid_patients, f"ERROR - Patient are invalid - {msg}"
#
# Sort all_patients by `start_timestep`
#
all_patients = sorted(all_patients, key = lambda x : x.start_timestep)
## Track which patient is the latest cutoff for admittance
current_timestep_to_all_patients_idx = {}
for p_idx, p in enumerate(all_patients):
current_timestep_to_all_patients_idx[p.start_timestep] = p_idx
most_recent_current_timestep_to_all_patients_idx = current_timestep_to_all_patients_idx.get(0, 0)
## Track which patient is the earliest cutoff for unfinished patients
earliest_unfinished_patient_tuple: Tuple[int, int] = (0, 0) # Tuple (X,Y) where X = idx of patient with 'start_timestep' = Y, where X is the smallest idx in 'all_patients' which corresponds to an unfinished patient
# Progress all patients until all are finished or hit `max_timestamp`...
while True:
#
# Check if we should end this simulation
#
if max_timesteps is not None and self.current_timestep >= max_timesteps:
self.current_timestep -= 1 # NOTE: Don't change this
self.log(f"Max timestep exceeded @ t={self.current_timestep}")
break
if len(all_patients) <= len(finished_patients):
self.current_timestep -= 1 # NOTE: Don't change this
self.log(f"All patients are finished @ t={self.current_timestep}")
break
self.log(f"Top of sim loop")
#
# Replenish resources
#
for v_id, v in self.variables.items():
if v.get('type', 'scalar') == 'resource':
if self.current_timestep - v['last_refill_timestep'] >= v['refill_duration']:
# Refill resource
v['last_refill_timestep'] = self.current_timestep
v['level'] += v['refill_amount']
# Cap resource at 'max_amount'
v['level'] = min(v['max_amount'], v['level'])
self.variable_history[v_id].append((self.current_timestep, v['level']))
assert v['level'] <= v['max_amount'], f"ERROR - Variable '{v}' value for 'level' exceeded 'max_amount' during REFILL"
#
# Admit new patients
#
if self.current_timestep in current_timestep_to_all_patients_idx:
most_recent_current_timestep_to_all_patients_idx = current_timestep_to_all_patients_idx[self.current_timestep]
admitted_patients_start_idx: int = earliest_unfinished_patient_tuple[0] # NOTE: Need to save this info here for usage in 'while: True' loop
admitted_patients_end_idx: int = most_recent_current_timestep_to_all_patients_idx + 1
admitted_patients: List[Patient] = all_patients[admitted_patients_start_idx:admitted_patients_end_idx]
## Set 'earliest_unfinished_patient_tuple' to best possible case (i.e. assume all 'admitted_patients' finish)
## We'll progressively decrease this value to "worse" cases (i.e. earlier patients in 'admitted_patients') as we loop through them below
earliest_unfinished_patient_tuple: Tuple[int,int] = (admitted_patients_end_idx, all_patients[min(admitted_patients_end_idx, len(all_patients) - 1)].start_timestep)
#
# Process admitted patients
#
# Within a timestep, ordering of patients is arbitrary, so process them
# in order of our constraint preference
# NOTE: this avoids the need to do separate processing on individual constraints
# NOTE: we need to return the actual sorting indices instead of sorting the list inplace
# b/c we need to know the actual indices of each element in `admitted_patients` so
# that we can adjust `earliest_unfinished_patient_tuple` correctly
patient_sort_preference_variable: str = self.metadata['patient_sort_preference_property'].get('variable') if self.metadata['patient_sort_preference_property'] else None
patient_sort_preference_is_ascending: str = self.metadata['patient_sort_preference_property'].get('is_ascending') if self.metadata['patient_sort_preference_property'] else None
sorted_indices = sort_patient_by_preference(admitted_patients, property_to_sort_by=patient_sort_preference_variable, is_ascending=patient_sort_preference_is_ascending)
self.log(f"admitted_patients_idxs=[{admitted_patients_start_idx}:{admitted_patients_end_idx}] | earliest unfinished tuple={earliest_unfinished_patient_tuple} | finished={len(finished_patients)} paused={len(paused_patients)} unpaused={len(unpaused_patients)}")
for p_idx in sorted_indices:
p = admitted_patients[p_idx]
while True:
# Progress patient until hit pause or finish...
if p.id in finished_patients:
break
if p.id in paused_patients:
# Note that this patient is unfinished
earliest_unfinished_patient_tuple = (p_idx + admitted_patients_start_idx, p.start_timestep) if p_idx + admitted_patients_start_idx < earliest_unfinished_patient_tuple[0] else earliest_unfinished_patient_tuple
break
# Go through 'current_state'...
current_state: State = self.states[p.current_state]
transition: Transition = None
# Evaluate variables
variables = self.evaluate_variables(p)
# Unpause patient if we've already waited the requisite timesteps
paused_state_or_transition = None
if p.id in unpaused_patients:
paused_state_or_transition: str = unpaused_patients[p.id]
del unpaused_patients[p.id]
# Track if we need to "wait X timesteps" AS SOON AS state is hit (unless we've already waited, i.e. patient is in 'unpaused_patients')
if paused_state_or_transition is not None and paused_state_or_transition == current_state.id:
# We HAVE already waited the requisite timesteps for 'current_state', so continue with rest of iteration
pass
else:
# We HAVE NOT already waited the requisite timesteps for 'current_state' (i.e. this is the first time we're hitting this state,
# or we were just waiting but on some other state / transition
current_state_duration = self.evaluate_duration(p, current_state.duration, variables)
if current_state_duration > 0:
# If we need to wait > 0 timesteps, then add patient to 'paused_patients'
paused_patients[p.id] = (current_state_duration, current_state.id)
continue
# Select TRANSITION / Update if patient is 'finished' with workflow (i.e. has reached an end state)
transition_idx: int = None
if current_state.type == 'end':
# If this is an END state, add patient to 'finished_patients'
finished_patients[p.id] = 1
# Update progress bar when a new patient finishes
if len(finished_patients) > prev_finished_count:
if is_print_tqdm:
pbar.update(len(finished_patients) - prev_finished_count)
prev_finished_count = len(finished_patients)
else:
# Select transition
transition_idx = self.select_transition(p, current_state.transitions, variables)
assert transition_idx is not None, f"ERROR - No transition conditional is TRUE for patient '{p.id}' given transitions: {current_state.transitions}"
transition = current_state.transitions[transition_idx]
# Determine STATE / TRANSITION utilities
state_utility_idxs, transition_utility_idxs = [], []
state_utility_vals, transition_utility_vals = [], []
for utils_idx, utils in enumerate([ current_state.utilities, transition.utilities if transition else [] ]):
for u_idx, u in enumerate(utils):
# Check that 'if' statement is TRUE (if present)
if not self.evaluate_utility_if(p, u, variables):
continue
# Evaluate value
u_val = self.evaluate_utility_value(p, u, variables)
if utils_idx == 0:
state_utility_idxs.append(u_idx)
state_utility_vals.append(u_val)
else:
transition_utility_idxs.append(u_idx)
transition_utility_vals.append(u_val)
# Record history
p.history.append(History(self.current_timestep,
current_state.id,
transition_idx,
state_utility_idxs,
transition_utility_idxs,
state_utility_vals,
transition_utility_vals,
variables,
)
)
# Take transition
next_state: str = None
if transition:
# Move patient to 'next_state'
assert transition.dest in self.states, f"ERROR - Transition dest '{transition.dest}' not in 'states' section of YAML"
next_state = transition.dest
# Track if we need to "wait X timesteps" AFTER we take this transition (NOTE: we don't need to do an 'unpaused_patients' check, like we do for state, b/c it's impossible for this transition to have already been taken)
transition_duration = self.evaluate_duration(p, transition.duration, variables)
if transition_duration > 0:
assert p.id not in paused_patients
paused_patients[p.id] = (transition_duration, transition.dest)
p.current_state = next_state
# Decrement variables used in this STATE or TRANSITION
resource_deltas: Dict[str, float] = { **current_state.resource_deltas, **(transition.resource_deltas if transition else {}) }
for v_id, delta in resource_deltas.items():
# Add 'delta' to resource
assert v_id in self.variables, f"ERROR - Variable '{v_id}' is not in the 'variables' section of the YAML, as it is used in the 'resource_deltas' of a state or transition"
self.variables[v_id]['level'] += delta
# Cap resource at 'max_amount'
self.variables[v_id]['level'] = min(self.variables[v_id]['max_amount'], self.variables[v_id]['level'])
self.variable_history[v_id].append((self.current_timestep, self.variables[v_id]['level']))
assert self.variables[v_id]['level'] <= self.variables[v_id]['max_amount'], f"ERROR - Variable '{v_id}' value for 'level' exceeded 'max_amount' "
self.log(f"Transition: ({p.id}) => {transition.dest if transition else 'N/A'}")
#
# By this point, all patients are either finished or paused
# Now, we take a timestep forward
#
# For all paused patients, advance them one timestep...
for p_id in list(paused_patients.keys()):
time_left = paused_patients[p_id][0]
paused_state = paused_patients[p_id][1]
# Decrement time left
time_left -= 1
if time_left <= 0:
del paused_patients[p_id]
# Record that we just unpaused this patient from this state/transition
unpaused_patients[p_id] = paused_state
else:
paused_patients[p_id] = (time_left, paused_state)
self.current_timestep += 1
# Close progress bar at end of simulation
if is_print_tqdm:
pbar.close()
return all_patients
[docs]
def get_all_utility_units(self) -> List[str]:
"""Gets all unique utility unit types used in the workflow.
Examines all utilities across all states and transitions to find unique unit types
(e.g. "QALY", "USD").
Returns:
List[str]: List of unique utility unit names
"""
units = []
for s in self.states.values():
for u in s.utilities:
units.append(u.unit)
for t in s.transitions:
for u in t.utilities:
units.append(u.unit)
return units
[docs]
def draw_workflow_diagram(self,
path_to_file: str = None,
is_display: bool = True,
figsize: Tuple[int, int] = (20, 20)):
"""Visualizes the workflow as a directed graph using pydot.
Creates a visual diagram showing:
* States as nodes
* Transitions as edges
* Transition conditions and probabilities
* State/transition utilities and durations
Args:
path_to_file (str, optional): Path to save diagram. Must include supported file extension.
If None, diagram is not saved. Defaults to None.
is_display (bool, optional): Whether to display diagram (useful for Jupyter).
Defaults to True.
figsize (Tuple[int, int], optional): Figure size for matplotlib.
Defaults to (20, 20).
Raises:
ValueError: If path_to_file has unsupported file extension
"""
dot_graph = pydot.Dot(graph_type='digraph')
colors = [
'blue',
'red',
'green',
'purple',
'brown',
'olive',
'cyan',
'darkblue',
'darkslategray',
'orange',
'maroon',
'darkcyan',
]
for idx, state in enumerate(self.states.values()):
color = colors[idx % len(colors)]
# Generate edges
for t in state.transitions:
# Label
if t.is_conditional_prob():
title = 'Prob = ' + str(t.prob)
elif t.is_conditional_if():
title = 'If ' + str(t.if_)
elif len(state.transitions) == 1:
title = 'Always'
else:
title = 'Otherwise'
label: str = draw.create_node_label(title, None, t.duration, t.utilities, t.resource_deltas, is_edge=True)
# Turn this edge into a node for visualization purposes
node_name: str = state.id + '-' + t.dest
node = pydot.Node(node_name, label=label)
node.set_shape('plain')
node.set_color(color)
dot_graph.add_node(node)
# Add edges to/from this new "edge" node
for (start, end) in [(state.id, node_name), (node_name, t.dest)]:
edge = pydot.Edge(start, end)
edge.set_color(color)
edge.set_fontcolor(color)
dot_graph.add_edge(edge)
# Generate node (default to intermediate node)
label: str = draw.create_node_label(state.label, state.id, state.duration, state.utilities, state.resource_deltas)
# Shape/color
node = pydot.Node(state.id, label=label)
node.set_shape('plain')
node.set_color(color)
dot_graph.add_node(node)
if path_to_file:
format: str = path_to_file.split('.')[-1]
if format not in dot_graph.formats:
raise ValueError(f"ERROR - Invalid file extension '{format}' specified for 'path_to_file'. Must be one of: {dot_graph.formats}")
dot_graph.write(path_to_file, format=format)
if is_display:
# Source: https://stackoverflow.com/questions/4596962/display-graph-without-saving-using-pydot
png_str = dot_graph.create_png()
sio = io.BytesIO()
sio.write(png_str)
sio.seek(0)
img = mpimg.imread(sio)
plt.figure(figsize = figsize)
plt.axis('off')
plt.imshow(img, aspect='equal')
[docs]
def create_patients_for_simulation(self,
patients: List[Patient],
func_match_patient_to_property_column: Callable = None,
is_overwrite_existing_properties: bool = True,
random_seed: int = 0) -> List[Patient]:
"""Creates a deep copy of patients and initializes their properties.
1. Deep copies patients using pickle
2. Sorts patients by ID
3. Initializes patient properties from:
- Constants
- CSV file data (using matching function)
- Random distributions
Args:
patients (List[Patient]): Template patients to copy and initialize
func_match_patient_to_property_column (Callable, optional): Function to match patients
to rows in properties CSV. Takes ``(patient_id, random_idx, df, column)``.
Required if using CSV properties without ID column. Defaults to ``None``.
is_overwrite_existing_properties (bool, optional): If TRUE, then overwrite each patient's existing properties with their default values, as specified in `self.variables`.
If FALSE, then only initialize properties that are not already set.
Defaults to ``True``.
random_seed (int, optional): Random seed for reproducibility. Defaults to ``0``.
Returns:
List[Patient]: New list of initialized patients
Raises:
ValueError: If property configuration is invalid or required matching function missing
"""
if is_overwrite_existing_properties:
print("\n!! WARNING - Because `is_overwrite_existing_properties` is TRUE, we are OVERWRITING each patient's existing properties (i.e. patient.properties) with default values from `simulation.variables`. "
"If this is UNDESIRED, i.e. you already defined each patient's properties manually and want to keep them as currently defined, "
"then set `is_overwrite_existing_properties` to FALSE.\n")
# create deep copy of the `patients` object using pickle
patients = pickle.loads(pickle.dumps(patients))
patients = sorted(patients, key = lambda x: x.id)
properties = [ (id, v) for id, v in self.variables.items() if v.get('type', 'scalar') == 'property' ]
# CSV for file-defined properties
path_to_properties = self.metadata.get('path_to_properties', None)
properties_col_for_patient_id = self.metadata.get('properties_col_for_patient_id', None)
if path_to_properties:
# Read CSV containing patient properties
_df = pd.read_csv(path_to_properties)
if properties_col_for_patient_id:
# Check that column corresponding to the Patient ID actually exists in CSV
if properties_col_for_patient_id not in _df.columns:
print(f"ERROR - The value for `properties_col_for_patient_id` ({properties_col_for_patient_id}) must be a column name in the file {path_to_properties}")
return None
# Sort patients by ID
_df = _df.sort_values(properties_col_for_patient_id)
# If we want to randomly sample patients from the CSV (instead of using their ID),
# then do this sampling deterministically by tracking `map_pid_to_random_df_idx`
np.random.seed(random_seed)
random_idxs = np.random.randint(0, _df.shape[0], size = len(patients))
map_pid_to_random_df_idx = { p.id: random_idxs[idx] for idx, p in enumerate(patients) }
#
# Add properties to each Patient
np.random.seed(random_seed)
for (v_id, v) in properties:
if 'value' in v:
# Set to constant
for p in patients:
if is_overwrite_existing_properties:
p.properties[v_id] = v['value']
else:
if v_id not in p.properties:
p.properties[v_id] = v['value']
elif 'column' in v:
# Load from 'path_to_properties' file
if not path_to_properties:
print(f"ERROR - If you specify a 'column' variable, you need to specify a 'path_to_properties' value in the 'metadata' section")
return None
if v['column'] not in _df:
print(f"ERROR - 'column' {v['column']} is not contained in the file pointed to by 'path_to_properties'")
return None
if properties_col_for_patient_id:
## NOTE: This may seem like an unnecessary special case of the functionality offered by `func_match_patient_to_property_column`,
## But its a necessary performance optimization that actually helps speed up the program a lot
sorted_properties = _df[v['column']].values
for p_idx, p in enumerate(patients):
if is_overwrite_existing_properties:
p.properties[v_id] = sorted_properties[p_idx]
else:
if v_id not in p.properties:
p.properties[v_id] = sorted_properties[p_idx]
else:
if func_match_patient_to_property_column is None and properties_col_for_patient_id is None:
raise ValueError(f"ERROR - You need to either specify a `func_match_patient_to_property_column` when calling this function, or the `properties_col_for_patient_id` metadata property in your YAML file (otherwise we have no idea how to match patients to rows in the file)")
for p in patients:
if is_overwrite_existing_properties:
p.properties[v_id] = func_match_patient_to_property_column(p.id, map_pid_to_random_df_idx[p.id], _df, v['column'])
else:
if v_id not in p.properties:
p.properties[v_id] = func_match_patient_to_property_column(p.id, map_pid_to_random_df_idx[p.id], _df, v['column'])
elif 'distribution' in v:
# Distribution
if v['distribution'] == 'bernoulli':
assert 'mean' in v, f"ERROR - Bernoulli variable '{v_id}' missing 'mean' property"
values = np.random.binomial(1, v['mean'], size = len(patients))
elif v['distribution'] == 'exponential':
assert 'mean' in v, f"ERROR - Exponential variable '{v_id}' missing 'mean' property"
values = np.random.exponential(v['mean'], size = len(patients))
elif v['distribution'] == 'binomial':
assert 'mean' in v and 'n' in v, f"ERROR - Binomial variable '{v_id}' missing 'mean' or 'n' property"
values = np.random.binomial(v['n'], v['mean'], size = len(patients))
elif v['distribution'] == 'normal':
assert 'mean' in v and 'std' in v, f"ERROR - Normal variable '{v_id}' missing 'mean' or 'std' property"
values = np.random.normal(v['mean'], v['std'], size = len(patients))
elif v['distribution'] == 'poisson':
assert 'mean' in v, f"ERROR - Poisson variable '{v_id}' missing 'mean' property"
values = np.random.poisson(v['mean'], size = len(patients))
elif v['distribution'] == 'uniform':
assert 'start' in v and 'end' in v, f"ERROR - Uniform variable '{v_id}' missing 'start' or 'end' property"
values = np.random.uniform(v['start'], v['end'], size = len(patients))
else:
raise ValueError(f"ERROR - Unrecognized 'distribution' in variable '{v_id}'")
for idx, p in enumerate(patients):
if is_overwrite_existing_properties:
p.properties[v_id] = values[idx]
else:
if v_id not in p.properties:
p.properties[v_id] = values[idx]
else:
raise ValueError(f"ERROR - Unrecognized properties for variable '{v_id}'")
return patients
[docs]
def load_seismometer_patients_for_simulation(self,
patients: List[Patient],
df_patients_seismometer: pd.DataFrame,
func_match_patient_to_property_column: Callable = None,
random_seed: int = 0) -> List[Patient]:
"""Loads patient properties from an Epic Seismometer dataframe.
Similar to `create_patients_for_simulation` but loads properties from a provided
dataframe instead of CSV file.
Args:
patients (List[Patient]): Template patients to initialize
df_patients_seismometer (pd.DataFrame): Dataframe containing patient properties
func_match_patient_to_property_column (Callable, optional): Function to match patients
to rows in dataframe. Takes ``(patient_id, random_idx, df, column)``.
Required if not using ID column. Defaults to ``None``.
random_seed (int, optional): Random seed for reproducibility. Defaults to ``0``.
Returns:
List[Patient]: Patients with initialized properties
Raises:
ValueError: If property configuration is invalid or required matching function missing
"""
patients = pickle.loads(pickle.dumps(patients))
patients = sorted(patients, key = lambda x: x.id)
properties = [ (id, v) for id, v in self.variables.items() if v.get('type', 'scalar') == 'property' ]
properties_col_for_patient_id = self.metadata.get('properties_col_for_patient_id', None)
if properties_col_for_patient_id:
# Check that column corresponding to the Patient ID actually exists in Seismometer dataframe
if properties_col_for_patient_id not in df_patients_seismometer.columns:
print(f"ERROR - The value for `properties_col_for_patient_id` ({properties_col_for_patient_id}) must be a column name in the dataframe {df_patients_seismometer}")
return None
# Sort Seismometer patients by ID
df_patients_seismometer = df_patients_seismometer.sort_values(properties_col_for_patient_id)
# If we want to randomly sample patients from the Seismometer dataframe (instead of using their ID),
# then do this sampling deterministically by tracking `map_pid_to_random_df_idx`
np.random.seed(random_seed)
random_idxs = np.random.randint(0, df_patients_seismometer.shape[0], size = len(patients))
map_pid_to_random_df_idx = { p.id: random_idxs[idx] for idx, p in enumerate(patients) }
# Add properties to each Patient
np.random.seed(random_seed)
for (v_id, v) in properties:
if 'value' in v:
# Set to constant
for p in patients:
p.properties[v_id] = v['value']
elif 'column' in v:
# Load from patients_dataframe
if v['column'] not in df_patients_seismometer:
print(f"ERROR - 'column' {v['column']} is not contained in the dataframe {df_patients_seismometer}")
return None
if properties_col_for_patient_id:
## NOTE: This may seem like an unnecessary special case of the functionality offered by `func_match_patient_to_property_column`,
## But its a necessary performance optimization that actually helps speed up the program a lot
sorted_properties = df_patients_seismometer[v['column']].values
for p_idx, p in enumerate(patients):
p.properties[v_id] = sorted_properties[p_idx]
else:
if func_match_patient_to_property_column is None and properties_col_for_patient_id is None:
print(f"ERROR - You need to either specify a `func_match_patient_to_property_column` when calling this function, or the `properties_col_for_patient_id` metadata property in your YAML file (otherwise we have no idea how to match patients to rows in the file)")
return None
for p in patients:
p.properties[v_id] = func_match_patient_to_property_column(p.id, map_pid_to_random_df_idx[p.id], df_patients_seismometer, v['column'])
return patients
[docs]
@classmethod
def create_from_yaml(cls, path_to_yaml: str, path_to_patient_properties: Optional[str] = None) -> 'aplusml.sim.Simulation':
"""Create a :class:`~aplusml.sim.Simulation` object from YAML
If ``path_to_patient_properties`` is provided, then it overwrites the current Metadata section's ``path_to_properties`` key.
Args:
path_to_yaml (str): Path to YAML file
path_to_patient_properties (str): Path to patient properties CSV file. Optional.
Returns:
:class:`~aplusml.sim.Simulation`: A :class:`~aplusml.sim.Simulation` object that contains all of the metadata, variables, states, and transitions from the YAML file
"""
yaml: dict = load_yaml(path_to_yaml)
if not is_valid_config_yaml(yaml):
raise ValueError("ERROR - Invalid YAML")
# Create new Simulation
simulation = cls()
#
# Metadata
metadata = yaml.get('metadata', {})
for key in metadata:
simulation.metadata[key] = metadata[key]
## Set defaults
simulation.metadata['name'] = simulation.metadata.get('name', '')
simulation.metadata['path_to_properties'] = simulation.metadata.get('path_to_properties', None)
simulation.metadata['properties_col_for_patient_id'] = simulation.metadata.get('properties_col_for_patient_id', None)
simulation.metadata['patient_sort_preference_property'] = simulation.metadata.get('patient_sort_preference_property', None)
#
# Variables
variables = yaml.get('variables', {})
for v_id, v in variables.items():
simulation.variables[v_id] = {
'type' : v.get('type', 'scalar'),
'value' : v.get('value', None),
**v,
}
#
# States
states = yaml.get('states', {})
for s_id, s in states.items():
transitions: list[Transition] = []
for t in s.get('transitions', []):
raw_utils = t.get('utilities', [])
if type(raw_utils) != list:
# Handle 'utilities' value being a float|int|str (i.e. non-list)
raw_utils = [{ 'value' : raw_utils }]
utilities: list[Utility] = []
for u in raw_utils:
utilities.append(Utility(u.get('value', 0.0), u.get('unit', ''), u.get('if')))
transitions.append(Transition(
t['dest'],
t.get('label', ''),
t.get('duration', 0),
utilities,
{ key: float(val) for key, val in t.get('resource_deltas', {}).items() },
if_ = t.get('if'),
prob = t.get('prob'),
))
raw_utils = s.get('utilities', [])
if type(raw_utils) != list:
# Handle 'utilities' value being a float|int|str (i.e. non-list)
raw_utils = [{ 'value' : raw_utils }]
utilities: list[Utility] = []
for u in raw_utils:
utilities.append(Utility(u.get('value', 0.0), u.get('unit', ''), u.get('if')))
simulation.states[s_id] = State(
s_id,
s.get('label', s_id),
s.get('type', 'intermediate'),
s.get('duration', 0),
utilities,
transitions,
{ key: float(val) for key, val in s.get('resource_deltas', {}).items() }
)
if path_to_patient_properties:
simulation.metadata['path_to_properties'] = path_to_patient_properties
return simulation
[docs]
@classmethod
def create_from_config(cls, config: Config, path_to_patient_properties: Optional[str] = None) -> 'aplusml.sim.Simulation':
"""Create a :class:`~aplusml.sim.Simulation` object from a :class:`~aplusml.config.Config` object.
If ``path_to_patient_properties`` is provided, then it overwrites the current Metadata section's ``path_to_properties`` key.
Args:
config (Config): A Python :class:`~aplusml.config.Config` object
path_to_patient_properties (str): Path to patient properties CSV file. Optional.
Returns:
:class:`~aplusml.sim.Simulation`: A :class:`~aplusml.sim.Simulation` object that contains all of the metadata, variables, states, and transitions from the YAML file
"""
if not config.is_valid():
raise ValueError("ERROR - Invalid Config")
# Create new Simulation
simulation = cls()
#
# Metadata
metadata = config.metadata
## Set defaults
simulation.metadata['name'] = metadata.name
simulation.metadata['path_to_properties'] = metadata.path_to_properties
simulation.metadata['properties_col_for_patient_id'] = metadata.properties_col_for_patient_id
simulation.metadata['patient_sort_preference_property'] = metadata.patient_sort_preference_property
#
# Variables
variables = config.variables
for v_id, v in variables.items():
simulation.variables[v_id] = {
'type' : v.type,
'value' : v.value,
**{ key: val for key, val in v.model_dump().items() if val is not None }, # ignore `None` values (for clarity)
}
#
# States
states = config.states
for s_id, s in states.items():
transitions: list[Transition] = []
for t in s.transitions:
raw_utils = t.utilities
if not isinstance(raw_utils, list):
# Handle 'utilities' value being a float|int|str (i.e. non-list)
raw_utils = [ ConfigUtility(value=raw_utils) ]
utilities: list[Utility] = []
for u in raw_utils:
utilities.append(Utility(u.value, u.unit, u.if_))
transitions.append(Transition(
t.dest,
t.label,
t.duration,
utilities,
t.resource_deltas,
if_ = t.if_,
prob = t.prob,
))
raw_utils = s.utilities
if not isinstance(raw_utils, list):
# Handle 'utilities' value being a float|int|str (i.e. non-list)
raw_utils = [ ConfigUtility(value=raw_utils) ]
utilities: list[Utility] = []
for u in raw_utils:
utilities.append(Utility(u.value, u.unit, u.if_))
simulation.states[s_id] = State(
s_id,
s.label if s.label is not None else s_id,
s.type,
s.duration,
utilities,
transitions,
s.resource_deltas
)
if path_to_patient_properties:
simulation.metadata['path_to_properties'] = path_to_patient_properties
return simulation
[docs]
def sort_patient_by_preference(patients: List[Patient],
property_to_sort_by: str = None,
is_ascending: bool = True) -> List[int]:
"""Returns indices that would sort patients by specified property.
Can sort by:
- Direct :class:`~aplusml.models.Patient` attributes (``id``, ``start_timestep``)
- Patient properties dictionary values
Args:
patients (List[Patient]): Patients to sort
property_to_sort_by (str, optional): Property name to sort by. Defaults to ``None``.
is_ascending (bool, optional): Sort order. Defaults to ``True``.
Returns:
List[int]: Indices that would sort the patients list
"""
if property_to_sort_by == 'id':
# Attribute that is directly part of `Patient` object
properties = [ x.id for x in patients ]
return sorted(range(len(patients)), key=properties.__getitem__, reverse = not is_ascending)
elif property_to_sort_by == 'start_timestep':
# Attribute that is directly part of `Patient` object
properties = [ x.start_timestep for x in patients ]
return sorted(range(len(patients)), key=properties.__getitem__, reverse = not is_ascending)
elif property_to_sort_by:
# Attribute in `.properties` attribute
properties = [ x.properties[property_to_sort_by] for x in patients ]
return sorted(range(len(patients)), key=properties.__getitem__, reverse = not is_ascending)
else:
return range(len(patients))
[docs]
def get_unit_utility_baselines(patients: List[Patient],
utilities: Dict[str, float],
y_true_column_name: str = 'ground_truth') -> Dict[str, float]:
"""Calculates baseline utility metrics for different treatment strategies.
Computes average per-patient utility for:
- Treat all patients
- Treat no patients
- Perfect treatment (treat only true positives)
Args:
patients (List[Patient]): Patients to analyze
utilities (Dict[str, float]): Utility values for TP/FP/FN/TN outcomes
y_true_column_name (str, optional): Patient property containing ground truth.
Defaults to 'ground_truth'.
Returns:
Dict[str, float]: Average utilities for 'all', 'none', and 'perfect' strategies
"""
positives = len([ p for p in patients if p.properties[y_true_column_name] == 1 ])
negatives = len([ p for p in patients if p.properties[y_true_column_name] == 0 ])
"""Treat all (i.e. TP -> TP; FP -> FP, FN -> TP, TN -> FP)
"""
treat_all = 0
treat_all += positives * utilities['tp']
treat_all += negatives * utilities['fp']
"""Treat none (i.e. TP -> FN; FP -> TN, FN -> FN, TN -> TN)
"""
treat_none = 0
treat_none += positives * utilities['fn']
treat_none += negatives * utilities['tn']
"""Treat perfect (i.e. TP -> TP; FP -> TN, FN -> TP, TN -> TN)
"""
treat_perfect = 0
treat_perfect += positives * utilities['tp']
treat_perfect += negatives * utilities['tn']
return {
'all' : treat_all / len(patients),
'none' : treat_none / len(patients),
'perfect' : treat_perfect / len(patients),
}
[docs]
def log_patients(simulation: Simulation, patients: List[Patient]):
"""Prints detailed debug information about patients.
For each patient, prints:
- ID and start timestep
- Properties
- State history
- Sum of utilities
Args:
simulation (Simulation): Simulation context
patients (List[Patient]): Patients to log
"""
for p in patients:
print(f"{p.id} (t_0 = {p.start_timestep})")
print('\t', p.properties)
print('\t', p.print_state_history())
print('\t', p.get_sum_utilities(simulation))
if __name__ == "__main__":
pass