redis.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # pyright: reportUnknownMemberType=false
  2. from __future__ import annotations
  3. import _thread
  4. import json
  5. import logging
  6. import random
  7. import time
  8. import typing
  9. import redis
  10. from . import exceptions, utils
  logger: logging.Logger = logging.getLogger(__name__)

  # Default number of seconds to wait for an existing lock holder to answer
  # a ping before it is treated as unavailable.
  DEFAULT_UNAVAILABLE_TIMEOUT: float = 1
  # Default number of seconds the pubsub worker sleeps between polls for
  # new messages.
  DEFAULT_THREAD_SLEEP_TIME: float = 0.1
  class PubSubWorkerThread(redis.client.PubSubWorkerThread):
      """redis-py pubsub worker thread that escalates failures to the main thread.

      An exception in a background thread normally only ends that thread.
      Here it also signals the main thread via ``_thread.interrupt_main()``
      (by default surfacing there as ``KeyboardInterrupt``), so the owning
      process notices that the lock keep-alive stopped.
      """

      def run(self) -> None:
          try:
              super(PubSubWorkerThread, self).run()
          except Exception:  # pragma: no cover
              # Wake the main thread first, then let this thread's exception
              # propagate unchanged.
              _thread.interrupt_main()
              raise
  class RedisLock(utils.LockBase):
      """
      An extremely reliable Redis lock based on pubsub with a keep-alive thread

      As opposed to most Redis locking systems based on key/value pairs,
      this locking method is based on the pubsub system. The big advantage is
      that if the connection gets killed due to network issues, crashing
      processes or otherwise, it will still immediately unlock instead of
      waiting for a lock timeout.

      To make sure both sides of the lock know about the connection state it is
      recommended to set the `health_check_interval` when creating the redis
      connection.

      Args:
          channel: the redis channel to use as locking key.
          connection: an optional redis connection if you already have one
              or if you need to specify the redis connection
          timeout: timeout when trying to acquire a lock
          check_interval: check interval while waiting
          fail_when_locked: when the lock is held by an owner that still
              answers pings, raise `AlreadyLocked` immediately instead of
              retrying until the timeout expires.
          thread_sleep_time: sleep time between fetching messages from redis to
              prevent a busy/wait loop. In the case of lock conflicts this
              increases the time it takes to resolve the conflict. This should
              be smaller than the `check_interval` to be useful.
          unavailable_timeout: If the conflicting lock is properly connected
              this should never exceed twice your redis latency. Note that this
              will increase the wait time possibly beyond your `timeout` and is
              always executed if a conflict arises.
          redis_kwargs: The redis connection arguments if no connection is
              given. The `DEFAULT_REDIS_KWARGS` are used as default, if you want
              to override these you need to explicitly specify a value (e.g.
              `health_check_interval=0`). The given dict is not modified.
      """

      redis_kwargs: dict[str, typing.Any]
      thread: PubSubWorkerThread | None
      channel: str
      timeout: float
      connection: redis.client.Redis[str] | None
      pubsub: redis.client.PubSub | None = None
      close_connection: bool

      DEFAULT_REDIS_KWARGS: typing.ClassVar[dict[str, typing.Any]] = dict(
          health_check_interval=10,
          decode_responses=True,
      )

      def __init__(
          self,
          channel: str,
          connection: redis.client.Redis[str] | None = None,
          timeout: float | None = None,
          check_interval: float | None = None,
          fail_when_locked: bool | None = False,
          thread_sleep_time: float = DEFAULT_THREAD_SLEEP_TIME,
          unavailable_timeout: float = DEFAULT_UNAVAILABLE_TIMEOUT,
          redis_kwargs: dict[str, typing.Any] | None = None,
      ) -> None:
          # We don't want to close connections given as an argument
          self.close_connection = connection is None
          self.thread = None
          self.channel = channel
          self.connection = connection
          self.thread_sleep_time = thread_sleep_time
          self.unavailable_timeout = unavailable_timeout
          # Build a new dict so the caller's ``redis_kwargs`` is never mutated;
          # explicitly passed values take precedence over the defaults.
          self.redis_kwargs = {
              **self.DEFAULT_REDIS_KWARGS,
              **(redis_kwargs or {}),
          }

          super().__init__(
              timeout=timeout,
              check_interval=check_interval,
              fail_when_locked=fail_when_locked,
          )

      def get_connection(self) -> redis.client.Redis[str]:
          """Return the redis connection, creating it from `redis_kwargs` on first use."""
          if self.connection is None:
              self.connection = redis.client.Redis(**self.redis_kwargs)

          return self.connection

      def channel_handler(self, message: dict[str, str]) -> None:
          """Answer a ping sent by `check_or_kill_lock` on the lock channel.

          The ping payload is JSON with a ``response_channel`` key; the reply
          is the current ``time.time()`` published on that channel. Malformed
          payloads are logged and ignored. Raising here would end the worker
          thread and interrupt the main thread, so anyone publishing garbage
          could otherwise take down the lock holder.
          """
          if message.get('type') != 'message':  # pragma: no cover
              return

          raw_data = message.get('data')
          if not raw_data:
              return

          try:
              data = json.loads(raw_data)
          except (TypeError, ValueError):  # pragma: no cover
              # ValueError covers json.JSONDecodeError for non-JSON payloads.
              logger.debug('Error while parsing: %r', message)
              return

          if not isinstance(data, dict) or not isinstance(
              data.get('response_channel'), str
          ):
              logger.debug('Ignoring malformed lock message: %r', message)
              return

          assert self.connection is not None
          self.connection.publish(data['response_channel'], str(time.time()))

      @property
      def client_name(self) -> str:
          """Redis client name used to recognise (and kill) this lock's clients."""
          return f'{self.channel}-lock'

      def _timeout_generator(
          self, timeout: float | None, check_interval: float | None
      ) -> typing.Iterator[int]:
          """Yield once per attempt until ``timeout`` seconds have elapsed.

          Always yields at least once, so ``timeout`` of ``None``/``0`` means
          exactly one attempt. Before every yield it sleeps for
          ``check_interval`` scaled by a random factor in [0.5, 1.5); a
          missing or non-positive interval falls back to `thread_sleep_time`.
          """
          if timeout is None:
              timeout = 0.0
          if check_interval is None:
              check_interval = self.thread_sleep_time
          deadline = time.monotonic() + timeout
          first = True
          while first or time.monotonic() < deadline:
              first = False
              effective_interval = (
                  check_interval
                  if check_interval > 0
                  else self.thread_sleep_time
              )
              # Jitter presumably keeps competing clients from retrying in
              # lockstep.
              sleep_time = effective_interval * (0.5 + random.random())
              time.sleep(sleep_time)
              yield 0

      def acquire(  # type: ignore[override]
          self,
          timeout: float | None = None,
          check_interval: float | None = None,
          fail_when_locked: bool | None = None,
      ) -> RedisLock:
          """Acquire the lock by becoming the sole subscriber of `channel`.

          Args:
              timeout: seconds to keep retrying; falls back to
                  ``self.timeout`` and then ``0.0`` (a single attempt).
              check_interval: base delay between attempts; falls back to
                  ``self.check_interval`` and then ``0.0``.
              fail_when_locked: raise as soon as a live holder answers a ping
                  instead of retrying until ``timeout``; falls back to
                  ``self.fail_when_locked``.

          Returns:
              This lock instance.

          Raises:
              exceptions.AlreadyLocked: the lock could not be acquired.
          """
          timeout = utils.coalesce(timeout, self.timeout, 0.0)
          check_interval = utils.coalesce(
              check_interval,
              self.check_interval,
              0.0,
          )
          fail_when_locked = utils.coalesce(
              fail_when_locked,
              self.fail_when_locked,
          )

          assert not self.pubsub, 'This lock is already active'
          connection = self.get_connection()

          for _ in self._timeout_generator(
              timeout, check_interval
          ):  # pragma: no branch
              subscribers = connection.pubsub_numsub(self.channel)[0][1]

              if subscribers:
                  logger.debug(
                      'Found %d lock subscribers for %s',
                      subscribers,
                      self.channel,
                  )
                  if self.check_or_kill_lock(
                      connection,
                      self.unavailable_timeout,
                  ):  # pragma: no branch
                      # The current holder answered, so the lock is taken.
                      if fail_when_locked:
                          raise exceptions.AlreadyLocked()
                      continue
                  # No answer: the unresponsive holder has been killed, so
                  # fall through and try to take the lock ourselves.

              connection.client_setname(self.client_name)
              self.pubsub = connection.pubsub()
              self.pubsub.subscribe(**{self.channel: self.channel_handler})
              self.thread = PubSubWorkerThread(
                  self.pubsub,
                  sleep_time=self.thread_sleep_time,
              )
              self.thread.start()

              # Brief pause before re-counting, presumably so a competing
              # subscription that raced ours becomes visible.
              time.sleep(0.01)
              subscribers = connection.pubsub_numsub(self.channel)[0][1]
              if subscribers == 1:  # pragma: no branch
                  return self
              else:  # pragma: no cover
                  # Race condition, let's try again
                  self.release()

          raise exceptions.AlreadyLocked()

      def check_or_kill_lock(
          self,
          connection: redis.client.Redis[str],
          timeout: float,
      ) -> bool | None:
          """Ping the current lock holder and kill it if it does not answer.

          A JSON ping naming a private response channel is published on
          `channel`; a live holder replies through its `channel_handler`.

          Args:
              connection: redis connection used for the ping and the kill.
              timeout: seconds to wait for a reply.

          Returns:
              ``True`` if the holder replied (the lock is alive); otherwise
              ``None`` after killing pubsub clients named `client_name`.
          """
          # Random channel name to get messages back from the lock
          response_channel = f'{self.channel}-{random.random()}'

          pubsub = connection.pubsub()
          try:
              pubsub.subscribe(response_channel)
              connection.publish(
                  self.channel,
                  json.dumps(
                      dict(
                          response_channel=response_channel,
                          message='ping',
                      ),
                  ),
              )

              check_interval = min(self.thread_sleep_time, timeout / 10)
              for _ in self._timeout_generator(
                  timeout,
                  check_interval,
              ):  # pragma: no branch
                  reply = pubsub.get_message(timeout=check_interval)
                  # get_message() also returns the confirmation of our own
                  # SUBSCRIBE; only a published reply proves the holder lives.
                  if reply and reply.get('type') == 'message':
                      return True
          finally:
              # Release the temporary pubsub on every path, not only success.
              pubsub.close()

          for client_ in connection.client_list('pubsub'):  # pragma: no cover
              if client_.get('name') == self.client_name:
                  logger.warning('Killing unavailable redis client: %r', client_)
                  connection.client_kill_filter(  # pyright: ignore
                      client_.get('id'),
                  )
          return None

      def release(self) -> None:
          """Stop the worker thread and unsubscribe; a no-op when not acquired."""
          if self.thread:  # pragma: no branch
              self.thread.stop()
              self.thread.join()
              self.thread = None
              time.sleep(0.01)

          if self.pubsub:  # pragma: no branch
              self.pubsub.unsubscribe(self.channel)
              self.pubsub.close()
              self.pubsub = None

      def __del__(self) -> None:
          self.release()