# Source code for autodoc2.db

"""A database interface for storing and querying the analysis items."""
from __future__ import annotations

import json
import re
import typing as t

if t.TYPE_CHECKING:
    from .utils import ItemData


class UniqueError(KeyError):
    """Raised when a uniqueness constraint is violated.

    For example, adding a second item whose ``full_name`` is already present.
    Because it subclasses :class:`KeyError`, callers may catch either type.
    """
class Database(t.Protocol):
    """Interface for storing and querying the analysis items of a single package.

    Keeping storage behind this protocol leaves room for other backends in the
    future, for example a persistent sqlite database.
    """

    def add(self, item: ItemData) -> None:
        """Store ``item`` in the database."""
        ...

    def remove(self, full_name: str, descendants: bool) -> None:
        """Delete the item called ``full_name`` from the database.

        :param full_name: The full name of the item to delete.
        :param descendants: When True, every descendant of the item is
            deleted as well.
        """
        ...

    def __contains__(self, full_name: str) -> bool:
        """Report whether an item called ``full_name`` is stored."""
        ...

    def get_item(self, full_name: str) -> ItemData | None:
        """Look up a single item by ``full_name``."""
        ...

    def get_items_like(self, full_name: str) -> t.Iterable[ItemData]:
        """Return all items whose full name matches a wildcard pattern.

        ``*`` stands for any run of characters and ``?`` for exactly one
        character.
        """
        ...

    def get_type(self, full_name: str) -> None | str:
        """Look up the type of the item called ``full_name``."""
        ...

    def get_by_type(self, type_: str) -> t.Iterable[ItemData]:
        """Return every item whose type equals ``type_``."""
        ...

    def get_overloads(self, full_name: str) -> t.Iterable[ItemData]:
        """Return every function overload registered under ``full_name``."""
        ...

    def get_children(
        self, full_name: str, types: None | set[str] = None, *, sort_name: bool = False
    ) -> t.Iterable[ItemData]:
        """Return the direct children of ``full_name``, i.e. ``{full_name}.{name}``.

        :param full_name: The full name of the parent item.
        :param types: Restrict the result to items of these types, if given.
        :param sort_name: When True, order the result alphabetically by name.
        """
        ...

    def get_children_names(
        self, full_name: str, types: None | set[str] = None, *, sort_name: bool = False
    ) -> t.Iterable[str]:
        """Return the names of the direct children of ``full_name``.

        Children are names of the form ``{full_name}.{name}``.

        :param full_name: The full name of the parent item.
        :param types: Restrict the result to items of these types, if given.
        :param sort_name: When True, order the names alphabetically.
        """
        ...

    def get_ancestors(
        self, full_name: str, include_self: bool
    ) -> t.Iterable[ItemData | None]:
        """Return the ancestors of ``full_name``, e.g. ``a.b`` then ``a`` for ``a.b.c``.

        Ancestors are ordered from the closest to the furthest.

        :param full_name: The full name of the item.
        :param include_self: When True, the item itself comes first.
        """
        ...
