diff --git a/magma/clock.py b/magma/clock.py index 09c938bef0..fcd0f0a018 100644 --- a/magma/clock.py +++ b/magma/clock.py @@ -1,3 +1,4 @@ +import functools from .t import Direction, In from .digital import DigitalMeta, Digital from .wire import wire @@ -126,16 +127,22 @@ def _get_first_clock(port, clocktype): return None -def wireclocktype(defn, inst, clocktype): +def _get_first_clock_of_defn(defn, clocktype): # Check common case: top level clock port clks = (port if isinstance(port, clocktype) else None for port in defn.interface.ports.values()) defnclk = _first(clks) - if defnclk is None: - # Check recursive types - clks = (_get_first_clock(port, clocktype) - for port in defn.interface.ports.values()) - defnclk = _first(clks) + if defnclk is not None: + return defnclk + # Check recursive types + clks = (_get_first_clock(port, clocktype) + for port in defn.interface.ports.values()) + return _first(clks) + + +@functools.lru_cache(maxsize=None) +def wireclocktype(defn, inst, clocktype): + defnclk = _get_first_clock_of_defn(defn, clocktype) if defnclk is None: return for port in inst.interface.inputs(include_clocks=True):