Skip to content

Commit

Permalink
Removing external_class as a config parameter, and checking if `nam…
Browse files Browse the repository at this point in the history
…e` is a key the registry or a libpath.
  • Loading branch information
drewoldag committed Sep 19, 2024
1 parent 2fb3a74 commit fa20613
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 18 deletions.
11 changes: 4 additions & 7 deletions src/fibad/fibad_default_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,18 @@ timeout = 3600
chunk_size = 990

[model]
# The name of the built-in model to use or the libpath to an external model
# e.g. "user_package.submodule.ExternalModel" or "ExampleAutoencoder"
name = "ExampleAutoencoder"

# An example of requesting an external model class
# external_class = "user_package.submodule.ExternalModel"

weights_filepath = "example_model.pth"
epochs = 10

[data_loader]
# Name of data loader to use
# Name of the built-in data loader to use or the libpath to an external data loader
# e.g. "user_package.submodule.ExternalDataLoader" or "HSCDataLoader"
name = "HSCDataLoader"

# An example of requesting an external data loader class
# external_class = "user_package.submodule.ExternalDataLoader"

# Directory path where the data is stored
path = "./data"

Expand Down
18 changes: 9 additions & 9 deletions src/fibad/plugin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,22 @@ def get_or_load_class(config: dict, registry: dict) -> type:
a `name` nor `external_cls` key was found in the config.
"""

# User specifies one of the built in classes by name
#! Once we have confidence in the config having default values, we can remove this check
if "name" in config:
class_name = config.get("name")
returned_class = None

if class_name not in registry:
raise ValueError(f"Could not find {class_name} in registry: {registry.keys()}")
# attempt to find the class in the registry
if class_name in registry:
returned_class = registry[class_name]

returned_class = registry[class_name]

# User provides an external class, attempt to import it with the module spec
elif "external_cls" in config:
returned_class = import_module_from_string(config["external_cls"])
# if the class is not in the registry, attempt to load it dynamically
else:
returned_class = import_module_from_string(class_name)

# User failed to define a class to load
else:
raise ValueError("No class requested. Specify a `name` or `external_cls` key in the runtime config.")
raise ValueError("No class requested. Specify a `name` key in the runtime config.")

return returned_class

Expand Down
1 change: 1 addition & 0 deletions src/fibad/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def run(config):

# Fetch model class specified in config and create an instance of it
model_cls = fetch_model_class(config)
logger.info(f"Training model class: {model_cls}")

Check warning on line 33 in src/fibad/train.py

View check run for this annotation

Codecov / codecov/patch

src/fibad/train.py#L33

Added line #L33 was not covered by tests
model = model_cls(model_config=config.get("model", {}), shape=data_loader.shape())

# Create trainer, a pytorch-ignite `Engine` object
Expand Down
4 changes: 2 additions & 2 deletions tests/fibad/test_plugin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_import_module_from_string_no_class():

def test_fetch_model_class():
"""Test the fetch_model_class function."""
config = {"model": {"external_cls": "builtins.BaseException"}}
config = {"model": {"name": "builtins.BaseException"}}

model_cls = fetch_model_class(config)

Expand All @@ -73,7 +73,7 @@ def test_fetch_model_class_no_model():
def test_fetch_model_class_no_model_cls():
"""Test that an exception is raised when a non-existent model class is requested."""

config = {"model": {"external_cls": "builtins.Nonexistent"}}
config = {"model": {"name": "builtins.Nonexistent"}}

with pytest.raises(AttributeError) as excinfo:
fetch_model_class(config)
Expand Down

0 comments on commit fa20613

Please sign in to comment.