diff --git a/aleph_message/models/__init__.py b/aleph_message/models/__init__.py index 10c1925..6d019af 100644 --- a/aleph_message/models/__init__.py +++ b/aleph_message/models/__init__.py @@ -1,12 +1,13 @@ import datetime import json +import logging from copy import copy from hashlib import sha256 from json import JSONDecodeError from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union, cast -from pydantic import BaseModel, Extra, Field, validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from typing_extensions import TypeAlias from .abstract import BaseContent, HashableModel @@ -16,6 +17,10 @@ from .execution.program import ProgramContent from .item_hash import ItemHash, ItemType +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + __all__ = [ "AggregateContent", "AggregateMessage", @@ -54,8 +59,7 @@ class MongodbId(BaseModel): oid: str = Field(alias="$oid") - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class ChainRef(BaseModel): @@ -76,8 +80,7 @@ class MessageConfirmationHash(BaseModel): binary: str = Field(alias="$binary") type: str = Field(alias="$type") - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class MessageConfirmation(BaseModel): @@ -93,15 +96,13 @@ class MessageConfirmation(BaseModel): default=None, description="The address that published the transaction." ) - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class AggregateContentKey(BaseModel): name: str - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class PostContent(BaseContent): @@ -116,16 +117,15 @@ class PostContent(BaseContent): ) type: str = Field(description="User-generated 'content-type' of a POST message") - @validator("type") + @field_validator("type") def check_type(cls, v, values): if v == "amend": - ref = values.get("ref") + ref = values.data.get("ref") if not ref: raise ValueError("A 'ref' is required for POST type 'amend'") return v - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class AggregateContent(BaseContent): @@ -136,8 +136,7 @@ class AggregateContent(BaseContent): ) content: Dict = Field(description="The content of an aggregate must be a dict") - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class StoreContent(BaseContent): @@ -148,10 +147,11 @@ class StoreContent(BaseContent): size: Optional[int] = None # Generated by the node on storage content_type: Optional[str] = None # Generated by the node on storage ref: Optional[str] = None - metadata: Optional[Dict[str, Any]] = Field(description="Metadata of the VM") + metadata: Optional[Dict[str, Any]] = Field( + default=None, description="Metadata of the VM" + ) - class Config: - extra = Extra.allow + model_config = ConfigDict(extra="allow") class ForgetContent(BaseContent): @@ -214,9 +214,9 @@ class BaseMessage(BaseModel): forgotten_by: Optional[List[str]] - @validator("item_content") + @field_validator("item_content") def check_item_content(cls, v: Optional[str], values) -> Optional[str]: - item_type = values["item_type"] + item_type = values.data.get("item_type") if v is None: return None elif item_type == ItemType.inline: @@ -232,14 +232,14 @@ def check_item_content(cls, v: Optional[str], values) -> Optional[str]: ) return v - @validator("item_hash") + @field_validator("item_hash") def check_item_hash(cls, v: ItemHash, values) -> ItemHash: - item_type = values["item_type"] + item_type = values.data.get("item_type") if item_type == ItemType.inline: - item_content: str = values["item_content"] + item_content: str = values.data.get("item_content") # Double check that the hash function is supported - hash_type = values["hash_type"] or HashType.sha256 + hash_type = values.data.get("hash_type") or HashType.sha256 assert hash_type.value == HashType.sha256 computed_hash: str = sha256(item_content.encode()).hexdigest() @@ -255,23 +255,25 @@ def check_item_hash(cls, v: ItemHash, values) -> ItemHash: assert item_type == ItemType.storage return v - @validator("confirmed") + @field_validator("confirmed") def check_confirmed(cls, v, values): - confirmations = values["confirmations"] + confirmations = values.data.get("confirmations") if v is True and not bool(confirmations): raise ValueError("Message cannot be 'confirmed' without 'confirmations'") return v - @validator("time") + @field_validator("time") def convert_float_to_datetime(cls, v, values): if isinstance(v, float): v = datetime.datetime.fromtimestamp(v) assert isinstance(v, datetime.datetime) return v - class Config: - extra = Extra.forbid - exclude = {"id_", "_id"} + model_config = ConfigDict(extra="forbid") + + def custom_dump(self): + """Exclude MongoDB identifiers from dumps for historical reasons.""" + return self.model_dump(exclude={"id_", "_id"}) class PostMessage(BaseMessage): @@ -279,6 +281,7 @@ class PostMessage(BaseMessage): type: Literal[MessageType.post] content: PostContent + forgotten_by: Optional[List[str]] = None class AggregateMessage(BaseMessage): @@ -286,18 +289,22 @@ class AggregateMessage(BaseMessage): type: Literal[MessageType.aggregate] content: AggregateContent + forgotten_by: Optional[list] = None class StoreMessage(BaseMessage): type: Literal[MessageType.store] content: StoreContent + forgotten_by: Optional[list] = None + metadata: Optional[Dict[str, Any]] = None class ForgetMessage(BaseMessage): type: Literal[MessageType.forget] content: ForgetContent + forgotten_by: Optional[list] = None - @validator("forgotten_by") + @field_validator("forgotten_by") def cannot_be_forgotten(cls, v: Optional[List[str]], values) -> Optional[List[str]]: assert values if v: @@ -308,18 +315,21 @@ def cannot_be_forgotten(cls, v: Optional[List[str]], values) -> Optional[List[st class ProgramMessage(BaseMessage): type: Literal[MessageType.program] content: ProgramContent + forgotten_by: Optional[List[str]] = None - @validator("content") + @field_validator("content") def check_content(cls, v, values): - item_type = values["item_type"] + """Ensure that the content of the message is correctly formatted.""" + item_type = values.data.get("item_type") if item_type == ItemType.inline: - item_content = json.loads(values["item_content"]) - if v.dict(exclude_none=True) != item_content: - # Print differences - vdict = v.dict(exclude_none=True) - for key, value in item_content.items(): - if vdict[key] != value: - print(f"{key}: {vdict[key]} != {value}") + # Ensure that the content correct JSON + item_content = json.loads(values.data.get("item_content")) + # Ensure that the content matches the expected structure + if v.model_dump(exclude_none=True) != item_content: + logger.warning( + "Content and item_content differ for message %s", + values.data["item_hash"], + ) raise ValueError("Content and item_content differ") return v @@ -327,6 +337,7 @@ def check_content(cls, v, values): class InstanceMessage(BaseMessage): type: Literal[MessageType.instance] content: InstanceContent + forgotten_by: Optional[List[str]] = None AlephMessage: TypeAlias = Union[ @@ -363,12 +374,12 @@ def parse_message(message_dict: Dict) -> AlephMessage: message_class.__annotations__["type"].__args__[0] ) if message_dict["type"] == message_type: - return message_class.parse_obj(message_dict) + return message_class.model_validate(message_dict) else: raise ValueError(f"Unknown message type {message_dict['type']}") -def add_item_content_and_hash(message_dict: Dict, inplace: bool = False): +def add_item_content_and_hash(message_dict: Dict, inplace: bool = False) -> Dict: if not inplace: message_dict = copy(message_dict) @@ -390,7 +401,7 @@ def create_new_message( """ message_content = add_item_content_and_hash(message_dict) if factory: - return cast(T, factory.parse_obj(message_content)) + return cast(T, factory.model_validate(message_content)) else: return cast(T, parse_message(message_content)) @@ -405,7 +416,7 @@ def create_message_from_json( message_dict = json.loads(json_data) message_content = add_item_content_and_hash(message_dict, inplace=True) if factory: - return factory.parse_obj(message_content) + return factory.model_validate(message_content) else: return parse_message(message_content) @@ -422,7 +433,7 @@ def create_message_from_file( message_dict = decoder.load(fd) message_content = add_item_content_and_hash(message_dict, inplace=True) if factory: - return factory.parse_obj(message_content) + return factory.model_validate(message_content) else: return parse_message(message_content) @@ -436,5 +447,4 @@ class MessagesResponse(BaseModel): pagination_per_page: int pagination_item: str - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") diff --git a/aleph_message/models/abstract.py b/aleph_message/models/abstract.py index f272dbd..2af6f8b 100644 --- a/aleph_message/models/abstract.py +++ b/aleph_message/models/abstract.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Extra +from pydantic import BaseModel, ConfigDict def hashable(obj): @@ -24,5 +24,4 @@ class BaseContent(BaseModel): address: str time: float - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") diff --git a/aleph_message/models/execution/__init__.py b/aleph_message/models/execution/__init__.py index b45a658..c9f6561 100644 --- a/aleph_message/models/execution/__init__.py +++ b/aleph_message/models/execution/__init__.py @@ -1,7 +1,7 @@ from .abstract import BaseExecutableContent +from .base import Encoding, Interface, MachineType, Payment, PaymentType from .instance import InstanceContent from .program import ProgramContent -from .base import Encoding, MachineType, PaymentType, Payment, Interface __all__ = [ "BaseExecutableContent", diff --git a/aleph_message/models/execution/base.py b/aleph_message/models/execution/base.py index c139dda..be3d551 100644 --- a/aleph_message/models/execution/base.py +++ b/aleph_message/models/execution/base.py @@ -35,7 +35,7 @@ class Payment(HashableModel): chain: Chain """Which chain to check for funds""" - receiver: Optional[str] + receiver: Optional[str] = None """Optional alternative address to send tokens to""" type: PaymentType """Whether to pay by holding $ALEPH or by streaming tokens""" diff --git a/aleph_message/models/execution/environment.py b/aleph_message/models/execution/environment.py index e0eae1d..7c6a475 100644 --- a/aleph_message/models/execution/environment.py +++ b/aleph_message/models/execution/environment.py @@ -3,7 +3,7 @@ from enum import Enum from typing import List, Literal, Optional, Union -from pydantic import Extra, Field, validator +from pydantic import ConfigDict, Field, field_validator from ...utils import Mebibytes from ..abstract import HashableModel @@ -13,8 +13,7 @@ class Subscription(HashableModel): """A subscription is used to trigger a program in response to a FunctionTrigger.""" - class Config: - extra = Extra.allow + model_config = ConfigDict(extra="allow") class FunctionTriggers(HashableModel): @@ -29,8 +28,7 @@ class FunctionTriggers(HashableModel): description="Persist the execution of the program instead of running it on demand.", ) - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class NetworkProtocol(str, Enum): @@ -85,8 +83,7 @@ class CpuProperties(HashableModel): description="CPU features required by the virtual machine. Examples: 'sev', 'sev_es', 'sev_snp'.", ) - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class HypervisorType(str, Enum): @@ -132,8 +129,7 @@ class TrustedExecutionEnvironment(HashableModel): description="Policy of the TEE. Default value is 0x01 for SEV without debugging.", ) - class Config: - extra = Extra.allow + model_config = ConfigDict(extra="allow") class InstanceEnvironment(HashableModel): @@ -150,9 +146,9 @@ class InstanceEnvironment(HashableModel): reproducible: bool = False shared_cache: bool = False - @validator("trusted_execution", pre=True) + @field_validator("trusted_execution", mode="before") def check_hypervisor(cls, v, values): - if v and values.get("hypervisor") != HypervisorType.qemu: + if v and values.data.get("hypervisor") != HypervisorType.qemu: raise ValueError("Trusted Execution Environment is only supported for QEmu") return v @@ -166,8 +162,7 @@ class NodeRequirements(HashableModel): default=None, description="Hash of the compute resource node that must be used" ) - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class HostRequirements(HashableModel): @@ -178,6 +173,4 @@ class HostRequirements(HashableModel): default=None, description="Required Compute Resource Node properties" ) - class Config: - # Allow users to add custom requirements - extra = Extra.allow + model_config = ConfigDict(extra="allow") diff --git a/aleph_message/models/execution/instance.py b/aleph_message/models/execution/instance.py index ebb8d48..79fb99c 100644 --- a/aleph_message/models/execution/instance.py +++ b/aleph_message/models/execution/instance.py @@ -1,12 +1,16 @@ from __future__ import annotations +from typing import List, Optional + from pydantic import Field from aleph_message.models.abstract import HashableModel +from ...utils import Gigabytes, gigabyte_to_mebibyte from .abstract import BaseExecutableContent +from .base import Payment from .environment import InstanceEnvironment -from .volume import ParentVolume, PersistentVolumeSizeMib, VolumePersistence +from .volume import ParentVolume, VolumePersistence class RootfsVolume(HashableModel): @@ -20,15 +24,22 @@ class RootfsVolume(HashableModel): parent: ParentVolume persistence: VolumePersistence # Use the same size constraint as persistent volumes for now - size_mib: PersistentVolumeSizeMib + size_mib: int = Field( + gt=-1, le=gigabyte_to_mebibyte(Gigabytes(100)), strict=True # Limit to 1GiB + ) + forgotten_by: Optional[List[str]] = None class InstanceContent(BaseExecutableContent): """Message content for scheduling a VM instance on the network.""" + metadata: Optional[dict] = None + payment: Optional[Payment] = None environment: InstanceEnvironment = Field( description="Properties of the instance execution environment" ) rootfs: RootfsVolume = Field( description="Root filesystem of the system, will be booted by the kernel" ) + + authorized_keys: Optional[List[str]] = None diff --git a/aleph_message/models/execution/program.py b/aleph_message/models/execution/program.py index 8afb6d9..9bb2228 100644 --- a/aleph_message/models/execution/program.py +++ b/aleph_message/models/execution/program.py @@ -7,7 +7,7 @@ from ..abstract import HashableModel from ..item_hash import ItemHash from .abstract import BaseExecutableContent -from .base import Encoding, Interface, MachineType +from .base import Encoding, Interface, MachineType, Payment from .environment import FunctionTriggers @@ -43,8 +43,8 @@ class DataContent(HashableModel): encoding: Encoding mount: str - ref: ItemHash - use_latest: bool = False + ref: Optional[ItemHash] = None + use_latest: Optional[bool] = False class Export(HashableModel): @@ -69,3 +69,7 @@ class ProgramContent(BaseExecutableContent): default=None, description="Data to export after computation" ) on: FunctionTriggers = Field(description="Signals that trigger an execution") + + metadata: Optional[dict] = None + authorized_keys: Optional[List[str]] = None + payment: Optional[Payment] = None diff --git a/aleph_message/models/execution/volume.py b/aleph_message/models/execution/volume.py index 6102fba..2b300fa 100644 --- a/aleph_message/models/execution/volume.py +++ b/aleph_message/models/execution/volume.py @@ -4,7 +4,7 @@ from enum import Enum from typing import Literal, Optional, Union -from pydantic import ConstrainedInt, Extra +from pydantic import ConfigDict, Field from ...utils import Gigabytes, gigabyte_to_mebibyte from ..abstract import HashableModel @@ -18,27 +18,22 @@ class AbstractVolume(HashableModel, ABC): @abstractmethod def is_read_only(self): ... - class Config: - extra = Extra.forbid + model_config = ConfigDict(extra="forbid") class ImmutableVolume(AbstractVolume): - ref: ItemHash + ref: Optional[ItemHash] = None use_latest: bool = True def is_read_only(self): return True -class EphemeralVolumeSize(ConstrainedInt): - gt = 0 - le = 1000 # Limit to 1 GiB - strict = True - - class EphemeralVolume(AbstractVolume): ephemeral: Literal[True] = True - size_mib: EphemeralVolumeSize + size_mib: int = Field( + gt=0, le=gigabyte_to_mebibyte(Gigabytes(1)), strict=True # Limit to 1GiB + ) def is_read_only(self): return False @@ -58,17 +53,13 @@ class VolumePersistence(str, Enum): store = "store" -class PersistentVolumeSizeMib(ConstrainedInt): - gt = 0 - le = gigabyte_to_mebibyte(Gigabytes(100)) - strict = True # Limit to 100 GiB - - class PersistentVolume(AbstractVolume): - parent: Optional[ParentVolume] - persistence: VolumePersistence - name: str - size_mib: PersistentVolumeSizeMib + parent: Optional[ParentVolume] = None + persistence: Optional[VolumePersistence] = None + name: Optional[str] = None + size_mib: int = Field( + gt=0, le=gigabyte_to_mebibyte(Gigabytes(100)), strict=True # Limit to 100GiB + ) def is_read_only(self): return False diff --git a/aleph_message/models/item_hash.py b/aleph_message/models/item_hash.py index e029416..433daf6 100644 --- a/aleph_message/models/item_hash.py +++ b/aleph_message/models/item_hash.py @@ -1,6 +1,10 @@ from enum import Enum from functools import lru_cache +from pydantic import GetCoreSchemaHandler +from pydantic.functional_serializers import model_serializer +from pydantic_core import core_schema + from ..exceptions import UnknownHashError @@ -32,6 +36,10 @@ def is_storage(cls, item_hash: str): def is_ipfs(cls, item_hash: str): return cls.from_hash(item_hash) == cls.ipfs + @model_serializer + def __str__(self): + return self.value + class ItemHash(str): item_type: ItemType @@ -45,18 +53,22 @@ def __new__(cls, value: str): return obj @classmethod - def __get_validators__(cls): - # one or more validators may be yielded which will be called in the - # order to validate the input, each validator will receive as an input - # the value returned from the previous validator - yield cls.validate + def __get_pydantic_core_schema__( + cls, source, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + # This function validates the input after the initial type validation (as a string). + # The returned value from this function will be used as the final validated value. + + # Return a string schema and add a post-validation function to convert to ItemHash + return core_schema.no_info_after_validator_function( + cls.validate, core_schema.str_schema() + ) @classmethod def validate(cls, v): if not isinstance(v, str): raise TypeError("Item hash must be a string") - - return cls(v) + return cls(v) # Convert to ItemHash def __repr__(self): return f"" diff --git a/aleph_message/tests/test_models.py b/aleph_message/tests/test_models.py index 78cdf04..44d9be7 100644 --- a/aleph_message/tests/test_models.py +++ b/aleph_message/tests/test_models.py @@ -1,8 +1,10 @@ import json import os.path +from functools import partial from os import listdir from os.path import isdir, join from pathlib import Path +from unittest import mock import pytest import requests @@ -14,6 +16,7 @@ AggregateMessage, ForgetMessage, InstanceMessage, + ItemHash, ItemType, MessagesResponse, MessageType, @@ -27,7 +30,14 @@ parse_message, ) from aleph_message.models.execution.environment import AMDSEVPolicy +from aleph_message.models.execution.instance import RootfsVolume +from aleph_message.models.execution.volume import ( + EphemeralVolume, + ParentVolume, + VolumePersistence, +) from aleph_message.tests.download_messages import MESSAGES_STORAGE_PATH +from aleph_message.utils import Gigabytes, Mebibytes, gigabyte_to_mebibyte console = Console(color_system="windows") @@ -47,9 +57,9 @@ def test_message_response_aggregate(): data_dict = requests.get(f"{ALEPH_API_SERVER}{path}").json() message = data_dict["messages"][0] - AggregateMessage.parse_obj(message) + AggregateMessage.model_validate(message) - response = MessagesResponse.parse_obj(data_dict) + response = MessagesResponse.model_validate(data_dict) assert response @@ -60,7 +70,7 @@ def test_message_response_post(): ) data_dict = requests.get(f"{ALEPH_API_SERVER}{path}").json() - response = MessagesResponse.parse_obj(data_dict) + response = MessagesResponse.model_validate(data_dict) assert response @@ -71,7 +81,7 @@ def test_message_response_store(): ) data_dict = requests.get(f"{ALEPH_API_SERVER}{path}").json() - response = MessagesResponse.parse_obj(data_dict) + response = MessagesResponse.model_validate(data_dict) assert response @@ -103,7 +113,7 @@ def test_post_content(): time=1.0, ) assert p1.type == custom_type - assert p1.dict() == { + assert p1.model_dump() == { "address": "0x1", "time": 1.0, "content": {"blah": "bar"}, @@ -180,7 +190,7 @@ def test_validation_on_confidential_options(): assert e.errors()[0]["loc"] == ("content", "environment", "trusted_execution") assert ( e.errors()[0]["msg"] - == "Trusted Execution Environment is only supported for QEmu" + == "Value error, Trusted Execution Environment is only supported for QEmu" ) @@ -243,8 +253,9 @@ def test_message_machine_named(): message = create_message_from_file(path, factory=ProgramMessage) assert isinstance(message, ProgramMessage) - assert isinstance(message.content.metadata, dict) - assert message.content.metadata["version"] == "10.2" + if message.content is not None: + assert isinstance(message.content.metadata, dict) + assert message.content.metadata["version"] == "10.2" def test_message_forget(): @@ -262,8 +273,8 @@ def test_message_forget_cannot_be_forgotten(): message_raw["forgotten_by"] = ["abcde"] with pytest.raises(ValueError) as e: - ForgetMessage.parse_obj(message_raw) - assert e.value.args[0][0].exc.args == ("This type of message may not be forgotten",) + ForgetMessage.model_validate(message_raw) + assert "This type of message may not be forgotten" in str(e.value) def test_message_forgotten_by(): @@ -273,10 +284,12 @@ def test_message_forgotten_by(): message_raw = add_item_content_and_hash(message_raw) # Test different values for field 'forgotten_by' - _ = ProgramMessage.parse_obj(message_raw) - _ = ProgramMessage.parse_obj({**message_raw, "forgotten_by": None}) - _ = ProgramMessage.parse_obj({**message_raw, "forgotten_by": ["abcde"]}) - _ = ProgramMessage.parse_obj({**message_raw, "forgotten_by": ["abcde", "fghij"]}) + _ = ProgramMessage.model_validate(message_raw) + _ = ProgramMessage.model_validate({**message_raw, "forgotten_by": None}) + _ = ProgramMessage.model_validate({**message_raw, "forgotten_by": ["abcde"]}) + _ = ProgramMessage.model_validate( + {**message_raw, "forgotten_by": ["abcde", "fghij"]} + ) def test_item_type_from_hash(): @@ -334,6 +347,70 @@ def test_create_new_message(): assert create_message_from_json(json.dumps(message_dict)) +def test_volume_size_constraints(): + """Test size constraints for volumes""" + + _ = EphemeralVolume(size_mib=1) + # A ValidationError should be raised if the size negative + with pytest.raises(ValidationError): + _ = EphemeralVolume(size_mib=-1) + size_mib: Mebibytes = gigabyte_to_mebibyte(Gigabytes(1)) + # A size of 1GiB should be allowed + _ = EphemeralVolume(size_mib=size_mib) + # A ValidationError should be raised if the size is greater than 1GiB + with pytest.raises(ValidationError): + _ = EphemeralVolume(size_mib=size_mib + 1) + + # Use partial function to avoid repeating the same code + create_test_rootfs = partial( + RootfsVolume, + parent=ParentVolume( + ref=ItemHash("QmX8K1c22WmQBAww5ShWQqwMiFif7XFrJD6iFBj7skQZXW") + ), + persistence=VolumePersistence.store, + ) + + _ = create_test_rootfs(size_mib=1) + + # A ValidationError should be raised if the size negative + with pytest.raises(ValidationError): + _ = create_test_rootfs(size_mib=-1) + size_mib_rootfs: Mebibytes = gigabyte_to_mebibyte(Gigabytes(100)) + # A size of 100GiB should be allowed + _ = create_test_rootfs(size_mib=size_mib_rootfs) + # A ValidationError should be raised if the size is greater than 100GiB + with pytest.raises(ValidationError): + _ = create_test_rootfs(size_mib=size_mib_rootfs + 1) + + +def test_program_message_content_and_item_content_differ(): + # Test that a ValidationError is raised if the content and item_content differ + + # Get a program message as JSON-compatible dict + path = Path(__file__).parent / "messages/machine.json" + with open(path) as fd: + message_dict_original = json.load(fd) + message_dict: dict = add_item_content_and_hash(message_dict_original, inplace=True) + + # patch hashlib.sha256 with a mock else this raises an error first + mock_hash = mock.MagicMock() + mock_hash.hexdigest.return_value = ( + "cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe" + ) + message_dict["item_hash"] = ( + "cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe" + ) + + # Patch the content to differ from item_content + message_dict["content"]["replaces"] = "does-not-exist" + + # Test that a ValidationError is raised if the content and item_content differ + with mock.patch("aleph_message.models.sha256", return_value=mock_hash): + with pytest.raises(ValidationError) as excinfo: + ProgramMessage.model_validate(message_dict) + assert "Content and item_content differ" in str(excinfo.value) + + @pytest.mark.slow @pytest.mark.skipif(not isdir(MESSAGES_STORAGE_PATH), reason="No file on disk to test") def test_messages_from_disk(): diff --git a/aleph_message/tests/test_types.py b/aleph_message/tests/test_types.py index a322b8a..1e8c03f 100644 --- a/aleph_message/tests/test_types.py +++ b/aleph_message/tests/test_types.py @@ -25,35 +25,35 @@ class ModelWithItemHash(BaseModel): def test_item_hash(): storage_object_dict = {"hash": STORAGE_HASH} - storage_object = ModelWithItemHash.parse_obj(storage_object_dict) + storage_object = ModelWithItemHash.model_validate(storage_object_dict) assert storage_object.hash == STORAGE_HASH assert storage_object.hash.item_type == ItemType.storage ipfs_object_dict = {"hash": IPFS_HASH} - ipfs_object = ModelWithItemHash.parse_obj(ipfs_object_dict) + ipfs_object = ModelWithItemHash.model_validate(ipfs_object_dict) assert ipfs_object.hash == IPFS_HASH assert ipfs_object.hash.item_type == ItemType.ipfs assert repr(ipfs_object.hash).startswith("=1.10.5,<2.0.0", + "pydantic>=2", "typing_extensions>=4.5.0", ], license="MIT",