Add tools to mange (thread-local) context information

This commit is contained in:
Henryk Plötz 2019-07-01 22:15:20 +02:00
parent 6e28db0f6c
commit 1687ecf2ff
4 changed files with 162 additions and 2 deletions

View file

@ -0,0 +1,62 @@
import threading
import contextlib
_localdata = threading.local()
class Context(dict):
def __init__(self, _parent=None, **kwargs):
self._parent = _parent
self._deleted = set()
super().__init__(**kwargs)
def __delitem__(self, key):
with contextlib.suppress(KeyError):
super().__delitem__(key)
self._deleted.add(key)
def __getitem__(self, key):
if key in self._deleted:
raise KeyError("{} deleted in context".format(key))
if key not in self:
if self._parent is not None:
return self._parent[key]
return super().__getitem__(key)
def __setitem__(self, key, value):
if key in self._deleted:
self._deleted.remove(key)
super().__setitem__(key, value)
def _get_current_context():
if not hasattr(_localdata, 'contexts'):
_localdata.contexts = [Context(_parent=GlobalContext)]
return _localdata.contexts[-1]
class MetaContext:
def __delitem__(self, item):
return _get_current_context().__delitem__(item)
def __getattribute__(self, key):
return _get_current_context().__getattribute__(key)
def __setitem__(self, item, value):
return _get_current_context().__setitem__(item, value)
def __getitem__(self, item):
return _get_current_context().__getitem__(item)
@contextlib.contextmanager
def enter_context(**kwargs):
_get_current_context()
_localdata.contexts.append(Context(_parent=_localdata.contexts[-1], **kwargs))
yield
_localdata.contexts.pop()
CurrentContext = MetaContext()
GlobalContext = Context()

View file

@ -269,3 +269,6 @@ class TLVField(Field):
instance._values[self] = TLVContainer()
instance._values[self].pending = True
return super().__get__(instance, objtype)
# FIXME text encoding

View file

@ -170,8 +170,6 @@ class TestAPDUBitmaps(TestCase):
self.assertEqual(b'\x02\xff\xaa', c.raw_tlv)
# FIXME Create empty TLV
# FIXME Create TLV on access
# FIXME TLV names

View file

@ -0,0 +1,97 @@
from ecrterm.packets.context import enter_context, GlobalContext, CurrentContext
from unittest import TestCase, main
import threading, time
class TestContext(TestCase):
def test_normal_access(self):
GlobalContext['test_normal_access'] = 1
self.assertEqual(1, GlobalContext['test_normal_access'])
del GlobalContext['test_normal_access']
self.assertRaises(KeyError, lambda: GlobalContext['test_normal_access'])
def test_nested_access_1(self):
GlobalContext['test_nested_access_1'] = 1
with enter_context():
self.assertEqual(1, CurrentContext['test_nested_access_1'])
CurrentContext['test_nested_access_1'] = 2
self.assertEqual(2, CurrentContext['test_nested_access_1'])
del CurrentContext['test_nested_access_1']
self.assertRaises(KeyError, lambda: CurrentContext['test_nested_access_1'])
self.assertEqual(1, CurrentContext['test_nested_access_1'])
def test_nested_access_2(self):
CurrentContext['test_nested_access_2'] = 1
with enter_context(test_nested_access_2=2):
self.assertEqual(2, CurrentContext['test_nested_access_2'])
with enter_context():
self.assertEqual(2, CurrentContext['test_nested_access_2'])
CurrentContext['test_nested_access_2'] = 3
self.assertEqual(3, CurrentContext['test_nested_access_2'])
self.assertEqual(2, CurrentContext['test_nested_access_2'])
def test_nested_delete(self):
CurrentContext['test_nested_delete'] = 1
with enter_context():
del CurrentContext['test_nested_delete']
self.assertRaises(KeyError, lambda: CurrentContext['test_nested_delete'])
CurrentContext['test_nested_delete'] = 2
self.assertEqual(2, CurrentContext['test_nested_delete'])
self.assertEqual(1, CurrentContext['test_nested_delete'])
def test_nested_access_3(self):
self.assertRaises(KeyError, lambda: CurrentContext['test_nested_access_3'])
with enter_context():
self.assertRaises(KeyError, lambda: CurrentContext['test_nested_access_3'])
GlobalContext['test_nested_access_3'] = 1
self.assertEqual(1, CurrentContext['test_nested_access_3'])
self.assertEqual(1, CurrentContext['test_nested_access_3'])
def test_threads(self):
GlobalContext['test_threads'] = 1
def test_fun(arg):
self.assertEqual(1, CurrentContext['test_threads'])
CurrentContext['test_threads'] = arg
time.sleep(0.01)
self.assertEqual(arg, CurrentContext['test_threads'])
t1 = threading.Thread(target=lambda: test_fun(2))
t2 = threading.Thread(target=lambda: test_fun(3))
t3 = threading.Thread(target=lambda: test_fun(4))
t1.start()
t2.start()
t3.start()
t1.join()
t2.join()
t3.join()
self.assertEqual(1, GlobalContext['test_threads'])
self.assertEqual(1, CurrentContext['test_threads'])
if __name__ == '__main__':
main()