Skip to content

Commit

Permalink
Upgrade to pydantic v2
Browse files Browse the repository at this point in the history
Pydantic used to be in 1.10.5 now moving to up to v2 accepting from v2.x to the latest

Replaced `__get_pydantic_core_schema__` with a more efficient schema
handling using `core_schema.str_schema()` and custom validation for ItemHash.

- Updated code to explicitly specify optional keys where necessary.
- Replaced direct `.get` calls with `data.get()` to handle new validation logic.
- Migrated model configuration to use `model_config = ConfigDict(extra="forbid")`
or `model_config = ConfigDict(extra="allow")` in place of Pydantic v1's configuration style.
- Fix: Refactor to use `model_dump` and `model_dump_json` in place of deprecated methods.
- Replaced `.dict()` with `.model_dump()` for model serialization.
- Replaced deprecated `.json()` with `.model_dump_json()` for JSON serialization.

---------

Co-authored-by: Hugo Herter <[email protected]>
  • Loading branch information
Antonyjin and hoh authored Oct 16, 2024
1 parent 934c8a2 commit 8a62363
Show file tree
Hide file tree
Showing 12 changed files with 221 additions and 124 deletions.
104 changes: 57 additions & 47 deletions aleph_message/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import datetime
import json
import logging
from copy import copy
from hashlib import sha256
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union, cast

from pydantic import BaseModel, Extra, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import TypeAlias

from .abstract import BaseContent, HashableModel
Expand All @@ -16,6 +17,10 @@
from .execution.program import ProgramContent
from .item_hash import ItemHash, ItemType

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


__all__ = [
"AggregateContent",
"AggregateMessage",
Expand Down Expand Up @@ -54,8 +59,7 @@ class MongodbId(BaseModel):

oid: str = Field(alias="$oid")

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class ChainRef(BaseModel):
Expand All @@ -76,8 +80,7 @@ class MessageConfirmationHash(BaseModel):
binary: str = Field(alias="$binary")
type: str = Field(alias="$type")

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class MessageConfirmation(BaseModel):
Expand All @@ -93,15 +96,13 @@ class MessageConfirmation(BaseModel):
default=None, description="The address that published the transaction."
)

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class AggregateContentKey(BaseModel):
name: str

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class PostContent(BaseContent):
Expand All @@ -116,16 +117,15 @@ class PostContent(BaseContent):
)
type: str = Field(description="User-generated 'content-type' of a POST message")

@validator("type")
@field_validator("type")
def check_type(cls, v, values):
if v == "amend":
ref = values.get("ref")
ref = values.data.get("ref")
if not ref:
raise ValueError("A 'ref' is required for POST type 'amend'")
return v

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class AggregateContent(BaseContent):
Expand All @@ -136,8 +136,7 @@ class AggregateContent(BaseContent):
)
content: Dict = Field(description="The content of an aggregate must be a dict")

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class StoreContent(BaseContent):
Expand All @@ -148,10 +147,11 @@ class StoreContent(BaseContent):
size: Optional[int] = None # Generated by the node on storage
content_type: Optional[str] = None # Generated by the node on storage
ref: Optional[str] = None
metadata: Optional[Dict[str, Any]] = Field(description="Metadata of the VM")
metadata: Optional[Dict[str, Any]] = Field(
default=None, description="Metadata of the VM"
)

class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")


class ForgetContent(BaseContent):
Expand Down Expand Up @@ -214,9 +214,9 @@ class BaseMessage(BaseModel):

forgotten_by: Optional[List[str]]

@validator("item_content")
@field_validator("item_content")
def check_item_content(cls, v: Optional[str], values) -> Optional[str]:
item_type = values["item_type"]
item_type = values.data.get("item_type")
if v is None:
return None
elif item_type == ItemType.inline:
Expand All @@ -232,14 +232,14 @@ def check_item_content(cls, v: Optional[str], values) -> Optional[str]:
)
return v

@validator("item_hash")
@field_validator("item_hash")
def check_item_hash(cls, v: ItemHash, values) -> ItemHash:
item_type = values["item_type"]
item_type = values.data.get("item_type")
if item_type == ItemType.inline:
item_content: str = values["item_content"]
item_content: str = values.data.get("item_content")

# Double check that the hash function is supported
hash_type = values["hash_type"] or HashType.sha256
hash_type = values.data.get("hash_type") or HashType.sha256
assert hash_type.value == HashType.sha256

computed_hash: str = sha256(item_content.encode()).hexdigest()
Expand All @@ -255,49 +255,56 @@ def check_item_hash(cls, v: ItemHash, values) -> ItemHash:
assert item_type == ItemType.storage
return v

@validator("confirmed")
@field_validator("confirmed")
def check_confirmed(cls, v, values):
confirmations = values["confirmations"]
confirmations = values.data.get("confirmations")
if v is True and not bool(confirmations):
raise ValueError("Message cannot be 'confirmed' without 'confirmations'")
return v

@validator("time")
@field_validator("time")
def convert_float_to_datetime(cls, v, values):
if isinstance(v, float):
v = datetime.datetime.fromtimestamp(v)
assert isinstance(v, datetime.datetime)
return v

