# Source code for autodoc2.astroid_utils

"""Utilities for working with astroid nodes."""

from __future__ import annotations

import builtins
import itertools
import re
import typing as t

import astroid
from astroid import nodes


def resolve_import_alias(name: str, import_names: list[tuple[str, str | None]]) -> str:
    """Map a possibly aliased imported name back to its original name.

    The ``(original, alias)`` pairs are scanned in order. The first pair whose
    original name equals ``name`` means the name is not aliased. The first
    pair whose alias equals ``name`` gives the original name.

    :param name: The potentially aliased name to resolve.
    :param import_names: The pairs of original names and aliases from the import.
    :returns: The original name (``name`` itself if no alias matched).
    """
    for original, alias in import_names:
        if original == name:
            return name
        if alias == name:
            return original
    return name
def is_constructor(node: nodes.NodeNG) -> bool:
    """Check if the function is a constructor.

    A constructor is a node named ``__init__`` whose parent's scope is a class.

    :param node: The node to check.
    :returns: True if the node is a class's ``__init__``, False otherwise.
    """
    # bool(...) so a missing parent yields False rather than None
    return bool(
        node.parent
        and isinstance(node.parent.scope(), nodes.ClassDef)
        and node.name == "__init__"
    )
def get_full_import_name(import_from: nodes.ImportFrom, name: str) -> str:
    """Get the full path of a name from a ``from x import y`` statement.

    Any alias is undone first (see :func:`resolve_import_alias`). Relative
    imports (``level`` set) are made absolute against the module that
    contains the import statement.

    :param import_from: The ``ImportFrom`` node the name was imported by.
    :param name: The (possibly aliased) name as used in the importing module.
    :returns: The full import path of the name.
    """
    original_name = resolve_import_alias(name, import_from.names)
    if not import_from.level:
        return f"{import_from.modname}.{original_name}"
    root_module = import_from.root()
    assert isinstance(root_module, nodes.Module)
    absolute_module = root_module.relative_to_absolute_name(
        import_from.modname, level=import_from.level
    )
    return f"{absolute_module}.{original_name}"
def get_assign_value(node: nodes.NodeNG) -> None | tuple[str, t.Any]:
    """Get the name and value of the assignment of the given node.

    Assignments to multiple names are ignored, as per PEP 257.

    :param node: The node to get the assignment value from.
    :returns: The name that is assigned to, and the value assigned to the name
        (if it can be converted), or None for unsupported assignments.
    """
    try:
        targets = node.targets
    except AttributeError:
        # nodes without ``targets`` expose a single ``target`` instead
        targets = [node.target]
    if len(targets) != 1:
        return None
    (target,) = targets
    if isinstance(target, nodes.AssignName):
        assigned_name = target.name
    elif isinstance(target, nodes.AssignAttr):
        assigned_name = target.attrname
    else:
        return None
    return assigned_name, get_const_values(node.value)
def get_const_values(node: nodes.NodeNG) -> t.Any:
    """Get the value of a constant node.

    Supported forms:

    - ``Const`` nodes give their value.
    - ``List``/``Tuple`` nodes whose elements are constants or (nested) lists
      and tuples give a ``list``/``tuple``.
    - ``Call`` nodes give ``"<func>(...)"``.

    Anything else, including a sequence with any other element, gives ``None``.
    """
    # TODO its not ideal that this goes to None if not understood
    # TODO better typing
    if isinstance(node, nodes.Const):
        return node.value
    if isinstance(node, nodes.Call):
        # TODO represent also the arguments
        return f"{node.func.repr_name()}(...)"
    if not isinstance(node, (nodes.List, nodes.Tuple)):
        return None
    items: list[t.Any] = []
    for element in node.elts:
        if isinstance(element, nodes.Const):
            items.append(element.value)
        elif isinstance(element, (nodes.List, nodes.Tuple)):
            items.append(get_const_values(element))
        else:
            # one element of any other kind makes the whole sequence unknown
            return None
    return tuple(items) if isinstance(node, nodes.Tuple) else items
def get_assign_annotation(node: nodes.Assign) -> None | str:
    """Get the type annotation of the assignment of the given node.

    The ``annotation`` attribute is preferred. Otherwise the ``type_annotation``
    attribute (a ``# type:`` comment) is used, if the node has one.

    :param node: The assignment node.
    :returns: The type annotation as a string, or None if one does not exist.
    """
    try:
        annotation_node = node.annotation
    except AttributeError:
        # getattr: nodes carrying neither attribute yield None instead of raising
        annotation_node = getattr(node, "type_annotation", None)
    if annotation_node is None:
        return None
    return resolve_annotation(annotation_node)
