Skip to content

Commit

Permalink
feat(polars): add Intersection and Difference ops
Browse files Browse the repository at this point in the history
  • Loading branch information
IndexSeek committed Dec 27, 2024
1 parent 28bafd1 commit 2742f44
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
38 changes: 38 additions & 0 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,44 @@ def execute_union(op, **kw):
return result


@translate.register(ops.Intersection)
def execute_intersection(op, **kw):
with pl.SQLContext(
frames={
"left": translate(op.left, **kw),
"right": translate(op.right, **kw),
}
) as ctx:
sql = (
sg.select(STAR)
.from_(sg.to_identifier("left", quoted=True))
.intersect(sg.select(STAR).from_(sg.to_identifier("right", quoted=True)))
)
result = ctx.execute(sql.sql())
if op.distinct is True:
return result.unique()
return result


@translate.register(ops.Difference)
def execute_difference(op, **kw):
with pl.SQLContext(
frames={
"left": translate(op.left, **kw),
"right": translate(op.right, **kw),
}
) as ctx:
sql = (
sg.select(STAR)
.from_(sg.to_identifier("left", quoted=True))
.except_(sg.select(STAR).from_(sg.to_identifier("right", quoted=True)))
)
result = ctx.execute(sql.sql())
if op.distinct is True:
return result.unique()
return result


@translate.register(ops.Hash)
def execute_hash(op, **kw):
# polars' hash() returns a uint64, but we want to return an int64
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/tests/test_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pytest import param

import ibis
import ibis.common.exceptions as com
import ibis.expr.types as ir
from ibis import _
from ibis.backends.tests.errors import PsycoPg2InternalError, PyDruidProgrammingError
Expand Down Expand Up @@ -84,7 +83,6 @@ def test_union_mixed_distinct(backend, union_subsets):
param(True, id="distinct"),
],
)
@pytest.mark.notimpl(["polars"])
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
def test_intersect(backend, alltypes, df, distinct):
a = alltypes.filter((_.id >= 5200) & (_.id <= 5210))
Expand Down Expand Up @@ -129,7 +127,6 @@ def test_intersect(backend, alltypes, df, distinct):
param(True, id="distinct"),
],
)
@pytest.mark.notimpl(["polars"])
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
def test_difference(backend, alltypes, df, distinct):
a = alltypes.filter((_.id >= 5200) & (_.id <= 5210))
Expand Down Expand Up @@ -238,7 +235,6 @@ def test_top_level_union(backend, con, alltypes, distinct, ordered):
),
],
)
@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
def test_top_level_intersect_difference(
backend, con, alltypes, distinct, opname, expected, ordered
Expand Down

0 comments on commit 2742f44

Please sign in to comment.