diff --git a/plum/parametric.py b/plum/parametric.py index 7103869..0ef155a 100644 --- a/plum/parametric.py +++ b/plum/parametric.py @@ -13,6 +13,7 @@ "CovariantMeta", "parametric", "type_parameter", + "type_nonparametric", "type_unparametrized", "kind", "Kind", @@ -281,7 +282,7 @@ def __class_nonparametric__(cls): """Return the non-parametric type of an object. :mod:`plum.parametric` produces parametric subtypes of classes. This - method can be used to get the non-parametric type of an object. + method can be used to get the original non-parametric type of an object. See Also -------- @@ -468,6 +469,72 @@ def type_parameter(x): ) +def type_nonparametric(q: T) -> Type[T]: + """Return the non-parametric type of an object. + + :mod:`plum.parametric` produces parametric subtypes of classes. This method + can be used to get the original non-parametric type of an object. + + See Also + -------- + :func:`plum.type_unparametrized` + A function that returns the non-concrete, but still parametric, type of + an object. + + Examples + -------- + In this example we will demonstrate how to retrieve the original + non-parametric class from a :func:`plum.parametric` decorated class. + + :func:`plum.parametric` defines a parametric class of the same name as the + original class, and then creates a subclass of the original class with the + type parameter inferred from the arguments of the constructor. + + >>> from plum import parametric + >>> class Obj: + ... @classmethod + ... def __infer_type_parameter__(cls, *arg): + ... return type(arg[0]) + ... def __init__(self, x): + ... self.x = x + ... def __repr__(self): + ... return f"Obj({self.x})" + >>> PObj = parametric(Obj) + >>> pobj = PObj(1) + + >>> type(pobj).mro() + [, , + , ] + + Note that the class `Obj` appears twice in the MRO. The first one is the + parametric class, and the second one is the non-parametric class. The + non-parametric class is the original class that was passed to the + ``parametric`` decorator. + + Rather than navigating the MRO, we can get the non-parametric class of an + object by calling ``type_nonparametric`` function. + + >>> type(pobj) is PObj[int] + True + >>> type(pobj) is PObj + False + >>> type(pobj) is Obj + False + + >>> type_nonparametric(pobj) is PObj[int] + False + >>> type_nonparametric(pobj) is PObj + False + >>> type_nonparametric(pobj) is Obj + True + """ + return ( + q.__class_nonparametric__() + if isinstance(type(q), ParametricTypeMeta) + else type(q) + ) + + def type_unparametrized(q: T) -> Type[T]: """Return the unparametrized type of an object. diff --git a/tests/test_parametric.py b/tests/test_parametric.py index 0532f67..3c3b2e9 100644 --- a/tests/test_parametric.py +++ b/tests/test_parametric.py @@ -15,7 +15,13 @@ parametric, type_parameter, ) -from plum.parametric import CovariantMeta, is_concrete, is_type, type_unparametrized +from plum.parametric import ( + CovariantMeta, + is_concrete, + is_type, + type_nonparametric, + type_unparametrized, +) def test_covariantmeta(): @@ -606,3 +612,27 @@ def __repr__(self): assert type(pobj) is Obj[int] assert type_unparametrized(pobj) is not Obj[int] assert type_unparametrized(pobj) is Obj + + +def test_type_nonparametric(): + """Test the `type_nonparametric` function.""" + + class NonParametricObj: + @classmethod + def __infer_type_parameter__(cls, *arg): + return type(arg[0]) + + def __init__(self, x): + self.x = x + + def __repr__(self): + return f"Obj({self.x})" + + Obj = parametric(NonParametricObj) + + pobj = Obj(1) + + assert type(pobj) is Obj[int] + assert type_nonparametric(pobj) is not Obj[int] + assert type_nonparametric(pobj) is not Obj + assert type_nonparametric(pobj) is NonParametricObj