utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. from __future__ import annotations
  2. import abc
  3. import atexit
  4. import contextlib
  5. import logging
  6. import os
  7. import pathlib
  8. import random
  9. import tempfile
  10. import time
  11. import typing
  12. import warnings
  13. from . import constants, exceptions, portalocker, types
  14. from .types import Filename, Mode
  15. logger = logging.getLogger(__name__)
  16. DEFAULT_TIMEOUT = 5
  17. DEFAULT_CHECK_INTERVAL = 0.25
  18. DEFAULT_FAIL_WHEN_LOCKED = False
  19. LOCK_METHOD = constants.LockFlags.EXCLUSIVE | constants.LockFlags.NON_BLOCKING
  20. __all__ = [
  21. 'Lock',
  22. 'open_atomic',
  23. ]
  24. def coalesce(*args: typing.Any, test_value: typing.Any = None) -> typing.Any:
  25. """Simple coalescing function that returns the first value that is not
  26. equal to the `test_value`. Or `None` if no value is valid. Usually this
  27. means that the last given value is the default value.
  28. Note that the `test_value` is compared using an identity check
  29. (i.e. `value is not test_value`) so changing the `test_value` won't work
  30. for all values.
  31. >>> coalesce(None, 1)
  32. 1
  33. >>> coalesce()
  34. >>> coalesce(0, False, True)
  35. 0
  36. >>> coalesce(0, False, True, test_value=0)
  37. False
  38. # This won't work because of the `is not test_value` type testing:
  39. >>> coalesce([], dict(spam='eggs'), test_value=[])
  40. []
  41. """
  42. return next((arg for arg in args if arg is not test_value), None)
  43. @contextlib.contextmanager
  44. def open_atomic(
  45. filename: Filename,
  46. binary: bool = True,
  47. ) -> typing.Iterator[types.IO]:
  48. """Open a file for atomic writing. Instead of locking this method allows
  49. you to write the entire file and move it to the actual location. Note that
  50. this makes the assumption that a rename is atomic on your platform which
  51. is generally the case but not a guarantee.
  52. http://docs.python.org/library/os.html#os.rename
  53. >>> filename = 'test_file.txt'
  54. >>> if os.path.exists(filename):
  55. ... os.remove(filename)
  56. >>> with open_atomic(filename) as fh:
  57. ... written = fh.write(b'test')
  58. >>> assert os.path.exists(filename)
  59. >>> os.remove(filename)
  60. >>> import pathlib
  61. >>> path_filename = pathlib.Path('test_file.txt')
  62. >>> with open_atomic(path_filename) as fh:
  63. ... written = fh.write(b'test')
  64. >>> assert path_filename.exists()
  65. >>> path_filename.unlink()
  66. """
  67. # `pathlib.Path` cast in case `path` is a `str`
  68. path: pathlib.Path
  69. if isinstance(filename, pathlib.Path):
  70. path = filename
  71. else:
  72. path = pathlib.Path(filename)
  73. assert not path.exists(), f'{path!r} exists'
  74. # Create the parent directory if it doesn't exist
  75. path.parent.mkdir(parents=True, exist_ok=True)
  76. with tempfile.NamedTemporaryFile(
  77. mode=(binary and 'wb') or 'w',
  78. dir=str(path.parent),
  79. delete=False,
  80. ) as temp_fh:
  81. yield temp_fh
  82. temp_fh.flush()
  83. os.fsync(temp_fh.fileno())
  84. try:
  85. os.rename(temp_fh.name, path)
  86. finally:
  87. with contextlib.suppress(Exception):
  88. os.remove(temp_fh.name)
  89. class LockBase(abc.ABC): # pragma: no cover
  90. #: timeout when trying to acquire a lock
  91. timeout: float
  92. #: check interval while waiting for `timeout`
  93. check_interval: float
  94. #: skip the timeout and immediately fail if the initial lock fails
  95. fail_when_locked: bool
  96. def __init__(
  97. self,
  98. timeout: float | None = None,
  99. check_interval: float | None = None,
  100. fail_when_locked: bool | None = None,
  101. ) -> None:
  102. self.timeout = coalesce(timeout, DEFAULT_TIMEOUT)
  103. self.check_interval = coalesce(check_interval, DEFAULT_CHECK_INTERVAL)
  104. self.fail_when_locked = coalesce(
  105. fail_when_locked,
  106. DEFAULT_FAIL_WHEN_LOCKED,
  107. )
  108. @abc.abstractmethod
  109. def acquire(
  110. self,
  111. timeout: float | None = None,
  112. check_interval: float | None = None,
  113. fail_when_locked: bool | None = None,
  114. ) -> typing.IO[typing.AnyStr]: ...
  115. def _timeout_generator(
  116. self,
  117. timeout: float | None,
  118. check_interval: float | None,
  119. ) -> typing.Iterator[int]:
  120. f_timeout = coalesce(timeout, self.timeout, 0.0)
  121. f_check_interval = coalesce(check_interval, self.check_interval, 0.0)
  122. yield 0
  123. i = 0
  124. start_time = time.perf_counter()
  125. while start_time + f_timeout > time.perf_counter():
  126. i += 1
  127. yield i
  128. # Take low lock checks into account to stay within the interval
  129. since_start_time = time.perf_counter() - start_time
  130. time.sleep(max(0.001, (i * f_check_interval) - since_start_time))
  131. @abc.abstractmethod
  132. def release(self) -> None: ...
  133. def __enter__(self) -> typing.IO[typing.AnyStr]:
  134. return self.acquire()
  135. def __exit__(
  136. self,
  137. exc_type: type[BaseException] | None,
  138. exc_value: BaseException | None,
  139. traceback: typing.Any, # Should be typing.TracebackType
  140. ) -> bool | None:
  141. self.release()
  142. return None
  143. def __delete__(self, instance: LockBase) -> None:
  144. instance.release()
  145. class Lock(LockBase):
  146. """Lock manager with built-in timeout
  147. Args:
  148. filename: filename
  149. mode: the open mode, 'a' or 'ab' should be used for writing. When mode
  150. contains `w` the file will be truncated to 0 bytes.
  151. timeout: timeout when trying to acquire a lock
  152. check_interval: check interval while waiting
  153. fail_when_locked: after the initial lock failed, return an error
  154. or lock the file. This does not wait for the timeout.
  155. **file_open_kwargs: The kwargs for the `open(...)` call
  156. fail_when_locked is useful when multiple threads/processes can race
  157. when creating a file. If set to true than the system will wait till
  158. the lock was acquired and then return an AlreadyLocked exception.
  159. Note that the file is opened first and locked later. So using 'w' as
  160. mode will result in truncate _BEFORE_ the lock is checked.
  161. """
  162. fh: types.IO | None
  163. filename: str
  164. mode: str
  165. truncate: bool
  166. timeout: float
  167. check_interval: float
  168. fail_when_locked: bool
  169. flags: constants.LockFlags
  170. file_open_kwargs: dict[str, typing.Any]
  171. def __init__(
  172. self,
  173. filename: Filename,
  174. mode: Mode = 'a',
  175. timeout: float | None = None,
  176. check_interval: float = DEFAULT_CHECK_INTERVAL,
  177. fail_when_locked: bool = DEFAULT_FAIL_WHEN_LOCKED,
  178. flags: constants.LockFlags = LOCK_METHOD,
  179. **file_open_kwargs: typing.Any,
  180. ) -> None:
  181. if 'w' in mode:
  182. truncate = True
  183. mode = typing.cast(Mode, mode.replace('w', 'a'))
  184. else:
  185. truncate = False
  186. if timeout is None:
  187. timeout = DEFAULT_TIMEOUT
  188. elif not (flags & constants.LockFlags.NON_BLOCKING):
  189. warnings.warn(
  190. 'timeout has no effect in blocking mode',
  191. stacklevel=1,
  192. )
  193. self.fh = None
  194. self.filename = str(filename)
  195. self.mode = mode
  196. self.truncate = truncate
  197. self.flags = flags
  198. self.file_open_kwargs = file_open_kwargs
  199. super().__init__(timeout, check_interval, fail_when_locked)
  200. def acquire(
  201. self,
  202. timeout: float | None = None,
  203. check_interval: float | None = None,
  204. fail_when_locked: bool | None = None,
  205. ) -> typing.IO[typing.AnyStr]:
  206. """Acquire the locked filehandle"""
  207. fail_when_locked = coalesce(fail_when_locked, self.fail_when_locked)
  208. if (
  209. not (self.flags & constants.LockFlags.NON_BLOCKING)
  210. and timeout is not None
  211. ):
  212. warnings.warn(
  213. 'timeout has no effect in blocking mode',
  214. stacklevel=1,
  215. )
  216. # If we already have a filehandle, return it
  217. fh = self.fh
  218. if fh:
  219. # Due to type invariance we need to cast the type
  220. return typing.cast(typing.IO[typing.AnyStr], fh)
  221. # Get a new filehandler
  222. fh = self._get_fh()
  223. def try_close() -> None: # pragma: no cover
  224. # Silently try to close the handle if possible, ignore all issues
  225. if fh is not None:
  226. with contextlib.suppress(Exception):
  227. fh.close()
  228. exception = None
  229. # Try till the timeout has passed
  230. for _ in self._timeout_generator(timeout, check_interval):
  231. exception = None
  232. try:
  233. # Try to lock
  234. fh = self._get_lock(fh)
  235. break
  236. except exceptions.LockException as exc:
  237. # Python will automatically remove the variable from memory
  238. # unless you save it in a different location
  239. exception = exc
  240. # We already tried to the get the lock
  241. # If fail_when_locked is True, stop trying
  242. if fail_when_locked:
  243. try_close()
  244. raise exceptions.AlreadyLocked(exception) from exc
  245. except Exception as exc:
  246. # Something went wrong with the locking mechanism.
  247. # Wrap in a LockException and re-raise:
  248. try_close()
  249. raise exceptions.LockException(exc) from exc
  250. # Wait a bit
  251. if exception:
  252. try_close()
  253. # We got a timeout... reraising
  254. raise exception
  255. # Prepare the filehandle (truncate if needed)
  256. fh = self._prepare_fh(fh)
  257. self.fh = fh
  258. return typing.cast(typing.IO[typing.AnyStr], fh)
  259. def __enter__(self) -> typing.IO[typing.AnyStr]:
  260. return self.acquire()
  261. def release(self) -> None:
  262. """Releases the currently locked file handle"""
  263. if self.fh:
  264. portalocker.unlock(self.fh)
  265. self.fh.close()
  266. self.fh = None
  267. def _get_fh(self) -> types.IO:
  268. """Get a new filehandle"""
  269. return typing.cast(
  270. types.IO,
  271. open( # noqa: SIM115
  272. self.filename,
  273. self.mode,
  274. **self.file_open_kwargs,
  275. ),
  276. )
  277. def _get_lock(self, fh: types.IO) -> types.IO:
  278. """
  279. Try to lock the given filehandle
  280. returns LockException if it fails"""
  281. portalocker.lock(fh, self.flags)
  282. return fh
  283. def _prepare_fh(self, fh: types.IO) -> types.IO:
  284. """
  285. Prepare the filehandle for usage
  286. If truncate is a number, the file will be truncated to that amount of
  287. bytes
  288. """
  289. if self.truncate:
  290. fh.seek(0)
  291. fh.truncate(0)
  292. return fh
  293. class RLock(Lock):
  294. """
  295. A reentrant lock, functions in a similar way to threading.RLock in that it
  296. can be acquired multiple times. When the corresponding number of release()
  297. calls are made the lock will finally release the underlying file lock.
  298. """
  299. def __init__(
  300. self,
  301. filename: Filename,
  302. mode: Mode = 'a',
  303. timeout: float = DEFAULT_TIMEOUT,
  304. check_interval: float = DEFAULT_CHECK_INTERVAL,
  305. fail_when_locked: bool = False,
  306. flags: constants.LockFlags = LOCK_METHOD,
  307. ) -> None:
  308. super().__init__(
  309. filename,
  310. mode,
  311. timeout,
  312. check_interval,
  313. fail_when_locked,
  314. flags,
  315. )
  316. self._acquire_count = 0
  317. def acquire(
  318. self,
  319. timeout: float | None = None,
  320. check_interval: float | None = None,
  321. fail_when_locked: bool | None = None,
  322. ) -> typing.IO[typing.AnyStr]:
  323. fh: typing.IO[typing.AnyStr]
  324. if self._acquire_count >= 1:
  325. fh = typing.cast(typing.IO[typing.AnyStr], self.fh)
  326. else:
  327. fh = super().acquire(timeout, check_interval, fail_when_locked)
  328. self._acquire_count += 1
  329. assert fh is not None
  330. return fh
  331. def release(self) -> None:
  332. if self._acquire_count == 0:
  333. raise exceptions.LockException(
  334. 'Cannot release more times than acquired',
  335. )
  336. if self._acquire_count == 1:
  337. super().release()
  338. self._acquire_count -= 1
  339. class TemporaryFileLock(Lock):
  340. def __init__(
  341. self,
  342. filename: str = '.lock',
  343. timeout: float = DEFAULT_TIMEOUT,
  344. check_interval: float = DEFAULT_CHECK_INTERVAL,
  345. fail_when_locked: bool = True,
  346. flags: constants.LockFlags = LOCK_METHOD,
  347. ) -> None:
  348. super().__init__(
  349. filename=filename,
  350. mode='w',
  351. timeout=timeout,
  352. check_interval=check_interval,
  353. fail_when_locked=fail_when_locked,
  354. flags=flags,
  355. )
  356. atexit.register(self.release)
  357. def release(self) -> None:
  358. Lock.release(self)
  359. if os.path.isfile(self.filename): # pragma: no branch
  360. os.unlink(self.filename)
  361. class BoundedSemaphore(LockBase):
  362. """
  363. Bounded semaphore to prevent too many parallel processes from running
  364. This method is deprecated because multiple processes that are completely
  365. unrelated could end up using the same semaphore. To prevent this,
  366. use `NamedBoundedSemaphore` instead. The
  367. `NamedBoundedSemaphore` is a drop-in replacement for this class.
  368. >>> semaphore = BoundedSemaphore(2, directory='')
  369. >>> str(semaphore.get_filenames()[0])
  370. 'bounded_semaphore.00.lock'
  371. >>> str(sorted(semaphore.get_random_filenames())[1])
  372. 'bounded_semaphore.01.lock'
  373. """
  374. lock: Lock | None
  375. def __init__(
  376. self,
  377. maximum: int,
  378. name: str = 'bounded_semaphore',
  379. filename_pattern: str = '{name}.{number:02d}.lock',
  380. directory: str = tempfile.gettempdir(),
  381. timeout: float | None = DEFAULT_TIMEOUT,
  382. check_interval: float | None = DEFAULT_CHECK_INTERVAL,
  383. fail_when_locked: bool | None = True,
  384. ) -> None:
  385. self.maximum = maximum
  386. self.name = name
  387. self.filename_pattern = filename_pattern
  388. self.directory = directory
  389. self.lock: Lock | None = None
  390. super().__init__(
  391. timeout=timeout,
  392. check_interval=check_interval,
  393. fail_when_locked=fail_when_locked,
  394. )
  395. if not name or name == 'bounded_semaphore':
  396. warnings.warn(
  397. '`BoundedSemaphore` without an explicit `name` '
  398. 'argument is deprecated, use NamedBoundedSemaphore',
  399. DeprecationWarning,
  400. stacklevel=1,
  401. )
  402. def get_filenames(self) -> typing.Sequence[pathlib.Path]:
  403. return [self.get_filename(n) for n in range(self.maximum)]
  404. def get_random_filenames(self) -> typing.Sequence[pathlib.Path]:
  405. filenames = list(self.get_filenames())
  406. random.shuffle(filenames)
  407. return filenames
  408. def get_filename(self, number: int) -> pathlib.Path:
  409. return pathlib.Path(self.directory) / self.filename_pattern.format(
  410. name=self.name,
  411. number=number,
  412. )
  413. def acquire( # type: ignore[override]
  414. self,
  415. timeout: float | None = None,
  416. check_interval: float | None = None,
  417. fail_when_locked: bool | None = None,
  418. ) -> Lock | None:
  419. assert not self.lock, 'Already locked'
  420. filenames = self.get_filenames()
  421. for n in self._timeout_generator(timeout, check_interval): # pragma:
  422. logger.debug('trying lock (attempt %d) %r', n, filenames)
  423. # no branch
  424. if self.try_lock(filenames): # pragma: no branch
  425. return self.lock # pragma: no cover
  426. if fail_when_locked := coalesce(
  427. fail_when_locked,
  428. self.fail_when_locked,
  429. ):
  430. raise exceptions.AlreadyLocked()
  431. return None
  432. def try_lock(self, filenames: typing.Sequence[Filename]) -> bool:
  433. filename: Filename
  434. for filename in filenames:
  435. logger.debug('trying lock for %r', filename)
  436. self.lock = Lock(filename, fail_when_locked=True)
  437. try:
  438. self.lock.acquire()
  439. except exceptions.AlreadyLocked:
  440. self.lock = None
  441. else:
  442. logger.debug('locked %r', filename)
  443. return True
  444. return False
  445. def release(self) -> None: # pragma: no cover
  446. if self.lock is not None:
  447. self.lock.release()
  448. self.lock = None
  449. class NamedBoundedSemaphore(BoundedSemaphore):
  450. """
  451. Bounded semaphore to prevent too many parallel processes from running
  452. It's also possible to specify a timeout when acquiring the lock to wait
  453. for a resource to become available. This is very similar to
  454. `threading.BoundedSemaphore` but works across multiple processes and across
  455. multiple operating systems.
  456. Because this works across multiple processes it's important to give the
  457. semaphore a name. This name is used to create the lock files. If you
  458. don't specify a name, a random name will be generated. This means that
  459. you can't use the same semaphore in multiple processes unless you pass the
  460. semaphore object to the other processes.
  461. >>> semaphore = NamedBoundedSemaphore(2, name='test')
  462. >>> str(semaphore.get_filenames()[0])
  463. '...test.00.lock'
  464. >>> semaphore = NamedBoundedSemaphore(2)
  465. >>> 'bounded_semaphore' in str(semaphore.get_filenames()[0])
  466. True
  467. """
  468. def __init__(
  469. self,
  470. maximum: int,
  471. name: str | None = None,
  472. filename_pattern: str = '{name}.{number:02d}.lock',
  473. directory: str = tempfile.gettempdir(),
  474. timeout: float | None = DEFAULT_TIMEOUT,
  475. check_interval: float | None = DEFAULT_CHECK_INTERVAL,
  476. fail_when_locked: bool | None = True,
  477. ) -> None:
  478. if name is None:
  479. name = f'bounded_semaphore.{random.randint(0, 1000000):d}'
  480. super().__init__(
  481. maximum,
  482. name,
  483. filename_pattern,
  484. directory,
  485. timeout,
  486. check_interval,
  487. fail_when_locked,
  488. )