from typing import NewType, Callable, Dict, Any, Set
import time
from ..utils import to_set
InvarianceTest = NewType('InvarianceTest', Callable[[Dict, Any, Any, Set], Dict])
[docs]class InvarianceTester:
def __init__(self):
pass
def is_invariant(self, node, context, cond_set=set()):
raise NotImplementedError
[docs]class MemoizedInvarianceTester(InvarianceTester):
def __init__(self, invariance_test: InvarianceTest, suffstat: Dict, track_times=False, detailed=False, **kwargs):
"""
Class for memoizing the results of invariance tests.
Parameters
----------
invariance_test:
Function taking suffstat, context, node, and conditioning set, and returning a dictionary that includes
the key 'reject'.
suffstat:
Dictionary containing sufficient statistics for all contexts.
track_times:
If True, keep a dictionary mapping each invariance test to the time taken to perform it.
detailed:
If True, keep a dictionary mapping each invariance test to its full set of results.
**kwargs:
Additional keyword arguments to be passed to the invariance test.
See Also
--------
PlainInvarianceTester
Example
-------
"""
InvarianceTester.__init__(self)
self.invariance_dict_detailed = dict()
self.invariance_dict = dict()
self.invariance_test = invariance_test
self.suffstat = suffstat
self.kwargs = kwargs
self.detailed = detailed
self.track_times = track_times
self.invariance_times = dict()
def is_invariant(self, node, context, cond_set=set()):
"""
Check if the conditional distribution of node, given cond_set, is invariant to the context.
"""
cond_set = to_set(cond_set)
index = (node, context, frozenset(cond_set))
# check if result exists and return
_is_invariant = self.invariance_dict.get(index)
if _is_invariant is not None:
return _is_invariant
# otherwise, compute result and save
if self.track_times:
start = time.time()
test_results = self.invariance_test(
self.suffstat,
context,
node,
cond_set=cond_set,
**self.kwargs
)
if self.track_times:
self.invariance_times[index] = time.time() - start
if self.detailed:
self.invariance_dict_detailed[index] = test_results
_is_invariant = not test_results['reject']
self.invariance_dict[index] = _is_invariant
return _is_invariant
class PlainInvarianceTester(InvarianceTester):
def __init__(self, invariance_test: InvarianceTest, suffstat: Dict, **kwargs):
"""
Class for returning the results of invariance tests.
Parameters
----------
invariance_test:
Function taking suffstat, context, node, and conditioning set, and returning a dictionary that includes
the key 'reject'.
suffstat:
Dictionary containing sufficient statistics for all contexts.
**kwargs:
Additional keyword arguments to be passed to the invariance test.
See Also
--------
MemoizedInvarianceTester
Example
-------
"""
InvarianceTester.__init__(self)
self.invariance_test = invariance_test
self.suffstat = suffstat
self.kwargs = kwargs
def is_invariant(self, node, context, cond_set=set()):
return self.invariance_test(
self.suffstat,
context,
node,
cond_set=cond_set,
**self.kwargs
)