ickc/pandoc-amsthm

View on GitHub
src/amsthm/__init__.py

Summary

Maintainability
D
3 days
Test Coverage
from __future__ import annotations

import re
from collections import defaultdict
from dataclasses import dataclass, field
from functools import cached_property, partial
from typing import TYPE_CHECKING

import panflute as pf
from panflute.tools import convert_text

from .helper import cancel_emph, cite_to_id_mode, cite_to_ref, merge_emph, parse_markdown_as_inline, to_emph
from .util import setup_logging

if TYPE_CHECKING:
    from typing import Union

    from panflute.elements import Doc, Element

    THM_DEF = list[Union[str, dict[str, str], dict[str, list[str]]]]

__version__: str = "2.0.0"

# Metadata key under which this filter's options are read from the document.
METADATA_KEY: str = "amsthm"

# Theorem styles read from the metadata; "proof" is predefined and handled separately.
STYLES: tuple[str, ...] = ("plain", "definition", "remark")
# Styles whose theorem text and number are wrapped in Strong (other styles use Emph for the text).
PLAIN_OR_DEF: set[str] = {"plain", "definition"}

# LaTeX sectioning counters accepted as a theorem's ``parent_counter``.
PARENT_COUNTERS: set[str] = {
    "part", "chapter", "section", "subsection", "subsubsection", "paragraph", "subparagraph",
}

# Output formats that receive real amsthm environments instead of the emulation.
LATEX_LIKE: set[str] = {"latex", "beamer"}

# A raw TeX inline of the form ``\ref{id}`` or ``\eqref{id}``; captures (command, id).
REF_REGEX = re.compile(r"^\\(ref|eqref)\{(.*)\}$")

# Default number of heading levels prepended to theorem numbers in non-LaTeX output.
COUNTER_DEPTH_DEFAULT: int = 0

# Module-wide logger, obtained from the package's ``util.setup_logging`` helper.
logger = setup_logging()


def parse_info(info: str | None) -> list[Element]:
    """Convert a theorem's ``info`` string into inline panflute elements.

    The parsed markdown is wrapped in a literal "(" and ")". A missing or empty
    ``info`` produces an empty list.
    """
    if not info:
        return []
    inlines = parse_markdown_as_inline(info)
    return [pf.Str("(")] + inlines + [pf.Str(")")]


