# Source code for class_resolver.docs

"""Utilities for automatic documentation."""

from __future__ import annotations

import importlib
import inspect
import sys
import textwrap
from collections import defaultdict
from collections.abc import Callable
from typing import Any, TypeVar

from typing_extensions import ParamSpec

from .base import BaseResolver

# The public API of this module.
__all__ = ["ResolverKey", "update_docstring_with_resolver_keys"]

# Type variables that let ``update_docstring_with_resolver_keys`` return a decorator
# preserving the decorated callable's exact signature.
T = TypeVar("T")  # return type of the decorated callable
P = ParamSpec("P")  # parameter specification of the decorated callable


def _get_qualpath_from_object(resolver: BaseResolver[Any, Any]) -> str:
    """Get the qualified name stored on a resolver, for use in Sphinx cross-references.

    :param resolver: The resolver whose ``location`` attribute is read.
    :returns: The value of ``resolver.location``.
    :raises NotImplementedError: If ``resolver.location`` is falsy (e.g., unset or empty).
    """
    location = resolver.location
    if not location:
        raise NotImplementedError(
            "Can not get a qualified name for auto-generation of sphinx documentation "
            "for a resolver that doesn't have the `location` variable set"
        )
    return location


class ResolverKey:
    """An object storing information about how a resolver is used in a signature."""

    #: The name of the parameter that takes the value to be resolved
    name: str
    #: The name of the companion parameter that takes keyword arguments
    key: str
    #: The qualified path of the resolver, used to build a ``:data:`` cross-reference
    resolver_path: str
    # note that resolver keys don't depend at all on the
    # types in the resolver
    #: The resolver itself, or None if it could not be imported from ``resolver_path``
    resolver: BaseResolver[Any, Any] | None

    def __init__(
        self,
        name: str,
        resolver: str | BaseResolver[Any, Any],
        key: str | None = None,
    ) -> None:
        """Initialize the key for :func:`update_docstring_with_resolver_keys`.

        :param name: The name of the parameter that is resolved.
        :param resolver: Either a path of the form ``module.variable`` pointing to the resolver, or the
            resolver instance itself.
        :param key: The name of the parameter holding keyword arguments. Defaults to ``f"{name}_kwargs"``.
        :raises AttributeError: If ``resolver`` is a path whose module imports but lacks the variable,
            unless the module is only partially initialized (circular import).
        :raises NotImplementedError: If ``resolver`` is a resolver instance without ``location`` set.
        :raises TypeError: If ``resolver`` is neither a string nor a resolver instance.
        """
        self.name = name
        self.key = f"{self.name}_kwargs" if key is None else key
        if isinstance(resolver, str):
            self.resolver_path = resolver
            try:
                module_name, variable_name = resolver.rsplit(".", 1)
                module = importlib.import_module(module_name)
                resolver_inst = getattr(module, variable_name)
            except (ImportError, ValueError):
                # the module can't be imported, or the path has no "." / an empty module name.
                # The path is still kept so it can be referenced in documentation.
                self.resolver = None
            except AttributeError as e:
                if "partially initialized module" not in str(e):
                    raise
                # this happens in a circular import case. just let it go
                self.resolver = None
            else:
                self.resolver = resolver_inst
        elif isinstance(resolver, BaseResolver):
            self.resolver_path = _get_qualpath_from_object(resolver)
            self.resolver = resolver
        else:
            raise TypeError(
                "resolver must be given as a qualified path string or a BaseResolver instance, "
                f"got {type(resolver).__name__}"
            )

    def __repr__(self) -> str:
        """Represent the key in constructor-call form.

        This keeps error messages that interpolate keys (e.g., the duplicate check in
        :func:`update_docstring_with_resolver_keys`) readable.
        """
        return f"{type(self).__name__}(name={self.name!r}, resolver={self.resolver_path!r}, key={self.key!r})"
