Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce hierarchical type system and refactor operator / implementation registration #36

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4288cea
move files around
finn-rudolph Oct 16, 2024
5b75205
add int32, int16, ...
finn-rudolph Oct 16, 2024
13c6149
add @overload to verbs for autocompletion
finn-rudolph Oct 16, 2024
315f1bb
add type conversion to polars / sqa
finn-rudolph Oct 16, 2024
88f1942
extend type conversion logic
finn-rudolph Oct 16, 2024
9b8ae70
add cast to ColExpr
finn-rudolph Oct 16, 2024
83f7a90
remove operator variants
finn-rudolph Oct 16, 2024
b8884e9
set schema_overrides in MSSQL for bool
finn-rudolph Oct 17, 2024
90f604b
fix skipped pre-commit checks in CI
finn-rudolph Oct 17, 2024
6a0a741
implement signature trie better, type conv costs
finn-rudolph Oct 17, 2024
b2c1155
implement new impl store / cm
finn-rudolph Oct 17, 2024
4ec5ddc
write a nice signature class
finn-rudolph Oct 17, 2024
17be153
rewrite all operator definitions
finn-rudolph Oct 17, 2024
9504903
make the operator store its signatures
finn-rudolph Oct 17, 2024
c0d31af
handle the hierarchical impl store in table_impl
finn-rudolph Oct 18, 2024
3691ccf
use signature distance logic for case expression
finn-rudolph Oct 18, 2024
837a6c0
port polars impls to the new world
finn-rudolph Oct 18, 2024
f1639f2
fix signature *args
finn-rudolph Oct 18, 2024
63ae193
move all backend impls to new version
finn-rudolph Oct 18, 2024
c337694
propagate const through fns / case exprs / cast
finn-rudolph Oct 18, 2024
ea639f2
update code generation
finn-rudolph Oct 18, 2024
f0d50d4
fix a lot of mistakes
finn-rudolph Oct 18, 2024
926ce48
update code generation
finn-rudolph Oct 18, 2024
b833cda
fix some problems with the impl store
finn-rudolph Oct 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
environments: py310

- name: Linting - Run pre-commit checks
run: pixi run postinstall && pixi run pre-commit run
run: pixi run postinstall && pixi run pre-commit run -a --color=always --show-diff-on-failure

test:
name: pytest
Expand Down
142 changes: 70 additions & 72 deletions generate_col_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,21 @@
from collections.abc import Iterable
from types import NoneType

