Source code for next.testing.signals

"""Signal capture utility for tests.

`SignalRecorder` is a context manager that connects to one or more
Django signals, stores every emission as a `SignalEvent`, and
disconnects on exit. It works with plain Django `TestCase`, the stdlib
`unittest.TestCase`, and pytest without any framework-specific code.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Self


if TYPE_CHECKING:
    from collections.abc import Iterator
    from types import TracebackType

    from django.dispatch import Signal


[docs] @dataclass(frozen=True, slots=True) class SignalEvent: """Single captured emission of a Django signal.""" signal: Signal sender: Any kwargs: dict[str, Any]
[docs] class SignalRecorder: """Record every emission of the given signals until `stop` is called. Use as a context manager for scoped capture or call `start` and `stop` explicitly when the lifecycle spans multiple helpers. """
[docs] def __init__(self, *signals: Signal) -> None: """Accept one or more Django signals to record.""" if not signals: msg = "SignalRecorder requires at least one signal" raise ValueError(msg) self.signals: tuple[Signal, ...] = signals self.events: list[SignalEvent] = [] self._started: bool = False
[docs] def start(self) -> Self: """Connect receivers for every tracked signal and return self.""" if self._started: return self for signal in self.signals: signal.connect(self._receiver, weak=False) self._started = True return self
[docs] def stop(self) -> None: """Disconnect receivers for every tracked signal.""" if not self._started: return for signal in self.signals: signal.disconnect(self._receiver) self._started = False
def _receiver( self, sender: Any, # noqa: ANN401 signal: Signal, **kwargs: Any, # noqa: ANN401 ) -> None: self.events.append( SignalEvent(signal=signal, sender=sender, kwargs=dict(kwargs)) )
[docs] def __enter__(self) -> Self: """Start recording on context entry.""" return self.start()
[docs] def __exit__( self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None, ) -> None: """Stop recording on context exit.""" self.stop()
[docs] def __iter__(self) -> Iterator[SignalEvent]: """Iterate over captured events in emission order.""" return iter(self.events)
[docs] def __len__(self) -> int: """Return the number of captured events.""" return len(self.events)
[docs] def events_for(self, signal: Signal) -> list[SignalEvent]: """Return every captured event emitted by the given signal.""" return [event for event in self.events if event.signal is signal]
[docs] def first_for(self, signal: Signal) -> SignalEvent: """Return the first captured event for `signal` or raise `LookupError`.""" for event in self.events: if event.signal is signal: return event msg = f"No captured events for signal {signal!r}" raise LookupError(msg)
[docs] def last_for(self, signal: Signal) -> SignalEvent: """Return the last captured event for `signal` or raise `LookupError`.""" for event in reversed(self.events): if event.signal is signal: return event msg = f"No captured events for signal {signal!r}" raise LookupError(msg)
[docs] def clear(self) -> None: """Drop every captured event without disconnecting.""" self.events.clear()
[docs] def capture_signals(*signals: Signal) -> SignalRecorder: """Return a started `SignalRecorder` for use as a context manager. Equivalent to `SignalRecorder(*signals).start()` but reads like a verb at the call site: `with capture_signals(sig) as rec: ...`. """ return SignalRecorder(*signals).start()
[docs] def capture_framework_signals() -> SignalRecorder: """Return a recorder connected to every signal in `next.signals.__all__`. Handy when a test wants to verify that nothing unexpected fires without wiring each signal by hand. """ # Lazy-imported to keep `next.testing` from pulling in every framework # subsystem at import time. from next import signals as framework_signals # noqa: PLC0415 tracked = tuple( getattr(framework_signals, name) for name in framework_signals.__all__ ) return SignalRecorder(*tracked).start()
__all__ = [ "SignalEvent", "SignalRecorder", "capture_framework_signals", "capture_signals", ]