Skip to content

Commit

Permalink
feat: Support complex __all__ assignments
Browse files Browse the repository at this point in the history
Issue #40: #40
  • Loading branch information
pawamoy committed Apr 18, 2022
1 parent 7191799 commit 9a2128b
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 22 deletions.
2 changes: 1 addition & 1 deletion config/flake8.ini
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ ignore =

per-file-ignores =
src/griffe/dataclasses.py:WPS115
src/griffe/agents/nodes.py:WPS115
src/griffe/agents/nodes.py:WPS115,WPS116,WPS120
src/griffe/visitor.py:N802,D102
src/griffe/encoders.py:WPS232
tests/*:WPS116,WPS118,WPS218
76 changes: 59 additions & 17 deletions src/griffe/agents/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ast import AnnAssign as NodeAnnAssign
from ast import Assign as NodeAssign
from ast import Attribute as NodeAttribute
from ast import AugAssign as NodeAugAssign
from ast import BinOp as NodeBinOp
from ast import BitAnd as NodeBitAnd
from ast import BitOr as NodeBitOr
Expand Down Expand Up @@ -523,31 +524,72 @@ def _join(sequence, item):
return new_sequence


def _parse__all__constant(node: NodeConstant, parent: Module) -> list[str]:
try:
return [node.value]
except AttributeError:
return [node.s] # TODO: remove once Python 3.7 is dropped


def _parse__all__name(node: NodeName, parent: Module) -> list[Name]:
return [Name(node.id, partial(parent.resolve, node.id))]


def _parse__all__starred(node: NodeStarred, parent: Module) -> list[str | Name]:
return _parse__all__(node.value, parent)


def _parse__all__sequence(node: NodeList | NodeSet | NodeTuple, parent: Module) -> list[str | Name]:
sequence = []
for elt in node.elts:
sequence.extend(_parse__all__(elt, parent))
return sequence


def _parse__all__binop(node: NodeBinOp, parent: Module) -> list[str | Name]:
left = _parse__all__(node.left, parent)
right = _parse__all__(node.right, parent)
return left + right


_node__all__map: dict[Type, Callable[[Any, Module], list[str | Name]]] = { # noqa: WPS234
NodeConstant: _parse__all__constant, # type: ignore[dict-item]
NodeName: _parse__all__name, # type: ignore[dict-item]
NodeStarred: _parse__all__starred,
NodeList: _parse__all__sequence,
NodeSet: _parse__all__sequence,
NodeTuple: _parse__all__sequence,
NodeBinOp: _parse__all__binop,
}

# TODO: remove once Python 3.7 support is dropped
if sys.version_info < (3, 8):

def parse__all__(node: NodeAssign) -> set[str]: # noqa: WPS116,WPS120
"""Get the values declared in `__all__`.
def _parse__all__nameconstant(node: NodeNameConstant, parent: Module) -> list[Name]:
return [node.value]

Parameters:
node: The assignment node.
def _parse__all__str(node: NodeStr, parent: Module) -> list[str]:
return [node.s]

Returns:
A set of names.
"""
return {elt.s for elt in node.value.elts} # type: ignore[attr-defined]
_node__all__map[NodeNameConstant] = _parse__all__nameconstant # type: ignore[assignment]
_node__all__map[NodeStr] = _parse__all__str # type: ignore[assignment]

else:

def parse__all__(node: NodeAssign) -> set[str]: # noqa: WPS116,WPS120,WPS440
"""Get the values declared in `__all__`.
def _parse__all__(node: AST, parent: Module) -> list[str | Name]:
return _node__all__map[type(node)](node, parent)

Parameters:
node: The assignment node.

Returns:
A set of names.
"""
return {elt.value for elt in node.value.elts} # type: ignore[attr-defined]
def parse__all__(node: NodeAssign | NodeAugAssign, parent: Module) -> list[str | Name]: # noqa: WPS120,WPS440
"""Get the values declared in `__all__`.
Parameters:
node: The assignment node.
parent: The parent module.
Returns:
A set of names.
"""
return _parse__all__(node.value, parent)


# ==========================================================
Expand Down
18 changes: 17 additions & 1 deletion src/griffe/agents/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def handle_attribute( # noqa: WPS231

if name == "__all__":
with suppress(AttributeError):
parent.exports = parse__all__(node) # type: ignore[arg-type]
parent.exports = parse__all__(node, self.current) # type: ignore[assignment,arg-type]

def visit_assign(self, node: ast.Assign) -> None:
"""Visit an assignment node.
Expand All @@ -601,6 +601,22 @@ def visit_annassign(self, node: ast.AnnAssign) -> None:
"""
self.handle_attribute(node, get_annotation(node.annotation, parent=self.current))

def visit_augassign(self, node: ast.AugAssign) -> None:
"""Visit an augmented assignment node.
Parameters:
node: The node to visit.
"""
with suppress(AttributeError):
all_augment = (
node.target.id == "__all__" # type: ignore[attr-defined]
and self.current.is_module
and isinstance(node.op, ast.Add)
)
if all_augment:
# we assume exports is not None at this point
self.current.exports.extend(parse__all__(node, self.current)) # type: ignore[arg-type,union-attr]

def visit_if(self, node: ast.If) -> None:
"""Visit an "if" node.
Expand Down
2 changes: 1 addition & 1 deletion src/griffe/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def __init__(
self.members: dict[str, Object | Alias] = {}
self.labels: set[str] = set()
self.imports: dict[str, str] = {}
self.exports: set[str] | None = None
self.exports: set[str] | list[str | Name] | None = None
self.aliases: dict[str, Alias] = {}
self.runtime: bool = runtime
self._lines_collection: LinesCollection | None = lines_collection
Expand Down
26 changes: 24 additions & 2 deletions src/griffe/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from griffe.dataclasses import Alias, Kind, Module, Object
from griffe.docstrings.parsers import Parser
from griffe.exceptions import AliasResolutionError, CyclicAliasError, LoadingError, UnimportableModuleError
from griffe.expressions import Name
from griffe.finder import ModuleFinder
from griffe.logger import get_logger
from griffe.stats import stats
Expand Down Expand Up @@ -162,8 +163,10 @@ def resolve_aliases( # noqa: WPS231
unresolved: set[str] = set("0") # init to enter loop
iteration = 0
collection = self.modules_collection.members
for w_module in list(collection.values()):
self.expand_wildcards(w_module)
for exports_module in list(collection.values()):
self.expand_exports(exports_module)
for wildcards_module in list(collection.values()):
self.expand_wildcards(wildcards_module)
while unresolved and unresolved != prev_unresolved and iteration < max_iterations: # type: ignore[operator]
prev_unresolved = unresolved - {"0"}
unresolved = set()
Expand All @@ -179,6 +182,25 @@ def resolve_aliases( # noqa: WPS231
)
return unresolved, iteration

def expand_exports(self, module: Module) -> None:
"""Expand exports: try to recursively expand all module exports.
Parameters:
module: The module to recurse on.
"""
if module.exports is None:
return
expanded = set()
for export in module.exports:
if isinstance(export, Name):
module_path = export.full.rsplit(".", 1)[0] # remove trailing .__all__
next_module = self.modules_collection[module_path]
self.expand_exports(next_module)
expanded |= next_module.exports
else:
expanded.add(export)
module.exports = expanded

def expand_wildcards( # noqa: WPS231
self,
obj: Object,
Expand Down
49 changes: 49 additions & 0 deletions tests/test_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,52 @@ def absolute(self, path: str | Path) -> str | Path:
assert overloads[1].parameters["path"].annotation.source == "Path" # noqa: WPS219
assert overloads[0].returns.source == "str"
assert overloads[1].returns.source == "Path"


@pytest.mark.parametrize(
"statements",
[
"""__all__ = moda_all + modb_all + modc_all + ["CONST_INIT"]""",
"""__all__ = ["CONST_INIT", *moda_all, *modb_all, *modc_all]""",
"""
__all__ = ["CONST_INIT"]
__all__ += moda_all + modb_all + modc_all
""",
"""
__all__ = moda_all + modb_all + modc_all
__all__ += ["CONST_INIT"]
""",
"""
__all__ = ["CONST_INIT"]
__all__ += moda_all
__all__ += modb_all + modc_all
""",
],
)
def test_parse_complex__all__assignments(statements):
"""Check our ability to expand exports based on `__all__` [augmented] assignments.
Parameters:
statements: Parametrized text containing `__all__` [augmented] assignments.
"""
with temporary_pypackage("package", ["moda.py", "modb.py", "modc.py"]) as tmp_package:
tmp_package.path.joinpath("moda.py").write_text("CONST_A = 1\n\n__all__ = ['CONST_A']")
tmp_package.path.joinpath("modb.py").write_text("CONST_B = 1\n\n__all__ = ['CONST_B']")
tmp_package.path.joinpath("modc.py").write_text("CONST_C = 2\n\n__all__ = ['CONST_C']")
code = """
from package.moda import *
from package.moda import __all__ as moda_all
from package.modb import *
from package.modb import __all__ as modb_all
from package.modc import *
from package.modc import __all__ as modc_all
CONST_INIT = 0
"""
tmp_package.path.joinpath("__init__.py").write_text(dedent(code) + dedent(statements))

loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
package = loader.load_module(tmp_package.name)
loader.resolve_aliases()

assert package.exports == {"CONST_INIT", "CONST_A", "CONST_B", "CONST_C"}

0 comments on commit 9a2128b

Please sign in to comment.