Skip to content

Commit

Permalink
Merge pull request #33 from pydiverse/codegen
Browse files Browse the repository at this point in the history
generate ColExpr methods and pdt. functions via pre-commit
  • Loading branch information
finn-rudolph authored Oct 15, 2024
2 parents 762999c + 1de0da7 commit 0047577
Show file tree
Hide file tree
Showing 33 changed files with 2,196 additions and 9,849 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ jobs:
- name: Checkout branch
uses: actions/checkout@v4

- name: Install Python 3.10
uses: actions/setup-python@v5
- name: Setup Pixi
uses: prefix-dev/setup-[email protected]
with:
python-version: "3.10"
environments: py310

- name: Linting - Run pre-commit checks
uses: pre-commit/[email protected]
run: pixi run postinstall && pixi run pre-commit run

test:
name: pytest
Expand Down
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
repos:
- repo: local
hooks:
- id: generate-col-ops
name: generate-col-ops
language: system
entry: python generate_col_ops.py
types: [python]
pass_filenames: false

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.5.7
hooks:
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ cd pydiverse.transform

# Create the environment, activate it and install the pre-commit hooks
pixi install
pixi run postinstall
pixi run pre-commit install
```

Expand Down
242 changes: 242 additions & 0 deletions generate_col_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
from __future__ import annotations

import os
from collections.abc import Iterable
from types import NoneType

from pydiverse.transform._internal.backend.polars import PolarsImpl
from pydiverse.transform._internal.ops.core import NoExprMethod, Operator
from pydiverse.transform._internal.tree.dtypes import (
Dtype,
Template,
pdt_type_to_python,
)
from pydiverse.transform._internal.tree.registry import Signature

col_expr_path = "./src/pydiverse/transform/_internal/tree/col_expr.py"
fns_path = "./src/pydiverse/transform/_internal/pipe/functions.py"
reg = PolarsImpl.registry
namespaces = ["str", "dt"]
rversions = {
"__add__",
"__sub__",
"__mul__",
"__truediv__",
"__floordiv__",
"__pow__",
"__mod__",
"__and__",
"__or__",
"__xor__",
}


def format_param(name: str, dtype: Dtype) -> str:
if dtype.vararg:
return f"*{name}"
return name


def type_annotation(param: Dtype, specialize_generic: bool) -> str:
if not specialize_generic or isinstance(param, Template):
return "ColExpr"
if param.const:
python_type = pdt_type_to_python(param)
return python_type.__name__ if python_type is not NoneType else "None"
return f"ColExpr[{param.__class__.__name__}]"


def generate_fn_decl(
op: Operator, sig: Signature, *, name=None, specialize_generic: bool = True
) -> str:
if name is None:
name = op.name

defaults: Iterable = (
op.defaults if op.defaults is not None else (... for _ in op.arg_names)
)

annotated_args = ", ".join(
f"{format_param(name, param)}: "
+ type_annotation(param, specialize_generic)
+ (f" = {default_val}" if default_val is not ... else "")
for param, name, default_val in zip(
sig.params, op.arg_names, defaults, strict=True
)
)

if op.context_kwargs is not None:
context_kwarg_annotation = {
"partition_by": "Col | ColName | Iterable[Col | ColName]",
"arrange": "ColExpr | Iterable[ColExpr]",
"filter": "ColExpr[Bool] | Iterable[ColExpr[Bool]]",
}

annotated_kwargs = "".join(
f", {kwarg}: {context_kwarg_annotation[kwarg]} | None = None"
for kwarg in op.context_kwargs
)

if len(sig.params) == 0 or not sig.params[-1].vararg:
annotated_kwargs = "*" + annotated_kwargs
if len(sig.params) > 0:
annotated_kwargs = ", " + annotated_kwargs
else:
annotated_kwargs = ""

return (
f"def {name}({annotated_args}{annotated_kwargs}) "
f"-> {type_annotation(sig.return_type, specialize_generic)}:\n"
)


def generate_fn_body(
op: Operator,
sig: Signature,
arg_names: list[str] | None = None,
*,
rversion: bool = False,
):
if arg_names is None:
arg_names = op.arg_names

if rversion:
assert len(arg_names) == 2
assert not any(param.vararg for param in sig.params)
arg_names = list(reversed(arg_names))

args = "".join(
f", {format_param(name, param)}"
for param, name in zip(sig.params, arg_names, strict=True)
)

if op.context_kwargs is not None:
kwargs = "".join(f", {kwarg}={kwarg}" for kwarg in op.context_kwargs)
else:
kwargs = ""

return f' return ColFn("{op.name}"{args}{kwargs})\n\n'


def generate_overloads(
op: Operator, *, name: str | None = None, rversion: bool = False
):
res = ""
in_namespace = "." in op.name
if name is None:
name = op.name if not in_namespace else op.name.split(".")[1]

has_overloads = len(op.signatures) > 1
if has_overloads:
for sig in op.signatures:
res += (
"@overload\n"
+ generate_fn_decl(op, Signature.parse(sig), name=name)
+ " ...\n\n"
)

res += generate_fn_decl(
op,
Signature.parse(op.signatures[0]),
name=name,
specialize_generic=not has_overloads,
) + generate_fn_body(
op,
Signature.parse(op.signatures[0]),
["self.arg"] + op.arg_names[1:] if in_namespace else None,
rversion=rversion,
)

return res


def indent(s: str, by: int) -> str:
return "".join(" " * by + line + "\n" for line in s.split("\n"))


with open(col_expr_path, "r+") as file:
new_file_contents = ""
in_col_expr_class = False
in_generated_section = False
namespace_contents: dict[str, str] = {
name: (
"@dataclasses.dataclass(slots=True)\n"
f"class {name.title()}Namespace(FnNamespace):\n"
)
for name in namespaces
}

for line in file:
if line.startswith("class ColExpr"):
in_col_expr_class = True
elif not in_generated_section and line.startswith(" @overload"):
in_generated_section = True
elif in_col_expr_class and line.startswith("class Col"):
for op_name in sorted(PolarsImpl.registry.ALL_REGISTERED_OPS):
op = PolarsImpl.registry.get_op(op_name)
if isinstance(op, NoExprMethod):
continue

op_overloads = generate_overloads(op)
if op_name in rversions:
op_overloads += generate_overloads(
op, name=f"__r{op_name[2:]}", rversion=True
)

op_overloads = indent(op_overloads, 4)

if "." in op.name:
namespace_contents[op.name.split(".")[0]] += op_overloads
else:
new_file_contents += op_overloads

for name in namespaces:
new_file_contents += (
" @property\n"
f" def {name}(self):\n"
f" return {name.title()}Namespace(self)\n\n"
)

new_file_contents += (
"@dataclasses.dataclass(slots=True)\n"
"class FnNamespace:\n"
" arg: ColExpr\n"
)

for name in namespaces:
new_file_contents += namespace_contents[name]

in_generated_section = False
in_col_expr_class = False

if not in_generated_section:
new_file_contents += line

file.seek(0)
file.write(new_file_contents)
file.truncate()

os.system(f"ruff format {col_expr_path}")

with open(fns_path, "r+") as file:
new_file_contents = ""
display_name = {"hmin": "min", "hmax": "max"}

for line in file:
new_file_contents += line
if line.startswith(" return LiteralCol"):
for op_name in sorted(PolarsImpl.registry.ALL_REGISTERED_OPS):
op = PolarsImpl.registry.get_op(op_name)
if not isinstance(op, NoExprMethod):
continue

new_file_contents += generate_overloads(
op, name=display_name.get(op_name)
)
break

file.seek(0)
file.write(new_file_contents)
file.truncate()

os.system(f"ruff format {fns_path}")
Loading

0 comments on commit 0047577

Please sign in to comment.