# Splits a pattern on its wildcard characters ``*`` and ``?``. The capturing
# group makes ``re.split`` keep each wildcard as its own list element, so the
# wildcards can be translated separately from the escaped literal text.
_LIKE_REGEX: re.Pattern[str] = re.compile(r"([\*\?])")
class InMemoryDb(Database):
    """A simple in-memory database for storing and querying the analysis items."""

    def __init__(self) -> None:
        """Create the database."""
        # Regular items, keyed by their (unique) full name.
        self._items: dict[str, ItemData] = {}
        # Overload items, keyed by the full name they overload; several
        # overloads may share one name, hence the list. Stored independently
        # of ``_items`` (see ``add``).
        self._overloads: dict[str, list[ItemData]] = {}

    def add(self, item: ItemData) -> None:
        """Add an item to the database.

        Items of type ``overload`` are appended to the overload list for their
        ``full_name``. Every other item must have a unique ``full_name``.

        :raises UniqueError: If a non-overload item with the same
            ``full_name`` has already been added.
        """
        if item["type"] == "overload":
            # note we do this here and not in the analyser,
            # because overloads come before the function they overload
            # and we don't want the analyzer to have to "look ahead"
            self._overloads.setdefault(item["full_name"], []).append(item)
            return
        if item["full_name"] in self._items:
            raise UniqueError(f"Item {item['full_name']} already exists")
        self._items[item["full_name"]] = item

    def remove(self, full_name: str, descendants: bool) -> None:
        """Remove an item and its overloads from the database, by full_name.

        Removing a name that is not present is a no-op.

        :param full_name: The full name of the item to remove.
        :param descendants: If True, also remove every item and overload whose
            name starts with ``{full_name}.``.
        """
        # remove the item itself
        self._items.pop(full_name, None)
        self._overloads.pop(full_name, None)
        if not descendants:
            return
        prefix = full_name + "."
        # Both stores must be scanned. Overloads are stored independently of
        # ``_items``, so a descendant name can have overloads without a
        # matching item; scanning only ``_items`` would leave them orphaned.
        # Names are collected first so the dicts aren't mutated mid-iteration.
        for name in [name for name in self._items if name.startswith(prefix)]:
            del self._items[name]
        for name in [name for name in self._overloads if name.startswith(prefix)]:
            del self._overloads[name]

    def __contains__(self, full_name: str) -> bool:
        """Check if a (non-overload) item is in the database, by full_name."""
        return full_name in self._items

    def get_item(self, full_name: str) -> ItemData | None:
        """Get an item by full_name, or None if it is not present."""
        return self._items.get(full_name)

    def get_items_like(self, full_name: str) -> t.Iterable[ItemData]:
        """Get all items whose full name matches a wildcard pattern.

        ``*`` matches any number of characters (including ``.``), and ``?``
        matches any single character. All other characters match literally,
        and the whole name must match. Results are produced lazily.
        """
        parts = _LIKE_REGEX.split(full_name)
        pattern = re.compile(
            "".join(
                ".*" if part == "*" else ("." if part == "?" else re.escape(part))
                for part in parts
            )
        )
        return (
            item
            for item in self._items.values()
            if pattern.fullmatch(item["full_name"])
        )

    def get_type(self, full_name: str) -> None | str:
        """Get the type of an item by full_name, or None if it is not present."""
        item = self._items.get(full_name)
        return None if item is None else item["type"]

    def get_by_type(self, type_: str) -> t.Iterable[ItemData]:
        """Lazily yield all items whose ``type`` equals ``type_``."""
        return (item for item in self._items.values() if item["type"] == type_)

    def get_overloads(self, full_name: str) -> t.Iterable[ItemData]:
        """Get all function overloads registered for this name.

        The stored list itself is returned (not a copy), so callers should
        not mutate it.
        """
        return self._overloads.get(full_name, [])

    def _iter_children(
        self, full_name: str, types: None | set[str]
    ) -> t.Iterator[ItemData]:
        """Yield items named ``{full_name}.{name}`` (no further dots), optionally filtered by type."""
        prefix = full_name + "."
        for item in self._items.values():
            name = item["full_name"]
            if (
                name.startswith(prefix)
                # a direct child has no further "." after the prefix
                and "." not in name[len(prefix) :]
                and (types is None or item["type"] in types)
            ):
                yield item

    def get_children(
        self, full_name: str, types: None | set[str] = None, *, sort_name: bool = False
    ) -> t.Iterable[ItemData]:
        """Get all items that are direct children of this name, i.e. ``{full_name}.{name}``.

        :param full_name: The full name of the item.
        :param types: If given, only return items of these types.
        :param sort_name: If True, return a list sorted by full name;
            otherwise return a lazy iterator.
        """
        children = self._iter_children(full_name, types)
        if sort_name:
            return sorted(children, key=lambda item: item["full_name"])
        return children

    def get_children_names(
        self, full_name: str, types: None | set[str] = None, *, sort_name: bool = False
    ) -> t.Iterable[str]:
        """Get the full names of all direct children of this name, i.e. ``{full_name}.{name}``.

        :param full_name: The full name of the item.
        :param types: If given, only return names of items of these types.
        :param sort_name: If True, return a sorted list; otherwise return a
            lazy iterator.
        """
        names = (item["full_name"] for item in self._iter_children(full_name, types))
        if sort_name:
            return sorted(names)
        return names

    def get_ancestors(
        self, full_name: str, include_self: bool
    ) -> t.Iterable[ItemData | None]:
        """Yield the ancestors of this name, closest first (``a.b``, then ``a`` for ``a.b.c``).

        An ancestor that is not in the database is yielded as None.

        :param full_name: The full name of the item.
        :param include_self: If True, first yield the item itself.
        """
        if include_self:
            yield self.get_item(full_name)
        parts = full_name.split(".")[:-1]
        while parts:
            yield self.get_item(".".join(parts))
            parts.pop()

    def write(self, stream: t.TextIO) -> None:
        """Write the database to a file, as JSON."""
        json.dump({"items": self._items, "overloads": self._overloads}, stream)

    @classmethod
    def read(cls, stream: t.TextIO) -> InMemoryDb:
        """Read a database from a file previously produced by ``write``.

        The JSON content is loaded as-is, without validation.
        """
        data = json.load(stream)
        db = cls()
        db._items = data["items"]
        db._overloads = data["overloads"]
        return db