Skip to content

Commit

Permalink
Merge pull request #6 from dirac-institute/awo/use-name-in-config
Browse files Browse the repository at this point in the history
Updating to match changes in fibad.
  • Loading branch information
drewoldag authored Sep 23, 2024
2 parents 851797b + fe6fe64 commit d5515ed
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
49 changes: 37 additions & 12 deletions example_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,21 @@ log_destination = "stderr"
log_level = "info" # Emit informational messages, warnings and all errors
# log_level = "debug" # Very verbose, emit all log messages.

data_dir = "/home/drew/code/fibad/data/"

[download]
sw = "22asec"
sh = "22asec"
filter = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"]
type = "coadd"
rerun = "pdr3_wide"
username = "mtauraso@local"
password = "cCw+nX53lmNLHMy+JbizpH/dl4t7sxljiNm6a7k1"
max_connections = 2
fits_file = "../hscplay/temp.fits"
cutout_dir = "../hscplay/cutouts/"
username = false
password = false
num_sources = -1 # Values below 1 here indicate all sources in the catalog will be downloaded
offset = 0
num_sources = 500
concurrent_connections = 4
stats_print_interval = 60
fits_file = "./catalog.fits"

# These control the downloader's HTTP requests and retries
# `retry_wait` How long to wait before retrying a failed HTTP request in seconds. Default 30s
Expand All @@ -38,7 +40,14 @@ retries = 3
# `timepout` How long should we wait to get a full HTTP response from the server. Default 3600s (1hr)
timeout = 3600
# `chunksize` How many sky location rectangles should we request in a single request. Default is 990
chunksize = 990
chunk_size = 990

# Whether to retrieve the image layer
image = true
# Whether to retrieve the variance layer
variance = false
# Whether to retrieve the mask layer
mask = false

[model]
# The name of the built-in model to use or the libpath to an external model
Expand All @@ -48,19 +57,35 @@ name = "kbmod_ml.models.cnn.CNN"
weights_filepath = "example_model.pth"
epochs = 10

base_channel_size = 32
latent_dim =64

[data_loader]
# Name of the built-in data loader to use or the libpath to an external data loader
# e.g. "user_package.submodule.ExternalDataLoader" or "HSCDataLoader"
name = "CifarDataLoader"
# name = "HSCDataLoader"

# Directory path where the data is stored
path = "/home/drew/code/fibad/data/"

# Pixel dimensions used to crop all images prior to loading. Will prune any images that are too small.
#
# If not provided by user, the default of 'false' scans the directory for the smallest dimensioned files, and
# uses those pixel dimensions as the crop size.
#
#crop_to = [100,100]
crop_to = false

# Limit data loader to only particular filters when there are more in the data set.
#
# When not provided by the user, the number of filters will be automatically gleaned from the data set.
# Defaults behavior is produced by the false value.
#
#filters = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"]
filters = false

# Default PyTorch DataLoader parameters
batch_size = 10
batch_size = 4
shuffle = true
num_workers = 10
num_workers = 2

[predict]
batch_size = 32
4 changes: 2 additions & 2 deletions src/kbmod_ml/models/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@fibad_model
class CNN(nn.Module):
def __init__(self, model_config, shape):
def __init__(self, config, shape):
logger.info("This is an external model, not in FIBAD!!!")
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
Expand All @@ -25,7 +25,7 @@ def __init__(self, model_config, shape):
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

self.config = model_config
self.config = config

# Optimizer and criterion could be set directly, i.e. `self.optimizer = optim.SGD(...)`
# but we define them as methods as a way to allow for more flexibility in the future.
Expand Down

0 comments on commit d5515ed

Please sign in to comment.