def resolve_annotation(annotation: nodes.NodeNG) -> str:
    """Resolve a type annotation to a string.

    Names, attributes and string constants are resolved to fully qualified
    names via :func:`resolve_qualname`. Subscripts, tuples, lists and binary
    operations (e.g. ``X | Y``) are resolved recursively. Any other node falls
    back to its source representation.

    :param annotation: The annotation node to resolve.
    :returns: The resolved annotation as a string.
    """
    resolved: str
    if isinstance(annotation, nodes.Const):
        resolved = resolve_qualname(annotation, str(annotation.value))
    elif isinstance(annotation, nodes.Name):
        resolved = resolve_qualname(annotation, annotation.name)
    elif isinstance(annotation, nodes.Attribute):
        resolved = resolve_qualname(annotation, annotation.as_string())
    elif isinstance(annotation, nodes.Subscript):
        value = resolve_annotation(annotation.value)
        slice_node = annotation.slice
        try:
            if isinstance(slice_node, nodes.Index):
                slice_node = slice_node.value
        except AttributeError:
            pass  # removed in astroid 3
        if isinstance(slice_node, nodes.Tuple) and slice_node.elts:
            slice_ = ", ".join(resolve_annotation(elt) for elt in slice_node.elts)
        else:
            # an empty tuple (``tuple[()]``) goes through the Tuple branch
            # below and renders as ``()``, not as an invalid empty ``[]``
            slice_ = resolve_annotation(slice_node)
        resolved = f"{value}[{slice_}]"
    elif isinstance(annotation, nodes.Tuple):
        resolved = (
            "(" + ", ".join(resolve_annotation(elt) for elt in annotation.elts) + ")"
        )
    elif isinstance(annotation, nodes.List):
        resolved = (
            "[" + ", ".join(resolve_annotation(elt) for elt in annotation.elts) + "]"
        )
    elif isinstance(annotation, nodes.BinOp):
        resolved = (
            resolve_annotation(annotation.left)
            + " "
            + annotation.op
            + " "
            + resolve_annotation(annotation.right)
        )
    else:
        resolved = annotation.as_string()

    # NOTE: sphinx-autoapi stripped a leading ``typing.`` here. That is not
    # done because (a) Sphinx already strips it in the rendered HTML, and
    # (b) it could lead to incorrect links for name clashes.
    # NOTE: sphinx-autoapi also stripped the current module's prefix, since
    # Sphinx can link anything in the same module without a fully qualified
    # path. That is not done either, because (a) it breaks when following
    # ``__all__`` and (b) it leads to hard to decipher missing references.
    # The downside is that fully qualified names show in the type annotations
    # of the output HTML.
    # TODO make sphinx not show fully qualified names of type annotations
    # TODO sphinx type resolver does not understand Ellipsis, maybe it should?
    if resolved == "Ellipsis":
        return "..."
    return resolved
def resolve_qualname(node: nodes.NodeNG, basename: str) -> str:
    """Resolve where a node is defined to get its fully qualified name.

    The first dotted component of ``basename`` (ignoring call arguments) is
    looked up from the node's scope. If that component was imported, or is a
    class or variable assigned in this code, only that component is replaced
    by its qualified origin, and any trailing attribute path is kept. For
    example, ``Outer.Inner`` resolves to ``<module>.Outer.Inner``.

    :param node: The node representing the base name.
    :param basename: The partial base name to resolve.
    :returns: The fully resolved base name. Call arguments are collapsed to
        ``()`` for call nodes, and a ``builtins.`` / ``__builtin__.`` prefix
        is stripped.
    """
    full_basename = basename
    # first dotted component with call arguments removed, e.g. "a(x).b" -> "a"
    top_level_name = re.sub(r"\(.*\)", "", basename).split(".", 1)[0]
    lookup_node = node if isinstance(node, nodes.LocalsDictNodeNG) else node.scope()
    assigns = lookup_node.lookup(top_level_name)[1]
    for assignment in assigns:
        if isinstance(assignment, nodes.ImportFrom):
            import_name = get_full_import_name(assignment, top_level_name)
            full_basename = basename.replace(top_level_name, import_name, 1)
            break
        if isinstance(assignment, nodes.Import):
            import_name = resolve_import_alias(top_level_name, assignment.names)
            full_basename = basename.replace(top_level_name, import_name, 1)
            break
        if isinstance(assignment, nodes.ClassDef):
            # replace only the looked-up component so any ``.attr`` tail survives
            full_basename = basename.replace(top_level_name, assignment.qname(), 1)
            break
        if isinstance(assignment, nodes.AssignName):
            # no ``break``: later entries in ``assigns`` may still override this
            qualified = f"{assignment.scope().qname()}.{assignment.name}"
            full_basename = basename.replace(top_level_name, qualified, 1)

    if isinstance(node, nodes.Call):
        full_basename = re.sub(r"\(.*\)", "()", full_basename)

    for prefix in ("builtins.", "__builtin__."):
        if full_basename.startswith(prefix):
            return full_basename[len(prefix) :]
    return full_basename
