Files

312 lines
10 KiB
Python

"""Sparse Bit Set encoding/decoding for IFT (Incremental Font Transfer).
Implements the sparse bit set format defined in the W3C IFT specification:
https://w3c.github.io/IFT/Overview.html#sparse-bit-set-decoding
"""
# Reference: https://github.com/googlefonts/fontations/blob/main/read-fonts/src/collections/int_set/sparse_bit_set.rs
import bisect
from collections import deque
from typing import Dict, Iterable, Optional, Set, Tuple
# Maximum tree height that fits within a 32-bit leaf node for each branching factor.
_BF_MAX_HEIGHT: Dict[int, int] = {2: 31, 4: 16, 8: 11, 32: 7}
class SparseBitSetDecodeError(Exception):
pass
def decode(
data: bytes, bias: int = 0, maxValue: int = 0xFFFFFFFF
) -> Tuple[Set[int], int]:
"""Decode a sparse bit set from binary data.
Args:
data: bytes-like object containing the sparse bit set encoding.
bias: integer added to each decoded value.
Returns:
A tuple (values, bytesConsumed) where values is a set of integers
and bytesConsumed is the number of bytes read from data.
"""
if not data:
raise SparseBitSetDecodeError("Empty data")
branchFactor, height = _decodeHeader(data[0])
maxHeight = _BF_MAX_HEIGHT[branchFactor]
if height > maxHeight:
raise SparseBitSetDecodeError(
f"Height {height} exceeds max {maxHeight} for branch factor {branchFactor}"
)
return _decodeImpl(data, branchFactor, height, bias, maxValue)
def _decodeImpl(
data: bytes, branchFactor: int, height: int, bias: int, maxValue: int
) -> Tuple[Set[int], int]:
if height == 0:
# 1 byte was used for the header.
return (set(), 1)
bitStream = _InputBitStream(data, branchFactor)
result: Set[int] = set()
# Queue entries are (startValue, depth), where startValue is the first
# integer that could be covered by this node.
queue: deque[Tuple[int, int]] = deque()
queue.append((0, 1))
while queue:
start, depth = queue.popleft()
bits = bitStream.next()
if bits is None:
raise SparseBitSetDecodeError("Unexpected end of data")
# all bits were were zero which is a special command to completely fill
# in all integers covered by this node.
if bits == 0:
exp = height - depth + 1
nodeSize = branchFactor**exp
fillStart = start + bias
if fillStart > maxValue:
continue
fillEnd = min(maxValue, start + nodeSize - 1 + bias)
if fillStart < 0:
fillStart = 0
if fillStart <= fillEnd:
result.update(range(fillStart, fillEnd + 1))
continue
# Non-zero node: each set bit identifies a child/value.
exp = height - depth
nextNodeSize = branchFactor**exp
while True:
bitIndex = _trailingZeros(bits, 32)
if bitIndex == 32:
break
if depth == height:
val = start + bitIndex + bias
if val > maxValue:
queue.clear()
break
if val >= 0:
result.add(val)
else:
startDelta = bitIndex * nextNodeSize
queue.append((start + startDelta, depth + 1))
bits &= ~(1 << bitIndex)
return (result, bitStream.bytesConsumed())
def encode(values: Iterable[int]) -> bytes:
"""Encode a set of integers as a sparse bit set.
Tries all branching factors and returns the shortest encoding.
Args:
values: iterable of non-negative integers.
Returns:
bytes containing the sparse bit set encoding.
"""
valuesSorted = sorted(set(values))
if not valuesSorted:
return _encodeHeader(2, 0)
maxValue = valuesSorted[-1]
valueSet = set(valuesSorted)
best: Optional[bytes] = None
for branchFactor in (2, 4, 8, 32):
height = _treeHeight(branchFactor, maxValue)
if height > _BF_MAX_HEIGHT[branchFactor]:
continue
encoded = _encodeWithBf(valueSet, branchFactor, height)
if best is None or len(encoded) < len(best):
best = encoded
if best is None:
raise ValueError(f"Cannot encode max value {maxValue}")
return best
def _encodeHeader(branchFactor: int, height: int) -> bytes:
branchFactorToId = {2: 0, 4: 1, 8: 2, 32: 3}
return bytes([(height << 2) | branchFactorToId[branchFactor]])
def _decodeHeader(headerByte: int) -> Tuple[int, int]:
id = headerByte & 0x03
idToBranchFactor = {0: 2, 1: 4, 2: 8, 3: 32}
height = (headerByte >> 2) & 0x1F
return idToBranchFactor[id], height
def _treeHeight(branchFactor: int, maxValue: int) -> int:
"""Return the minimum tree height needed to represent maxValue."""
height = 1
capacity = branchFactor
while capacity <= maxValue:
capacity *= branchFactor
height += 1
return height
def _encodeWithBf(valueSet: Set[int], branchFactor: int, height: int) -> bytes:
if height == 0:
return _encodeHeader(branchFactor, 0)
# Build layers bottom-up: layer 0 = leaves (individual values),
# each higher layer groups bf children into one parent bitmask.
layers: list[Dict[int, int]] = [{}] # list of dicts: nodeIndex -> bitmask
for v in valueSet:
nodeIndex = v // branchFactor
bitPos = v % branchFactor
if nodeIndex not in layers[0]:
layers[0][nodeIndex] = 0
layers[0][nodeIndex] |= 1 << bitPos
for _ in range(1, height):
prevLayer = layers[-1]
newLayer: Dict[int, int] = {}
for nodeIndex, bitmask in prevLayer.items():
parentIndex = nodeIndex // branchFactor
bitPos = nodeIndex % branchFactor
if parentIndex not in newLayer:
newLayer[parentIndex] = 0
newLayer[parentIndex] |= 1 << bitPos
layers.append(newLayer)
# For zero-node optimization: track count of values in sorted list.
valuesSorted = sorted(valueSet)
def rangeCount(lo: int, hi: int) -> int:
return bisect.bisect_right(valuesSorted, hi) - bisect.bisect_left(
valuesSorted, lo
)
# Emit nodes BFS order (root to leaves).
# Queue entries: (nodeIndex, depthFromRoot, rangeStart, rangeEnd)
stream = _OutputBitStream(branchFactor)
subtreeSize = branchFactor**height
queue: deque[Tuple[int, int, int, int]] = deque([(0, 0, 0, subtreeSize - 1)])
while queue:
nodeIndex, depth, rangeStart, rangeEnd = queue.popleft()
layerIdx = height - 1 - depth # layers[0]=leaves, layers[height-1]=root
bitmask = (
layers[layerIdx].get(nodeIndex, 0) if 0 <= layerIdx < len(layers) else 0
)
# Zero-node optimization: if entire range is filled on an INTERNAL node,
# write 0 and skip children. At leaf level we always write the explicit
# bitmask so the encoding matches the reference.
if (
depth < height - 1
and rangeCount(rangeStart, rangeEnd) == rangeEnd - rangeStart + 1
):
stream.write(0)
continue
stream.write(bitmask)
if bitmask != 0 and depth < height - 1:
childSize = (rangeEnd - rangeStart + 1) // branchFactor
bits = bitmask
while bits:
bitIndex = _trailingZeros(bits, 32)
childIndex = nodeIndex * branchFactor + bitIndex
childStart = rangeStart + bitIndex * childSize
childEnd = childStart + childSize - 1
queue.append((childIndex, depth + 1, childStart, childEnd))
bits &= ~(1 << bitIndex)
return _encodeHeader(branchFactor, height) + stream.toBytes()
def _trailingZeros(val: int, maxBits: int) -> int:
if val == 0:
return maxBits
count = 0
while (val & 1) == 0:
val >>= 1
count += 1
return count
class _InputBitStream:
"""Reads bit nodes from a byte array, starting after the header byte."""
def __init__(self, data: bytes, branchFactor: int):
self.data = data
self.branchFactor = branchFactor
self.byteIndex = 1 # skip the header byte
self.subIndex = 0 # bit offset within current byte (for bf=2,4)
def next(self) -> Optional[int]:
if self.branchFactor in (2, 4):
if self.byteIndex >= len(self.data):
return None
mask = (1 << self.branchFactor) - 1
val = (self.data[self.byteIndex] >> self.subIndex) & mask
self.subIndex += self.branchFactor
if self.subIndex >= 8:
self.subIndex = 0
self.byteIndex += 1
return val
elif self.branchFactor == 8:
if self.byteIndex >= len(self.data):
return None
val = self.data[self.byteIndex]
self.byteIndex += 1
return val
elif self.branchFactor == 32:
if self.byteIndex + 3 >= len(self.data):
return None
b = self.data
i = self.byteIndex
val = b[i] | (b[i + 1] << 8) | (b[i + 2] << 16) | (b[i + 3] << 24)
self.byteIndex += 4
return val
return None
def bytesConsumed(self) -> int:
return self.byteIndex + (1 if self.subIndex > 0 else 0)
class _OutputBitStream:
"""Writes bit nodes into a byte array."""
def __init__(self, branchFactor: int):
self.branchFactor = branchFactor
self.data = bytearray()
self.subIndex = 0 # bit offset within current byte (for bf=2,4)
def write(self, value: int) -> None:
if self.branchFactor in (2, 4):
mask = (1 << self.branchFactor) - 1
value &= mask
if self.subIndex == 0:
self.data.append(0)
self.data[-1] |= value << self.subIndex
self.subIndex += self.branchFactor
if self.subIndex >= 8:
self.subIndex = 0
elif self.branchFactor == 8:
self.data.append(value & 0xFF)
elif self.branchFactor == 32:
self.data.append(value & 0xFF)
self.data.append((value >> 8) & 0xFF)
self.data.append((value >> 16) & 0xFF)
self.data.append((value >> 24) & 0xFF)
def toBytes(self) -> bytes:
return bytes(self.data)