| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270 |
- # pyright: reportUnknownMemberType=false
- from __future__ import annotations
- import _thread
- import json
- import logging
- import random
- import time
- import typing
- import redis
- from . import exceptions, utils
- logger = logging.getLogger(__name__)
- # Default `unavailable_timeout` for `RedisLock`, in seconds: how long a
- # conflicting lock holder gets to answer a ping before it is treated as
- # unavailable.
- DEFAULT_UNAVAILABLE_TIMEOUT = 1
- # Default `thread_sleep_time` for `RedisLock`, in seconds: handed to the
- # pubsub worker thread as its `sleep_time` and used as the fallback polling
- # interval while waiting.
- DEFAULT_THREAD_SLEEP_TIME = 0.1
class PubSubWorkerThread(redis.client.PubSubWorkerThread):
    """Pubsub worker thread that signals the main thread when it crashes."""

    def run(self) -> None:
        """Run the parent worker loop and escalate any unexpected failure.

        Raises:
            Exception: whatever the parent `run()` raised, re-raised after
                the main thread has been interrupted.
        """
        try:
            super().run()
        except Exception:  # pragma: no cover
            # A dead worker can no longer answer pings for the lock, so make
            # the main thread aware of it (this raises `KeyboardInterrupt`
            # there) instead of failing silently in the background.
            _thread.interrupt_main()
            raise
