""" Backrefs Regex parser. Licensed under MIT Copyright (c) 2011 - 2020 Isaac Muse """ from __future__ import annotations import unicodedata as _unicodedata import copyreg as _copyreg from . import util as _util import regex as _regex # type: ignore[import] try: # pragma: no cover from regex.regex import _compile_replacement_helper # type: ignore[import] except ImportError: # pragma: no cover from regex._main import _compile_replacement_helper # type: ignore[import] from typing import Generic, AnyStr, Any, cast from ._bregex_typing import Pattern, Match _ASCII_LETTERS = frozenset( ( 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ) ) _DIGIT = frozenset(('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')) _OCTAL = frozenset(('0', '1', '2', '3', '4', '5', '6', '7')) _HEX = frozenset(('a', 'b', 'c', 'd', 'e', 'f', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9')) _LETTERS_UNDERSCORE = _ASCII_LETTERS | frozenset(('_',)) _WORD = _LETTERS_UNDERSCORE | _DIGIT _STANDARD_ESCAPES = frozenset(('a', 'b', 'f', 'n', 'r', 't', 'v')) _CURLY_BRACKETS = frozenset(('{', '}')) _PROPERTY_STRIP = frozenset((' ', '-', '_')) _PROPERTY = _WORD | _DIGIT | _PROPERTY_STRIP _GLOBAL_FLAGS = frozenset(('b', 'e', 'p', 'r', 'u')) _SCOPED_FLAGS = frozenset(('a', 'f', 'i', 'L', 'm', 's', 'u', 'w', 'x')) _VERSIONS = frozenset(('0', '1')) _SCOPED_END = frozenset((':', ')')) _CURLY_BRACKETS_ORD = frozenset((0x7b, 0x7d)) # Case upper or lower _UPPER = 1 _LOWER = 2 # Format Constants _BACK_SLASH_TRANSLATION = { "\\a": '\a', "\\b": '\b', "\\f": '\f', "\\r": '\r', "\\t": '\t', "\\n": '\n', "\\v": '\v', "\\\\": '\\' } _FMT_CONV_TYPE = ('a', 'r', 's') class LoopException(Exception): """Loop exception.""" class GlobalRetryException(Exception): """Global retry exception.""" class _SearchParser(Generic[AnyStr]): """Search Template.""" _new_refs = ("e", "R", "Q", "E") _line_break = r'(?>\r\n|[\n\v\f\r\x85\u2028\u2029])' _bytes_line_break = r'(?>\r\n|[\n\v\f\r\x85])' verbose: bool version: int global_flag_swap: dict[str, bool] temp_global_flag_swap: dict[str, bool] is_bytes: bool search: AnyStr def __init__(self, search: AnyStr, re_verbose: bool = False, re_version: int = 0) -> None: """Initialize.""" if isinstance(search, bytes): self.is_bytes = True else: self.is_bytes = False if self.is_bytes: self._re_line_break = self._bytes_line_break else: self._re_line_break = self._line_break self.re_verbose = re_verbose self.re_version = re_version self.search = search def process_quotes(self, text: str) -> str: """Process quotes.""" escaped = False in_quotes = False current = [] quoted = [] # type: list[str] i = _util.StringIter(text) for t in i: if not escaped and t == "\\": escaped = True elif escaped: escaped = False if t == "E": if in_quotes: current.append(_regex.escape("".join(quoted))) quoted = [] in_quotes = False elif t == "Q" and not in_quotes: in_quotes = True elif in_quotes: quoted.extend(["\\", t]) else: current.extend(["\\", t]) elif in_quotes: quoted.extend(t) else: current.append(t) if in_quotes and escaped: quoted.append("\\") elif escaped: current.append("\\") if quoted: current.append(_regex.escape("".join(quoted))) return "".join(current) def verbose_comment(self, t: str, i: _util.StringIter) -> list[str]: """Handle verbose comments.""" current = [] escaped = False try: while t != "\n": if not escaped and t == "\\": escaped = True current.append(t) elif escaped: escaped = False if t in self._new_refs: current.append("\\") current.append(t) else: current.append(t) t = next(i) except StopIteration: pass if t == "\n": current.append(t) return current def flags(self, text: str, scoped: bool = False) -> None: """Analyze flags.""" flags = text.split('-') enable = flags[0] disable = flags[1] if len(flags) > 1 else '' global_retry = False if (self.version == _regex.V1 or scoped) and 'x' in disable and self.verbose: self.verbose = False elif 'x' in enable and not self.verbose: self.verbose = True if not scoped and self.version == _regex.V0: self.temp_global_flag_swap['verbose'] = True global_retry = True if 'V0' in enable and self.version == _regex.V1: # pragma: no cover # Default is V0 if none is selected, # so it is unlikely that this will be selected. self.temp_global_flag_swap['version'] = True self.version = _regex.V0 global_retry = True elif "V1" in enable and self.version == _regex.V0: self.temp_global_flag_swap['version'] = True self.version = _regex.V1 global_retry = True if global_retry: raise GlobalRetryException('Global Retry') def reference(self, t: str, i: _util.StringIter, in_group: bool = False) -> list[str]: """Handle references.""" current = [] if not in_group and t == "R": current.append(self._re_line_break) else: current.extend(["\\", t]) return current def get_posix(self, i: _util.StringIter) -> str | None: """Get POSIX.""" index = i.index value = ['['] try: c = next(i) if c != ':': raise ValueError('Not a valid property!') else: value.append(c) c = next(i) if c == '^': value.append(c) c = next(i) while c != ':': if c not in _PROPERTY: raise ValueError('Not a valid property!') if c not in _PROPERTY_STRIP: value.append(c) c = next(i) value.append(c) c = next(i) if c != ']' or not value: raise ValueError('Unmatched ]') value.append(c) except Exception: i.rewind(i.index - index) value = [] return ''.join(value) if value else None def get_comments(self, i: _util.StringIter) -> str | None: """Get comments.""" index = i.index value = ['('] escaped = False try: c = next(i) if c != '?': i.rewind(1) return None value.append(c) c = next(i) if c != '#': i.rewind(2) return None value.append(c) c = next(i) while c != ')' or escaped is True: if escaped: escaped = False elif c == '\\': escaped = True value.append(c) c = next(i) value.append(c) except StopIteration as e: raise SyntaxError(f"Unmatched '(' at {index - 1}!") from e return ''.join(value) if value else None def get_flags(self, i: _util.StringIter) -> tuple[str | None, bool]: """ Get flags. Regex is more difficult to determine when flags are used in a global and scoped context. There is a specific list of global flags, but they can be used in scoped notation and will still be considered global, but that does not mean other flags are global. Additionally, flags that can be scoped can also used be used in global syntax, but can be disabled with a minus unlike in RE. Bregex only cares about capturing the "verbose" flag and the version flags. Version flags are always global and verbose flags will be scoped based on whether they are in a scoped group. The returned "scoped" parameter only refers to "verbose". """ index = i.index value = ['('] version = False toggle = False smells_scoped = False try: c = next(i) if c != '?': i.rewind(1) return None, False value.append(c) c = next(i) while c not in _SCOPED_END: if toggle: if c not in _SCOPED_FLAGS: raise ValueError('Bad scope') elif version: if c not in _VERSIONS: raise ValueError('Bad version') version = False elif c == '-': toggle = True elif c == 'V': version = True elif c not in _GLOBAL_FLAGS and c not in _SCOPED_FLAGS: raise ValueError("Bad flag") value.append(c) c = next(i) if c == ':': smells_scoped = True value.append(c) except Exception: i.rewind(i.index - index) value = [] return ''.join(value) if value else None, smells_scoped def subgroup(self, t: str, i: _util.StringIter) -> list[str]: """Handle parenthesis.""" # (?#comment) comments = self.get_comments(i) if comments: return [comments] verbose = self.verbose # (?flags:pattern) or (?flags) # "scoped" only refers to verbose flags, scoped = self.get_flags(i) if flags: t = flags self.flags(flags[2:-1], scoped=scoped) if not scoped: return [flags] current = [] # type: list[str] try: while t != ')': if not current: current.append(t) else: current.extend(self.normal(t, i)) t = next(i) except StopIteration: pass self.verbose = verbose if t == ")": current.append(t) return current def char_groups(self, t: str, i: _util.StringIter) -> list[str]: """Handle character groups.""" current = [] pos = i.index - 1 found = 0 sub_first = 0 escaped = False first = 0 try: while True: if not escaped and t == "\\": escaped = True elif escaped: escaped = False current.extend(self.reference(t, i, True)) elif t == "[" and not found: found += 1 first = pos current.append(t) elif t == "[" and found and self.version == _regex.V1: # Start of sub char set found posix = None if self.is_bytes else self.get_posix(i) if posix: current.append(posix) pos = i.index - 2 else: found += 1 sub_first = pos current.append(t) elif t == "[": posix = None if self.is_bytes else self.get_posix(i) if posix: current.append(posix) pos = i.index - 2 else: current.append(t) elif t == "^" and found == 1 and (pos == first + 1): # Found ^ at start of first char set; adjust 1st char position current.append(t) first = pos elif self.version == _regex.V1 and t == "^" and found > 1 and (pos == sub_first + 1): # Found ^ at start of sub char set; adjust 1st char sub position current.append(t) sub_first = pos elif t == "]" and found == 1 and (pos != first + 1): # First char set closed; log range current.append(t) found = 0 break elif self.version == _regex.V1 and t == "]" and found > 1 and (pos != sub_first + 1): # Sub char set closed; decrement depth counter found -= 1 current.append(t) else: current.append(t) pos += 1 t = next(i) except StopIteration: pass if escaped: current.append(t) return current def normal(self, t: str, i: _util.StringIter) -> list[str]: """Handle normal chars.""" current = [] if t == "\\": try: t = next(i) current.extend(self.reference(t, i)) except StopIteration: current.append(t) elif t == "(": current.extend(self.subgroup(t, i)) elif self.verbose and t == "#": current.extend(self.verbose_comment(t, i)) elif t == "[": current.extend(self.char_groups(t, i)) else: current.append(t) return current def main_group(self, i: _util.StringIter) -> list[str]: """The main group: group 0.""" current = [] try: while True: t = next(i) current.extend(self.normal(t, i)) except StopIteration: pass return current def _parse(self, search: str) -> str: """Begin parsing.""" self.verbose = bool(self.re_verbose) self.version = self.re_version if self.re_version else _regex.DEFAULT_VERSION self.global_flag_swap = { "version": self.re_version != 0, "verbose": False } self.temp_global_flag_swap = { "version": False, "verbose": False } new_pattern = [] i = _util.StringIter(self.process_quotes(search)) retry = True while retry: retry = False try: new_pattern = self.main_group(i) except GlobalRetryException as e: # Prevent a loop of retry over and over for a pattern like ((?V0)(?V1)) # or on V0 (?-x:(?x)) if self.temp_global_flag_swap['version']: if self.global_flag_swap['version']: raise LoopException('Global version flag recursion.') from e else: self.global_flag_swap["version"] = True if self.temp_global_flag_swap['verbose']: if self.global_flag_swap['verbose']: raise LoopException('Global verbose flag recursion.') from e else: self.global_flag_swap['verbose'] = True self.temp_global_flag_swap = { "version": False, "verbose": False } i.rewind(i.index) retry = True return "".join(new_pattern) def parse(self) -> AnyStr: """Apply search template.""" if isinstance(self.search, bytes): return self._parse(self.search.decode('latin-1')).encode('latin-1') else: return self._parse(self.search) class _ReplaceParser(Generic[AnyStr]): """Pre-replace template.""" def __init__(self, pattern: Pattern[AnyStr], template: AnyStr, use_format: bool = False) -> None: """Initialize.""" self.pattern = pattern # type: Pattern[AnyStr] self._original = template # type: AnyStr self._template = template # type: AnyStr self.use_format = use_format self.end_found = False self.group_slots = [] # type: list[tuple[int, tuple[int | None, int | None, Any]]] self.literal_slots = [] # type: list[str] self.result = [] # type: list[str] self.span_stack = [] # type: list[int] self.single_stack = [] # type: list[int] self.literals = [] # type: list[AnyStr | None] self.groups = [] # type: list[tuple[int, int]] self.slot = 0 self.manual = False self.auto = False self.auto_index = 0 self.is_bytes = isinstance(self._original, bytes) def parse_format_index(self, text: str) -> int | str: """Parse format index.""" base = 10 prefix = text[1:3] if text[0] == "-" else text[:2] if prefix[0:1] == "0": char = prefix[-1] if char == "b": base = 2 elif char == "o": base = 8 elif char == "x": base = 16 try: idx = int(text, base) # type: int | str except Exception: idx = text return idx def get_format(self, c: str, i: _util.StringIter) -> tuple[str, list[tuple[int, Any]]]: """Get format group.""" index = i.index field = '' value = [] # type: list[tuple[int, Any]] try: if c == '}': value.append((_util.FMT_FIELD, '')) value.append((_util.FMT_INDEX, None)) else: # Field temp = [] # type: list[str] if c in _LETTERS_UNDERSCORE: # Handle name temp.append(c) c = self.format_next(i) while c in _WORD: temp.append(c) c = self.format_next(i) elif c in _DIGIT: # Handle group number temp.append(c) c = self.format_next(i) while c in _DIGIT: temp.append(c) c = self.format_next(i) # Try and covert to integer index field = ''.join(temp).strip() try: value = [(_util.FMT_FIELD, str(int(field, 10)))] except ValueError: value = [(_util.FMT_FIELD, field)] pass if c != '[': value.append((_util.FMT_INDEX, None)) # Attributes and indexes while c in ('[', '.'): if c == '[': findex = [] sindex = i.index - 1 c = self.format_next(i) try: while c != ']': findex.append(c) c = self.format_next(i) except StopIteration as e: raise SyntaxError(f"Unmatched '[' at {sindex - 1}") from e idx = self.parse_format_index(''.join(findex)) value.append((_util.FMT_INDEX, idx)) c = self.format_next(i) else: findex = [] c = self.format_next(i) while c in _WORD: findex.append(c) c = self.format_next(i) value.append((_util.FMT_ATTR, ''.join(findex))) # Conversion if c == '!': c = self.format_next(i) if c not in _FMT_CONV_TYPE: raise SyntaxError(f"Invalid conversion type at {i.index - 1}!") value.append((_util.FMT_CONV, c)) c = self.format_next(i) # Format spec if c == ':': fill = None # type: str | None width = [] align = None convert = None c = self.format_next(i) if c in ('<', '>', '^'): # Get fill and alignment align = c c = self.format_next(i) if c in ('<', '>', '^'): fill = align align = c c = self.format_next(i) elif c in _DIGIT: # Get Width fill = c c = self.format_next(i) if c in ('<', '>', '^'): align = c c = self.format_next(i) else: width.append(fill) fill = None else: fill = c c = self.format_next(i) if fill == 's' and c == '}': convert = fill fill = None if fill is not None: if c not in ('<', '>', '^'): raise SyntaxError(f'Invalid format spec char at {i.index - 1}!') align = c c = self.format_next(i) while c in _DIGIT: width.append(c) c = self.format_next(i) if not align and len(width) and width[0] == '0': raise ValueError("'=' alignment is not supported!") if align and not fill and len(width) and width[0] == '0': fill = '0' if c == 's': convert = c c = self.format_next(i) if not fill: fill = ' ' value.append( ( _util.FMT_SPEC, ( fill.encode('latin-1') if self.is_bytes else fill, align, (int(''.join(width)) if width else 0), convert ) ) ) if c != '}': raise SyntaxError(f"Unmatched '{{' at {index - 1}") except StopIteration as e: raise SyntaxError(f"Unmatched '{{' at {index - 1}!") from e return field, value def handle_format(self, t: str, i: _util.StringIter) -> None: """Handle format.""" if t == '{': t = self.format_next(i) if t == '{': self.get_single_stack() self.result.append(t) else: field, text = self.get_format(t, i) self.handle_format_group(field, text) else: t = self.format_next(i) if t == '}': self.get_single_stack() self.result.append(t) else: raise SyntaxError(f"Unmatched '}}' at {i.index - 2}!") def get_octal(self, c: str, i: _util.StringIter) -> str | None: """Get octal.""" index = i.index value = [] zero_count = 0 try: if c == '0': for _ in range(3): if c != '0': break value.append(c) c = next(i) zero_count = len(value) if zero_count < 3: for _ in range(3 - zero_count): if c not in _OCTAL: break value.append(c) c = next(i) i.rewind(1) except StopIteration: pass octal_count = len(value) if not (self.use_format and octal_count) and not (zero_count and octal_count < 3) and octal_count != 3: i.rewind(i.index - index) value = [] return ''.join(value) if value else None def parse_octal(self, text: str, i: _util.StringIter) -> None: """Parse octal value.""" value = int(text, 8) if value > 0xFF and self.is_bytes: # Re fails on octal greater than `0o377` or `0xFF` raise ValueError("octal escape value outside of range 0-0o377!") else: single = self.get_single_stack() if self.span_stack: text = self.convert_case(chr(value), self.span_stack[-1]) value = ord(self.convert_case(text, single)) if single is not None else ord(text) elif single: value = ord(self.convert_case(chr(value), single)) if self.use_format and value in _CURLY_BRACKETS_ORD: self.handle_format(chr(value), i) elif value <= 0xFF: self.result.append(f'\\{value:03o}') else: self.result.append(chr(value)) def get_named_unicode(self, i: _util.StringIter) -> str: """Get named Unicode.""" index = i.index value = [] try: if next(i) != '{': raise SyntaxError(f"Named Unicode missing '{{' at {i.index - 1}!") c = next(i) while c != '}': value.append(c) c = next(i) except StopIteration as e: raise SyntaxError(f"Unmatched '{{' at {index}!") from e return ''.join(value) def parse_named_unicode(self, i: _util.StringIter) -> None: """Parse named Unicode.""" value = ord(_unicodedata.lookup(self.get_named_unicode(i))) single = self.get_single_stack() if self.span_stack: text = self.convert_case(chr(value), self.span_stack[-1]) value = ord(self.convert_case(text, single)) if single is not None else ord(text) elif single: value = ord(self.convert_case(chr(value), single)) if self.use_format and value in _CURLY_BRACKETS_ORD: self.handle_format(chr(value), i) elif value <= 0xFF: self.result.append(f'\\{value:03o}') else: self.result.append(chr(value)) def get_wide_unicode(self, i: _util.StringIter) -> str: """Get narrow Unicode.""" value = [] for _ in range(3): c = next(i) if c == '0': value.append(c) else: # pragma: no cover raise SyntaxError(f'Invalid wide Unicode character at {i.index - 1}!') c = next(i) if c in ('0', '1'): value.append(c) else: # pragma: no cover raise SyntaxError(f'Invalid wide Unicode character at {i.index - 1}!') for _ in range(4): c = next(i) if c.lower() in _HEX: value.append(c) else: # pragma: no cover raise SyntaxError(f'Invalid wide Unicode character at {i.index - 1}!') return ''.join(value) def get_narrow_unicode(self, i: _util.StringIter) -> str: """Get narrow Unicode.""" value = [] for _ in range(4): c = next(i) if c.lower() in _HEX: value.append(c) else: # pragma: no cover raise SyntaxError(f'Invalid Unicode character at {i.index - 1}!') return ''.join(value) def parse_unicode(self, i: _util.StringIter, wide: bool = False) -> None: """Parse Unicode.""" text = self.get_wide_unicode(i) if wide else self.get_narrow_unicode(i) value = int(text, 16) single = self.get_single_stack() if self.span_stack: text = self.convert_case(chr(value), self.span_stack[-1]) value = ord(self.convert_case(text, single)) if single is not None else ord(text) elif single: value = ord(self.convert_case(chr(value), single)) if self.use_format and value in _CURLY_BRACKETS_ORD: self.handle_format(chr(value), i) elif value <= 0xFF: self.result.append(f'\\{value:03o}') else: self.result.append(chr(value)) def get_byte(self, i: _util.StringIter) -> str: """Get byte.""" value = [] for _ in range(2): c = next(i) if c.lower() in _HEX: value.append(c) else: # pragma: no cover raise SyntaxError(f'Invalid byte character at {i.index - 1}!') return ''.join(value) def parse_bytes(self, i: _util.StringIter) -> None: """Parse byte.""" value = int(self.get_byte(i), 16) single = self.get_single_stack() if self.span_stack: text = self.convert_case(chr(value), self.span_stack[-1]) value = ord(self.convert_case(text, single)) if single is not None else ord(text) elif single: value = ord(self.convert_case(chr(value), single)) if self.use_format and value in _CURLY_BRACKETS_ORD: self.handle_format(chr(value), i) else: self.result.append(f'\\{value:03o}') def get_named_group(self, t: str, i: _util.StringIter) -> str: """Get group number.""" index = i.index value = [t] try: c = next(i) if c != "<": raise SyntaxError(f"Group missing '<' at {i.index - 1}!") value.append(c) c = next(i) if c in _DIGIT: value.append(c) c = next(i) while c != '>': if c in _DIGIT: value.append(c) c = next(i) value.append(c) elif c in _LETTERS_UNDERSCORE: value.append(c) c = next(i) while c != '>': if c in _WORD: value.append(c) c = next(i) value.append(c) else: raise SyntaxError(f"Invalid group character at {i.index - 1}!") except StopIteration as e: raise SyntaxError(f"Unmatched '<' at {index}!") from e return ''.join(value) def get_group(self, t: str, i: _util.StringIter) -> str | None: """Get group number.""" value = [] try: if t in _DIGIT and t != '0': value.append(t) t = next(i) if t in _DIGIT: value.append(t) else: i.rewind(1) except StopIteration: pass return ''.join(value) if value else None def format_next(self, i: _util.StringIter) -> str: """Get next format char.""" c = next(i) return self.format_references(next(i), i) if c == '\\' else c def format_references(self, t: str, i: _util.StringIter) -> str: """Handle format references.""" octal = self.get_octal(t, i) if octal: o = int(octal, 8) if o > 0xFF and self.is_bytes: # Re fails on octal greater than `0o377` or `0xFF` raise ValueError("octal escape value outside of range 0-0o377!") value = chr(o) elif t in _STANDARD_ESCAPES or t == '\\': value = _BACK_SLASH_TRANSLATION['\\' + t] elif not self.is_bytes and t == "U": value = chr(int(self.get_wide_unicode(i), 16)) elif not self.is_bytes and t == "u": value = chr(int(self.get_narrow_unicode(i), 16)) elif not self.is_bytes and t == "N": value = _unicodedata.lookup(self.get_named_unicode(i)) elif t == "x": value = chr(int(self.get_byte(i), 16)) else: i.rewind(1) value = '\\' return value def reference(self, t: str, i: _util.StringIter) -> None: """Handle references.""" octal = self.get_octal(t, i) if t in _OCTAL and octal: self.parse_octal(octal, i) elif (t in _DIGIT or t == 'g') and not self.use_format: group = self.get_group(t, i) if not group: group = self.get_named_group(t, i) self.handle_group('\\' + group) elif t in _STANDARD_ESCAPES: self.get_single_stack() self.result.append('\\' + t) elif t == "l": self.single_case(i, _LOWER) elif t == "L": self.span_case(i, _LOWER) elif t == "c": self.single_case(i, _UPPER) elif t == "C": self.span_case(i, _UPPER) elif t == "E": self.end_found = True elif not self.is_bytes and t == "U": self.parse_unicode(i, True) elif not self.is_bytes and t == "u": self.parse_unicode(i) elif not self.is_bytes and t == "N": self.parse_named_unicode(i) elif t == "x": self.parse_bytes(i) elif self.use_format and t in _CURLY_BRACKETS: self.result.append('\\\\') self.handle_format(t, i) elif self.use_format and t == 'g': self.result.append('\\\\') self.result.append(t) else: value = '\\' + t self.get_single_stack() if self.span_stack: value = self.convert_case(value, self.span_stack[-1]) self.result.append(value) def _parse_template(self, template: str) -> str: """Parse template.""" self.result = [""] i = _util.StringIter(template) try: while True: t = next(i) if self.use_format and t in _CURLY_BRACKETS: self.handle_format(t, i) elif t == '\\': try: t = next(i) self.reference(t, i) except StopIteration: self.result.append(t) raise else: self.result.append(t) except StopIteration: pass if len(self.result) > 1: self.literal_slots.append("".join(self.result)) del self.result[:] self.result.append("") self.slot += 1 return "".join(self.literal_slots) def parse_template(self) -> None: """Parse template.""" if isinstance(self._original, bytes): self._template = self._parse_template(self._original.decode('latin-1')).encode('latin-1') else: self._template = self._parse_template(self._original) count = 0 for part in _compile_replacement_helper(self.pattern, self._template): if isinstance(part, int): self.literals.append(None) self.groups.append((count, part)) else: self.literals.append(cast(AnyStr, part)) count += 1 def span_case(self, i: _util.StringIter, case: int) -> None: """Uppercase or lowercase the next range of characters until end marker is found.""" # A new \L, \C or \E should pop the last in the stack. if self.span_stack: self.span_stack.pop() if self.single_stack: self.single_stack.pop() self.span_stack.append(case) count = len(self.span_stack) self.end_found = False try: while not self.end_found: t = next(i) if self.use_format and t in _CURLY_BRACKETS: self.handle_format(t, i) elif t == '\\': try: t = next(i) self.reference(t, i) except StopIteration: self.result.append(t) raise else: self.result.append(self.convert_case(t, case)) if self.end_found or count > len(self.span_stack): self.end_found = False break except StopIteration: pass if count == len(self.span_stack): self.span_stack.pop() def convert_case(self, value: str, case: int) -> str: """Convert case.""" if self.is_bytes: cased = [] for c in value: if c in _ASCII_LETTERS: cased.append(c.lower() if case == _LOWER else c.upper()) else: cased.append(c) return "".join(cased) else: return value.lower() if case == _LOWER else value.upper() def single_case(self, i: _util.StringIter, case: int) -> None: """Uppercase or lowercase the next character.""" # Pop a previous case if we have consecutive ones. if self.single_stack: self.single_stack.pop() self.single_stack.append(case) try: t = next(i) if self.use_format and t in _CURLY_BRACKETS: self.handle_format(t, i) elif t == '\\': try: t = next(i) self.reference(t, i) except StopIteration: self.result.append(t) raise else: this_case = self.get_single_stack() if this_case is not None: self.result.append(self.convert_case(t, this_case)) except StopIteration: pass def get_single_stack(self) -> int | None: """Get the correct single stack item to use.""" single = None while self.single_stack: single = self.single_stack.pop() return single def handle_format_group(self, field: str, text: list[tuple[int, Any]]) -> None: """Handle format group.""" # Handle auto incrementing group indexes if field == '': if self.auto: field = str(self.auto_index) text[0] = (_util.FMT_FIELD, field) self.auto_index += 1 elif not self.manual and not self.auto: self.auto = True field = str(self.auto_index) text[0] = (_util.FMT_FIELD, field) self.auto_index += 1 else: raise ValueError("Cannot switch to auto format during manual format!") elif not self.manual and not self.auto: self.manual = True elif not self.manual: raise ValueError("Cannot switch to manual format during auto format!") self.handle_group(field, tuple(text), True) def handle_group( self, text: str, capture: tuple[tuple[int, Any], ...] | None = None, is_format: bool = False ) -> None: """Handle groups.""" if len(self.result) > 1: self.literal_slots.append("".join(self.result)) if is_format: self.literal_slots.extend(["\\g<", text, ">"]) else: self.literal_slots.append(text) del self.result[:] self.result.append("") self.slot += 1 elif is_format: self.literal_slots.extend(["\\g<", text, ">"]) else: self.literal_slots.append(text) self.group_slots.append( ( self.slot, ( (self.span_stack[-1] if self.span_stack else None), self.get_single_stack(), (() if self.is_bytes else '') if capture is None else capture ) ) ) self.slot += 1 def get_base_template(self) -> AnyStr: """Return the unmodified template before expansion.""" return self._original def parse(self) -> ReplaceTemplate[AnyStr]: """Parse template.""" if not isinstance(self.pattern.pattern, type(self._original)): raise TypeError('Pattern string type must match replace template string type!') self.parse_template() return ReplaceTemplate( tuple(self.groups), tuple(self.group_slots), tuple(self.literals), hash(self.pattern), self.use_format, self.is_bytes ) class ReplaceTemplate(_util.Immutable, Generic[AnyStr]): """Replacement template expander.""" __slots__ = ("groups", "group_slots", "literals", "pattern_hash", "use_format", "_hash", "_bytes") groups: tuple[tuple[int, int], ...] group_slots: tuple[tuple[int, tuple[int | None, int | None, Any]], ...] literals: tuple[AnyStr | None, ...] pattern_hash: int use_format: bool _hash: int _bytes: bool def __init__( self, groups: tuple[tuple[int, int], ...], group_slots: tuple[tuple[int, tuple[int | None, int | None, Any]], ...], literals: tuple[AnyStr | None, ...], pattern_hash: int, use_format: bool, is_bytes: bool ) -> None: """Initialize.""" super().__init__( use_format=use_format, groups=groups, group_slots=group_slots, literals=literals, pattern_hash=pattern_hash, _bytes=is_bytes, _hash=hash( ( type(self), groups, group_slots, literals, pattern_hash, use_format, is_bytes ) ) ) def __call__(self, m: Match[AnyStr] | None) -> AnyStr: """Call.""" return self.expand(m) def __hash__(self) -> int: """Hash.""" return self._hash def __eq__(self, other: Any) -> bool: """Equal.""" return ( isinstance(other, ReplaceTemplate) and self.groups == other.groups and self.group_slots == other.group_slots and self.literals == other.literals and self.pattern_hash == other.pattern_hash and self.use_format == other.use_format and self._bytes == other._bytes ) def __ne__(self, other: Any) -> bool: """Equal.""" return ( not isinstance(other, ReplaceTemplate) or self.groups != other.groups or self.group_slots != other.group_slots or self.literals != other.literals or self.pattern_hash != other.pattern_hash or self.use_format != other.use_format or self._bytes != other._bytes ) def __repr__(self) -> str: # pragma: no cover """Representation.""" return "{}.{}({!r}, {!r}, {!r}, {!r}, {!r})".format( self.__module__, self.__class__.__name__, self.groups, self.group_slots, self.literals, self.pattern_hash, self.use_format ) def _get_group_index(self, index: int) -> int: """Find and return the appropriate group index.""" g_index = 0 for group in self.groups: if group[0] == index: g_index = group[1] break return g_index def _get_group_attributes(self, index: int) -> tuple[int | None, int | None, Any]: """Find and return the appropriate group case.""" g_case = (None, None, -1) # type: tuple[int | None, int | None, Any] for group in self.group_slots: if group[0] == index: g_case = group[1] break return g_case def expand(self, m: Match[AnyStr] | None) -> AnyStr: """Using the template, expand the string.""" if m is None: raise ValueError("Match is None!") sep = m.re.pattern[:0] # type: AnyStr if isinstance(sep, bytes) != self._bytes: raise TypeError('Match string type does not match expander string type!') text = [] # Expand string for index in range(0, len(self.literals)): l = self.literals[index] # type: AnyStr | None if l is None: g_index = self._get_group_index(index) span_case, single_case, capture = self._get_group_attributes(index) if not self.use_format: # Non format replace try: l = cast('AnyStr | None', m.group(g_index)) if l is None: l = sep except IndexError as e: # pragma: no cover raise IndexError(f"'{g_index}' is out of range!") from e else: # String format replace try: obj = cast('list[AnyStr]', m.captures(g_index)) except IndexError as e: # pragma: no cover raise IndexError(f"'{g_index}' is out of range!") from e l = _util.format_captures( obj, capture, _util._to_bstr if isinstance(sep, bytes) else _util._to_str, sep ) if span_case is not None: if span_case == _LOWER: l = l.lower() else: l = l.upper() if single_case is not None: if single_case == _LOWER: l = l[0:1].lower() + l[1:] else: l = l[0:1].upper() + l[1:] text.append(l) return sep.join(text) def _pickle(r): # type: ignore[no-untyped-def] """Pickle.""" return ReplaceTemplate, (r.groups, r.group_slots, r.literals, r.pattern_hash, r.use_format, r._bytes) _copyreg.pickle(ReplaceTemplate, _pickle)