from pydiverse.transform._internal.backend.polars import PolarsImpl
from pydiverse.transform._internal.ops.core import NoExprMethod, Operator
from pydiverse.transform._internal.tree.dtypes import (
from pydiverse.transform._internal.ops import ops
from pydiverse.transform._internal.ops.op import Operator
from pydiverse.transform._internal.ops.signature import Signature
from pydiverse.transform._internal.tree.types import (
Dtype,
Template,
Tvar,
pdt_type_to_python,
)
from pydiverse.transform._internal.tree.registry import Signature

col_expr_path = "./src/pydiverse/transform/_internal/tree/col_expr.py"
fns_path = "./src/pydiverse/transform/_internal/pipe/functions.py"
reg = PolarsImpl.registry
namespaces = ["str", "dt"]
rversions = {
COL_EXPR_PATH = "./src/pydiverse/transform/_internal/tree/col_expr.py"
FNS_PATH = "./src/pydiverse/transform/_internal/pipe/functions.py"

NAMESPACES = ["str", "dt"]

RVERSIONS = {
"__add__",
"__sub__",
"__mul__",
Expand All @@ -31,19 +32,18 @@
}


def format_param(name: str, dtype: Dtype) -> str:
if dtype.vararg:
return f"*{name}"
return name
def add_vararg_star(formatted_args: str) -> str:
    """Prefix the last entry of a ``", "``-separated argument string with ``*``.

    Everything up to and including the final ``", "`` separator is kept
    unchanged; a string that contains no separator is starred as a whole.
    """
    # rpartition splits on the LAST separator, which is exactly the boundary
    # between "all earlier arguments" and "the final argument".
    head, sep, tail = formatted_args.rpartition(", ")
    return f"{head}{sep}*{tail}"


def type_annotation(param: Dtype, specialize_generic: bool) -> str:
if not specialize_generic or isinstance(param, Template):
def type_annotation(dtype: Dtype, specialize_generic: bool) -> str:
if (not specialize_generic and not dtype.const) or isinstance(dtype, Tvar):
return "ColExpr"
if param.const:
python_type = pdt_type_to_python(param)
if dtype.const:
python_type = pdt_type_to_python(dtype)
return python_type.__name__ if python_type is not NoneType else "None"
return f"ColExpr[{param.__class__.__name__}]"
return f"ColExpr[{dtype.__class__.__name__}]"


def generate_fn_decl(
Expand All @@ -53,19 +53,24 @@ def generate_fn_decl(
name = op.name

defaults: Iterable = (
op.defaults if op.defaults is not None else (... for _ in op.arg_names)
op.default_values
if op.default_values is not None
else (... for _ in op.param_names)
)

annotated_args = ", ".join(
f"{format_param(name, param)}: "
+ type_annotation(param, specialize_generic)
name
+ ": "
+ type_annotation(dtype, specialize_generic)
+ (f" = {default_val}" if default_val is not ... else "")
for param, name, default_val in zip(
sig.params, op.arg_names, defaults, strict=True
for dtype, name, default_val in zip(
sig.types, op.param_names, defaults, strict=True
)
)
if sig.is_vararg:
annotated_args = add_vararg_star(annotated_args)

if op.context_kwargs is not None:
if len(op.context_kwargs) > 0:
context_kwarg_annotation = {
"partition_by": "Col | ColName | Iterable[Col | ColName]",
"arrange": "ColExpr | Iterable[ColExpr]",
Expand All @@ -77,9 +82,9 @@ def generate_fn_decl(
for kwarg in op.context_kwargs
)

if len(sig.params) == 0 or not sig.params[-1].vararg:
if len(sig.types) == 0 or not sig.is_vararg:
annotated_kwargs = "*" + annotated_kwargs
if len(sig.params) > 0:
if len(sig.types) > 0:
annotated_kwargs = ", " + annotated_kwargs
else:
annotated_kwargs = ""
Expand All @@ -93,33 +98,33 @@ def generate_fn_decl(
def generate_fn_body(
op: Operator,
sig: Signature,
arg_names: list[str] | None = None,
param_names: list[str] | None = None,
*,
op_var_name: str,
rversion: bool = False,
):
if arg_names is None:
arg_names = op.arg_names
if param_names is None:
param_names = op.param_names

if rversion:
assert len(arg_names) == 2
assert not any(param.vararg for param in sig.params)
arg_names = list(reversed(arg_names))
assert len(param_names) == 2
assert not sig.is_vararg
param_names = list(reversed(param_names))

args = "".join(
f", {format_param(name, param)}"
for param, name in zip(sig.params, arg_names, strict=True)
)
args = "".join(f", {name}" for name in param_names)
if sig.is_vararg:
args = add_vararg_star(args)

if op.context_kwargs is not None:
kwargs = "".join(f", {kwarg}={kwarg}" for kwarg in op.context_kwargs)
else:
kwargs = ""

return f' return ColFn("{op.name}"{args}{kwargs})\n\n'
return f" return ColFn(ops.{op_var_name}{args}{kwargs})\n\n"


def generate_overloads(
op: Operator, *, name: str | None = None, rversion: bool = False
op: Operator, *, name: str | None = None, rversion: bool = False, op_var_name: str
):
res = ""
in_namespace = "." in op.name
Expand All @@ -129,22 +134,16 @@ def generate_overloads(
has_overloads = len(op.signatures) > 1
if has_overloads:
for sig in op.signatures:
res += (
"@overload\n"
+ generate_fn_decl(op, Signature.parse(sig), name=name)
+ " ...\n\n"
)
res += "@overload\n" + generate_fn_decl(op, sig, name=name) + " ...\n\n"

res += generate_fn_decl(
op,
Signature.parse(op.signatures[0]),
name=name,
specialize_generic=not has_overloads,
op, op.signatures[0], name=name, specialize_generic=not has_overloads
) + generate_fn_body(
op,
Signature.parse(op.signatures[0]),
["self.arg"] + op.arg_names[1:] if in_namespace else None,
op.signatures[0],
["self.arg"] + op.param_names[1:] if in_namespace else None,
rversion=rversion,
op_var_name=op_var_name,
)

return res
Expand All @@ -154,7 +153,7 @@ def indent(s: str, by: int) -> str:
return "".join(" " * by + line + "\n" for line in s.split("\n"))


with open(col_expr_path, "r+") as file:
with open(COL_EXPR_PATH, "r+") as file:
new_file_contents = ""
in_col_expr_class = False
in_generated_section = False
Expand All @@ -163,7 +162,7 @@ def indent(s: str, by: int) -> str:
"@dataclasses.dataclass(slots=True)\n"
f"class {name.title()}Namespace(FnNamespace):\n"
)
for name in namespaces
for name in NAMESPACES
}

for line in file:
Expand All @@ -172,15 +171,18 @@ def indent(s: str, by: int) -> str:
elif not in_generated_section and line.startswith(" @overload"):
in_generated_section = True
elif in_col_expr_class and line.startswith("class Col"):
for op_name in sorted(PolarsImpl.registry.ALL_REGISTERED_OPS):
op = PolarsImpl.registry.get_op(op_name)
if isinstance(op, NoExprMethod):
for op_var_name in sorted(ops.__dict__):
op = ops.__dict__[op_var_name]
if not isinstance(op, Operator) or not op.generate_expr_method:
continue

op_overloads = generate_overloads(op)
if op_name in rversions:
op_overloads = generate_overloads(op, op_var_name=op_var_name)
if op.name in RVERSIONS:
op_overloads += generate_overloads(
op, name=f"__r{op_name[2:]}", rversion=True
op,
name=f"__r{op.name[2:]}",
rversion=True,
op_var_name=op_var_name,
)

op_overloads = indent(op_overloads, 4)
Expand All @@ -190,7 +192,7 @@ def indent(s: str, by: int) -> str:
else:
new_file_contents += op_overloads

for name in namespaces:
for name in NAMESPACES:
new_file_contents += (
" @property\n"
f" def {name}(self):\n"
Expand All @@ -203,7 +205,7 @@ def indent(s: str, by: int) -> str:
" arg: ColExpr\n"
)

for name in namespaces:
for name in NAMESPACES:
new_file_contents += namespace_contents[name]

in_generated_section = False
Expand All @@ -216,27 +218,23 @@ def indent(s: str, by: int) -> str:
file.write(new_file_contents)
file.truncate()

os.system(f"ruff format {col_expr_path}")
os.system(f"ruff format {COL_EXPR_PATH}")


with open(fns_path, "r+") as file:
with open(FNS_PATH, "r+") as file:
new_file_contents = ""
display_name = {"hmin": "min", "hmax": "max"}

for line in file:
new_file_contents += line
if line.startswith(" return LiteralCol"):
for op_name in sorted(PolarsImpl.registry.ALL_REGISTERED_OPS):
op = PolarsImpl.registry.get_op(op_name)
if not isinstance(op, NoExprMethod):
continue

new_file_contents += generate_overloads(
op, name=display_name.get(op_name)
)
for op_var_name in sorted(ops.__dict__):
op = ops.__dict__[op_var_name]
if isinstance(op, Operator) and not op.generate_expr_method:
new_file_contents += generate_overloads(op, op_var_name=op_var_name)
break

file.seek(0)
file.write(new_file_contents)
file.truncate()

os.system(f"ruff format {fns_path}")
os.system(f"ruff format {FNS_PATH}")
51 changes: 24 additions & 27 deletions src/pydiverse/transform/_internal/backend/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,82 +4,79 @@
import sqlalchemy as sqa
from sqlalchemy.sql.type_api import TypeEngine as TypeEngine

from pydiverse.transform._internal import ops
from pydiverse.transform._internal.backend import sql
from pydiverse.transform._internal.backend.sql import SqlImpl
from pydiverse.transform._internal.backend.targets import Polars, Target
from pydiverse.transform._internal.tree import dtypes, verbs
from pydiverse.transform._internal.ops import ops
from pydiverse.transform._internal.ops.ops.aggregation import Any
from pydiverse.transform._internal.tree import types, verbs
from pydiverse.transform._internal.tree.ast import AstNode
from pydiverse.transform._internal.tree.col_expr import Cast, Col, ColFn, LiteralCol


class DuckDbImpl(SqlImpl):
@classmethod
def export(cls, nd: AstNode, target: Target, final_select: list[Col]):
def export(
cls,
nd: AstNode,
target: Target,
final_select: list[Col],
schema_overrides: dict[str, Any],
):
# insert casts after sum() over integer columns (duckdb converts them to floats)
for desc in nd.iter_subtree():
if isinstance(desc, verbs.Verb):
desc.map_col_nodes(
lambda u: Cast(u, dtypes.Int64())
lambda u: Cast(u, types.Int64())
if isinstance(u, ColFn)
and u.name == "sum"
and u.dtype() == dtypes.Int64
and u.dtype() == types.Int64
else u
)

if isinstance(target, Polars):
engine = sql.get_engine(nd)
with engine.connect() as conn:
return pl.read_database(
DuckDbImpl.build_query(nd, final_select), connection=conn
DuckDbImpl.build_query(nd, final_select),
connection=conn,
schema_overrides=schema_overrides,
)
return SqlImpl.export(nd, target, final_select)
return SqlImpl.export(nd, target, final_select, schema_overrides)

@classmethod
def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast:
if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int64:
if cast.val.dtype() == types.Float64 and cast.target_type == types.Int64:
return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast(
sqa.BigInteger()
)
return super().compile_cast(cast, sqa_col)

@classmethod
def compile_lit(cls, lit: LiteralCol) -> sqa.ColumnElement:
if lit.dtype() == dtypes.Int64:
if lit.dtype() == types.Int64:
return sqa.cast(lit.val, sqa.BigInteger)
return super().compile_lit(lit)


with DuckDbImpl.op(ops.FloorDiv()) as op:
with DuckDbImpl.impl_store.impl_manager as impl:

@op.auto
@impl(ops.floordiv)
def _floordiv(lhs, rhs):
return sqa.func.divide(lhs, rhs)


with DuckDbImpl.op(ops.IsInf()) as op:

@op.auto
@impl(ops.is_inf)
def _is_inf(x):
return sqa.func.isinf(x)


with DuckDbImpl.op(ops.IsNotInf()) as op:

@op.auto
@impl(ops.is_not_inf)
def _is_not_inf(x):
return sqa.func.isfinite(x)


with DuckDbImpl.op(ops.IsNan()) as op:

@op.auto
@impl(ops.is_nan)
def _is_nan(x):
return sqa.func.isnan(x)


with DuckDbImpl.op(ops.IsNotNan()) as op:

@op.auto
@impl(ops.is_not_nan)
def _is_not_nan(x):
return ~sqa.func.isnan(x)
Loading
Loading