Skip to content

Commit

Permalink
add test for qfilter + replace return_query by query + decorator is n…
Browse files Browse the repository at this point in the history
…ow usable without args
  • Loading branch information
jacquesfize committed Dec 8, 2023
1 parent 08d8134 commit 7637cf2
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 10 deletions.
32 changes: 22 additions & 10 deletions src/utils_flask_sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
def qfilter(*args_dec, **kwargs_dec):
"""
This decorator allows you to constrain a SQLAlchemy model method to return a whereclause (by default) or a query. If
its return_query is set to True and no query is given in a `query` parameter, it will create one with a simple select: `select(model)`. The latter
its `query` is set to True and no query is given in a `query` parameter, it will create one with a simple select: `select(model)`. The latter
is accessible through `kwargs.get("query")` in the decorated method.
The decorated query requires the following minimum parameters (cls,**kwargs).
>>> from utils_flask_sqla.models import qfilter
Expand All @@ -20,12 +19,12 @@ def qfilter(*args_dec, **kwargs_dec):
# If you wish the method to return a whereclause
@qfilter
def filter_by_params(cls,**kwargs):
query = kwargs("query") # select(Station)
filters = []
if "id_station" in kwargs:
query = query.filter_by(id_station=kwargs["id_station"])
filters.append(Station.id_station == kwargs["id_station"])
return query.whereclause
# If you wish the method to return a query
@qfilter(return_query=True)
@qfilter(query=True)
def filter_by_paramsQ(cls,**kwargs):
query = kwargs("query") # select(Station)
if "id_station" in kwargs:
Expand All @@ -35,6 +34,11 @@ def filter_by_paramsQ(cls,**kwargs):
>>> query = Station.filter_by_paramsQ(id_station=1)
>>> query2 = select(Station).where(Station.filter_by_params(id_station=1))
Parameters
----------
query : bool
decorated function must (or not) return a query (Select)
Returns
-------
function
Expand All @@ -45,13 +49,21 @@ def filter_by_paramsQ(cls,**kwargs):
ValueError
Method's class is not DefaultMeta class
ValueError
if return_query is True and return value of the decorated method is not Select
if query is True and return value of the decorated method is not Select
ValueError
if return_query is False and return value of the decorated method is not a : `bool` or sqlalchemy.sql.expression.BooleanClauseList` or `sqlalchemy.sql.expression.BinaryExpression`
if query is False and return value of the decorated method is not a : `bool` or sqlalchemy.sql.expression.BooleanClauseList` or `sqlalchemy.sql.expression.BinaryExpression`
"""
is_query = kwargs_dec.get("return_query", False)
if len(args_dec) == 1 and len(kwargs_dec) == 0 and callable(args_dec[0]):
return _qfilter()(args_dec[0])
else:
return _qfilter(*args_dec, **kwargs_dec)


def _qfilter(*args_dec, **kwargs_dec):
is_query = kwargs_dec.get("query", False)

def _qfilter(method):
def _qfilter_decorator(method):
def _(*args, **kwargs):
# verify if class of the method is ORM model
sqla_class = args[0]
Expand Down Expand Up @@ -85,4 +97,4 @@ def _(*args, **kwargs):

return classmethod(_)

return _qfilter
return _qfilter_decorator
70 changes: 70 additions & 0 deletions src/utils_flask_sqla/tests/test_qfilter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import pytest
from flask import Flask
from sqlalchemy import func

from flask_sqlalchemy import SQLAlchemy

from utils_flask_sqla.models import qfilter


db = SQLAlchemy()


class FooModel(db.Model):
pk = db.Column(db.Integer, primary_key=True)


class BarModel(db.Model):
pk = db.Column(db.Integer, primary_key=True)

@qfilter
def where_pk(cls, pk, **kwargs):
return BarModel.pk == pk

@qfilter(query=True)
def where_pk_query(cls, pk, **kwargs):
query = kwargs["query"]
return query.where(BarModel.pk == pk)


@pytest.fixture(scope="session")
def app():
app = Flask("utils-flask-sqla")
app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///"
db.init_app(app)
with app.app_context():
db.create_all()
yield app


@pytest.fixture(scope="session")
def foo(app):
foo = FooModel()
db.session.add(foo)
db.session.commit()
return foo


@pytest.fixture(scope="session")
def bar(app):
bar = BarModel()
db.session.add(bar)
db.session.commit()
return bar


class TestQfilter:
def test_qfilter_returns_whereclause(self, bar):
assert db.session.scalars(BarModel.where_pk_query(bar.pk)).one_or_none() is bar
assert (
db.session.scalars(db.select(BarModel).where(BarModel.where_pk(bar.pk))).one_or_none()
is bar
)

assert db.session.scalars(BarModel.where_pk_query(bar.pk + 1)).one_or_none() is not bar
assert (
db.session.scalars(
db.select(BarModel).where(BarModel.where_pk(bar.pk + 1))
).one_or_none()
is not bar
)

0 comments on commit 7637cf2

Please sign in to comment.