def get_module_all(node: nodes.Module) -> None | list[str]:
    """Get the contents of the ``__all__`` variable from a module.

    :param node: The module node.
    :returns: The string entries of ``__all__`` that could be inferred, or
        None if the module defines no ``__all__`` or its value cannot be
        inferred.
    """
    if "__all__" not in node.locals:
        return None
    try:
        assigned = next(node.igetattr("__all__"), astroid.Uninferable)
    except astroid.InferenceError:
        # ``__all__`` exists but its value cannot be inferred; treat as unknown
        return None
    if assigned is astroid.Uninferable:
        return None
    all_: list[str] = []
    for elt in getattr(assigned, "elts", ()):
        try:
            elt_name = next(elt.infer())
        except astroid.InferenceError:
            continue
        if elt_name is astroid.Uninferable:
            continue
        # only string constants are valid ``__all__`` entries
        if isinstance(elt_name, nodes.Const) and isinstance(elt_name.value, str):
            all_.append(elt_name.value)
    return all_
def is_decorated_with_singledispatch(
    node: nodes.FunctionDef | nodes.AsyncFunctionDef,
) -> bool:
    """Check if the function is decorated as a singledispatch.

    Both the bare-name form (``@singledispatch``) and the attribute form
    (``@functools.singledispatch``) are checked by inferring the decorator.

    :param node: The function to check.
    :returns: True if any decorator infers to ``functools.singledispatch``.
    """
    if not node.decorators:
        return False
    for decorator in node.decorators.nodes:
        # ``@functools.singledispatch`` is an Attribute node, not a Name
        if not isinstance(decorator, (astroid.Name, astroid.Attribute)):
            continue
        try:
            if is_singledispatch_decorator(decorator):
                return True
        except astroid.InferenceError:
            # best effort: a decorator that cannot be inferred is not counted
            pass
    return False
def is_singledispatch_decorator(decorator: astroid.Name) -> bool:
    """Check if the decorator is a singledispatch.

    The decorator matches if any value it infers to is a function named
    ``singledispatch`` defined in the ``functools`` module. Inference errors
    are not caught here.

    :param decorator: The decorator node to check.
    :returns: True if the decorator is ``functools.singledispatch``.
    """
    return any(
        isinstance(inferred, nodes.FunctionDef)
        and inferred.name == "singledispatch"
        and inferred.root().name == "functools"
        for inferred in decorator.infer()
    )
def is_decorated_as_singledispatch_register(
    node: nodes.FunctionDef | nodes.AsyncFunctionDef,
) -> bool:
    """Check if the function is decorated as a singledispatch register.

    Only the call form ``@<something>.register(...)`` is matched. The object
    ``register`` is accessed on is not checked.

    :param node: The function to check.
    :returns: True if any decorator is a ``.register(...)`` call.
    """
    if not node.decorators:
        return False
    # TODO any more checking?
    return any(
        isinstance(decorator, nodes.Call)
        and isinstance(decorator.func, nodes.Attribute)
        and decorator.func.attrname == "register"
        for decorator in node.decorators.nodes
    )