class Config:
extra = Extra.forbid
exclude = {"id_", "_id"}
model_config = ConfigDict(extra="forbid")

def custom_dump(self):
"""Exclude MongoDB identifiers from dumps for historical reasons."""
return self.model_dump(exclude={"id_", "_id"})


class PostMessage(BaseMessage):
"""Unique data posts (unique data points, events, ...)"""

type: Literal[MessageType.post]
content: PostContent
forgotten_by: Optional[List[str]] = None


class AggregateMessage(BaseMessage):
"""A key-value storage specific to an address"""

type: Literal[MessageType.aggregate]
content: AggregateContent
forgotten_by: Optional[list] = None


class StoreMessage(BaseMessage):
type: Literal[MessageType.store]
content: StoreContent
forgotten_by: Optional[list] = None
metadata: Optional[Dict[str, Any]] = None


class ForgetMessage(BaseMessage):
type: Literal[MessageType.forget]
content: ForgetContent
forgotten_by: Optional[list] = None

@validator("forgotten_by")
@field_validator("forgotten_by")
def cannot_be_forgotten(cls, v: Optional[List[str]], values) -> Optional[List[str]]:
assert values
if v:
Expand All @@ -308,25 +315,29 @@ def cannot_be_forgotten(cls, v: Optional[List[str]], values) -> Optional[List[st
class ProgramMessage(BaseMessage):
type: Literal[MessageType.program]
content: ProgramContent
forgotten_by: Optional[List[str]] = None

@validator("content")
@field_validator("content")
def check_content(cls, v, values):
item_type = values["item_type"]
"""Ensure that the content of the message is correctly formatted."""
item_type = values.data.get("item_type")
if item_type == ItemType.inline:
item_content = json.loads(values["item_content"])
if v.dict(exclude_none=True) != item_content:
# Print differences
vdict = v.dict(exclude_none=True)
for key, value in item_content.items():
if vdict[key] != value:
print(f"{key}: {vdict[key]} != {value}")
# Ensure that the content correct JSON
item_content = json.loads(values.data.get("item_content"))
# Ensure that the content matches the expected structure
if v.model_dump(exclude_none=True) != item_content:
logger.warning(
"Content and item_content differ for message %s",
values.data["item_hash"],
)
raise ValueError("Content and item_content differ")
return v


class InstanceMessage(BaseMessage):
type: Literal[MessageType.instance]
content: InstanceContent
forgotten_by: Optional[List[str]] = None


AlephMessage: TypeAlias = Union[
Expand Down Expand Up @@ -363,12 +374,12 @@ def parse_message(message_dict: Dict) -> AlephMessage:
message_class.__annotations__["type"].__args__[0]
)
if message_dict["type"] == message_type:
return message_class.parse_obj(message_dict)
return message_class.model_validate(message_dict)
else:
raise ValueError(f"Unknown message type {message_dict['type']}")


def add_item_content_and_hash(message_dict: Dict, inplace: bool = False):
def add_item_content_and_hash(message_dict: Dict, inplace: bool = False) -> Dict:
if not inplace:
message_dict = copy(message_dict)

Expand All @@ -390,7 +401,7 @@ def create_new_message(
"""
message_content = add_item_content_and_hash(message_dict)
if factory:
return cast(T, factory.parse_obj(message_content))
return cast(T, factory.model_validate(message_content))
else:
return cast(T, parse_message(message_content))

Expand All @@ -405,7 +416,7 @@ def create_message_from_json(
message_dict = json.loads(json_data)
message_content = add_item_content_and_hash(message_dict, inplace=True)
if factory:
return factory.parse_obj(message_content)
return factory.model_validate(message_content)
else:
return parse_message(message_content)

Expand All @@ -422,7 +433,7 @@ def create_message_from_file(
message_dict = decoder.load(fd)
message_content = add_item_content_and_hash(message_dict, inplace=True)
if factory:
return factory.parse_obj(message_content)
return factory.model_validate(message_content)
else:
return parse_message(message_content)

Expand All @@ -436,5 +447,4 @@ class MessagesResponse(BaseModel):
pagination_per_page: int
pagination_item: str

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")
5 changes: 2 additions & 3 deletions aleph_message/models/abstract.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, Extra
from pydantic import BaseModel, ConfigDict


def hashable(obj):
Expand All @@ -24,5 +24,4 @@ class BaseContent(BaseModel):
address: str
time: float

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")
2 changes: 1 addition & 1 deletion aleph_message/models/execution/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .abstract import BaseExecutableContent
from .base import Encoding, Interface, MachineType, Payment, PaymentType
from .instance import InstanceContent
from .program import ProgramContent
from .base import Encoding, MachineType, PaymentType, Payment, Interface

__all__ = [
"BaseExecutableContent",
Expand Down
2 changes: 1 addition & 1 deletion aleph_message/models/execution/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Payment(HashableModel):

chain: Chain
"""Which chain to check for funds"""
receiver: Optional[str]
receiver: Optional[str] = None
"""Optional alternative address to send tokens to"""
type: PaymentType
"""Whether to pay by holding $ALEPH or by streaming tokens"""
Expand Down
Loading

0 comments on commit 8a62363

Please sign in to comment.