@dataclass
class NewTheorem:
    """A LaTeX amsthm new theorem.

    :param style: theorem style, one of ``STYLES`` when built from metadata.
    :param env_name: LaTeX environment name. A trailing ``*`` marks the theorem as
        unnumbered and is stripped from the stored name.
    :param text: displayed theorem name; defaults to ``env_name`` when empty.
    :param parent_counter: for LaTeX output, controlling the number of numbers in a theorem.
        Should be used with counter_depth to match LaTeX and non-LaTeX output.
        Must be one of ``PARENT_COUNTERS``; any other value is dropped with a warning.
    :param shared_counter: name of the theorem whose counter this one shares.
        Dropped (with a warning) when a numbered theorem also has a ``parent_counter``.
    :param numbered: whether the theorem is numbered.
    """

    style: str
    env_name: str
    text: str = ""
    parent_counter: str | None = None
    shared_counter: str | None = None
    numbered: bool = True

    def __post_init__(self) -> None:
        if self.env_name.endswith("*"):
            self.env_name = self.env_name[:-1]
            self.numbered = False
        if not self.text:
            logger.debug("Defaulting text to %s", self.env_name)
            self.text = self.env_name
        if self.parent_counter is not None and self.parent_counter not in PARENT_COUNTERS:
            logger.warning("Unsupported parent_counter %s, ignoring.", self.parent_counter)
            # Actually ignore it, as the warning says. Otherwise `latex` would emit an
            # invalid ``\newtheorem{...}{...}[<unsupported>]`` and shared_counter would
            # be dropped below for no reason.
            self.parent_counter = None
        if self.numbered and self.parent_counter is not None and self.shared_counter is not None:
            logger.warning("Dropping shared_counter as parent_counter is defined.")
            self.shared_counter = None

    @property
    def latex(self) -> str:
        r"""The ``\newtheorem`` (or ``\newtheorem*``) declaration of this theorem."""
        res = [r"\newtheorem"]
        if not self.numbered:
            res.append(f"*{{{self.env_name}}}{{{self.text}}}")
        elif self.shared_counter is None:
            if self.parent_counter is None:
                res.append(f"{{{self.env_name}}}{{{self.text}}}")
            else:
                # numbered within the parent counter: \newtheorem{name}{text}[parent]
                res.append(f"{{{self.env_name}}}{{{self.text}}}[{self.parent_counter}]")
        else:
            # shares another theorem's counter: \newtheorem{name}[shared]{text}
            res.append(f"{{{self.env_name}}}[{self.shared_counter}]{{{self.text}}}")
        return "".join(res)

    @property
    def class_name(self) -> str:
        """Name in pandoc div classes.

        It cannot have space, so spaces are replaced by underscores.
        """
        return self.env_name.replace(" ", "_")

    @property
    def counter_name(self) -> str:
        """Key of the counter this theorem increments: its own name or the shared counter."""
        return self.env_name if self.shared_counter is None else self.shared_counter

    def to_panflute_theorem_header(
        self,
        options: DocOptions,
        id: str | None,
        info: str | None,
    ) -> list[pf.Element]:
        """Return a theorem header as panflute AST.

        The header consists of the theorem text, its number (if numbered), the
        parenthesized ``info`` (if any) and a final ".", followed by a Space.

        :param options: document state supplying the heading and theorem counters.
        :param id: identifier of the theorem Div; if numbered, it is mapped to the number.
        :param info: optional theorem title, parsed as markdown.

        This mutates `options.theorem_counters`, `options.identifiers` in-place.
        """
        text = self.text

        theorem_number: str | None
        if self.numbered:
            counter_name = self.counter_name
            options.theorem_counters[counter_name] += 1
            theorem_counter = options.theorem_counters[counter_name]
            # prefix with the current heading counters, e.g. "2.3" for counter_depth == 1
            theorem_number = ".".join([str(i) for i in options.header_counters] + [str(theorem_counter)])
            if id:
                options.identifiers[id] = theorem_number
        else:
            theorem_number = None

        # parenthesized info, no additional styling here
        info_list = parse_info(info)

        # plain/definition: Strong text and number; other styles: Emph text, unstyled number
        TextType: type[Element]
        NumberType: type[Element]
        if self.style in PLAIN_OR_DEF:
            TextType = pf.Strong
            NumberType = pf.Strong
        else:
            TextType = pf.Emph
            NumberType = pf.Str

        # We are normalizing the Emph/Strong boundary manually by having 6 cases:
        # unnumbered (with/without info) and numbered (number styled like the text
        # or not, each with/without info). Text and number are separated by a Space.
        if theorem_number is None:
            if info_list:
                res = [TextType(pf.Str(text)), pf.Space] + info_list + [TextType(pf.Str(".")), pf.Space]
            else:
                res = [TextType(pf.Str(f"{text}.")), pf.Space]
        elif TextType is NumberType:
            if info_list:
                res = (
                    [TextType(pf.Str(f"{text} {theorem_number}")), pf.Space]
                    + info_list
                    + [TextType(pf.Str(".")), pf.Space]
                )
            else:
                res = [TextType(pf.Str(f"{text} {theorem_number}.")), pf.Space]
        elif info_list:
            res = (
                [TextType(pf.Str(text)), pf.Space, pf.Str(theorem_number), pf.Space]
                + info_list
                + [TextType(pf.Str(".")), pf.Space]
            )
        else:
            res = [
                TextType(pf.Str(text)),
                pf.Space,
                pf.Str(theorem_number),
                TextType(pf.Str(".")),
                pf.Space,
            ]
        return res


@dataclass
class Proof(NewTheorem):
    """The ``proof`` environment, which amsthm predefines.

    It is unnumbered and is never declared in the generated LaTeX preamble. Its
    header reads "Proof." unless an ``info`` title is given, which then replaces
    the word "Proof".
    """

    style: str = "proof"
    env_name: str = "proof"
    text: str = "proof"
    parent_counter: str | None = None
    shared_counter: str | None = None
    numbered: bool = False

    def to_panflute_theorem_header(
        self,
        options: DocOptions,
        id: str | None,
        info: str | None,
    ) -> list[pf.Element]:
        """Return a proof header as panflute AST.

        ``options`` and ``id`` are unused because proofs are unnumbered. They are
        accepted for interface compatibility with ``NewTheorem``.

        :param info: optional proof title, parsed as markdown and emphasized.
        """
        if not info:
            # A missing and an empty title are treated alike, as `parse_info` and the
            # LaTeX path (`if info:`) do. Otherwise an empty title yields a lone ".".
            return [pf.Emph(pf.Str("Proof.")), pf.Space]
        # put it into a Para then walk
        para = pf.Para(*parse_markdown_as_inline(info))
        para.walk(to_emph)
        para.walk(cancel_emph)
        para.walk(merge_emph)
        return list(para.content) + [pf.Emph(pf.Str(".")), pf.Space]


