Skip to content

Commit

Permalink
final commit for de 0.1.8
Browse files Browse the repository at this point in the history
  • Loading branch information
T-Almeida committed Jan 26, 2022
1 parent aad908c commit 228cd9c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
2 changes: 1 addition & 1 deletion polus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
'''

__version__="0.1.7"
__version__="0.1.8"

# add main lib sub packages
import polus.callbacks
Expand Down
28 changes: 24 additions & 4 deletions polus/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ def __init__(self,
self.accelerated_map_batch = accelerated_map_batch
self.accelerated_map_f = accelerated_map_f

if source_generator is not None:
self.name = source_generator.__name__
else:
self.name = "None"

self.sample_generator = self._build_sample_generator(source_generator)

dytpes, shapes = find_dtype_and_shapes(self.sample_generator(), k=magic_k)
Expand All @@ -100,6 +105,12 @@ def __init__(self,
for method in filter(lambda x: not x[0].startswith("_"), inspect.getmembers(self.tf_dataset, predicate=inspect.ismethod)):
setattr(self, method[0], method[1])

def set_name(self, _name):
self.name = _name

@property
def __name__(self):
return f"{self.__class__.__name__}_{self.name}"

def _build_sample_generator(self, source_generator):

Expand Down Expand Up @@ -226,11 +237,21 @@ def __init__(self,
try:
super().__init__(source_generator = source_generator, **kwargs)
except Exception as e:
# here we dont want to solve the exception, we just want to clean up de previously created files
self.logger.info("An error has occured so all the created files will be deleted")
self.clean()

if cache_index is None:
# here we dont want to solve the exception, we just want to clean up de previously created files
self.logger.info("An error has occured so all the created files will be deleted")
self.clean()
raise e

@property
def __name__(self):
if self.cache_index_path is not None:
_name = os.path.splitext(os.path.basename(self.cache_index_path))[0]
return f"{self.__class__.__name__}_{self.name}"
else:
return super().__name__

@classmethod
def from_cached_index(cls, index_path):

Expand All @@ -255,7 +276,6 @@ def merge(cls, *cache_dataloaders):

index_info["n_samples"] += index["n_samples"]
index_info["files"].extend(index["files"])

# this is a bit starange, but it is possible to have different DL with diff chunk size so we will pick the larger one to define the conjunt
index_info["cache_chunk_size"] = max(index_info["cache_chunk_size"], index["cache_chunk_size"])

Expand Down
2 changes: 0 additions & 2 deletions polus/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ def save(self,

path = os.path.join(base_path, self.name+extension)



with open(path+".cfg","w") as f:
json_str = complex_json_serializer(self.savable_config)
json.dump(json_str , f)
Expand Down

0 comments on commit 228cd9c

Please sign in to comment.