CLN: Refactor model abstraction, fix #737 #726 (#753)
* Change type annotations on models.definition.validate to just be Type, not Type[ModelDefinition]

* Rewrite models.base.Model to not subclass LightningModule, and instead to have class variables definition and family, that are *both* used by the from_config method to make a new instance of the family with instances of the definition's attributes (trust me, this makes perfect sense)

* Change model decorator to not subclass family, and to instead subclass Model, then make a new instance of the subclass, add that instance to the registry, and then return the instance. This makes it possible to call the 'from_config' method of the instance and get a new class instance. Think of the instance as a Singleton

* Rewrite FrameClassificationModel to subclass LightningModule directly, remove from_config method

* Rewrite ParametricUMAPModel to subclass LightningModule directly, remove from_config method

* Rewrite base.Model class, move/rename methods so that we will get a singleton that can return new lightning.Module instances with the appropriate instances of network + loss + optimizer + metrics from its from_config method

* Add load_state_dict_from_path method to FrameClassificationModel and ParametricUmapModel

* Change model_family decorator to check if family_class is a subtype of LightningModule, not vak.base.Model

* Rewrite models.base.Model.__init__ to take definition and family attributes, that are used by from_config method

* Fix how models.decorator.model makes Model instance -- we don't subclass anymore, but we do change the Model instance's __name__, __doc__, and __module__ to match those of the definition

* Fix FrameClassificationModel to subclass lightning.LightningModule (not pytorch_lightning.LightningModule :eyeroll:) and to no longer pass network, loss, etc to super().__init__ since it's no longer a sub-class of models.base.Model

* Fix ParametricUMAPModel to subclass lightning.LightningModule (not lightning.pytorch.LightningModule) and to no longer pass network, loss, etc to super().__init__ since it's no longer a sub-class of models.base.Model

* Fix how we add a Model instance to MODEL_REGISTRY

* Fix src/vak/models/frame_classification_model.py to set network/loss/optimizer/metrics as attributes on self

* Fix src/vak/models/parametric_umap_model.py to set network/loss/optimizer/metrics as attributes on self

* Fix how we get MODEL_FAMILY_FROM_NAME dict in models.registry.__getattr__

* Fix classes in tests/test_models/conftest.py so we can use them to run tests

* Fix tests in tests/test_models/test_base.py

* Add method from_instances to vak.models.base.Model

* Rename vak.models.base.Model -> vak.models.factory.ModelFactory

* Add tests in tests/test_models/test_factory.py from test_frame_classification_model

* Fix unit test in tests/test_models/test_convencoder_umap.py

* Fix unit tests in tests/test_models/test_decorator.py

* Fix unit tests in tests/test_models/test_tweetynet.py

* Fix adding load_from_state_dict method to ParametricUMAPModel

* Fix unit tests in tests/test_models/test_frame_classification_model.py

* Rename method in tests/test_models/test_convencoder_umap.py

* Fix unit tests in tests/test_models/test_ed_tcn.py

* Add a unit test from another test_models module to test_factory.py

* Add a unit test from another test_models module to test_factory.py

* Fix unit tests in tests/test_models/test_registry.py

* Remove unused fixture 'monkeypatch' in tests/test_models/test_frame_classification_model.py

* Fix unit tests in tests/test_models/test_parametric_umap_model.py

* BUG: Fix how we check if we need to add empty dicts to model config in src/vak/config/model.py

* Rename model_class -> model_factory in src/vak/models/get.py

* Clean up docstring in src/vak/models/factory.py

* Fix how we parametrize two unit tests in tests/test_config/test_model.py

* Fix ConvEncoderUMAP configs in tests/data_for_tests/configs to have network.encoder sub-table

* Rewrite docstring, fix type annotations, rename vars for clarity in src/vak/models/decorator.py

* Revise docstring in src/vak/models/definition.py

* Revise type hinting + docstring in src/vak/models/get.py

* Revise docstring + comment in src/vak/models/registry.py

* Fix unit test in tests/test_models/test_factory.py

* Fix ParametricUMAPModel to use a ModuleDict

* Fix unit test in tests/test_models/test_convencoder_umap.py

* Fix unit test in tests/test_models/test_factory.py

* Fix unit test in tests/test_models/test_parametric_umap_model.py

* Fix common.tensorboard.events2df to avoid pandas error about re-indexing with duplicate values -- we need to not use the 'epoch' Scalar since it's all zeros
NickleDave authored May 6, 2024
1 parent d34c3e7 commit 4862c9d
Showing 23 changed files with 1,103 additions and 993 deletions.
2 changes: 1 addition & 1 deletion src/vak/common/tensorboard.py
@@ -128,5 +128,5 @@ def events2df(
         ).set_index("step")
         if drop_wall_time:
             dfs[scalar_tag].drop("wall_time", axis=1, inplace=True)
-    df = pd.concat([v for k, v in dfs.items()], axis=1)
+    df = pd.concat([v for k, v in dfs.items() if k != "epoch"], axis=1)
     return df
4 changes: 3 additions & 1 deletion src/vak/config/model.py
@@ -78,6 +78,8 @@ def from_config_dict(cls, config_dict: dict):
                 f"Model name not found in registry: {model_name}\n"
                 f"Model names in registry:\n{MODEL_NAMES}"
             )
+
+        # NOTE: we are getting model_config here
         model_config = config_dict[model_name]
         if not all(key in MODEL_TABLES for key in model_config.keys()):
             invalid_keys = (
@@ -89,7 +91,7 @@ def from_config_dict(cls, config_dict: dict):
             )
         # for any tables not specified, default to empty dict so we can still use ``**`` operator on it
         for model_table in MODEL_TABLES:
-            if model_table not in config_dict:
+            if model_table not in model_config:
                 model_config[model_table] = {}
         return cls(name=model_name, **model_config)

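The effect of the one-line fix in this hunk can be sketched in isolation. This is a minimal illustration, not vak's code: `fill_missing_tables`, the `fixed` flag, the contents of `MODEL_TABLES`, and the example config are all assumptions made for the sketch. As far as the hunk shows, the outer `config_dict` is keyed by model name, so a membership test against it is true for every table.

```python
# Hypothetical stand-ins: the table names and the config layout are
# assumptions for illustration, not copied from vak.
MODEL_TABLES = ("network", "optimizer", "loss", "metrics")


def fill_missing_tables(config_dict: dict, model_name: str, fixed: bool = True) -> dict:
    """Return the model's config with an empty dict for any table it omits."""
    # copy each table so the caller's dict is not mutated
    model_config = {k: dict(v) for k, v in config_dict[model_name].items()}
    # the old check looked in the outer dict, whose only key is the model name
    lookup = model_config if fixed else config_dict
    for model_table in MODEL_TABLES:
        if model_table not in lookup:
            model_config[model_table] = {}
    return model_config


config_dict = {"TweetyNet": {"optimizer": {"lr": 0.001}}}

fixed = fill_missing_tables(config_dict, "TweetyNet", fixed=True)
buggy = fill_missing_tables(config_dict, "TweetyNet", fixed=False)

print(fixed["optimizer"])  # {'lr': 0.001}: the user's setting survives
print(buggy["optimizer"])  # {}: the old check replaced it with an empty dict
```

With the old membership test, every table looks "missing", so a user-specified table would be replaced by an empty dict; testing against `model_config` only fills the tables that are actually absent.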
8 changes: 4 additions & 4 deletions src/vak/models/__init__.py
@@ -1,5 +1,5 @@
-from . import base, decorator, definition, registry
-from .base import Model
+from . import decorator, definition, factory, registry
+from .factory import ModelFactory
 from .convencoder_umap import ConvEncoderUMAP
 from .decorator import model
 from .ed_tcn import ED_TCN
@@ -10,14 +10,14 @@
 from .tweetynet import TweetyNet

 __all__ = [
-    "base",
+    "factory",
     "ConvEncoderUMAP",
     "decorator",
     "definition",
     "ED_TCN",
     "FrameClassificationModel",
     "get",
-    "Model",
+    "ModelFactory",
     "model",
     "model_family",
     "ParametricUMAPModel",
95 changes: 44 additions & 51 deletions src/vak/models/decorator.py
@@ -1,22 +1,25 @@
-"""Decorator that makes a model class,
+"""Decorator that makes a :class:`vak.models.ModelFactory`,
 given a definition of the model,
-and another class that represents a
+and a :class:`lightning.LightningModule` that represents a
 family of models that the new model belongs to.

-The function returns a newly-created subclass
-of the class representing the family of models.
-The subclass can then be instantiated
-and have all model methods.
+The function returns a new instance of :class:`vak.models.ModelFactory`,
+that can create new instances of the model with its
+:meth:`~vak.models.ModelFactory.from_config` and
+:meth:`~vak.models.ModelFactory.from_instances` methods.
 """

 from __future__ import annotations

-from typing import Type
+from typing import Type, TYPE_CHECKING

+import lightning
+
-from .base import Model
 from .definition import validate as validate_definition
 from .registry import register_model

+if TYPE_CHECKING:
+    from .factory import ModelFactory

 class ModelDefinitionValidationError(Exception):
     """Exception raised when validating a model
@@ -28,16 +31,16 @@ class ModelDefinitionValidationError(Exception):
     pass


-def model(family: Type[Model]):
-    """Decorator that makes a model class,
+def model(family: lightning.pytorch.LightningModule):
+    """Decorator that makes a :class:`vak.models.ModelFactory`,
     given a definition of the model,
-    and another class that represents a
+    and a :class:`lightning.LightningModule` that represents a
     family of models that the new model belongs to.

-    Returns a newly-created subclass
-    of the class representing the family of models.
-    The subclass can then be instantiated
-    and have all model methods.
+    The function returns a new instance of :class:`vak.models.ModelFactory`,
+    that can create new instances of the model with its
+    :meth:`~vak.models.ModelFactory.from_config` and
+    :meth:`~vak.models.ModelFactory.from_instances` methods.

     Parameters
     ----------
@@ -46,50 +49,40 @@ def model(family: Type[Model]):
         A class with all the class variables required
         by :func:`vak.models.definition.validate`.
         See docstring of that function for specification.
-    family : subclass of vak.models.Model
+        See also :class:`vak.models.definition.ModelDefinition`,
+        but note that it is not necessary to subclass
+        :class:`~vak.models.definition.ModelDefinition` to
+        define a model.
+    family : lightning.LightningModule
         The class representing the family of models
         that the new model will belong to.
         E.g., :class:`vak.models.FrameClassificationModel`.
+        Should be a subclass of :class:`lightning.LightningModule`
+        that was registered with the
+        :func:`vak.models.registry.model_family` decorator.

     Returns
     -------
-    model : type
-        A sub-class of ``model_family``,
-        with attribute ``definition``,
+    model_factory : vak.models.ModelFactory
+        An instance of :class:`~vak.models.ModelFactory`,
+        with attributes ``definition`` and ``family``,
         that will be used when making
-        new instances of the model.
+        new instances of the model by calling the
+        :meth:`~vak.models.ModelFactory.from_config` method
+        or the :meth:`~vak.models.ModelFactory.from_instances` method.
     """

-    def _model(definition: Type):
-        if not issubclass(family, Model):
-            raise TypeError(
-                "The ``family`` argument to the ``vak.models.model`` decorator"
-                "should be a subclass of ``vak.models.base.Model``,"
-                f"but the type was: {type(family)}, "
-                "which was not recognized as a subclass "
-                "of ``vak.models.base.Model``."
-            )
-
-        try:
-            validate_definition(definition)
-        except ValueError as err:
-            raise ModelDefinitionValidationError(
-                f"Validation failed for the following model definition:\n{definition}"
-            ) from err
-        except TypeError as err:
-            raise ModelDefinitionValidationError(
-                f"Validation failed for the following model definition:\n{definition}"
-            ) from err
-
-        attributes = dict(family.__dict__)
-        attributes.update({"definition": definition})
-        subclass_name = definition.__name__
-        subclass = type(subclass_name, (family,), attributes)
-        subclass.__module__ = definition.__module__
-
-        # finally, add model to registry
-        register_model(subclass)
-
-        return subclass
+    def _model(definition: Type) -> ModelFactory:
+        from .factory import ModelFactory  # avoid circular import
+
+        model_factory = ModelFactory(
+            definition,
+            family
+        )
+        model_factory.__name__ = definition.__name__
+        model_factory.__doc__ = definition.__doc__
+        model_factory.__module__ = definition.__module__
+        register_model(model_factory)
+        return model_factory

     return _model
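The pattern this file moves to (a decorator that replaces the definition class with a factory instance, which then builds fresh family instances) can be sketched without Lightning. Everything below is a simplified stand-in and an assumption for illustration: `ToyFactory`, `ToyFamily`, `ToyNetwork`, `ToyLoss`, and the plain-dict `MODEL_REGISTRY` are hypothetical names, and the real `ModelFactory.from_config` presumably also handles optimizer and metrics.

```python
from typing import Type

# Hypothetical registry; vak keeps its own in vak.models.registry.
MODEL_REGISTRY: dict = {}


class ToyFamily:
    """Stand-in for a LightningModule subclass that represents a model family."""

    def __init__(self, network, loss):
        self.network = network
        self.loss = loss


class ToyFactory:
    """Simplified stand-in for ModelFactory: holds a definition and a family."""

    def __init__(self, definition: Type, family: Type):
        self.definition = definition
        self.family = family

    def from_config(self, config: dict):
        # Build fresh instances of the definition's attributes,
        # then hand them to a new instance of the family.
        network = self.definition.network(**config.get("network", {}))
        loss = self.definition.loss(**config.get("loss", {}))
        return self.family(network=network, loss=loss)


def model(family: Type):
    """Decorator that swaps a definition class for a factory instance."""

    def _model(definition: Type) -> ToyFactory:
        factory = ToyFactory(definition, family)
        # make the factory instance look like the definition it wraps
        factory.__name__ = definition.__name__
        factory.__doc__ = definition.__doc__
        factory.__module__ = definition.__module__
        MODEL_REGISTRY[factory.__name__] = factory
        return factory

    return _model


class ToyNetwork:
    def __init__(self, hidden_size: int = 8):
        self.hidden_size = hidden_size


class ToyLoss:
    pass


@model(family=ToyFamily)
class ToyNet:
    """A toy model definition."""

    network = ToyNetwork
    loss = ToyLoss


# After decoration, the name ToyNet is bound to the factory instance.
first = ToyNet.from_config({"network": {"hidden_size": 16}})
second = ToyNet.from_config({})
```

In this sketch the decorated name refers to a single factory object, and each `from_config` call returns an independent family instance, which matches the "think of the instance as a Singleton" note in the commit message.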
17 changes: 10 additions & 7 deletions src/vak/models/definition.py
@@ -27,10 +27,7 @@
 class ModelDefinition:
     """A class that represents the definition of a neural network model.

-    Note it is **not** necessary to sub-class this class;
-    it exists mainly for type-checking purposes.
-
-    A model definition is a class that has the following class variables:
+    A model definition is any class that has the following class variables:

     network: torch.nn.Module or dict
         Neural network.
@@ -48,6 +45,12 @@ class ModelDefinition:
         Used by ``vak.models.base.Model`` and its
         sub-classes that represent model families. E.g., those classes will do:
         ``network = self.definition.network(**self.definition.default_config['network'])``.
+
+    Note it is **not** necessary to sub-class this class;
+    it exists mainly for type-checking purposes.
+
+    For more detail, see :func:`vak.models.decorator.model`
+    and :class:`vak.models.ModelFactory`.
     """

     network: Union[torch.nn.Module, dict]
@@ -67,7 +70,7 @@ class ModelDefinition:
     }


-def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]:
+def validate(definition: Type) -> Type:
     """Validate a model definition.

     A model definition is a class that has the following class variables:
@@ -124,8 +127,8 @@ def validate(definition: Type[ModelDefinition]) -> Type[ModelDefinition]:
     converting it into a sub-class of
     ``vak.models.Model``.

-    It's also used by ``vak.models.Model``
-    to validate a definition when initializing
+    It's also used by :class:`vak.models.ModelFactory`,
+    to validate a definition before building
     a new model instance from the definition.
     """
     # need to set this default first
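The relaxed annotation (`Type` instead of `Type[ModelDefinition]`) fits a duck-typed check: any class with the right class variables qualifies, whether or not it subclasses `ModelDefinition`. A much-reduced sketch of such a check follows. The function `check_definition`, the tuple `REQUIRED_CLASS_VARIABLES`, and both example classes are hypothetical; the attribute names are taken from the commit message ("network + loss + optimizer + metrics"), and the real `validate` presumably checks more than attribute presence.

```python
from typing import Type

# Assumed attribute names, taken from the commit message;
# not the full specification that vak.models.definition.validate enforces.
REQUIRED_CLASS_VARIABLES = ("network", "loss", "optimizer", "metrics")


def check_definition(definition: Type) -> Type:
    """Hypothetical, reduced check that a class looks like a model definition."""
    if not isinstance(definition, type):
        raise TypeError(f"definition should be a class, but was: {type(definition)}")
    missing = [name for name in REQUIRED_CLASS_VARIABLES if not hasattr(definition, name)]
    if missing:
        raise ValueError(f"definition is missing class variables: {missing}")
    return definition


class PlainDefinition:  # note: does not subclass ModelDefinition
    network = dict
    loss = object
    optimizer = object
    metrics = {"acc": object}


class Incomplete:
    network = dict


checked = check_definition(PlainDefinition)  # passes without any subclassing

try:
    check_definition(Incomplete)
except ValueError as err:
    error_message = str(err)
```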