@dataclass
class DocOptions:
    """Document options.

    :param theorems: theorem definitions keyed by their pandoc div class name.
    :param counter_depth: can be n=0-6 inclusive.
        n means n+1 numbers shown in non-LaTeX outputs.
        e.g. n=1 means x.y, where x is the heading 1 counter, y is the theorem counter.
        Should be used with parent_counter to match LaTeX and non-LaTeX output.
        A value that cannot be converted to int falls back to ``COUNTER_DEPTH_DEFAULT``.
    :param counter_ignore_headings: heading texts that do not advance the heading counters.
    """

    theorems: dict[str, NewTheorem] = field(default_factory=dict)
    counter_depth: int = COUNTER_DEPTH_DEFAULT
    counter_ignore_headings: set[str] = field(default_factory=set)

    def __post_init__(self) -> None:
        try:
            self.counter_depth = int(self.counter_depth)
        except (TypeError, ValueError):
            # TypeError covers non-scalar metadata values such as lists or mappings
            logger.warning("counter_depth must be int, default to %s.", COUNTER_DEPTH_DEFAULT)
            self.counter_depth = COUNTER_DEPTH_DEFAULT

        # initial count is zero
        # should be += 1 before using
        self.header_counters: list[int] = [0] * self.counter_depth
        self.reset_theorem_counters()
        # from identifiers to numbers
        self.identifiers: dict[str, str] = {}

    def reset_theorem_counters(self) -> None:
        """Start every theorem counter from zero again."""
        self.theorem_counters: dict[str, int] = defaultdict(int)

    @cached_property
    def theorems_set(self) -> set[str]:
        """Div class names that denote a theorem environment."""
        return set(self.theorems)

    @classmethod
    def from_doc(
        cls,
        doc: Doc,
    ) -> DocOptions:
        """Build the options from the ``amsthm`` metadata of *doc*.

        Each style key (``plain``, ``definition``, ``remark``) holds a list. Each item
        is either an environment name, or a mapping from an environment name to one
        or more environment names that share its counter. ``proof`` is always added.
        """
        options: dict[
            str,
            dict[str, str | dict[str, str] | THM_DEF],
        ] = doc.get_metadata(METADATA_KEY, {})

        name_to_text: dict[str, str] = options.get("name_to_text", {})  # type: ignore[assignment, arg-type]
        parent_counter: str | None = options.get("parent_counter", None)  # type: ignore[assignment, arg-type]

        theorems: dict[str, NewTheorem] = {}

        # register a theorem under the class name used in pandoc divs
        def add(theorem: NewTheorem) -> None:
            theorems[theorem.class_name] = theorem

        for style in STYLES:
            option: THM_DEF = options.get(style, [])  # type: ignore[assignment]
            for opt in option:
                if isinstance(opt, dict):
                    for key, value in opt.items():
                        # the key owns its counter, optionally numbered within parent_counter
                        add(NewTheorem(style, key, text=name_to_text.get(key, ""), parent_counter=parent_counter))
                        # the value(s) share the key's counter
                        sharers = value if isinstance(value, list) else [value]
                        for v in sharers:
                            add(NewTheorem(style, v, text=name_to_text.get(v, ""), shared_counter=key))
                else:
                    add(NewTheorem(style, opt, text=name_to_text.get(opt, ""), parent_counter=parent_counter))
        # proof is predefined in amsthm
        theorems["proof"] = Proof()

        ignore_headings = options.get("counter_ignore_headings", [])
        if isinstance(ignore_headings, str):
            # a single heading given as a scalar must not be split into its characters by set()
            ignore_headings = [ignore_headings]
        return cls(
            theorems,
            counter_depth=options.get("counter_depth", COUNTER_DEPTH_DEFAULT),  # type: ignore[arg-type] # will be verified at __post_init__
            counter_ignore_headings=set(ignore_headings),  # type: ignore[arg-type]
        )

    @property
    def latex(self) -> str:
        r"""LaTeX preamble declaring all theorems.

        A ``\theoremstyle`` line is emitted whenever the style changes between
        consecutive theorems, followed by each theorem's ``\newtheorem`` line.
        """
        cur_style: str = ""
        res: list[str] = []
        for theorem in self.theorems.values():
            # proof is predefined in amsthm
            if not isinstance(theorem, Proof):
                if theorem.style != cur_style:
                    cur_style = theorem.style
                    res.append(f"\\theoremstyle{{{cur_style}}}")
                res.append(theorem.latex)
        return "\n".join(res)

    @property
    def to_panflute(self) -> pf.RawBlock:
        """The `latex` preamble wrapped in a raw LaTeX block."""
        return pf.RawBlock(self.latex, format="latex")


