Skip to content

Commit

Permalink
CU-8695d4www: Update to pydantic 2 ONLY
Browse files Browse the repository at this point in the history
  • Loading branch information
mart-r committed Oct 15, 2024
1 parent 6746b34 commit b7f895e
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 66 deletions.
3 changes: 1 addition & 2 deletions medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from medcat.stats.stats import get_stats
from medcat.utils.filters import set_project_filters
from medcat.utils.usage_monitoring import UsageMonitor
from medcat.utils.pydantic_version import get_model_dump


logger = logging.getLogger(__name__) # separate logger from the package-level one
Expand Down Expand Up @@ -591,7 +590,7 @@ def _print_stats(self,

def _init_ckpts(self, is_resumed, checkpoint):
if self.config.general.checkpoint.steps is not None or checkpoint is not None:
checkpoint_config = CheckpointConfig(**get_model_dump(self.config.general.checkpoint))
checkpoint_config = CheckpointConfig(**self.config.general.checkpoint.model_dump())
checkpoint_manager = CheckpointManager('cat_train', checkpoint_config)
if is_resumed:
# TODO: probably remove is_resumed mark and always resume if a checkpoint is provided,
Expand Down
13 changes: 6 additions & 7 deletions medcat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from medcat.utils.matutils import intersect_nonempty_set
from medcat.utils.config_utils import attempt_fix_weighted_average_function
from medcat.utils.config_utils import weighted_average, is_old_type_config_dict
from medcat.utils.pydantic_version import get_model_dump, get_model_fields
from medcat.utils.saving.coding import CustomDelegatingEncoder, default_hook


Expand Down Expand Up @@ -125,7 +124,7 @@ def merge_config(self, config_dict: Dict) -> None:
attr = None # new attribute
value = config_dict[key]
if isinstance(value, BaseModel):
value = get_model_dump(value)
value = value.model_dump()
if isinstance(attr, MixingConfig):
attr.merge_config(value)
else:
Expand Down Expand Up @@ -177,7 +176,7 @@ def rebuild_re(self) -> None:
def _calc_hash(self, hasher: Optional[Hasher] = None) -> Hasher:
if hasher is None:
hasher = Hasher()
for _, v in get_model_dump(cast(BaseModel, self)).items():
for _, v in cast(BaseModel, self).model_dump().items():
if isinstance(v, MixingConfig):
v._calc_hash(hasher)
else:
Expand All @@ -189,7 +188,7 @@ def get_hash(self, hasher: Optional[Hasher] = None):
return hasher.hexdigest()

def __str__(self) -> str:
return str(get_model_dump(cast(BaseModel, self)))
return str(cast(BaseModel, self).model_dump())

@classmethod
def load(cls, save_path: str) -> "MixingConfig":
Expand Down Expand Up @@ -238,15 +237,15 @@ def asdict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The dictionary associated with this config
"""
return get_model_dump(cast(BaseModel, self))
return cast(BaseModel, self).model_dump()

def fields(self) -> dict:
    """Get the fields associated with this config.

    This returns the pydantic field definitions of the model (field
    name -> field info), NOT the current values. Use ``asdict`` to get
    the values.

    Returns:
        dict: The dictionary of the field names and fields
    """
    # NOTE: pydantic 2 exposes field definitions as ``model_fields``.
    #       Calling ``model_dump()`` here would return the values and
    #       make this method a duplicate of ``asdict``.
    #       The attribute is read from the class since that is where
    #       pydantic defines it.
    return type(cast(BaseModel, self)).model_fields


class VersionInfo(MixingConfig, BaseModel):
Expand Down Expand Up @@ -618,7 +617,7 @@ def rebuild_re(self) -> None:
# Override
def get_hash(self):
hasher = Hasher()
for k, v in get_model_dump(self).items():
for k, v in self.model_dump().items():
if k in ['hash', ]:
# ignore hash
continue
Expand Down
30 changes: 0 additions & 30 deletions medcat/utils/pydantic_version.py

This file was deleted.

3 changes: 1 addition & 2 deletions medcat/utils/regression/checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pydantic import BaseModel, Field

from medcat.cat import CAT
from medcat.utils.pydantic_version import get_model_dump
from medcat.utils.regression.targeting import TranslationLayer, OptionSet
from medcat.utils.regression.targeting import FinalTarget, TargetedPhraseChanger
from medcat.utils.regression.utils import partial_substitute, MedCATTrainerExportConverter
Expand Down Expand Up @@ -412,7 +411,7 @@ def to_dict(self) -> dict:
d = {}
for case in self.cases:
d[case.name] = case.to_dict()
d['meta'] = get_model_dump(self.metadata)
d['meta'] = self.metadata.model_dump()
fix_np_float64(d['meta'])

return d
Expand Down
3 changes: 1 addition & 2 deletions medcat/utils/regression/regression_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Optional, Tuple

from medcat.cat import CAT
from medcat.utils.pydantic_version import get_model_dump
from medcat.utils.regression.checking import RegressionSuite, TranslationLayer
from medcat.utils.regression.results import Strictness, Finding, STRICTNESS_MATRIX

Expand Down Expand Up @@ -119,7 +118,7 @@ def main(model_pack_dir: Path, test_suite_file: Path,
examples_strictness = Strictness[examples_strictness_str]
if jsonpath:
logger.info('Writing to %s', str(jsonpath))
dumped = get_model_dump(res, strictness=examples_strictness)
dumped = res._model_dump(strictness=examples_strictness)
jsonpath.write_text(json.dumps(dumped, indent=jsonindent))
else:
logger.info(res.get_report(phrases_separately=phrases,
Expand Down
27 changes: 9 additions & 18 deletions medcat/utils/regression/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from medcat.utils.regression.targeting import TranslationLayer, FinalTarget
from medcat.utils.regression.utils import limit_str_len, add_doc_strings_to_enum
from medcat.utils.pydantic_version import HAS_PYDANTIC2, get_model_dump


class Finding(Enum):
Expand Down Expand Up @@ -396,17 +395,17 @@ def _dict(self, **kwargs) -> dict:
key.name: value for key, value in self.findings.items()
}
serialized_examples = [
(get_model_dump(ft, **kwargs), (f[0].name, f[1])) for ft, f in self.examples
(ft.model_dump(**kwargs), (f[0].name, f[1])) for ft, f in self.examples
# only count if NOT in strictness matrix (i.e 'failures')
if f[0] not in STRICTNESS_MATRIX[strictness]
]
model_dict = get_model_dump(cast(pydantic.BaseModel, super()), **kwargs)
model_dict = cast(pydantic.BaseModel, super()).model_dump(**kwargs)
model_dict['findings'] = serialized_dict
model_dict['examples'] = serialized_examples
return model_dict

def json(self, **kwargs) -> str:
d = get_model_dump(self, **kwargs)
d = self.model_dump(**kwargs)
return json.dumps(d)


Expand Down Expand Up @@ -479,15 +478,15 @@ def get_report(self, phrases_separately: bool = False) -> str:
for srd in self.per_phrase_results.values()])
return sr + '\n\t\t' + children.replace('\n', '\n\t\t')

def _dict(self, **kwargs) -> dict:
def _model_dump(self, **kwargs) -> dict:
if 'exclude' in kwargs and kwargs['exclude'] is not None:
exclude: set = kwargs['exclude']
else:
exclude = set()
kwargs['exclude'] = exclude
# NOTE: ignoring here so that examples are only present in the per phrase part
exclude.update(('examples', 'per_phrase_results'))
d = get_model_dump(cast(pydantic.BaseModel, super()), **kwargs)
d = cast(pydantic.BaseModel, super()).model_dump(**kwargs)
if 'examples' in d:
# NOTE: I don't really know why, but the examples still
# seem to be a part of the resulting dict, so I need
Expand All @@ -496,7 +495,7 @@ def _dict(self, **kwargs) -> dict:
# NOTE: need to propagate here manually so the strictness keyword
# makes sense and doesn't cause issues due to being an unexpected keyword
per_phrase_results = {
phrase: get_model_dump(res, **kwargs) for phrase, res in
phrase: res.model_dump(**kwargs) for phrase, res in
sorted(self.per_phrase_results.items(), key=lambda it: it[0])
}
d['per_phrase_results'] = per_phrase_results
Expand Down Expand Up @@ -678,7 +677,7 @@ def get_report(self, phrases_separately: bool,
])
return "\n".join(ret_vals) + f"\n{delegated}"

def _dict(self, **kwargs) -> dict:
def _model_dump(self, **kwargs) -> dict:
if 'strictness' in kwargs:
strict_raw = kwargs.pop('strictness')
if isinstance(strict_raw, Strictness):
Expand All @@ -689,20 +688,12 @@ def _dict(self, **kwargs) -> dict:
raise ValueError(f"Unknown stircntess specified: {strict_raw}")
else:
strictness = Strictness.NORMAL
out_dict = get_model_dump(cast(pydantic.BaseModel, super()), exclude={'parts'}, **kwargs)
out_dict['parts'] = [get_model_dump(part, strictness=strictness) for part in self.parts]
out_dict = cast(pydantic.BaseModel, super()).model_dump(exclude={'parts'}, **kwargs)
out_dict['parts'] = [part._model_dump(strictness=strictness) for part in self.parts]
return out_dict


class MalformedFinding(ValueError):

def __init__(self, *args: object) -> None:
super().__init__(*args)


TO_BE_FIXED = [SingleResultDescriptor, ResultDescriptor, MultiDescriptor]
for fixer in TO_BE_FIXED:
if HAS_PYDANTIC2:
fixer.model_dump = fixer._dict # type: ignore
else:
fixer.dict = fixer._dict # type: ignore
8 changes: 4 additions & 4 deletions tests/utils/regression/test_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def setUpClass(cls) -> None:
final_phrase='FINAL PHRASE'), finding=(Finding.FOUND_OTHER, 'CUI=OTHER'))

def test_result_is_json_serialisable(self):
rd = self.res._dict()
rd = self.res._model_dump()
s = json.dumps(rd)
self.assertIsInstance(s, str)

Expand All @@ -249,19 +249,19 @@ def test_result_is_json_serialisable_pydantic(self):

def test_can_use_strictness(self):
e1 = [
example for part in self.res._dict(strictness=Strictness.STRICTEST)['parts']
example for part in self.res._model_dump(strictness=Strictness.STRICTEST)['parts']
for per_phrase in part['per_phrase_results'].values()
for example in per_phrase['examples']
]
e2 = [
example for part in self.res._dict(strictness=Strictness.LENIENT)['parts']
example for part in self.res._model_dump(strictness=Strictness.LENIENT)['parts']
for per_phrase in part['per_phrase_results'].values()
for example in per_phrase['examples']
]
self.assertGreater(len(e1), len(e2))

def test_dict_includes_all_parts(self):
d_parts = self.res._dict()['parts']
d_parts = self.res._model_dump()['parts']
self.assertEqual(len(self.res.parts), len(d_parts))


Expand Down
1 change: 0 additions & 1 deletion tests/utils/test_versioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import shutil

import dill
import pydantic

from medcat.utils.versioning import get_version_from_modelcard, get_semantic_version_from_model
from medcat.utils.versioning import get_version_from_cdb_dump, get_version_from_modelpack_zip
Expand Down

0 comments on commit b7f895e

Please sign in to comment.