diff --git a/magma/circuit.py b/magma/circuit.py index 2a7b89174..9746d6146 100644 --- a/magma/circuit.py +++ b/magma/circuit.py @@ -27,6 +27,7 @@ pass from magma.clock import is_clock_or_nested_clock, Clock, ClockTypes +from magma.common import only, IterableException from magma.config import get_debug_mode, set_debug_mode, config, RuntimeConfig from magma.definition_context import ( DefinitionContext, @@ -40,7 +41,6 @@ from magma.protocol_type import MagmaProtocol from magma.ref import TempNamedRef, AnonRef from magma.t import In, Type -from magma.view import PortView from magma.wire_container import WiringLog, AggregateWireable @@ -735,6 +735,21 @@ def __new__(metacls, name, bases, dct): return self + def __getattr__(self, attr): + error = None + try: + return object.__getattribute__(self, attr) + except AttributeError as e: + # NOTE(rsetaluri): The scope of `e` is only within this `except` + # block. Therefore we have to stash it in a local variable to be + # able to raise it later. + error = e + try: + return only(filter(lambda i: i.name == attr, self.instances)) + except IterableException: + pass + raise error from None + @property def is_definition(self): return self._is_definition or self.verilog or self.verilogFile diff --git a/tests/test_bind2.py b/tests/test_bind2.py index 400605c6f..3d9260636 100644 --- a/tests/test_bind2.py +++ b/tests/test_bind2.py @@ -58,13 +58,14 @@ class TopBasicAsserts(m.Circuit): @pytest.mark.parametrize( - "backend,flatten_all_tuples", + "backend,flatten_all_tuples,inst_attr", ( - ("mlir", True), - ("mlir", False), + ("mlir", False, False), + ("mlir", True, False), + ("mlir", True, True), ) ) -def test_xmr(backend, flatten_all_tuples): +def test_xmr(backend, flatten_all_tuples, inst_attr): class T(m.Product): x = m.Bit @@ -80,8 +81,12 @@ class Middle(m.Circuit): class Top(m.Circuit): io = m.IO(I=m.In(T), O=m.Out(T)) - middle = Middle(name="middle") - io.O @= middle(io.I) + if inst_attr: + middle = Middle(name="middle") + O = middle(io.I) + else: + O = Middle(name="middle")(io.I) + io.O @= O class TopXMRAsserts(m.Circuit): name = f"TopXMRAsserts_{backend}" diff --git a/tests/test_circuit/test_instance.py b/tests/test_circuit/test_instance.py index fdd2cbbb3..44f182178 100644 --- a/tests/test_circuit/test_instance.py +++ b/tests/test_circuit/test_instance.py @@ -1,6 +1,7 @@ import pytest import magma as m +from magma.common import only def test_callback_basic(): @@ -36,3 +37,34 @@ class _Test(m.Circuit): m.register_instance_callback(_Test.reg, lambda _, __: None) with pytest.raises(AttributeError): m.register_instance_callback(_Test.reg, lambda _, __: None) + + +def test_getattr_instance_name(): + Bar = m.Register(m.Bit) # just any module + + class Foo(m.Circuit): + my_placeholder_var = Bar(name="my_instance_name") + + assert Foo.my_placeholder_var is Foo.my_instance_name + + +def test_getattr_instance_name_overwritten(): + Bar = m.Register(m.Bit) # just any module + + class Foo(m.Circuit): + my_placeholder_var = Bar(name="my_instance_name") + my_instance_name = None + + assert Foo.my_instance_name is None + assert Foo.my_placeholder_var is only(Foo.instances) + + +def test_getattr_attribute_error(): + + class Foo(m.Circuit): + pass + + with pytest.raises(AttributeError) as e: + Foo.bar + + assert "has no attribute 'bar'" in str(e)