def prepare(doc: Doc) -> None:
    r"""Attach the parsed amsthm options to ``doc._amsthm``.

    For LaTeX-like output, the ``\theoremstyle``/``\newtheorem`` declarations are
    also inserted as a raw LaTeX block at the very start of the document.
    """
    options = DocOptions.from_doc(doc)
    doc._amsthm = options
    if doc.format not in LATEX_LIKE:
        return
    doc.content.insert(0, options.to_panflute)


def amsthm(elem: Element, doc: Doc) -> None:
    """General amsthm transformation working for all document types.

    Essentially we replicate LaTeX amsthm behavior in this filter:

    - A Header of level <= ``counter_depth`` whose text is not in
      ``counter_ignore_headings`` increments its heading counter, zeroes the deeper
      ones and resets all theorem counters.
    - A Div with exactly one theorem class gets its theorem header prepended.
      For the ``plain`` style, the body first goes through the
      to_emph/cancel_emph/merge_emph passes. Proofs additionally get an HTML "◻"
      marker appended, floated right.

    Elements are modified in place; nothing is returned.
    """
    options: DocOptions = doc._amsthm
    if isinstance(elem, pf.Header):
        if elem.level <= options.counter_depth:
            header_string = None
            if (counter_ignore_headings := options.counter_ignore_headings) and (
                header_string := pf.stringify(elem)
            ) in counter_ignore_headings:
                logger.debug("Ignoring header %s in header_counters as it is in counter_ignore_headings", header_string)
            else:
                # Header.level is 1-indexed, while list is 0-indexed
                options.header_counters[elem.level - 1] += 1
                # reset deeper levels
                for i in range(elem.level, options.counter_depth):
                    options.header_counters[i] = 0
                logger.debug(
                    "Header encounter: %s, current counter: %s", header_string or elem, options.header_counters
                )
                options.reset_theorem_counters()
    elif isinstance(elem, pf.Div):
        environments: set[str] = options.theorems_set.intersection(elem.classes)
        if environments:
            if len(environments) != 1:
                logger.warning("Multiple environments found: %s", environments)
                return None
            environment = environments.pop()
            theorem = options.theorems[environment]

            info = elem.attributes.get("info", None)
            identifier = elem.identifier

            res = theorem.to_panflute_theorem_header(options, identifier, info)

            # theorem body
            if theorem.style == "plain":
                elem.walk(to_emph)
                elem.walk(cancel_emph)
                elem.walk(merge_emph)

            # The header/QED can only go inside a block whose content accepts inlines.
            # These errors signal that the block is unsuitable:
            # - IndexError: the Div is empty.
            # - AttributeError: the block has no content, e.g. CodeBlock.
            # - TypeError: panflute's container rejects inlines, e.g. BulletList or nested Div.
            # Since every element of `res` is inline, a rejection happens on the first
            # insert, so nothing is left half-inserted.
            unsuitable_block = (AttributeError, IndexError, TypeError)
            try:
                # insert in the beginning of the first block element
                first_content = elem.content[0].content
                for inline in reversed(res):
                    first_content.insert(0, inline)
            except unsuitable_block:
                # if fail, insert a Para before content
                elem.content.insert(0, pf.Para(*res))

            if theorem.style == "proof":
                qed = pf.RawInline("<span style='float: right'>◻</span>", format="html")
                try:
                    # insert in the end of the last block element
                    elem.content[-1].content.append(qed)
                except unsuitable_block:
                    # if fail, append a Para
                    elem.content.append(pf.Para(qed))


def resolve_ref(elem: Element, doc: Doc) -> pf.Str | None:
    r"""Replace a reference to a known theorem with its number.

    Consider this as post-process ref for general output formats.

    - ``[@id]`` (NormalCitation) becomes "(number)" and ``@id`` (AuthorInText) becomes
      "number". Other citation modes are logged and left alone.
    - A raw TeX ``\eqref{id}`` becomes "(number)" and ``\ref{id}`` becomes "number".

    Anything whose id is not in ``options.identifiers`` is left untouched (None).
    """
    options: DocOptions = doc._amsthm
    known = options.identifiers

    if isinstance(elem, pf.Cite):
        found = cite_to_id_mode(elem)
        if found is None or found[0] not in known:
            return None
        target, mode = found[0], found[1]
        number = known[target]
        if mode == "NormalCitation":  # [@...]
            return pf.Str(f"({number})")
        if mode == "AuthorInText":  # @...
            return pf.Str(number)
        logger.warning("Unknown citation mode %s from Cite: %s. Ignoring...", mode, elem)
        return None

    if isinstance(elem, pf.RawInline) and elem.format == "tex":
        text = elem.text
        matches = REF_REGEX.findall(text)
        if not matches:
            return None
        if len(matches) != 1:
            logger.warning("Ignoring ref matching in %s: %s", text, matches)
            return None
        ref_type, target = matches[0]
        if target not in known:
            return None
        number = known[target]
        return pf.Str(f"({number})" if ref_type == "eqref" else number)

    return None


