Skip to content

Commit

Permalink
Merge pull request #47 from PnX-SI/refactor_smartrelashionshipsmixin
Browse files Browse the repository at this point in the history
Refactor smartrelashionshipsmixin
  • Loading branch information
camillemonchicourt authored Dec 1, 2023
2 parents 0197e64 + a665b50 commit c809083
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 22 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@ jobs:

strategy:
matrix:
python-version: [ '3.7', '3.9' ]
python-version: ['3.9']
sqlalchemy-version: [ '1.3', '1.4' ]
include:
- sqlalchemy-version: '1.3'
sqlalchemy-lt-version: '1.4'
flask-sqlalchemy-version: '2.0'
flask-sqlalchemy-lt-version: '3.0'
flask-version: '2'
flask-lt-version: '3'
- sqlalchemy-version: '1.4'
sqlalchemy-lt-version: '2.0'
flask-sqlalchemy-version: '3.0'
flask-sqlalchemy-lt-version: '4.0'
flask-version: '3'
flask-lt-version: '4'

name: Python ${{ matrix.python-version }} - SQLAlchemy ${{ matrix.sqlalchemy-version }}

Expand All @@ -43,7 +47,8 @@ jobs:
python -m pip install --upgrade pip
python -m pip install -e .[tests] pytest-cov \
"sqlalchemy>=${{ matrix.sqlalchemy-version }},<${{ matrix.sqlalchemy-lt-version }}" \
"flask-sqlalchemy>=${{ matrix.flask-sqlalchemy-version }},<${{ matrix.flask-sqlalchemy-lt-version }}"
"flask-sqlalchemy>=${{ matrix.flask-sqlalchemy-version }},<${{ matrix.flask-sqlalchemy-lt-version }}" \
"flask>=${{ matrix.flask-version }},<${{ matrix.flask-lt-version }}"
- name: Test with pytest
run: |
Expand Down
36 changes: 30 additions & 6 deletions src/utils_flask_sqla/schema.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
from marshmallow.fields import Nested
from marshmallow_sqlalchemy.fields import RelatedList, Related

# from flask_marshmallow.fields import RelatedList


class SmartRelationshipsMixin:
"""
This mixin automatically exclude from serialization:
- Nested fields
- all fields with exclude=True in their metadata (e.g. fields.String(metadata={'exclude': True}))
Adding Nested fields to only will serialize defaults fields and specified Nested fields.
Adding exclude=True fields to only will serialize only specified fields (default marshmallow behaviour).
You can use '+field_name' syntax to serialize excluded fields without excluding defaults fields.
* Nested, RelatedList and Related fields
* all fields with exclude=True in their metadata (e.g. ``fields.String(metadata={'exclude': True})``)
Adding only Nested fields to ``only`` will not exclude others fields and serialize specified Nested fields.
Adding exclude=True fields to ``only`` will serialize only specified fields (default marshmallow behaviour).
You can use '+field_name' syntax on `only` to serialize default excluded fields (with metadata exclude = True) without other fields.
Examples :
.. code-block:: python
class FooSchema(SmartRelationshipsMixin):
id = fields.Int()
name = field.Str()
default_excluded_field = fields.Str(metadata={"exclude": True})
relationship = fields.Nested(OtherSchema) # or field.RelatedList() / field.Related()
FooSchema().dump() -> {"id": 1, "name": "toto" }
FooSchema(only=["+default_excluded_field"]).dump() -> {"id": 1, "name": "toto", default_excluded_field: "test" }
FooSchema(only=["relationship"]).dump() -> {"id": 1, "name": "toto", relationship : {OtherSchema...} }
FooSchema(only=["id", "relationship"]).dump() -> {"id": 1, relationship : {OtherSchema...} }
"""

def __init__(self, *args, **kwargs):
Expand All @@ -19,7 +38,11 @@ def __init__(self, *args, **kwargs):
# excluded fields at meta level are not even generated by auto-schema
if field is None:
continue
if isinstance(field, Nested):
if (
isinstance(field, Nested)
or isinstance(field, RelatedList)
or isinstance(field, Related)
):
nested_fields.add(name)
elif field.metadata.get("exclude", False):
excluded_fields.add(name)
Expand All @@ -40,6 +63,7 @@ def __init__(self, *args, **kwargs):
exclude = kwargs.pop("exclude", None)
exclude = set(exclude) if exclude is not None else set()
exclude |= (excluded_fields | nested_fields) - firstlevel_only

