|
23 | 23 | from functools import partial, reduce |
24 | 24 | from itertools import groupby |
25 | 25 | from operator import itemgetter |
| 26 | +from weakref import WeakKeyDictionary, WeakSet |
26 | 27 |
|
27 | 28 | import ctypes |
28 | 29 | import functools |
@@ -706,3 +707,120 @@ def batch_contextmanager(f, kwargs_list): |
706 | 707 | for kwargs in kwargs_list: |
707 | 708 | stack.enter_context(f(**kwargs)) |
708 | 709 | yield |
| 710 | + |
| 711 | +class tls_property: |
| 712 | + """ |
| 713 | + Use it like `property` decorator, but the result will be memoized per |
| 714 | + thread. When the owning thread dies, the values for that thread will be |
| 715 | + destroyed. |
| 716 | +
|
| 717 | + In order to get the values, it's necessary to call the object |
| 718 | + given by the property. This is necessary in order to be able to add methods |
| 719 | + to that object, like :meth:`_BoundTLSProperty.get_all_values`. |
| 720 | +
|
| 721 | + Values can be set and deleted as well, which will be a thread-local set. |
| 722 | + """ |
| 723 | + |
| 724 | + @property |
| 725 | + def name(self): |
| 726 | + return self.factory.__name__ |
| 727 | + |
| 728 | + def __init__(self, factory): |
| 729 | + self.factory = factory |
| 730 | + # Lock accesses to shared WeakKeyDictionary and WeakSet |
| 731 | + self.lock = threading.Lock() |
| 732 | + |
| 733 | + def __get__(self, instance, owner=None): |
| 734 | + return _BoundTLSProperty(self, instance, owner) |
| 735 | + |
| 736 | + def _get_value(self, instance, owner): |
| 737 | + tls, values = self._get_tls(instance) |
| 738 | + try: |
| 739 | + return tls.value |
| 740 | + except AttributeError: |
| 741 | + # Bind the method to `instance` |
| 742 | + f = self.factory.__get__(instance, owner) |
| 743 | + obj = f() |
| 744 | + tls.value = obj |
| 745 | + # Since that's a WeakSet, values will be removed automatically once |
| 746 | + # the threading.local variable that holds them is destroyed |
| 747 | + with self.lock: |
| 748 | + values.add(obj) |
| 749 | + return obj |
| 750 | + |
| 751 | + def _get_all_values(self, instance, owner): |
| 752 | + with self.lock: |
| 753 | + # Grab a reference to all the objects at the time of the call by |
| 754 | + # using a regular set |
| 755 | + tls, values = self._get_tls(instance=instance) |
| 756 | + return set(values) |
| 757 | + |
| 758 | + def __set__(self, instance, value): |
| 759 | + tls, values = self._get_tls(instance) |
| 760 | + tls.value = value |
| 761 | + with self.lock: |
| 762 | + values.add(value) |
| 763 | + |
| 764 | + def __delete__(self, instance): |
| 765 | + tls, values = self._get_tls(instance) |
| 766 | + with self.lock: |
| 767 | + values.discard(tls.value) |
| 768 | + del tls.value |
| 769 | + |
| 770 | + def _get_tls(self, instance): |
| 771 | + dct = instance.__dict__ |
| 772 | + name = self.name |
| 773 | + try: |
| 774 | + # Using instance.__dict__[self.name] is safe as |
| 775 | + # getattr(instance, name) will return the property instead, as |
| 776 | + # the property is a descriptor |
| 777 | + tls = dct[name] |
| 778 | + except KeyError: |
| 779 | + with self.lock: |
| 780 | + # Double check after taking the lock to avoid a race |
| 781 | + if name not in dct: |
| 782 | + tls = (threading.local(), WeakSet()) |
| 783 | + dct[name] = tls |
| 784 | + |
| 785 | + return tls |
| 786 | + |
| 787 | + @property |
| 788 | + def basic_property(self): |
| 789 | + """ |
| 790 | + Return a basic property that can be used to access the TLS value |
| 791 | + without having to call it first. |
| 792 | +
|
| 793 | + The drawback is that it's not possible to do anything over than |
| 794 | + getting/setting/deleting. |
| 795 | + """ |
| 796 | + def getter(instance, owner=None): |
| 797 | + prop = self.__get__(instance, owner) |
| 798 | + return prop() |
| 799 | + |
| 800 | + return property(getter, self.__set__, self.__delete__) |
| 801 | + |
| 802 | +class _BoundTLSProperty: |
| 803 | + """ |
| 804 | + Simple proxy object to allow either calling it to get the TLS value, or get |
| 805 | + some other informations by calling methods. |
| 806 | + """ |
| 807 | + def __init__(self, tls_property, instance, owner): |
| 808 | + self.tls_property = tls_property |
| 809 | + self.instance = instance |
| 810 | + self.owner = owner |
| 811 | + |
| 812 | + def __call__(self): |
| 813 | + return self.tls_property._get_value( |
| 814 | + instance=self.instance, |
| 815 | + owner=self.owner, |
| 816 | + ) |
| 817 | + |
| 818 | + def get_all_values(self): |
| 819 | + """ |
| 820 | + Returns all the thread-local values currently in use in the process for |
| 821 | + that property for that instance. |
| 822 | + """ |
| 823 | + return self.tls_property._get_all_values( |
| 824 | + instance=self.instance, |
| 825 | + owner=self.owner, |
| 826 | + ) |
0 commit comments