CU-8695d4www pydantic 2 #476

Draft
wants to merge 46 commits into
base: master
Changes from 21 commits
46 commits
b0b3d43
CU-8695d4www: Bump pydantic requirement to 2.6+
mart-r Aug 12, 2024
cb0104f
CU-8695d4www: Update methods to use pydantic2 based ones
mart-r Aug 12, 2024
e806d54
CU-8695d4www: Update methods to use pydantic2 based ones [part 2]
mart-r Aug 12, 2024
ea7e04a
CU-8695d4www: Use identifier based config when setting last train dat…
mart-r Aug 12, 2024
3879fe5
CU-8695d4www: Use pydantic2-based model validation
mart-r Aug 12, 2024
960e405
CU-8695d4www: Add workarounds for pydantic1 methods
mart-r Aug 12, 2024
10a7a58
CU-8695d4www: Add missing utils module for pydantic1 methods
mart-r Aug 13, 2024
080ae71
Revert "CU-8695d4www: Bump pydantic requirement to 2.6+"
mart-r Aug 13, 2024
b86135a
CU-8695d4www: [TEMP] Add type-ingores to pydantic2-based methods for …
mart-r Aug 13, 2024
0eb9f76
CU-8695d4www: Make pydantic2-requires getattribute wrapper only apply…
mart-r Aug 13, 2024
0e9fe91
CU-8695d4www: Fix missin model dump getter abstraction
mart-r Aug 13, 2024
0cb31ee
CU-8695d4www: Fix missin model dump getter abstraction (in CAT)
mart-r Aug 13, 2024
a7aab98
CU-8695d4www: Update tests for pydantic 1 and 2 support
mart-r Aug 13, 2024
897df2d
Revert "CU-8695d4www: [TEMP] Add type-ingores to pydantic2-based meth…
mart-r Aug 13, 2024
1bbe88e
Reapply "CU-8695d4www: Bump pydantic requirement to 2.6+"
mart-r Aug 13, 2024
cc7c2ce
CU-8695d4www: Allow both pydantic 1 and 2
mart-r Aug 13, 2024
0ee1a8a
CU-8695d4www: Deprecated pydantic utils for removal in 1.15
mart-r Aug 13, 2024
a89e680
CU-8695d4www: Allow usage of specified deprecated method(s) during tests
mart-r Aug 13, 2024
825628e
CU-8695d4www: Allow usage of pydantic 1-2 workaround methods during t…
mart-r Aug 13, 2024
927f807
CU-8695d4www: Add documentation for argument allowing usage during te…
mart-r Aug 13, 2024
fadc7d1
CU-8695d4www: Fix allowing deprecation during test time
mart-r Aug 13, 2024
b1b11ce
CU-8695d4www: Fix model dump getting in regression checker
mart-r Aug 14, 2024
e30ca16
Revert "CU-8695d4www: Fix allowing deprecation during test time"
mart-r Aug 15, 2024
0c5b7ca
Revert "CU-8695d4www: Add documentation for argument allowing usage d…
mart-r Aug 15, 2024
6c76acc
Revert "CU-8695d4www: Allow usage of pydantic 1-2 workaround methods …
mart-r Aug 15, 2024
a4b2ea0
Revert "CU-8695d4www: Allow usage of specified deprecated method(s) d…
mart-r Aug 15, 2024
414f70a
Revert "CU-8695d4www: Deprecated pydantic utils for removal in 1.15"
mart-r Aug 15, 2024
ecc54ab
CU-8695d4www: Add comment regarding pydantic backwards compatiblity w…
mart-r Aug 21, 2024
b160295
CU-8695d4www: Add pydantic 1 check to GHA workflow
mart-r Aug 21, 2024
6c6881a
Merge branch 'master' into CU-8695d4www-pydantic-2
mart-r Aug 29, 2024
b5ddf91
Merge branch 'master' into CU-8695d4www-pydantic-2
mart-r Aug 29, 2024
23d03c7
CU-8695d4www: Fix usage of pydantic-1 based dict method in regression…
mart-r Aug 29, 2024
8777256
CU-8695d4www: Fix usage of pydantic-1 based dict method in regression…
mart-r Aug 29, 2024
44e470a
CU-8695d4www: New workflow step to install and run mypy on pydantic 1
mart-r Aug 29, 2024
9eab8f0
CU-8695d4www: Add type ignore comments to pydantic2 versions in versi…
mart-r Aug 29, 2024
3d19cd3
Merge branch 'master' into CU-8695d4www-pydantic-2
mart-r Aug 30, 2024
ebe17e0
Merge branch 'master' into CU-8695d4www-pydantic-2
mart-r Oct 15, 2024
6746b34
CU-8695d4www: Update pydantic requirement to 2.0+ only
mart-r Oct 15, 2024
b7f895e
CU-8695d4www: Update to pydantic 2 ONLY
mart-r Oct 15, 2024
3fe2c47
CU-869671bn4: Update mypy dev requirement to be less than 1.12
mart-r Oct 15, 2024
65d653f
CU-869671bn4: Fix model fields in config
mart-r Oct 15, 2024
11b1c7a
CU-869671bn4: Fix stats helper method - use correct type adapter
mart-r Oct 15, 2024
bc5458b
CU-869671bn4: Fix some model type issues
mart-r Oct 15, 2024
95d294e
CU-869671bn4: Line up with previous model dump methods
mart-r Oct 15, 2024
4e716ae
CU-869671bn4: Fix overwriting model dump methods
mart-r Oct 15, 2024
2834c27
CU-869671bn4: Remove pydantic1 workflow step
mart-r Oct 15, 2024
4 changes: 2 additions & 2 deletions install_requires.txt
Expand Up @@ -19,6 +19,6 @@
'xxhash>=3.0.0' # allow later versions, tested with 3.1.0
'blis>=0.7.5,<1.0.0' # allow later versions, tested with 0.7.9, avoid 1.0.0 (depends on numpy 2)
'click>=8.0.4' # allow later versions, tested with 8.1.3
'pydantic>=1.10.0,<2.0' # for spacy compatibility; avoid 2.0 due to breaking changes
'pydantic>=1.10.0,<3.0' # avoid next major release
"humanfriendly~=10.0" # for human readable file / RAM sizes
"peft>=0.8.2"
"peft>=0.8.2"
3 changes: 2 additions & 1 deletion medcat/cat.py
Expand Up @@ -42,6 +42,7 @@
from medcat.stats.stats import get_stats
from medcat.utils.filters import set_project_filters
from medcat.utils.usage_monitoring import UsageMonitor
from medcat.utils.pydantic_version import get_model_dump


logger = logging.getLogger(__name__) # separate logger from the package-level one
Expand Down Expand Up @@ -585,7 +586,7 @@ def _print_stats(self,

def _init_ckpts(self, is_resumed, checkpoint):
if self.config.general.checkpoint.steps is not None or checkpoint is not None:
checkpoint_config = CheckpointConfig(**self.config.general.checkpoint.dict())
checkpoint_config = CheckpointConfig(**get_model_dump(self.config.general.checkpoint))
checkpoint_manager = CheckpointManager('cat_train', checkpoint_config)
if is_resumed:
# TODO: probably remove is_resumed mark and always resume if a checkpoint is provided,
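The `medcat/utils/pydantic_version.py` module that provides `get_model_dump` (and `get_model_fields`, used in `medcat/config.py`) is not part of the visible diff. The following is a minimal sketch of how such helpers could dispatch between the pydantic 1 and pydantic 2 method names, not the PR's actual implementation. The stand-in classes `V1StyleCheckpoint` and `V2StyleCheckpoint` are hypothetical and only mimic the relevant method names, so the snippet runs without pydantic installed.

```python
from typing import Any, Dict


def get_model_dump(obj: Any, **kwargs: Any) -> Dict[str, Any]:
    """Return the model contents as a dict on either pydantic major version."""
    # pydantic 2 models expose model_dump(); they still carry a deprecated
    # dict() as well, so the newer name has to be checked first.
    if hasattr(obj, "model_dump"):
        return obj.model_dump(**kwargs)
    # pydantic 1 models only have dict().
    return obj.dict(**kwargs)


def get_model_fields(obj: Any) -> Dict[str, Any]:
    """Return the field mapping on either pydantic major version."""
    cls = type(obj)  # look the mapping up on the class, not the instance
    if hasattr(cls, "model_fields"):  # pydantic 2 name
        return cls.model_fields
    return cls.__fields__  # pydantic 1 name


# Hypothetical stand-ins that mimic the two method-naming schemes.
class V1StyleCheckpoint:
    __fields__ = {"steps": int}

    def dict(self, **kwargs: Any) -> Dict[str, Any]:
        return {"steps": 100}


class V2StyleCheckpoint:
    model_fields = {"steps": int}

    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        return {"steps": 200}


v1_dump = get_model_dump(V1StyleCheckpoint())
v2_dump = get_model_dump(V2StyleCheckpoint())
```

Feature detection with `hasattr` is one option; comparing `pydantic.VERSION` would be another, at the cost of parsing a version string.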
42 changes: 22 additions & 20 deletions medcat/config.py
@@ -1,6 +1,5 @@
from datetime import datetime
from pydantic import BaseModel, Extra, ValidationError
from pydantic.fields import ModelField
from pydantic import BaseModel, ValidationError
from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union, Type, Literal
from multiprocessing import cpu_count
import logging
Expand All @@ -13,6 +12,7 @@
from medcat.utils.matutils import intersect_nonempty_set
from medcat.utils.config_utils import attempt_fix_weighted_average_function
from medcat.utils.config_utils import weighted_average, is_old_type_config_dict
from medcat.utils.pydantic_version import get_model_dump, get_model_fields
from medcat.utils.saving.coding import CustomDelegatingEncoder, default_hook


Expand Down Expand Up @@ -125,7 +125,7 @@ def merge_config(self, config_dict: Dict) -> None:
attr = None # new attribute
value = config_dict[key]
if isinstance(value, BaseModel):
value = value.dict()
value = get_model_dump(value)
if isinstance(attr, MixingConfig):
attr.merge_config(value)
else:
Expand Down Expand Up @@ -177,7 +177,7 @@ def rebuild_re(self) -> None:
def _calc_hash(self, hasher: Optional[Hasher] = None) -> Hasher:
if hasher is None:
hasher = Hasher()
for _, v in cast(BaseModel, self).dict().items():
for _, v in get_model_dump(cast(BaseModel, self)).items():
if isinstance(v, MixingConfig):
v._calc_hash(hasher)
else:
Expand All @@ -189,7 +189,7 @@ def get_hash(self, hasher: Optional[Hasher] = None):
return hasher.hexdigest()

def __str__(self) -> str:
return str(cast(BaseModel, self).dict())
return str(get_model_dump(cast(BaseModel, self)))

@classmethod
def load(cls, save_path: str) -> "MixingConfig":
Expand Down Expand Up @@ -238,15 +238,15 @@ def asdict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The dictionary associated with this config
"""
return cast(BaseModel, self).dict()
return get_model_dump(cast(BaseModel, self))

def fields(self) -> Dict[str, ModelField]:
def fields(self) -> dict:
"""Get the fields associated with this config.

Returns:
Dict[str, ModelField]: The dictionary of the field names and fields
dict: The dictionary of the field names and fields
"""
return cast(BaseModel, self).__fields__
return get_model_fields(cast(BaseModel, self))


class VersionInfo(MixingConfig, BaseModel):
Expand All @@ -272,7 +272,7 @@ class VersionInfo(MixingConfig, BaseModel):
"""Which version of medcat was used to build the CDB"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand All @@ -290,7 +290,7 @@ class CDBMaker(MixingConfig, BaseModel):
"""Minimum number of letters required in a name to be accepted for a concept"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand All @@ -303,7 +303,7 @@ class AnnotationOutput(MixingConfig, BaseModel):
include_text_in_output: bool = False

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand All @@ -317,7 +317,7 @@ class CheckPoint(MixingConfig, BaseModel):
"""When training the maximum checkpoints will be kept on the disk"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand Down Expand Up @@ -351,7 +351,7 @@ class General(MixingConfig, BaseModel):
'entity_linker', 'sentencizer', 'entity_ruler', 'merge_noun_chunks',
'merge_entities', 'merge_subtokens']
checkpoint: CheckPoint = CheckPoint()
usage_monitor = UsageMonitor()
usage_monitor: UsageMonitor = UsageMonitor()
"""Checkpointing config"""
log_level: int = logging.INFO
"""Logging config for everything | 'tagger' can be disabled, but will cause a drop in performance"""
Expand Down Expand Up @@ -392,7 +392,7 @@ class General(MixingConfig, BaseModel):
reliable due to not taking into account all the details of the changes."""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand All @@ -417,7 +417,7 @@ class Preprocessing(MixingConfig, BaseModel):
"""Documents longer than this will be trimmed"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand All @@ -437,7 +437,7 @@ class Ner(MixingConfig, BaseModel):
"""Try reverse word order for short concepts (2 words max), e.g. heart disease -> disease heart"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand Down Expand Up @@ -572,7 +572,7 @@ class Linking(MixingConfig, BaseModel):
"""If true when the context of a concept is calculated (embedding) the words making that concept are not taken into accout"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand All @@ -593,7 +593,7 @@ class Config:
# this if for word_skipper and punct_checker which would otherwise
# not have a validator
arbitrary_types_allowed = True
extra = Extra.allow
extra = 'allow'
validate_assignment = True

def __init__(self, *args, **kwargs):
Expand All @@ -611,7 +611,7 @@ def rebuild_re(self) -> None:
# Override
def get_hash(self):
hasher = Hasher()
for k, v in self.dict().items():
for k, v in get_model_dump(self).items():
if k in ['hash', ]:
# ignore hash
continue
Expand Down Expand Up @@ -667,4 +667,6 @@ def wrapper(*args, **kwargs):
# we get a nicer exceptio
_waf_advice = "You can use `cat.cdb.weighted_average_function` to access it directly"
Linking.__getattribute__ = _wrapper(Linking.__getattribute__, Linking, _waf_advice, AttributeError) # type: ignore
if hasattr(Linking, '__getattr__'):
Linking.__getattr__ = _wrapper(Linking.__getattr__, Linking, _waf_advice, AttributeError) # type: ignore
Linking.__getitem__ = _wrapper(Linking.__getitem__, Linking, _waf_advice, KeyError) # type: ignore
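The body of `_wrapper` is outside the visible hunk. As a rough illustration of the pattern in the lines above, the sketch below wraps an accessor so that a failed lookup of one specific name re-raises the same exception type with extra advice appended. `wrap_with_advice` and `DemoLinking` are hypothetical names, and the signature differs from the PR's `_wrapper`. The `hasattr` guard mirrors the one in the diff, which suggests that `__getattr__` is not defined on every supported pydantic version.

```python
from typing import Any, Callable, Dict, Type

ADVICE = "You can use `cat.cdb.weighted_average_function` to access it directly"


def wrap_with_advice(func: Callable, name: str, advice: str,
                     exc_type: Type[Exception]) -> Callable:
    """Wrap an accessor so a failed lookup of `name` carries extra advice."""
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return func(*args, **kwargs)
        except exc_type as err:
            # For __getitem__ / __getattr__, args is (self, requested_name).
            if len(args) > 1 and args[1] == name:
                raise exc_type(f"{err}. {advice}") from err
            raise  # unrelated lookups keep their original error
    return wrapper


class DemoLinking:
    """Hypothetical stand-in for a config section with item access."""

    def __init__(self) -> None:
        self._values: Dict[str, Any] = {"similarity_threshold": 0.25}

    def __getitem__(self, key: str) -> Any:
        return self._values[key]


DemoLinking.__getitem__ = wrap_with_advice(  # type: ignore
    DemoLinking.__getitem__, "weighted_average_function", ADVICE, KeyError)

# Only wrap __getattr__ when the class actually defines it.
if hasattr(DemoLinking, "__getattr__"):
    DemoLinking.__getattr__ = wrap_with_advice(  # type: ignore
        DemoLinking.__getattr__, "weighted_average_function", ADVICE,
        AttributeError)
```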
12 changes: 6 additions & 6 deletions medcat/config_meta_cat.py
@@ -1,5 +1,5 @@
from typing import Dict, Any
from medcat.config import MixingConfig, BaseModel, Optional, Extra
from medcat.config import MixingConfig, BaseModel, Optional


class General(MixingConfig, BaseModel):
Expand Down Expand Up @@ -57,7 +57,7 @@ class General(MixingConfig, BaseModel):
Otherwise defaults to doc._.ents or doc.ents per the annotate_overlapping settings"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand Down Expand Up @@ -136,7 +136,7 @@ class Model(MixingConfig, BaseModel):
"""If set to True center positions will be ignored when calculating representation"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand All @@ -158,7 +158,7 @@ class Train(MixingConfig, BaseModel):
"""If set only this CUIs will be used for training"""
auto_save_model: bool = True
"""Should do model be saved during training for best results"""
last_train_on: Optional[int] = None
last_train_on: Optional[float] = None
Member

👍

"""When was the last training run"""
metric: Dict[str, str] = {'base': 'weighted avg', 'score': 'f1-score'}
"""What metric should be used for choosing the best model"""
Expand All @@ -173,7 +173,7 @@ class Train(MixingConfig, BaseModel):
"""Focal Loss hyperparameter - determines importance the loss gives to hard-to-classify examples"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand All @@ -184,5 +184,5 @@ class ConfigMetaCAT(MixingConfig, BaseModel):
train: Train = Train()

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True
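A likely reason for changing `last_train_on` from `Optional[int]` to `Optional[float]` in this file is that the value assigned elsewhere in the PR comes from `datetime.now().timestamp()`, which returns a float. With `validate_assignment = True`, pydantic 1 silently truncates such a value to an int, whereas pydantic 2 in its default mode rejects a float that has a fractional part for an `int` field. The sketch below uses only the standard library to show the type involved and what truncation loses; `stamp_if_missing` is a hypothetical helper, not part of MedCAT.

```python
from datetime import datetime
from typing import Optional


def stamp_if_missing(last_train_on: Optional[float]) -> float:
    """Return the existing timestamp, or the current time if none is set."""
    if last_train_on is None:
        return datetime.now().timestamp()
    return last_train_on


stamp = stamp_if_missing(None)   # a float such as 1723456789.123456
truncated = int(stamp)           # what an int-typed field would keep
lost_fraction = stamp - truncated
```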
8 changes: 4 additions & 4 deletions medcat/config_rel_cat.py
@@ -1,6 +1,6 @@
import logging
from typing import Dict, Any, List
from medcat.config import MixingConfig, BaseModel, Optional, Extra
from medcat.config import MixingConfig, BaseModel, Optional


class General(MixingConfig, BaseModel):
Expand Down Expand Up @@ -56,7 +56,7 @@ class Model(MixingConfig, BaseModel):
"""If set to True center positions will be ignored when calculating represenation"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand All @@ -83,7 +83,7 @@ class Train(MixingConfig, BaseModel):
"""Should the model be saved during training for best results"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand All @@ -94,5 +94,5 @@ class ConfigRelCAT(MixingConfig, BaseModel):
train: Train = Train()

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True
8 changes: 4 additions & 4 deletions medcat/config_transformers_ner.py
@@ -1,4 +1,4 @@
from medcat.config import MixingConfig, BaseModel, Optional, Extra
from medcat.config import MixingConfig, BaseModel, Optional


class General(MixingConfig, BaseModel):
Expand All @@ -16,11 +16,11 @@ class General(MixingConfig, BaseModel):
chunking_overlap_window: Optional[int] = 5
"""Size of the overlap window used for chunking"""
test_size: float = 0.2
last_train_on: Optional[int] = None
last_train_on: Optional[float] = None
verbose_metrics: bool = False

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


Expand All @@ -29,5 +29,5 @@ class ConfigTransformersNER(MixingConfig, BaseModel):
general: General = General()

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True
6 changes: 3 additions & 3 deletions medcat/meta_cat.py
Expand Up @@ -114,8 +114,8 @@ def get_hash(self) -> str:
"""
hasher = Hasher()
# Set last_train_on if None
if self.config.train['last_train_on'] is None:
self.config.train['last_train_on'] = datetime.now().timestamp()
if self.config.train.last_train_on is None:
self.config.train.last_train_on = datetime.now().timestamp()

hasher.update(self.config.get_hash())
return hasher.hexdigest()
Expand Down Expand Up @@ -311,7 +311,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
# Save everything now
self.save(save_dir_path=save_dir_path)

self.config.train['last_train_on'] = datetime.now().timestamp()
self.config.train.last_train_on = datetime.now().timestamp()
return report

def eval(self, json_path: str) -> Dict:
6 changes: 3 additions & 3 deletions medcat/ner/transformers_ner.py
Expand Up @@ -97,8 +97,8 @@ def get_hash(self) -> str:
"""
hasher = Hasher()
# Set last_train_on if None
if self.config.general['last_train_on'] is None:
self.config.general['last_train_on'] = datetime.now().timestamp()
if self.config.general.last_train_on is None:
self.config.general.last_train_on = datetime.now().timestamp()

hasher.update(self.config.get_hash())
return hasher.hexdigest()
Expand Down Expand Up @@ -236,7 +236,7 @@ def train(self,
trainer.train() # type: ignore

# Save the training time
self.config.general['last_train_on'] = datetime.now().timestamp() # type: ignore
self.config.general.last_train_on = datetime.now().timestamp() # type: ignore

# Save everything
self.save(save_dir_path=os.path.join(self.training_arguments.output_dir, 'final_model'))
9 changes: 7 additions & 2 deletions medcat/utils/decorators.py
Expand Up @@ -7,16 +7,21 @@ def _format_version(ver: Tuple[int, int, int]) -> str:
return ".".join(str(v) for v in ver)


def deprecated(message: str, depr_version: Tuple[int, int, int], removal_version: Tuple[int, int, int]) -> Callable:
def deprecated(message: str, depr_version: Tuple[int, int, int],
removal_version: Tuple[int, int, int],
allow_usage: bool = False) -> Callable:
Member

Maybe there is no need to add this argument and expose it in the public API? It looks like deprecated() will be monkey-patched before the tests run anyway.

Collaborator Author

The method does indeed get monkey-patched during testing.

I suppose we could add **kwargs to the method to avoid describing the argument. But we can't remove it entirely, since passing the allow_usage keyword argument would then raise an exception during use (a TypeError due to an unexpected keyword argument).
At the same time, if this were hidden, a future maintainer of this resource might not know how to achieve the same result.

The main reason the method gets monkey-patched during testing is that we (generally) want to avoid using deprecated methods in our code. And if we had 100% test coverage (which we certainly don't, but that's beside the point), raising an exception during test time would guarantee that we're not.
But I've made an exception for this one, mostly so we can pre-specify when we'd stop supporting pydantic 1. That way we don't forget to do this at that later date, and it's actually documented (at least in code) and caught during the GHA workflow (at the appropriate release time).

But if you've got some ideas on how to do this more elegantly, then don't hesitate to propose them!
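To make the trade-off in this comment concrete, here is a minimal sketch of the two behaviours being discussed, under stated assumptions. `deprecated` follows the signature shown in the diff but its body is a guess; `deprecated_for_tests` is a hypothetical test-time replacement that raises unless `allow_usage` is set; and the version tuples are illustrative (only "removal in 1.15" appears in the commit titles).

```python
import functools
import warnings
from typing import Any, Callable, Tuple

Version = Tuple[int, int, int]


def deprecated(message: str, depr_version: Version, removal_version: Version,
               allow_usage: bool = False) -> Callable:
    """Runtime behaviour (assumed): warn on every call, ignore allow_usage."""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapped
    return decorator


def deprecated_for_tests(message: str, depr_version: Version,
                         removal_version: Version,
                         allow_usage: bool = False) -> Callable:
    """Hypothetical test-time patch: fail loudly unless usage is allowed."""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            if not allow_usage:
                raise AssertionError(f"Deprecated method used: {message}")
            return func(*args, **kwargs)
        return wrapped
    return decorator


def deprecated_without_flag(message: str, depr_version: Version,
                            removal_version: Version) -> Callable:
    """A signature lacking allow_usage, to show the TypeError mentioned."""
    return lambda func: func


@deprecated_for_tests("use model_dump", (1, 13, 0), (1, 15, 0), allow_usage=True)
def allowed_helper() -> str:
    return "ok"


@deprecated_for_tests("old helper", (1, 13, 0), (1, 15, 0))
def blocked_helper() -> str:
    return "never reached"


try:
    deprecated_without_flag("msg", (1, 13, 0), (1, 15, 0), allow_usage=True)
    type_error_raised = False
except TypeError:
    type_error_raised = True
```

The last block shows why the argument cannot simply be dropped from the real signature: any call site that passes `allow_usage=True` would fail at decoration time.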

Member

@baixiac baixiac Aug 14, 2024

Alright. In that case, I think there is no need to deprecate get_model_dump and get_model_fields in this PR, given that the goal is to support both versions ATM and they are supposed to do the job well. Besides, it is a bit weird to add a new method and deprecate it at the same time, though I have no strong opinion on this.

There will be warnings like Field "model_X" has conflict with protected namespace "model_", so to me, some base model fields will need renaming before medcat fully embraces pydantic v2 and deprecates v1.
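One way to size up the renaming work mentioned here would be to list the annotated fields whose names start with the `model_` prefix. The helper below is a sketch using only the standard library; `fields_in_protected_namespace` and `DemoModelConfig` (including its field names) are hypothetical and not taken from MedCAT.

```python
from typing import List, get_type_hints


def fields_in_protected_namespace(cls: type, prefix: str = "model_") -> List[str]:
    """List annotated attributes (including inherited ones) starting with prefix."""
    return [name for name in get_type_hints(cls) if name.startswith(prefix)]


# Hypothetical config class; the field names are made up for illustration.
class DemoModelConfig:
    model_name: str = "lstm"
    model_variant: str = "base"
    nclasses: int = 2


conflicting = fields_in_protected_namespace(DemoModelConfig)
```

Pydantic 2 also offers a `protected_namespaces` config setting that can silence the warning without renaming fields; whether that is preferable to renaming is a design decision for the maintainers.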

"""Deprecate a method.

NOTE: The `allow_usage` argument is only read and used during test time.

Args:
message (str): The deprecation message.
depr_version (Tuple[int, int, int]): The first version of MedCAT where this was deprecated.
removal_version (Tuple[int, int, int]): The first version of MedCAT where this will be removed.
allow_usage (bool): Whether to allow usage during test time.

Returns:
Callable: _description_
Callable: The wrapped method.
"""
def decorator(func: Callable) -> Callable:
@functools.wraps(func)