def _clean_docstring(s: str) -> str:
    """Clean a docstring.

    :param s: Input docstring
    :returns: Cleaned docstring
    :raises ValueError: if the docstring is empty or improperly formatted

    This method does the following

    1. strip
    2. pop off first line
    3. ensure second line is blank
    4. dedent on all remaining lines
    5. join everything back together
    """
    s = s.strip()
    lines = s.splitlines()
    if not lines:
        # a whitespace-only docstring would otherwise fail below with an opaque unpacking error
        raise ValueError("can not clean an empty docstring")
    if len(lines) == 1:
        return lines[0]
    if len(lines) == 2:
        # ``s`` was stripped, so its last line is never blank: a two-line docstring always
        # lacks the blank line that must separate the summary from the body
        raise ValueError("not sure how to clean a two line docstring")
    first, second, *rest = lines
    if second.strip():
        raise ValueError(f"the second line of a multi-line docstring must be blank, got: {second!r}")
    rest_j = textwrap.dedent("\n".join(rest))
    return f"{first.strip()}\n\n{rest_j}"
def _build_resolver_note(resolver_keys: tuple[ResolverKey, ...]) -> str:
    """Render the reST note documenting which parameter pairs are resolved by which resolver.

    :param resolver_keys: The keys to describe. Keys sharing the same ``resolver_path`` are grouped
        into a single sentence.
    :returns: A ``.. note ::`` directive (without trailing newline) to append to a docstring.
    """
    # group keys by resolver so that parameter pairs sharing a resolver are reported together
    resolver_to_keys: defaultdict[str, list[ResolverKey]] = defaultdict(list)
    for key in resolver_keys:
        resolver_to_keys[key.resolver_path].append(key)

    parameter_pair_strs = []
    for resolver_qualname, subkeys in resolver_to_keys.items():
        pair_strs = [f"``({key.name}, {key.key})``" for key in subkeys]
        if len(subkeys) > 1:
            parameter_pair_str = f"pairs {', '.join(pair_strs)} are"
        else:
            parameter_pair_str = f"pair {pair_strs[0]} is"
        parameter_pair_strs.append(f"The parameter {parameter_pair_str} used for :data:`{resolver_qualname}`")

    indent = " " * 4
    lines = [".. note ::", ""]
    if len(parameter_pair_strs) == 1:
        lines.append(indent + parameter_pair_strs[0])
    else:
        # count distinct resolvers (one bullet each) rather than keys, since several keys may share a resolver
        lines.append(f"{indent}{len(parameter_pair_strs)} resolvers are used in this function.")
        lines.append("")
        lines.extend(f"{indent}- {pair_str}" for pair_str in parameter_pair_strs)
    lines.extend(
        [
            "",
            f"{indent}An explanation of resolvers and how to use them is given in",
            f"{indent}https://class-resolver.readthedocs.io/en/latest/.",
        ]
    )
    return "\n".join(lines)