# If only contains only nested & additional fields, we need to add included_fields to serialize nested, additional & included fields.
# If only does not contains nested or additional fields, we do nothing and marshmallow will serialize only specified fields.
if only and not firstlevel_only - nested_fields - additional_fields:
Expand Down
60 changes: 46 additions & 14 deletions src/utils_flask_sqla/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,35 @@ class Parent(db.Model):
col = Column(String)


cor_hobby_child = db.Table(
"cor_hobby_child",
db.Column("id_child", db.Integer, ForeignKey("child.pk")),
db.Column("id_hobby", db.Integer, ForeignKey("hobby.pk")),
)


class Hobby(db.Model):
__tablename__ = "hobby"
pk = Column(Integer, primary_key=True)
name = Column(Integer)


class Address(db.Model):
__tablename__ = "address"
pk = Column(Integer, primary_key=True)
street = Column(Integer)
city = Column(Integer)


class Child(db.Model):
__tablename__ = "child"
pk = Column(Integer, primary_key=True)
col = Column(String)
parent_pk = Column(Integer, ForeignKey(Parent.pk))
address_pk = Column(Integer, ForeignKey(Address.pk))
parent = relationship("Parent", backref="childs")
hobbies = relationship(Hobby, secondary=cor_hobby_child)
address = relationship(Address)


class ParentSchema(SmartRelationshipsMixin, SQLAlchemyAutoSchema):
Expand All @@ -34,18 +58,34 @@ class Meta:
childs = Nested("ChildSchema", many=True)


class HobbySchema(SQLAlchemyAutoSchema):
class Meta:
model = Hobby


class AdressSchema(SQLAlchemyAutoSchema):
class Meta:
model = Address


class ChildSchema(SmartRelationshipsMixin, SQLAlchemyAutoSchema):
class Meta:
model = Child
include_fk = True

parent = Nested(ParentSchema)
hobbies = (
auto_field()
) # For a n-n relationship a RelatedList field is created by marshmallow_sqalchemy
address = auto_field()


class TestSmartRelationshipsMixin:
def test_only(self):
parent = Parent(pk=1, col="p")
child = Child(pk=1, col="c", parent_pk=1, parent=parent)
child = Child(pk=1, col="c", parent_pk=1, address_pk=1, parent=parent)
child.hobbies = [Hobby(pk=1, name="Tennis"), Hobby(pk=2, name="petanque")]
child.address = Address(pk=1, street="5th avenue", city="New-York")
parent.childs = [child]

TestCase().assertDictEqual(
Expand All @@ -58,20 +98,16 @@ def test_only(self):

TestCase().assertDictEqual(
ChildSchema().dump(child),
{
"pk": 1,
"col": "c",
"parent_pk": 1,
},
{"pk": 1, "col": "c", "parent_pk": 1, "address_pk": 1},
)

TestCase().assertDictEqual(
ParentSchema(only=["childs"]).dump(parent),
ParentSchema(only=["childs", "childs.hobbies"]).dump(parent),
{
"pk": 1,
"col": "p",
"childs": [
{"pk": 1, "col": "c", "parent_pk": 1},
{"pk": 1, "col": "c", "parent_pk": 1, "address_pk": 1, "hobbies": [1, 2]},
],
},
)
Expand All @@ -97,6 +133,7 @@ def test_only(self):
"pk": 1,
"col": "c",
"parent_pk": 1,
"address_pk": 1,
"parent": {
"pk": 1,
"col": "p",
Expand Down Expand Up @@ -176,12 +213,7 @@ def test_null_relationship(self):

TestCase().assertDictEqual(
ChildSchema(only=("parent",)).dump(child),
{
"pk": 1,
"col": None,
"parent_pk": None,
"parent": None,
},
{"pk": 1, "col": None, "parent_pk": None, "parent": None, "address_pk": None},
)

def test_polymorphic_model(self):
Expand Down

0 comments on commit c809083

Please sign in to comment.