From 2742f440bf6ce462445e4206b70a693db0854507 Mon Sep 17 00:00:00 2001
From: Tyler White <50381805+IndexSeek@users.noreply.github.com>
Date: Fri, 27 Dec 2024 21:46:41 +0000
Subject: [PATCH] feat(polars): add Intersection and Difference ops

Register polars translations for ops.Intersection and ops.Difference.
Each one exposes the translated left/right inputs to a pl.SQLContext,
executes a statement built with sqlglot's intersect()/except_(), and
calls .unique() on the result when op.distinct is True.

Remove the polars notimpl markers from test_intersect, test_difference
and test_top_level_intersect_difference, together with the
ibis.common.exceptions import in that test module.

---
 ibis/backends/polars/compiler.py    | 38 +++++++++++++++++++++++++++++
 ibis/backends/tests/test_set_ops.py |  4 ---
 2 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py
index 04c8a8cc928d..13612de1ff21 100644
--- a/ibis/backends/polars/compiler.py
+++ b/ibis/backends/polars/compiler.py
@@ -1256,6 +1256,44 @@ def execute_union(op, **kw):
     return result
 
 
+@translate.register(ops.Intersection)
+def execute_intersection(op, **kw):
+    with pl.SQLContext(
+        frames={
+            "left": translate(op.left, **kw),
+            "right": translate(op.right, **kw),
+        }
+    ) as ctx:
+        sql = (
+            sg.select(STAR)
+            .from_(sg.to_identifier("left", quoted=True))
+            .intersect(sg.select(STAR).from_(sg.to_identifier("right", quoted=True)))
+        )
+        result = ctx.execute(sql.sql())
+        if op.distinct is True:
+            return result.unique()
+        return result
+
+
+@translate.register(ops.Difference)
+def execute_difference(op, **kw):
+    with pl.SQLContext(
+        frames={
+            "left": translate(op.left, **kw),
+            "right": translate(op.right, **kw),
+        }
+    ) as ctx:
+        sql = (
+            sg.select(STAR)
+            .from_(sg.to_identifier("left", quoted=True))
+            .except_(sg.select(STAR).from_(sg.to_identifier("right", quoted=True)))
+        )
+        result = ctx.execute(sql.sql())
+        if op.distinct is True:
+            return result.unique()
+        return result
+
+
 @translate.register(ops.Hash)
 def execute_hash(op, **kw):
     # polars' hash() returns a uint64, but we want to return an int64
diff --git a/ibis/backends/tests/test_set_ops.py b/ibis/backends/tests/test_set_ops.py
index 64467b067012..c459a4055346 100644
--- a/ibis/backends/tests/test_set_ops.py
+++ b/ibis/backends/tests/test_set_ops.py
@@ -6,7 +6,6 @@
 from pytest import param
 
 import ibis
-import ibis.common.exceptions as com
 import ibis.expr.types as ir
 from ibis import _
 from ibis.backends.tests.errors import PsycoPg2InternalError, PyDruidProgrammingError
@@ -84,7 +83,6 @@ def test_union_mixed_distinct(backend, union_subsets):
         param(True, id="distinct"),
     ],
 )
-@pytest.mark.notimpl(["polars"])
 @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
 def test_intersect(backend, alltypes, df, distinct):
     a = alltypes.filter((_.id >= 5200) & (_.id <= 5210))
@@ -129,7 +127,6 @@ def test_intersect(backend, alltypes, df, distinct):
         param(True, id="distinct"),
     ],
 )
-@pytest.mark.notimpl(["polars"])
 @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
 def test_difference(backend, alltypes, df, distinct):
     a = alltypes.filter((_.id >= 5200) & (_.id <= 5210))
@@ -238,7 +235,6 @@ def test_top_level_union(backend, con, alltypes, distinct, ordered):
         ),
     ],
 )
-@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
 @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
 def test_top_level_intersect_difference(
     backend, con, alltypes, distinct, opname, expected, ordered