# lzw.py
  1. import logging
  2. from io import BytesIO
  3. from typing import BinaryIO, Iterator, List, Optional, cast
  4. from pdfminer.pdfexceptions import PDFEOFError, PDFException
# Module-level logger; LZWDecoder.run() writes a debug record through it
# for every decoded code.
logger = logging.getLogger(__name__)
  6. class CorruptDataError(PDFException):
  7. pass
  8. class LZWDecoder:
  9. def __init__(self, fp: BinaryIO) -> None:
  10. self.fp = fp
  11. self.buff = 0
  12. self.bpos = 8
  13. self.nbits = 9
  14. # NB: self.table stores None only in indices 256 and 257
  15. self.table: List[Optional[bytes]] = []
  16. self.prevbuf: Optional[bytes] = None
  17. def readbits(self, bits: int) -> int:
  18. v = 0
  19. while 1:
  20. # the number of remaining bits we can get from the current buffer.
  21. r = 8 - self.bpos
  22. if bits <= r:
  23. # |-----8-bits-----|
  24. # |-bpos-|-bits-| |
  25. # | |----r----|
  26. v = (v << bits) | ((self.buff >> (r - bits)) & ((1 << bits) - 1))
  27. self.bpos += bits
  28. break
  29. else:
  30. # |-----8-bits-----|
  31. # |-bpos-|---bits----...
  32. # | |----r----|
  33. v = (v << r) | (self.buff & ((1 << r) - 1))
  34. bits -= r
  35. x = self.fp.read(1)
  36. if not x:
  37. raise PDFEOFError
  38. self.buff = ord(x)
  39. self.bpos = 0
  40. return v
  41. def feed(self, code: int) -> bytes:
  42. x = b""
  43. if code == 256:
  44. self.table = [bytes((c,)) for c in range(256)] # 0-255
  45. self.table.append(None) # 256
  46. self.table.append(None) # 257
  47. self.prevbuf = b""
  48. self.nbits = 9
  49. elif code == 257:
  50. pass
  51. elif not self.prevbuf:
  52. x = self.prevbuf = cast(bytes, self.table[code]) # assume not None
  53. else:
  54. if code < len(self.table):
  55. x = cast(bytes, self.table[code]) # assume not None
  56. self.table.append(self.prevbuf + x[:1])
  57. elif code == len(self.table):
  58. self.table.append(self.prevbuf + self.prevbuf[:1])
  59. x = cast(bytes, self.table[code])
  60. else:
  61. raise CorruptDataError
  62. table_length = len(self.table)
  63. if table_length == 511:
  64. self.nbits = 10
  65. elif table_length == 1023:
  66. self.nbits = 11
  67. elif table_length == 2047:
  68. self.nbits = 12
  69. self.prevbuf = x
  70. return x
  71. def run(self) -> Iterator[bytes]:
  72. while 1:
  73. try:
  74. code = self.readbits(self.nbits)
  75. except EOFError:
  76. break
  77. try:
  78. x = self.feed(code)
  79. except CorruptDataError:
  80. # just ignore corrupt data and stop yielding there
  81. break
  82. yield x
  83. logger.debug(
  84. "nbits=%d, code=%d, output=%r, table=%r",
  85. self.nbits,
  86. code,
  87. x,
  88. self.table[258:],
  89. )
  90. def lzwdecode(data: bytes) -> bytes:
  91. fp = BytesIO(data)
  92. s = LZWDecoder(fp).run()
  93. return b"".join(s)