jbig2.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. import math
  2. import os
  3. from struct import calcsize, pack, unpack
  4. from typing import BinaryIO, Dict, Iterable, List, Optional, Tuple, Union, cast
  5. from pdfminer.pdfexceptions import PDFValueError
  6. # segment structure base
  7. SEG_STRUCT = [
  8. (">L", "number"),
  9. (">B", "flags"),
  10. (">B", "retention_flags"),
  11. (">B", "page_assoc"),
  12. (">L", "data_length"),
  13. ]
  14. # segment header literals
  15. HEADER_FLAG_DEFERRED = 0b10000000
  16. HEADER_FLAG_PAGE_ASSOC_LONG = 0b01000000
  17. SEG_TYPE_MASK = 0b00111111
  18. REF_COUNT_SHORT_MASK = 0b11100000
  19. REF_COUNT_LONG_MASK = 0x1FFFFFFF
  20. REF_COUNT_LONG = 7
  21. DATA_LEN_UNKNOWN = 0xFFFFFFFF
  22. # segment types
  23. SEG_TYPE_IMMEDIATE_GEN_REGION = 38
  24. SEG_TYPE_END_OF_PAGE = 49
  25. SEG_TYPE_END_OF_FILE = 51
  26. # file literals
  27. FILE_HEADER_ID = b"\x97\x4a\x42\x32\x0d\x0a\x1a\x0a"
  28. FILE_HEAD_FLAG_SEQUENTIAL = 0b00000001
  29. def bit_set(bit_pos: int, value: int) -> bool:
  30. return bool((value >> bit_pos) & 1)
  31. def check_flag(flag: int, value: int) -> bool:
  32. return bool(flag & value)
  33. def masked_value(mask: int, value: int) -> int:
  34. for bit_pos in range(31):
  35. if bit_set(bit_pos, mask):
  36. return (value & mask) >> bit_pos
  37. raise PDFValueError("Invalid mask or value")
  38. def mask_value(mask: int, value: int) -> int:
  39. for bit_pos in range(31):
  40. if bit_set(bit_pos, mask):
  41. return (value & (mask >> bit_pos)) << bit_pos
  42. raise PDFValueError("Invalid mask or value")
  43. def unpack_int(format: str, buffer: bytes) -> int:
  44. assert format in {">B", ">I", ">L"}
  45. [result] = cast(Tuple[int], unpack(format, buffer))
  46. return result
  47. JBIG2SegmentFlags = Dict[str, Union[int, bool]]
  48. JBIG2RetentionFlags = Dict[str, Union[int, List[int], List[bool]]]
  49. JBIG2Segment = Dict[
  50. str,
  51. Union[bool, int, bytes, JBIG2SegmentFlags, JBIG2RetentionFlags],
  52. ]
  53. class JBIG2StreamReader:
  54. """Read segments from a JBIG2 byte stream"""
  55. def __init__(self, stream: BinaryIO) -> None:
  56. self.stream = stream
  57. def get_segments(self) -> List[JBIG2Segment]:
  58. segments: List[JBIG2Segment] = []
  59. while not self.is_eof():
  60. segment: JBIG2Segment = {}
  61. for field_format, name in SEG_STRUCT:
  62. field_len = calcsize(field_format)
  63. field = self.stream.read(field_len)
  64. if len(field) < field_len:
  65. segment["_error"] = True
  66. break
  67. value = unpack_int(field_format, field)
  68. parser = getattr(self, "parse_%s" % name, None)
  69. if callable(parser):
  70. value = parser(segment, value, field)
  71. segment[name] = value
  72. if not segment.get("_error"):
  73. segments.append(segment)
  74. return segments
  75. def is_eof(self) -> bool:
  76. if self.stream.read(1) == b"":
  77. return True
  78. else:
  79. self.stream.seek(-1, os.SEEK_CUR)
  80. return False
  81. def parse_flags(
  82. self,
  83. segment: JBIG2Segment,
  84. flags: int,
  85. field: bytes,
  86. ) -> JBIG2SegmentFlags:
  87. return {
  88. "deferred": check_flag(HEADER_FLAG_DEFERRED, flags),
  89. "page_assoc_long": check_flag(HEADER_FLAG_PAGE_ASSOC_LONG, flags),
  90. "type": masked_value(SEG_TYPE_MASK, flags),
  91. }
  92. def parse_retention_flags(
  93. self,
  94. segment: JBIG2Segment,
  95. flags: int,
  96. field: bytes,
  97. ) -> JBIG2RetentionFlags:
  98. ref_count = masked_value(REF_COUNT_SHORT_MASK, flags)
  99. retain_segments = []
  100. ref_segments = []
  101. if ref_count < REF_COUNT_LONG:
  102. for bit_pos in range(5):
  103. retain_segments.append(bit_set(bit_pos, flags))
  104. else:
  105. field += self.stream.read(3)
  106. ref_count = unpack_int(">L", field)
  107. ref_count = masked_value(REF_COUNT_LONG_MASK, ref_count)
  108. ret_bytes_count = int(math.ceil((ref_count + 1) / 8))
  109. for ret_byte_index in range(ret_bytes_count):
  110. ret_byte = unpack_int(">B", self.stream.read(1))
  111. for bit_pos in range(7):
  112. retain_segments.append(bit_set(bit_pos, ret_byte))
  113. seg_num = segment["number"]
  114. assert isinstance(seg_num, int)
  115. if seg_num <= 256:
  116. ref_format = ">B"
  117. elif seg_num <= 65536:
  118. ref_format = ">I"
  119. else:
  120. ref_format = ">L"
  121. ref_size = calcsize(ref_format)
  122. for ref_index in range(ref_count):
  123. ref_data = self.stream.read(ref_size)
  124. ref = unpack_int(ref_format, ref_data)
  125. ref_segments.append(ref)
  126. return {
  127. "ref_count": ref_count,
  128. "retain_segments": retain_segments,
  129. "ref_segments": ref_segments,
  130. }
  131. def parse_page_assoc(self, segment: JBIG2Segment, page: int, field: bytes) -> int:
  132. if cast(JBIG2SegmentFlags, segment["flags"])["page_assoc_long"]:
  133. field += self.stream.read(3)
  134. page = unpack_int(">L", field)
  135. return page
  136. def parse_data_length(
  137. self,
  138. segment: JBIG2Segment,
  139. length: int,
  140. field: bytes,
  141. ) -> int:
  142. if length:
  143. if (
  144. cast(JBIG2SegmentFlags, segment["flags"])["type"]
  145. == SEG_TYPE_IMMEDIATE_GEN_REGION
  146. ) and (length == DATA_LEN_UNKNOWN):
  147. raise NotImplementedError(
  148. "Working with unknown segment length is not implemented yet",
  149. )
  150. else:
  151. segment["raw_data"] = self.stream.read(length)
  152. return length
  153. class JBIG2StreamWriter:
  154. """Write JBIG2 segments to a file in JBIG2 format"""
  155. EMPTY_RETENTION_FLAGS: JBIG2RetentionFlags = {
  156. "ref_count": 0,
  157. "ref_segments": cast(List[int], []),
  158. "retain_segments": cast(List[bool], []),
  159. }
  160. def __init__(self, stream: BinaryIO) -> None:
  161. self.stream = stream
  162. def write_segments(
  163. self,
  164. segments: Iterable[JBIG2Segment],
  165. fix_last_page: bool = True,
  166. ) -> int:
  167. data_len = 0
  168. current_page: Optional[int] = None
  169. seg_num: Optional[int] = None
  170. for segment in segments:
  171. data = self.encode_segment(segment)
  172. self.stream.write(data)
  173. data_len += len(data)
  174. seg_num = cast(Optional[int], segment["number"])
  175. if fix_last_page:
  176. seg_page = cast(int, segment.get("page_assoc"))
  177. if (
  178. cast(JBIG2SegmentFlags, segment["flags"])["type"]
  179. == SEG_TYPE_END_OF_PAGE
  180. ):
  181. current_page = None
  182. elif seg_page:
  183. current_page = seg_page
  184. if fix_last_page and current_page and (seg_num is not None):
  185. segment = self.get_eop_segment(seg_num + 1, current_page)
  186. data = self.encode_segment(segment)
  187. self.stream.write(data)
  188. data_len += len(data)
  189. return data_len
  190. def write_file(
  191. self,
  192. segments: Iterable[JBIG2Segment],
  193. fix_last_page: bool = True,
  194. ) -> int:
  195. header = FILE_HEADER_ID
  196. header_flags = FILE_HEAD_FLAG_SEQUENTIAL
  197. header += pack(">B", header_flags)
  198. # The embedded JBIG2 files in a PDF always
  199. # only have one page
  200. number_of_pages = pack(">L", 1)
  201. header += number_of_pages
  202. self.stream.write(header)
  203. data_len = len(header)
  204. data_len += self.write_segments(segments, fix_last_page)
  205. seg_num = 0
  206. for segment in segments:
  207. seg_num = cast(int, segment["number"])
  208. if fix_last_page:
  209. seg_num_offset = 2
  210. else:
  211. seg_num_offset = 1
  212. eof_segment = self.get_eof_segment(seg_num + seg_num_offset)
  213. data = self.encode_segment(eof_segment)
  214. self.stream.write(data)
  215. data_len += len(data)
  216. return data_len
  217. def encode_segment(self, segment: JBIG2Segment) -> bytes:
  218. data = b""
  219. for field_format, name in SEG_STRUCT:
  220. value = segment.get(name)
  221. encoder = getattr(self, "encode_%s" % name, None)
  222. if callable(encoder):
  223. field = encoder(value, segment)
  224. else:
  225. field = pack(field_format, value)
  226. data += field
  227. return data
  228. def encode_flags(self, value: JBIG2SegmentFlags, segment: JBIG2Segment) -> bytes:
  229. flags = 0
  230. if value.get("deferred"):
  231. flags |= HEADER_FLAG_DEFERRED
  232. if "page_assoc_long" in value:
  233. flags |= HEADER_FLAG_PAGE_ASSOC_LONG if value["page_assoc_long"] else flags
  234. else:
  235. flags |= (
  236. HEADER_FLAG_PAGE_ASSOC_LONG
  237. if cast(int, segment.get("page", 0)) > 255
  238. else flags
  239. )
  240. flags |= mask_value(SEG_TYPE_MASK, value["type"])
  241. return pack(">B", flags)
  242. def encode_retention_flags(
  243. self,
  244. value: JBIG2RetentionFlags,
  245. segment: JBIG2Segment,
  246. ) -> bytes:
  247. flags = []
  248. flags_format = ">B"
  249. ref_count = value["ref_count"]
  250. assert isinstance(ref_count, int)
  251. retain_segments = cast(List[bool], value.get("retain_segments", []))
  252. if ref_count <= 4:
  253. flags_byte = mask_value(REF_COUNT_SHORT_MASK, ref_count)
  254. for ref_index, ref_retain in enumerate(retain_segments):
  255. if ref_retain:
  256. flags_byte |= 1 << ref_index
  257. flags.append(flags_byte)
  258. else:
  259. bytes_count = math.ceil((ref_count + 1) / 8)
  260. flags_format = ">L" + ("B" * bytes_count)
  261. flags_dword = mask_value(REF_COUNT_SHORT_MASK, REF_COUNT_LONG) << 24
  262. flags.append(flags_dword)
  263. for byte_index in range(bytes_count):
  264. ret_byte = 0
  265. ret_part = retain_segments[byte_index * 8 : byte_index * 8 + 8]
  266. for bit_pos, ret_seg in enumerate(ret_part):
  267. ret_byte |= 1 << bit_pos if ret_seg else ret_byte
  268. flags.append(ret_byte)
  269. ref_segments = cast(List[int], value.get("ref_segments", []))
  270. seg_num = cast(int, segment["number"])
  271. if seg_num <= 256:
  272. ref_format = "B"
  273. elif seg_num <= 65536:
  274. ref_format = "I"
  275. else:
  276. ref_format = "L"
  277. for ref in ref_segments:
  278. flags_format += ref_format
  279. flags.append(ref)
  280. return pack(flags_format, *flags)
  281. def encode_data_length(self, value: int, segment: JBIG2Segment) -> bytes:
  282. data = pack(">L", value)
  283. data += cast(bytes, segment["raw_data"])
  284. return data
  285. def get_eop_segment(self, seg_number: int, page_number: int) -> JBIG2Segment:
  286. return {
  287. "data_length": 0,
  288. "flags": {"deferred": False, "type": SEG_TYPE_END_OF_PAGE},
  289. "number": seg_number,
  290. "page_assoc": page_number,
  291. "raw_data": b"",
  292. "retention_flags": JBIG2StreamWriter.EMPTY_RETENTION_FLAGS,
  293. }
  294. def get_eof_segment(self, seg_number: int) -> JBIG2Segment:
  295. return {
  296. "data_length": 0,
  297. "flags": {"deferred": False, "type": SEG_TYPE_END_OF_FILE},
  298. "number": seg_number,
  299. "page_assoc": 0,
  300. "raw_data": b"",
  301. "retention_flags": JBIG2StreamWriter.EMPTY_RETENTION_FLAGS,
  302. }