def is_decorated_with_property(
    node: nodes.FunctionDef | nodes.AsyncFunctionDef,
) -> bool:
    """Check if the function is decorated as a property.

    Both bare-name (``@property``) and attribute (e.g. ``@abc.abstractproperty``)
    decorators are checked by inferring them to ``property`` or a subclass.

    :param node: The function to check.
    :returns: True if any decorator infers to ``property`` or a subclass of it.
    """
    if not node.decorators:
        return False
    for decorator in node.decorators.nodes:
        # attribute-access decorators are Attribute nodes, not Names
        if not isinstance(decorator, (astroid.Name, astroid.Attribute)):
            continue
        try:
            if is_property_decorator(decorator):
                return True
        except astroid.InferenceError:
            # best effort: a decorator that cannot be inferred is not counted
            pass
    return False
def is_property_decorator(decorator: astroid.Name) -> bool:
    """Check if the decorator is a property.

    The decorator matches if any class it infers to is the builtin
    ``property`` or has it among its ancestors. Inference errors are not
    caught here.

    :param decorator: The decorator node to check.
    :returns: True if the decorator is ``property`` or a subclass of it.
    """

    def _is_builtin_property(class_node: nodes.ClassDef) -> bool:
        return (  # type: ignore[no-any-return]
            class_node.name == "property"
            and class_node.root().name == builtins.__name__
        )

    for inferred in decorator.infer():
        if isinstance(inferred, nodes.ClassDef) and (
            _is_builtin_property(inferred)
            or any(_is_builtin_property(base) for base in inferred.ancestors())
        ):
            return True
    return False
def is_decorated_with_property_setter(
    node: nodes.FunctionDef | nodes.AsyncFunctionDef,
) -> bool:
    """Check if the function is decorated as a property setter.

    A setter is recognised by any decorator of the form ``@<something>.setter``.

    :param node: The node to check.
    :returns: True if the function is a property setter, False otherwise.
    """
    if not node.decorators:
        return False
    return any(
        isinstance(decorator, astroid.nodes.Attribute)
        and decorator.attrname == "setter"
        for decorator in node.decorators.nodes
    )
def get_class_docstring(node: nodes.ClassDef) -> tuple[str, str | None]:
    """Get the docstring of a class, falling back to an ancestor's docstring.

    The ``object`` and ``type`` builtins are never used as the fallback source.

    :param node: The class node.
    :returns: ``(docstring, source_qualname)``. ``source_qualname`` is the
        qualified name of the ancestor the docstring came from, or None if it
        is the class's own docstring (or none was found, giving ``""``).
    """
    own_doc = node.doc_node
    if own_doc is not None:
        return own_doc.value, None
    skipped = ("__builtins__.object", "builtins.object", "builtins.type")
    for ancestor in node.ancestors():
        if ancestor.qname() in skipped:
            continue
        if ancestor.doc_node is not None:
            return ancestor.doc_node.value, ancestor.qname()
    return "", None
def is_exception(node: nodes.ClassDef) -> bool:
    """Check if a class is an exception.

    A class is an exception if it is, or transitively inherits from, the
    builtin ``Exception`` or ``BaseException`` class.

    :param node: The class to check.
    :returns: True if the class is an exception, False otherwise.
    """

    def _is_builtin_exception(class_node: nodes.ClassDef) -> bool:
        return (
            class_node.name in ("Exception", "BaseException")
            and class_node.root().name == "builtins"
        )

    if _is_builtin_exception(node):
        return True
    if not hasattr(node, "ancestors"):
        return False
    # ``ancestors(recurs=True)`` already walks the whole hierarchy, so each
    # ancestor only needs the direct check. Recursing into every ancestor
    # would re-walk shared parts of the hierarchy again and again.
    return any(
        _is_builtin_exception(ancestor) for ancestor in node.ancestors(recurs=True)
    )
def is_decorated_with_overload(node: nodes.FunctionDef) -> bool:
    """Check if the function is decorated as an overload definition.

    Name (``@overload``) and attribute (``@typing.overload``) decorators are
    inferred via :func:`is_overload_decorator`.

    :param node: The function to check.
    :returns: True if any decorator is recognised as an overload decorator.
    """
    decorators = node.decorators
    if not decorators:
        return False
    for decorator in decorators.nodes:
        if isinstance(decorator, (astroid.Name, astroid.Attribute)):
            try:
                if is_overload_decorator(decorator):
                    return True
            except astroid.InferenceError:
                # best effort: a decorator that cannot be inferred is not counted
                pass
    return False