class RedisLock(utils.LockBase):
    """
    An extremely reliable Redis lock based on pubsub with a keep-alive thread

    As opposed to most Redis locking systems based on key/value pairs,
    this locking method is based on the pubsub system. The big advantage is
    that if the connection gets killed due to network issues, crashing
    processes or otherwise, it will still immediately unlock instead of
    waiting for a lock timeout.

    To make sure both sides of the lock know about the connection state it is
    recommended to set the `health_check_interval` when creating the redis
    connection.

    Args:
        channel: the redis channel to use as locking key.
        connection: an optional redis connection if you already have one
            or if you need to specify the redis connection
        timeout: timeout when trying to acquire a lock
        check_interval: check interval while waiting
        fail_when_locked: if the lock is held by a responsive holder (or the
            attempt to take it loses a race), raise `AlreadyLocked` right
            away. This does not wait for the timeout.
        thread_sleep_time: sleep time between fetching messages from redis to
            prevent a busy/wait loop. In the case of lock conflicts this
            increases the time it takes to resolve the conflict. This should
            be smaller than the `check_interval` to be useful.
        unavailable_timeout: If the conflicting lock is properly connected
            this should never exceed twice your redis latency. Note that this
            will increase the wait time possibly beyond your `timeout` and is
            always executed if a conflict arises.
        redis_kwargs: The redis connection arguments if no connection is
            given. The `DEFAULT_REDIS_KWARGS` are used as default, if you want
            to override these you need to explicitly specify a value (e.g.
            `health_check_interval=0`)
    """

    redis_kwargs: dict[str, typing.Any]
    thread: PubSubWorkerThread | None
    channel: str
    timeout: float
    connection: redis.client.Redis[str] | None
    pubsub: redis.client.PubSub | None = None
    # NOTE(review): recorded in `__init__` but not consulted anywhere in this
    # class; confirm whether a subclass or caller relies on it.
    close_connection: bool

    DEFAULT_REDIS_KWARGS: typing.ClassVar[dict[str, typing.Any]] = dict(
        health_check_interval=10,
        decode_responses=True,
    )

    def __init__(
        self,
        channel: str,
        connection: redis.client.Redis[str] | None = None,
        timeout: float | None = None,
        check_interval: float | None = None,
        fail_when_locked: bool | None = False,
        thread_sleep_time: float = DEFAULT_THREAD_SLEEP_TIME,
        unavailable_timeout: float = DEFAULT_UNAVAILABLE_TIMEOUT,
        redis_kwargs: dict[str, typing.Any] | None = None,
    ) -> None:
        """Store the lock configuration; no redis traffic happens here."""
        # We don't want to close connections given as an argument
        self.close_connection = not connection

        self.thread = None
        self.channel = channel
        self.connection = connection
        self.thread_sleep_time = thread_sleep_time
        self.unavailable_timeout = unavailable_timeout

        # Merge into a fresh dict (explicit values win over the defaults) so
        # the dict passed in by the caller is never modified.
        self.redis_kwargs = {
            **self.DEFAULT_REDIS_KWARGS,
            **(redis_kwargs or {}),
        }

        super().__init__(
            timeout=timeout,
            check_interval=check_interval,
            fail_when_locked=fail_when_locked,
        )

    def get_connection(self) -> redis.client.Redis[str]:
        """Return the redis connection, creating it on first use.

        A connection is only created (from `redis_kwargs`) when none was
        passed to the constructor and none has been created before.
        """
        if not self.connection:
            self.connection = redis.client.Redis(**self.redis_kwargs)

        return self.connection

    def channel_handler(self, message: dict[str, str]) -> None:
        """Answer a ping that was published on the lock channel.

        Runs in the pubsub worker thread while the lock is held. The message
        payload is expected to be a JSON object containing a
        `response_channel` key; the current time is published to that
        channel as proof that this lock holder is still responsive.

        Payloads that are empty, not valid JSON, or lack the
        `response_channel` key are ignored: an exception escaping from here
        would crash the worker thread and interrupt the main thread.
        """
        if message.get('type') != 'message':  # pragma: no cover
            return

        raw_data = message.get('data')
        if not raw_data:
            return

        try:
            data = json.loads(raw_data)
        except (TypeError, ValueError) as exc:  # pragma: no cover
            # `json.JSONDecodeError` is a `ValueError` subclass, so this
            # covers both wrong payload types and malformed JSON.
            logger.debug(
                '%s while parsing: %r',
                type(exc).__name__,
                message,
            )
            return

        if not isinstance(data, dict) or 'response_channel' not in data:
            logger.debug('No response channel in message: %r', message)
            return

        assert self.connection is not None
        self.connection.publish(data['response_channel'], str(time.time()))

    @property
    def client_name(self) -> str:
        """Redis client name set on the connection when acquiring the lock."""
        return f'{self.channel}-lock'

    def _timeout_generator(
        self, timeout: float | None, check_interval: float | None
    ) -> typing.Iterator[int]:
        """Yield `0` repeatedly until `timeout` seconds have passed.

        Before every yield (including the first) the generator sleeps for a
        randomised interval between 0.5 and 1.5 times `check_interval`. At
        least one value is always yielded, even when `timeout` is `0`.

        Args:
            timeout: total time budget in seconds, `None` means `0`.
            check_interval: base sleep time in seconds; `None` or a
                non-positive value falls back to `thread_sleep_time`.
        """
        if timeout is None:
            timeout = 0.0
        if check_interval is None:
            check_interval = self.thread_sleep_time

        deadline = time.monotonic() + timeout
        first = True
        while first or time.monotonic() < deadline:
            first = False
            effective_interval = (
                check_interval
                if check_interval > 0
                else self.thread_sleep_time
            )
            # Jitter the sleep so competing clients don't poll in lockstep.
            sleep_time = effective_interval * (0.5 + random.random())
            time.sleep(sleep_time)
            yield 0

    def acquire(  # type: ignore[override]
        self,
        timeout: float | None = None,
        check_interval: float | None = None,
        fail_when_locked: bool | None = None,
    ) -> RedisLock:
        """Acquire the lock by becoming the only subscriber of the channel.

        Args:
            timeout: seconds to keep retrying; defaults to the instance
                value, then `0`.
            check_interval: base interval between attempts; defaults to the
                instance value, then `0` (which makes the retry loop fall
                back to `thread_sleep_time`).
            fail_when_locked: raise immediately when the lock is held by a
                responsive holder instead of retrying until the timeout;
                defaults to the instance value.

        Returns:
            This lock instance, once it is the sole channel subscriber.

        Raises:
            exceptions.AlreadyLocked: the lock could not be acquired.
            AssertionError: this instance already holds the lock.
        """
        timeout = utils.coalesce(timeout, self.timeout, 0.0)
        check_interval = utils.coalesce(
            check_interval,
            self.check_interval,
            0.0,
        )
        fail_when_locked = utils.coalesce(
            fail_when_locked,
            self.fail_when_locked,
        )

        assert not self.pubsub, 'This lock is already active'
        connection = self.get_connection()

        timeout_generator = self._timeout_generator(timeout, check_interval)
        for _ in timeout_generator:  # pragma: no branch
            subscribers = connection.pubsub_numsub(self.channel)[0][1]

            if subscribers:
                logger.debug(
                    'Found %d lock subscribers for %s',
                    subscribers,
                    self.channel,
                )

                if self.check_or_kill_lock(
                    connection,
                    self.unavailable_timeout,
                ):  # pragma: no branch
                    # The current holder answered the ping, so the lock is
                    # really taken: honour `fail_when_locked` here, otherwise
                    # the `continue` would silently skip that check.
                    if fail_when_locked:
                        raise exceptions.AlreadyLocked()
                    continue
                else:  # pragma: no cover
                    # The holder did not respond and was killed, treat the
                    # channel as free.
                    subscribers = 0

            # Note: this should not be changed to an elif because the if
            # above can still end up here
            if not subscribers:
                connection.client_setname(self.client_name)
                self.pubsub = connection.pubsub()
                self.pubsub.subscribe(**{self.channel: self.channel_handler})
                self.thread = PubSubWorkerThread(
                    self.pubsub,
                    sleep_time=self.thread_sleep_time,
                )
                self.thread.start()
                # Give redis a moment to register the subscription
                time.sleep(0.01)

                subscribers = connection.pubsub_numsub(self.channel)[0][1]
                if subscribers == 1:  # pragma: no branch
                    return self
                else:  # pragma: no cover
                    # Race condition, let's try again
                    self.release()

            if fail_when_locked:  # pragma: no cover
                raise exceptions.AlreadyLocked()

        raise exceptions.AlreadyLocked()

    def check_or_kill_lock(
        self,
        connection: redis.client.Redis[str],
        timeout: float,
    ) -> bool | None:
        """Ping the current lock holder and kill it when it stays silent.

        A ping carrying a random response channel is published on the lock
        channel. If a reply arrives within `timeout` seconds the holder is
        alive. Otherwise every pubsub client whose name equals
        `client_name` is killed.

        Args:
            connection: redis connection used for the ping and the kill.
            timeout: seconds to wait for a reply from the holder.

        Returns:
            `True` if the holder responded, `None` if no response arrived
            and matching clients were killed.
        """
        # Random channel name to get messages back from the lock
        response_channel = f'{self.channel}-{random.random()}'

        pubsub = connection.pubsub()
        try:
            pubsub.subscribe(response_channel)
            connection.publish(
                self.channel,
                json.dumps(
                    dict(
                        response_channel=response_channel,
                        message='ping',
                    ),
                ),
            )

            check_interval = min(self.thread_sleep_time, timeout / 10)
            for _ in self._timeout_generator(
                timeout,
                check_interval,
            ):  # pragma: no branch
                # Skip the confirmation of our own `subscribe` call: only a
                # message published by the lock holder proves it is alive.
                if pubsub.get_message(
                    ignore_subscribe_messages=True,
                    timeout=check_interval,
                ):
                    return True

            for client_ in connection.client_list(
                'pubsub',
            ):  # pragma: no cover
                if client_.get('name') == self.client_name:
                    logger.warning(
                        'Killing unavailable redis client: %r',
                        client_,
                    )
                    connection.client_kill_filter(  # pyright: ignore
                        client_.get('id'),
                    )
        finally:
            # Always give the temporary response subscription back, also
            # when nobody answered or an error occurred.
            pubsub.close()

        return None

    def release(self) -> None:
        """Release the lock: stop the worker thread and unsubscribe.

        Safe to call when the lock is not held, and on instances whose
        `__init__` never ran (relevant for `__del__`).
        """
        # `thread` has no class-level default, so it is missing when
        # `__init__` did not run.
        if getattr(self, 'thread', None):  # pragma: no branch
            assert self.thread is not None
            self.thread.stop()
            self.thread.join()
            self.thread = None
            time.sleep(0.01)

        if self.pubsub:  # pragma: no branch
            self.pubsub.unsubscribe(self.channel)
            self.pubsub.close()
            self.pubsub = None

    def __del__(self) -> None:
        """Release the lock when the instance is garbage collected."""
        self.release()
|