def collect_ref_id(elem: Element, doc: Doc) -> None:
    """Only collect all amsthm environment id.

    This should be used before the `amsthm_latex` filter.
    Both run as separate passes because an id may be cited/referenced before the
    theorem defining it. Consider this as pre-process of ref for LaTeX output.

    `options.identifiers` modified in-place: each id of a Div carrying exactly one
    theorem class is recorded with an empty number.
    """
    options: DocOptions = doc._amsthm
    if not isinstance(elem, pf.Div):
        return None
    environments = options.theorems_set.intersection(elem.classes)
    if not environments:
        return None
    if len(environments) > 1:
        logger.warning("Multiple environments found: %s", environments)
        return None
    identifier = elem.identifier
    if identifier:
        # LaTeX numbers the theorem itself; only the fact that the id exists matters here
        options.identifiers[identifier] = ""
    return None


def amsthm_latex(elem: Element, doc: Doc) -> pf.RawBlock | None:
    r"""Transform an amsthm environment Div into a raw LaTeX environment.

    A Div with exactly one theorem class becomes
    ``\begin{env}[info]\label{id}`` followed by the Div converted to LaTeX and
    ``\end{env}``. The ``[info]`` and ``\label`` parts are only emitted when present.
    Non-Div elements are passed to ``cite_to_ref``, which decides whether they are
    citations to convert.
    """
    options: DocOptions = doc._amsthm
    if not isinstance(elem, pf.Div):
        # check if pf.Cite is done inside cite_to_ref
        return cite_to_ref(elem, doc, options.identifiers)

    environments: set[str] = options.theorems_set.intersection(elem.classes)
    if not environments:
        return None
    if len(environments) != 1:
        logger.warning("Multiple environments found: %s", environments)
        return None
    theorem = options.theorems[environments.pop()]
    env_name = theorem.env_name

    body = pf.convert_text(elem, input_format="panflute", output_format="latex")
    info = elem.attributes.get("info", None)
    identifier = elem.identifier

    parts = [f"\\begin{{{env_name}}}"]
    if info:
        # wrap in Para for walk
        para = pf.Para(*parse_markdown_as_inline(info))
        para.walk(partial(cite_to_ref, check_id=options.identifiers))
        title = convert_text(para, input_format="panflute", output_format="latex").strip()
        parts.append(f"[{title}]")
    if identifier:
        parts.append(f"\\label{{{identifier}}}")
    parts.append(f"\n{body}\n\\end{{{env_name}}}")
    return pf.RawBlock("".join(parts), format="latex")


def action1(elem: Element, doc: Doc) -> pf.RawBlock | None:
    """First filter pass; always returns None (elements are modified in place).

    LaTeX-like output only records theorem identifiers. Other output numbers and
    renders theorems via `amsthm`.
    """
    handler = collect_ref_id if doc.format in LATEX_LIKE else amsthm
    handler(elem, doc)
    return None


def action2(elem: Element, doc: Doc) -> pf.Str | pf.RawInline | None:
    """Second filter pass, returning the replacement element (or None to keep it).

    LaTeX-like output delegates to `amsthm_latex`. Other output delegates to
    `resolve_ref`.
    """
    handler = amsthm_latex if doc.format in LATEX_LIKE else resolve_ref
    return handler(elem, doc)


def finalize(doc: Doc) -> None:
    """Drop the ``_amsthm`` options that `prepare` attached to *doc*."""
    delattr(doc, "_amsthm")


def main(doc: Doc | None = None) -> Doc | None:
    """Run the amsthm filter: `prepare`, then passes `action1` and `action2`, then `finalize`.

    Per panflute's ``run_filters``: with ``doc=None``, the document is read from stdin,
    written to stdout and None is returned. Otherwise *doc* is processed and the
    resulting document is returned. The original annotation ``-> None`` hid this.
    """
    return pf.run_filters(
        (action1, action2),
        prepare=prepare,
        finalize=finalize,
        doc=doc,
    )