def is_overload_decorator(decorator: astroid.Name | astroid.Attribute) -> bool:
    """Check if the decorator is an ``overload`` decorator.

    The decorator matches if any value it infers to is a function named
    ``overload`` defined in ``typing``. ``typing_extensions`` is also
    accepted, since some of its versions define their own ``overload``.
    Inference errors are not caught here.

    :param decorator: The decorator node to check.
    :returns: True if the decorator is ``typing(_extensions).overload``.
    """
    for inferred in decorator.infer():
        if not isinstance(inferred, astroid.nodes.FunctionDef):
            continue
        if inferred.name == "overload" and inferred.root().name in (
            "typing",
            "typing_extensions",
        ):
            return True
    return False
def get_func_docstring(node: nodes.FunctionDef) -> tuple[str, None | str]:
    """Get the docstring of a function, falling back to an inherited one.

    For methods without their own docstring, the class's ancestors are
    searched for a same-named member of the same node type that has a
    docstring. ``object``/``type`` are skipped for ``__init__``/``__new__``.

    :param node: The function node.
    :returns: ``(docstring, source_qualname)``. ``source_qualname`` is the
        qualified name of the member the docstring came from, or None if it is
        the function's own docstring (or none was found, giving ``""``).
    """
    own_doc = node.doc_node
    if own_doc is not None:
        return own_doc.value, None
    if not isinstance(node.parent, nodes.ClassDef):
        return "", None
    is_ctor = node.name in ("__init__", "__new__")
    for ancestor in node.parent.ancestors():
        if is_ctor and ancestor.qname() in (
            "__builtins__.object",
            "builtins.object",
            "builtins.type",
        ):
            continue
        for member in ancestor.get_children():
            if (
                isinstance(member, node.__class__)
                and member.name == node.name
                and member.doc_node is not None
            ):
                return str(member.doc_node.value), member.qname()
    return "", None
def get_return_annotation(node: nodes.FunctionDef) -> None | str:
    """Get the return annotation of a function.

    The inline ``-> ...`` annotation is preferred. Otherwise the return type
    from a ``# type:`` comment is used.

    :param node: The function node.
    :returns: The resolved return annotation, or None if there is none.
    """
    returns = node.returns
    if not returns:
        returns = node.type_comment_returns
    return resolve_annotation(returns) if returns else None
def get_args_info(
    args_node: astroid.Arguments,
) -> list[tuple[None | str, None | str, None | str, None | str]]:
    """Get the arguments of a function.

    Inline annotations are preferred. Otherwise the per-argument and
    function-level ``# type:`` comments are used.

    :param args_node: The ``Arguments`` node of a function.
    :returns: a list of (type, name, annotation, default). ``type`` is None for
        ordinary arguments, ``"/"`` for the positional-only marker, ``"*"`` for
        varargs or the bare keyword-only marker, and ``"**"`` for kwargs.
        For methods and classmethods the first entry is dropped.
    """
    result: list[tuple[None | str, None | str, None | str, None | str]] = []

    # ``defaults`` is right-aligned over positional-only + positional-or-keyword
    # arguments; split it between the two groups.
    positional_only_defaults = []
    positional_or_keyword_defaults = args_node.defaults
    if args_node.defaults:
        args = args_node.args or []
        positional_or_keyword_defaults = args_node.defaults[-len(args) :]
        positional_only_defaults = args_node.defaults[
            : len(args_node.defaults) - len(args)
        ]

    # ``annotations`` only covers ``args``; prepend the positional-only
    # annotations so indices line up with the type comments, which start at
    # the first positional-only argument.
    plain_annotations = [
        *(args_node.posonlyargs_annotations or []),
        *(args_node.annotations or []),
    ]

    func_comment_annotations = args_node.parent.type_comment_args or []
    if args_node.parent.type in ("method", "classmethod"):
        # the function-level type comment has no entry for ``self``/``cls``
        func_comment_annotations = [None, *func_comment_annotations]

    # Build a new list: ``+=`` on the node's own list would mutate the AST and
    # corrupt the result of any later call on the same node.
    comment_annotations = [
        *(args_node.type_comment_posonlyargs or []),
        *(args_node.type_comment_args or []),
        *(args_node.type_comment_kwonlyargs or []),
    ]

    annotations = list(
        _merge_annotations(
            plain_annotations,
            _merge_annotations(func_comment_annotations, comment_annotations),
        )
    )
    annotation_offset = 0

    if args_node.posonlyargs:
        num_args = len(args_node.posonlyargs)
        for arg, annotation, default in _iter_args(
            args_node.posonlyargs,
            annotations[annotation_offset : annotation_offset + num_args],
            positional_only_defaults,
        ):
            result.append((None, arg, annotation, default))
        result.append(("/", None, None, None))
        annotation_offset += num_args

    if args_node.args:
        num_args = len(args_node.args)
        for arg, annotation, default in _iter_args(
            args_node.args,
            annotations[annotation_offset : annotation_offset + num_args],
            positional_or_keyword_defaults,
        ):
            result.append((None, arg, annotation, default))
        annotation_offset += num_args

    if args_node.vararg:
        annotation = None
        if args_node.varargannotation:
            annotation = resolve_annotation(args_node.varargannotation)
        elif len(annotations) > annotation_offset and annotations[annotation_offset]:
            annotation = resolve_annotation(annotations[annotation_offset])
            annotation_offset += 1
        result.append(("*", args_node.vararg, annotation, None))

    if args_node.kwonlyargs:
        if not args_node.vararg:
            result.append(("*", None, None, None))

        kwonlyargs_annotations = args_node.kwonlyargs_annotations
        if not any(args_node.kwonlyargs_annotations):
            num_args = len(args_node.kwonlyargs)
            kwonlyargs_annotations = annotations[
                annotation_offset : annotation_offset + num_args
            ]

        for arg, annotation, default in _iter_args(
            args_node.kwonlyargs,
            kwonlyargs_annotations,
            args_node.kw_defaults,
        ):
            result.append((None, arg, annotation, default))

        if not any(args_node.kwonlyargs_annotations):
            annotation_offset += num_args

    if args_node.kwarg:
        annotation = None
        if args_node.kwargannotation:
            annotation = resolve_annotation(args_node.kwargannotation)
        elif len(annotations) > annotation_offset and annotations[annotation_offset]:
            annotation = resolve_annotation(annotations[annotation_offset])
            annotation_offset += 1
        result.append(("**", args_node.kwarg, annotation, None))

    # drop the implicit first argument (``self``/``cls``) of methods
    if args_node.parent.type in ("method", "classmethod") and result:
        result.pop(0)

    return result