def update_docstring_with_resolver_keys(*resolver_keys: ResolverKey) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Build a decorator to add information about resolved parameter pairs.

    The decorator is intended for methods which follow the ``param`` + ``param_kwargs`` pattern and internally
    use a class resolver. The decorated function must have a docstring, which gets extended.

    .. code-block:: python

        from typing import Any

        from torch import Tensor, nn

        from class_resolver import update_docstring_with_resolver_keys, ResolverKey
        from class_resolver.contrib.torch import activation_resolver


        @update_docstring_with_resolver_keys(
            ResolverKey("activation", "class_resolver.contrib.torch.activation_resolver"),
        )
        def f(
            tensor: Tensor,
            activation: None | str | type[nn.Module] | nn.Module,
            activation_kwargs: dict[str, Any] | None,
        ) -> Tensor:
            '''Apply an activation to the tensor.'''
            _activation = activation_resolver.make(activation, activation_kwargs)
            return _activation(tensor)

    This also can be stacked for multiple resolvers.

    .. code-block:: python

        from typing import Any

        from torch import Tensor, nn

        from class_resolver import update_docstring_with_resolver_keys, ResolverKey
        from class_resolver.contrib.torch import activation_resolver, aggregation_resolver


        @update_docstring_with_resolver_keys(
            ResolverKey("activation", "class_resolver.contrib.torch.activation_resolver"),
            ResolverKey("aggregation", "class_resolver.contrib.torch.aggregation_resolver"),
        )
        def f(
            tensor: Tensor,
            activation: None | str | type[nn.Module] | nn.Module,
            activation_kwargs: dict[str, Any] | None,
            aggregation: None | str | type[nn.Module] | nn.Module,
            aggregation_kwargs: dict[str, Any] | None,
        ) -> Tensor:
            '''Apply an activation, then an aggregation, to the tensor.'''
            _activation = activation_resolver.make(activation, activation_kwargs)
            _aggregation = aggregation_resolver.make(aggregation, aggregation_kwargs)
            return _aggregation(_activation(tensor))

    It might be the case that you have two different arguments that use the same resolver.
    No prob!

    .. code-block:: python

        from typing import Any

        from torch import Tensor, nn

        from class_resolver import update_docstring_with_resolver_keys, ResolverKey
        from class_resolver.contrib.torch import activation_resolver, aggregation_resolver


        @update_docstring_with_resolver_keys(
            ResolverKey("activation_1", "class_resolver.contrib.torch.activation_resolver"),
            ResolverKey("activation_2", "class_resolver.contrib.torch.activation_resolver"),
            ResolverKey("aggregation", "class_resolver.contrib.torch.aggregation_resolver"),
        )
        def f(
            tensor: Tensor,
            activation_1: None | str | type[nn.Module] | nn.Module,
            activation_1_kwargs: dict[str, Any] | None,
            aggregation: None | str | type[nn.Module] | nn.Module,
            aggregation_kwargs: dict[str, Any] | None,
            activation_2: None | str | type[nn.Module] | nn.Module,
            activation_2_kwargs: dict[str, Any] | None,
        ) -> Tensor:
            '''Apply an activation, an aggregation, and another activation to the tensor.'''
            _activation_1 = activation_resolver.make(activation_1, activation_1_kwargs)
            _activation_2 = activation_resolver.make(activation_2, activation_2_kwargs)
            _aggregation = aggregation_resolver.make(aggregation, aggregation_kwargs)
            return _activation_2(_aggregation(_activation_1(tensor)))

    :param resolver_keys: A variadic list of keys, each describing:

        1. the name of the parameter
        2. the resolver used to construct a reference via the ``:data:`` role.
        3. the name of the parameter for giving keyword arguments. By default,
           this is constructed by taking the name and appending ``_kwargs``.

    :returns: a decorator which extends a function's docstring.
    :raises ValueError: When either no parameter name was provided, or there was a duplicate parameter name.
    """
    # input validation
    if not resolver_keys:
        raise ValueError("Must provide at least one parameter name.")
    # check for duplicates: each key contributes two names (the parameter and its kwargs companion),
    # so any collision makes the set smaller than twice the number of keys
    expanded_params = {e for key in resolver_keys for e in (key.name, key.key)}
    if len(expanded_params) < 2 * len(resolver_keys):
        raise ValueError(f"There are duplicates in (the expanded) {resolver_keys=}")
    # TODO: we could do some more sanitization, e.g., trying to match types, ...

    # the note only depends on the keys, so render it once rather than once per decorated function
    note = _build_resolver_note(resolver_keys)

    def add_note(func: Callable[P, T]) -> Callable[P, T]:
        """Extend the function's docstring with a note about resolved parameters.

        :param func: the function to decorate.
        :returns: the function with extended docstring.
        :raises ValueError: When the signature does not contain the resolved parameter names, or the
            docstring is missing.
        """
        signature = inspect.signature(func)
        if missing := expanded_params.difference(signature.parameters):
            raise ValueError(f"{missing=} parameters in {signature=}.")
        if func.__doc__ is None and sys.flags.optimize >= 2:
            # docstrings are stripped under ``python -OO``: there is nothing to extend, and raising
            # here would make every decorated definition fail at import time
            return func
        if not func.__doc__:
            raise ValueError("docstring is empty")
        func.__doc__ = f"{_clean_docstring(func.__doc__)}\n\n{note}".rstrip()
        return func

    return add_note