Skip to content

Commit

Permalink
fix some problems with the impl store
Browse files Browse the repository at this point in the history
  • Loading branch information
finn-rudolph committed Oct 18, 2024
1 parent 926ce48 commit b833cda
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 52 deletions.
4 changes: 3 additions & 1 deletion src/pydiverse/transform/_internal/backend/impl_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def add_impl(
assert op not in self.default_impl
self.default_impl[op] = f
else:
if op not in self.impl_trie:
self.impl_trie[op] = SignatureTrie()
self.impl_trie[op].insert(sig, f, is_vararg)

def get_impl(self, op: Operator, sig: Sequence[Dtype]) -> Callable | None:
Expand All @@ -40,7 +42,7 @@ def get_impl(self, op: Operator, sig: Sequence[Dtype]) -> Callable | None:
if (trie := self.impl_trie.get(op)) is not None:
_, best_match = trie.best_match(sig)
if best_match is None:
best_match = self.default_impl[op]
best_match = self.default_impl.get(op)

if best_match is None:
return None
Expand Down
19 changes: 4 additions & 15 deletions src/pydiverse/transform/_internal/backend/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,24 +93,13 @@ def compile_order(
)


def compile_col_expr(
expr: ColExpr, name_in_df: dict[UUID, str], *, compile_literals: bool = True
) -> pl.Expr:
def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr:
if isinstance(expr, Col):
return pl.col(name_in_df[expr._uuid])

elif isinstance(expr, ColFn):
impl = PolarsImpl.get_impl(expr.op, tuple(arg.dtype() for arg in expr.args))

# TODO: technically, constness of our parameters has nothing to do with whether
# the polars function wants a pdt.lit or python type. We should rather specify
# this in the impl or always pass both the compiled and uncompiled args. But if
# we know a polars function can only take a python scalar, then our param also
# has to be const, so we can unwrap the python scalar from the compiled expr.
args: list[pl.Expr] = [
compile_col_expr(arg, name_in_df, compile_literals=not param.const)
for arg, param in zip(expr.args, impl.impl.signature, strict=False)
]
args: list[pl.Expr] = [compile_col_expr(arg, name_in_df) for arg in expr.args]

if (partition_by := expr.context_kwargs.get("partition_by")) is not None:
partition_by = [compile_col_expr(pb, name_in_df) for pb in partition_by]
Expand Down Expand Up @@ -195,7 +184,7 @@ def compile_col_expr(
return compiled

elif isinstance(expr, LiteralCol):
return pl.lit(expr.val) if compile_literals else expr.val
return pl.lit(expr.val)

elif isinstance(expr, Cast):
compiled = compile_col_expr(expr.val, name_in_df).cast(
Expand Down Expand Up @@ -615,7 +604,7 @@ def _horizontal_min(*x):

@impl(ops.round)
def _round(x, digits):
return x.round(digits)
return x.round(pl.select(digits).item())

@impl(ops.exp)
def _exp(x):
Expand Down
23 changes: 4 additions & 19 deletions src/pydiverse/transform/_internal/backend/table_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,13 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]) -> Any: ..
def get_impl(cls, op: Operator, sig: Sequence[Dtype]) -> Any:
if (impl := cls.impl_store.get_impl(op, sig)) is not None:
return impl

if cls is TableImpl:
raise Exception
raise NotSupportedError

try:
super().get_impl(op, sig)
except Exception as err:
return cls.__bases__[0].get_impl(op, sig)
except NotSupportedError as err:
raise NotSupportedError(
f"operation `{op.name}` is not supported by the backend "
f"`{cls.__name__.lower()[:-4]}`"
Expand All @@ -112,22 +113,6 @@ def get_impl(cls, op: Operator, sig: Sequence[Dtype]) -> Any:

with TableImpl.impl_store.impl_manager as impl:

@impl(ops.nulls_first)
def _nulls_first(_):
raise AssertionError

@impl(ops.nulls_last)
def _nulls_last(_):
raise AssertionError

@impl(ops.ascending)
def _ascending(_):
raise AssertionError

@impl(ops.descending)
def _descending(_):
raise AssertionError

@impl(ops.add)
def _add(lhs, rhs):
return lhs + rhs
Expand Down
12 changes: 6 additions & 6 deletions src/pydiverse/transform/_internal/ops/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,22 @@ def best_signature_match(
) -> int:
assert len(candidates) > 0

best = candidates[0]
best_index = 0
best_distance = sig_distance(sig, candidates[0])

for match in candidates[1:]:
for i, match in enumerate(candidates[1:]):
if best_distance > (this_distance := sig_distance(sig, match)):
best = match
best_index = i
best_distance = this_distance

assert (
sum(int(best_distance == sig_distance(match, sig)) for match in candidates) == 1
sum(int(best_distance == sig_distance(sig, match)) for match in candidates) == 1
)
return best
return best_index


def sig_distance(sig: Sequence[Dtype], target: Sequence[Dtype]) -> tuple[int, int]:
return (
return tuple(
sum(z)
for z in zip(
*(types.conversion_cost(s, t) for s, t in zip(sig, target, strict=True)),
Expand Down
2 changes: 1 addition & 1 deletion src/pydiverse/transform/_internal/pipe/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def preprocess_arg(arg: Any, table: Table, *, update_partition_by: bool = True)
update_partition_by
and isinstance(expr, ColFn)
and "partition_by" not in expr.context_kwargs
and (expr.op().ftype in (Ftype.WINDOW, Ftype.AGGREGATE))
and (expr.op.ftype in (Ftype.WINDOW, Ftype.AGGREGATE))
):
expr.context_kwargs["partition_by"] = table._cache.partition_by

Expand Down
32 changes: 22 additions & 10 deletions src/pydiverse/transform/_internal/tree/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ class Dtype:
def __init__(self, *, const: bool = False):
self.const = const

def __eq__(self, rhs: Dtype | type[Dtype]) -> bool:
def __eq__(self, rhs: Dtype | type[Dtype] | None) -> bool:
if rhs is None:
return False
if inspect.isclass(rhs) and issubclass(rhs, Dtype):
rhs = rhs()
if is_supertype(rhs.without_const()) and not is_subtype(rhs.without_const()):
Expand Down Expand Up @@ -51,10 +53,8 @@ def without_const(self) -> Dtype:

def converts_to(self, target: Dtype) -> bool:
return (
(not target.const or self.const)
and self.without_const() in IMPLICIT_CONVS
and target.without_const() in IMPLICIT_CONVS[self.without_const()]
)
not target.const or self.const
) and target.without_const() in IMPLICIT_CONVS[self.without_const()]


class Float(Dtype):
Expand Down Expand Up @@ -236,9 +236,21 @@ def is_subtype(dtype: Dtype) -> bool:


IMPLICIT_CONVS: dict[Dtype, dict[Dtype, tuple[int, int]]] = {
Int(): {Float(): (1, 0), Decimal(): (2, 0)},
**{int_subtype: {Int(): (0, 1)} for int_subtype in INT_SUBTYPES},
**{float_subtype: {Float(): (0, 1)} for float_subtype in FLOAT_SUBTYPES},
Int(): {Float(): (1, 0), Decimal(): (2, 0), Int(): (0, 0)},
**{
int_subtype: {Int(): (0, 1), int_subtype: (0, 0)}
for int_subtype in INT_SUBTYPES
},
**{
float_subtype: {Float(): (0, 1), float_subtype: (0, 0)}
for float_subtype in FLOAT_SUBTYPES
},
String(): {String(): (0, 0)},
Datetime(): {Datetime(): (0, 0)},
Date(): {Date(): (0, 0)},
Bool(): {Bool(): (0, 0)},
NullType(): {NullType(): (0, 0)},
Duration(): {Duration(): (0, 0)},
}

# compute transitive closure of cost graph
Expand All @@ -247,7 +259,7 @@ def is_subtype(dtype: Dtype) -> bool:
for intermediate_type, cost1 in IMPLICIT_CONVS[start_type].items():
if intermediate_type in IMPLICIT_CONVS:
for target_type, cost2 in IMPLICIT_CONVS[intermediate_type].items():
added_edges[target_type] = (
added_edges[target_type] = tuple(
sum(z) for z in zip(cost1, cost2, strict=True)
)
if start_type not in IMPLICIT_CONVS:
Expand All @@ -256,7 +268,7 @@ def is_subtype(dtype: Dtype) -> bool:


def conversion_cost(dtype: Dtype, target: Dtype) -> tuple[int, int]:
return IMPLICIT_CONVS[dtype][target]
return IMPLICIT_CONVS[dtype.without_const()][target.without_const()]


NUMERIC = (Int(), Float(), Decimal())
Expand Down

0 comments on commit b833cda

Please sign in to comment.