huffman.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
"""
Huffman encoding for HPACK header compression, for the case where we already
know the Huffman table: the per-byte codes and code lengths are supplied to
HuffmanEncoder by the caller.
"""
  5. from __future__ import annotations
  6. class HuffmanEncoder:
  7. """
  8. Encodes a string according to the Huffman encoding table defined in the
  9. HPACK specification.
  10. """
  11. def __init__(self, huffman_code_list: list[int], huffman_code_list_lengths: list[int]) -> None:
  12. self.huffman_code_list = huffman_code_list
  13. self.huffman_code_list_lengths = huffman_code_list_lengths
  14. def encode(self, bytes_to_encode: bytes | None) -> bytes:
  15. """
  16. Given a string of bytes, encodes them according to the HPACK Huffman
  17. specification.
  18. """
  19. # If handed the empty string, just immediately return.
  20. if not bytes_to_encode:
  21. return b""
  22. final_num = 0
  23. final_int_len = 0
  24. # Turn each byte into its huffman code. These codes aren't necessarily
  25. # octet aligned, so keep track of how far through an octet we are. To
  26. # handle this cleanly, just use a single giant integer.
  27. for byte in bytes_to_encode:
  28. bin_int_len = self.huffman_code_list_lengths[byte]
  29. bin_int = self.huffman_code_list[byte] & (
  30. 2 ** (bin_int_len + 1) - 1
  31. )
  32. final_num <<= bin_int_len
  33. final_num |= bin_int
  34. final_int_len += bin_int_len
  35. # Pad out to an octet with ones.
  36. bits_to_be_padded = (8 - (final_int_len % 8)) % 8
  37. final_num <<= bits_to_be_padded
  38. final_num |= (1 << bits_to_be_padded) - 1
  39. # Convert the number to hex and strip off the leading '0x' and the
  40. # trailing 'L', if present.
  41. s = hex(final_num)[2:].rstrip("L")
  42. # If this is odd, prepend a zero.
  43. s = "0" + s if len(s) % 2 != 0 else s
  44. # This number should have twice as many digits as bytes. If not, we're
  45. # missing some leading zeroes. Work out how many bytes we want and how
  46. # many digits we have, then add the missing zero digits to the front.
  47. total_bytes = (final_int_len + bits_to_be_padded) // 8
  48. expected_digits = total_bytes * 2
  49. if len(s) != expected_digits:
  50. missing_digits = expected_digits - len(s)
  51. s = ("0" * missing_digits) + s
  52. return bytes.fromhex(s)