Commit 5c47b3f
attention
kapoorlab committed Oct 9, 2024
1 parent 2898c84 commit 5c47b3f
Showing 3 changed files with 41 additions and 18 deletions.
3 changes: 0 additions & 3 deletions .gitignore
@@ -79,6 +79,3 @@ venv/
 
 # OS
 .DS_Store
-
-# written by setuptools_scm
-**/_version.py
2 changes: 2 additions & 0 deletions src/kapoorlabs_lightning/_version.py
@@ -0,0 +1,2 @@
+__version__ = version = "5.4.6"
+__version_tuple__ = version_tuple = (5, 4, 6)
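Together with the .gitignore change above, `_version.py` is now committed rather than generated (and ignored) as a setuptools_scm artifact, so the version can be read directly. A quick check, assuming the package is importable:

```python
from kapoorlabs_lightning._version import __version__, __version_tuple__

print(__version__)        # "5.4.6"
print(__version_tuple__)  # (5, 4, 6)
```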
54 changes: 39 additions & 15 deletions src/kapoorlabs_lightning/lightning_trainer.py
@@ -102,7 +102,7 @@ def __init__(
         t_max: int = None,
         weight_decay: float = 1e-5,
         eps: float = 1e-1,
-        attention_dim: int =64,
+        attention_dim: int = 64,
         strategy: str = "auto",
     ):
         self.npz_file = npz_file
@@ -316,30 +316,29 @@ def setup_densenet_model(self):
             kernel_size=self.kernel_size,
         )
         print(f"Training Mitosis Inception Model {self.model}")
 
     def setup_attention_model(self):
 
         self.model = AttentionNet(
             input_channels=self.input_channels,
             num_classes=self.num_classes,
-            attention_dim=self.attention_dim # Add this as a parameter in your class
+            attention_dim=self.attention_dim,  # Add this as a parameter in your class
         )
         print(f"Training Attention Model {self.model}")
 
     def setup_hybrid_attention_model(self):
-        self.model = HybridAttentionDenseNet(
-            input_channels=self.input_channels,
-            num_classes=self.num_classes,
-            growth_rate=self.growth_rate,
-            block_config=self.block_config,
-            num_init_features=self.num_init_features,
-            bottleneck_size=self.bottleneck_size,
-            kernel_size=self.kernel_size,
-            attention_dim=self.attention_dim # Add this as a parameter in your class
-        )
-        print(f"Training Hybrid DenseNet with Attention Model {self.model}")
-
+        self.model = HybridAttentionDenseNet(
+            input_channels=self.input_channels,
+            num_classes=self.num_classes,
+            growth_rate=self.growth_rate,
+            block_config=self.block_config,
+            num_init_features=self.num_init_features,
+            bottleneck_size=self.bottleneck_size,
+            kernel_size=self.kernel_size,
+            attention_dim=self.attention_dim,  # Add this as a parameter in your class
+        )
+        print(f"Training Hybrid DenseNet with Attention Model {self.model}")
 
     def setup_mitosisnet_model(self):
         self.model = MitosisNet(
             self.input_channels,
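For context, the new `setup_attention_model` reduces to a three-argument construction. A minimal standalone sketch; the import path is an assumption, since the diff only shows the class name:

```python
# Hypothetical import path; the diff does not show where AttentionNet lives.
from kapoorlabs_lightning.pytorch_models import AttentionNet

model = AttentionNet(
    input_channels=1,   # example value: single-channel volumes
    num_classes=2,      # example value: binary classification
    attention_dim=64,   # matches the new default added to __init__ in this commit
)
```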
@@ -986,6 +985,8 @@ def extract_mitosis_model(
         num_init_features = mitosis_data["num_init_features"]
         bottleneck_size = mitosis_data["bottleneck_size"]
         kernel_size = mitosis_data["kernel_size"]
+        if "attention_dim" in mitosis_data.keys():
+            attention_dim = mitosis_data["attention_dim"]
 
         if ckpt_model_path is None:
             if local_model_path is None:
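The membership check keeps checkpoint loading backwards compatible: metadata written before this commit carries no "attention_dim" key, so an unconditional read would raise a KeyError. A `dict.get` is an equivalent, slightly more idiomatic sketch of the same guard:

```python
# Equivalent guard via dict.get; None marks metadata without an attention model.
attention_dim = mitosis_data.get("attention_dim")
```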
@@ -1028,6 +1029,29 @@ def extract_mitosis_model(
                 bottleneck_size=bottleneck_size,
                 kernel_size=kernel_size,
             )
+            if (
+                "attention_dim" in mitosis_data.keys()
+                and "growth_rate" in mitosis_data.keys()
+            ):
+                network = mitosis_model(
+                    input_channels,
+                    num_classes,
+                    growth_rate=growth_rate,
+                    block_config=block_config,
+                    num_init_features=num_init_features,
+                    bottleneck_size=bottleneck_size,
+                    kernel_size=kernel_size,
+                    attention_dim=attention_dim,
+                )
+            if (
+                "attention_dim" in mitosis_data.keys()
+                and "growth_rate" not in mitosis_data.keys()
+            ):
+
+                network = mitosis_model(
+                    input_channels, num_classes, attention_dim=attention_dim
+                )
+
         checkpoint_lightning_model = cls.load_from_checkpoint(
             most_recent_checkpoint_ckpt,
             network=network,
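The added branches form a small dispatch on which hyperparameter keys the metadata carries: "attention_dim" together with "growth_rate" selects the hybrid DenseNet-with-attention construction, while "attention_dim" alone selects the plain attention model. A condensed sketch of the same logic; the helper name and example dict are hypothetical:

```python
def build_network(mitosis_model, mitosis_data, input_channels, num_classes):
    """Choose constructor arguments based on the keys present in the metadata."""
    has_attention = "attention_dim" in mitosis_data
    has_growth = "growth_rate" in mitosis_data
    if has_attention and has_growth:
        # Hybrid model: full DenseNet hyperparameters plus attention_dim.
        return mitosis_model(
            input_channels,
            num_classes,
            growth_rate=mitosis_data["growth_rate"],
            block_config=mitosis_data["block_config"],
            num_init_features=mitosis_data["num_init_features"],
            bottleneck_size=mitosis_data["bottleneck_size"],
            kernel_size=mitosis_data["kernel_size"],
            attention_dim=mitosis_data["attention_dim"],
        )
    if has_attention:
        # Attention-only model: just the attention dimension beyond the basics.
        return mitosis_model(
            input_channels, num_classes, attention_dim=mitosis_data["attention_dim"]
        )
    # Otherwise fall through to the pre-existing non-attention construction.
    return None

# Hypothetical metadata, e.g. loaded from a JSON file next to the checkpoint:
meta = {"attention_dim": 64}
```

Note that in the committed code both conditions re-test "attention_dim" independently rather than using an if/elif chain; the behavior is the same because the two conditions are mutually exclusive.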
