Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 124 additions & 55 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,117 @@
"""Tests to verify all modules can be imported successfully."""
"""Tests to verify all modules can be imported and that __init__.py is consistent."""

import ast
import importlib
from pathlib import Path

import pytest

PACKAGE_DIR = Path(__file__).resolve().parent.parent / "ruf_common"

# Internal sub-modules that are not part of the public API
# (imported by other modules, not re-exported from __init__.py)
_INTERNAL_MODULES = {"database_sqlite3"}


def _get_module_files():
"""Return the set of module names derived from .py files in the package directory."""
return {
p.stem
for p in PACKAGE_DIR.glob("*.py")
if p.stem != "__init__" and not p.name.startswith("_")
}


def _get_public_module_files():
"""Return module files excluding known internal sub-modules."""
return _get_module_files() - _INTERNAL_MODULES


def _parse_init():
"""Parse __init__.py and return the set of imported names and the __all__ list."""
init_path = PACKAGE_DIR / "__init__.py"
source = init_path.read_text()
tree = ast.parse(source, filename=str(init_path))

imports = set()
all_list = None

for node in ast.walk(tree):
# Catch `from . import foo`
if isinstance(node, ast.ImportFrom) and node.module is None and node.level == 1:
for alias in node.names:
imports.add(alias.name)
# Catch `__all__ = [...]`
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id == "__all__":
if isinstance(node.value, ast.List):
all_list = {
elt.value
for elt in node.value.elts
if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
}

return imports, all_list


# ---------------------------------------------------------------------------
# Consistency tests — these would have caught the aws issue
# ---------------------------------------------------------------------------

class TestPackageConsistency:
"""Verify __init__.py imports, __all__, and module files are all in sync."""

def test_no_imports_without_module_files(self):
"""Every 'from . import X' in __init__.py must have a corresponding .py file."""
imports, _ = _parse_init()
module_files = _get_module_files()
missing_files = imports - module_files
assert not missing_files, (
f"__init__.py imports modules that have no .py file: {sorted(missing_files)}"
)

def test_no_all_entries_without_module_files(self):
"""Every entry in __all__ must have a corresponding .py file."""
_, all_list = _parse_init()
assert all_list is not None, "__all__ is not defined in __init__.py"
module_files = _get_module_files()
missing_files = all_list - module_files
assert not missing_files, (
f"__all__ lists modules that have no .py file: {sorted(missing_files)}"
)

def test_no_all_entries_without_import(self):
"""Every entry in __all__ must have a matching 'from . import' statement."""
imports, all_list = _parse_init()
assert all_list is not None, "__all__ is not defined in __init__.py"
missing_imports = all_list - imports
assert not missing_imports, (
f"__all__ lists modules not imported in __init__.py: {sorted(missing_imports)}"
)

def test_no_imports_missing_from_all(self):
"""Every 'from . import X' should be listed in __all__."""
imports, all_list = _parse_init()
assert all_list is not None, "__all__ is not defined in __init__.py"
missing_from_all = imports - all_list
assert not missing_from_all, (
f"__init__.py imports modules not listed in __all__: {sorted(missing_from_all)}"
)

def test_no_public_module_files_missing_from_init(self):
"""Every public .py module file should be imported in __init__.py."""
imports, _ = _parse_init()
public_modules = _get_public_module_files()
not_imported = public_modules - imports
assert not not_imported, (
f"Module files exist but are not imported in __init__.py: {sorted(not_imported)}"
)


# ---------------------------------------------------------------------------
# Import tests — verify each module actually loads at runtime
# ---------------------------------------------------------------------------

class TestModuleImports:
"""Verify all modules in the ruf_common package can be imported."""
Expand All @@ -11,57 +121,16 @@ def test_import_ruf_common(self):
import ruf_common
assert ruf_common is not None

def test_import_country_code_converter(self):
"""Test importing country_code_converter module."""
from ruf_common import country_code_converter
assert country_code_converter is not None

def test_import_data(self):
"""Test importing data module."""
from ruf_common import data
assert data is not None

def test_import_database(self):
"""Test importing database module."""
from ruf_common import database
assert database is not None

def test_import_helper(self):
"""Test importing helper module."""
from ruf_common import helper
assert helper is not None

def test_import_html_to_markdown(self):
"""Test importing html_to_markdown module."""
from ruf_common import html_to_markdown
assert html_to_markdown is not None

def test_import_lfs(self):
"""Test importing lfs module."""
from ruf_common import lfs
assert lfs is not None

def test_import_logging(self):
"""Test importing logging module."""
from ruf_common import logging
assert logging is not None

def test_import_network(self):
"""Test importing network module."""
from ruf_common import network
assert network is not None

def test_import_stats(self):
"""Test importing stats module."""
from ruf_common import stats
assert stats is not None

def test_import_timezone_lookup(self):
"""Test importing timezone_lookup module."""
from ruf_common import timezone_lookup
assert timezone_lookup is not None

def test_import_xml_formatter(self):
"""Test importing xml_formatter module."""
from ruf_common import xml_formatter
assert xml_formatter is not None
@pytest.mark.parametrize("module_name", sorted(_get_public_module_files()))
def test_import_module(self, module_name):
"""Test that each module file can be imported from the package."""
mod = importlib.import_module(f"ruf_common.{module_name}")
assert mod is not None

def test_all_entries_importable(self):
"""Every entry in __all__ must be importable."""
_, all_list = _parse_init()
assert all_list is not None
for name in sorted(all_list):
mod = importlib.import_module(f"ruf_common.{name}")
assert mod is not None, f"Could not import ruf_common.{name}"
Loading