def _iter_args(
    args: list[nodes.NodeNG],
    annotations: list[nodes.NodeNG],
    defaults: list[nodes.NodeNG],
) -> t.Iterable[tuple[str, None | str, str | None]]:
    """Iterate over arguments, pairing each with its annotation and default.

    ``annotations`` is matched positionally with ``args``. ``defaults`` is
    right-aligned with ``args``, so the last default belongs to the last
    argument, and None entries mean "no default".

    :param args: The argument nodes.
    :param annotations: The annotation nodes (or None) for each argument.
    :param defaults: The default value nodes (or None); may be None.
    :yields: ``(name, annotation string or None, default source or None)``.
    """
    defaults = list(defaults or ())
    default_offset = len(args) - len(defaults)
    packed = itertools.zip_longest(args, annotations or ())
    for i, (arg, annotation) in enumerate(packed):
        default = None
        if i >= default_offset and defaults[i - default_offset] is not None:
            default = defaults[i - default_offset].as_string()

        # check for tuple arguments before reading ``.name``, which Tuple
        # nodes do not have
        if isinstance(arg, astroid.Tuple):
            argument_names = ", ".join(x.name for x in arg.elts)
            name = f"({argument_names})"
        else:
            name = arg.name

        yield (
            name,
            resolve_annotation(annotation) if annotation else None,
            default,
        )
def _merge_annotations(
    annotations: t.Iterable[t.Any], comment_annotations: t.Iterable[t.Any]
) -> t.Iterable[t.Any]:
    """Merge two annotation sequences position by position.

    One item is yielded per position of the longer input. It is the entry
    from ``annotations`` if that is truthy and not ``...``, otherwise the
    entry from ``comment_annotations`` under the same condition, otherwise
    None.
    """
    for primary, fallback in itertools.zip_longest(annotations, comment_annotations):
        if primary and not _is_ellipsis(primary):
            yield primary
            continue
        yield fallback if fallback and not _is_ellipsis(fallback) else None
def _is_ellipsis(node: t.Any) -> bool:
    """Check whether ``node`` is an astroid constant holding ``...``."""
    if not isinstance(node, astroid.Const):
        return False
    return node.value == Ellipsis