Commit

add shuffle_queue_size for random shuffling #7 + fixed problem of duplicate data when num_workers > 0 + attempt to solve the problem of generator pickling #5
mboudiaf committed Apr 1, 2021
1 parent a421087 commit cad64b3
Showing 8 changed files with 128 additions and 68 deletions.
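
The duplicate-data fix hinges on the new worker_init_fn_ helper that example.py now imports from pytorch_meta_dataset.utils; utils.py is one of the files not expanded on this page, so the snippet below is only a minimal sketch of what such a hook typically does (per-worker reseeding), not the committed implementation:

# Hypothetical sketch of worker_init_fn_; the real version lives in
# pytorch_meta_dataset/utils.py, which is not expanded in this diff view.
import random

import numpy as np
import torch


def worker_init_fn_(worker_id: int) -> None:
    # Give each DataLoader worker its own RNG stream. Without this, every
    # worker forks the parent's NumPy state and replays the same class and
    # episode choices, which is the duplicate-data symptom fixed here.
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None and hasattr(worker_info.dataset, 'random_gen'):
        # The datasets in this commit carry a per-instance RandomState
        # (self.random_gen); reseed the copy owned by this worker.
        worker_info.dataset.random_gen = np.random.RandomState(worker_seed)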
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,4 +1,5 @@

scripts/
conversion/

# Sublime files
*.sublime-project
14 changes: 11 additions & 3 deletions example.py
@@ -6,6 +6,7 @@
import os
import argparse
import pytorch_meta_dataset.pipeline as pipeline
from pytorch_meta_dataset.utils import worker_init_fn_


def parse_args() -> argparse.Namespace:
@@ -31,6 +32,9 @@ def parse_args() -> argparse.Namespace:

parser.add_argument('--num_workers', type=int, default=4)

parser.add_argument('--shuffle_queue_size', type=int, default=10,
help='Number of samples loaded into memory and shuffled')

# Episode configuration
parser.add_argument('--num_ways', type=int, default=None,
help='Set it if you want a fixed # of ways per task')
@@ -128,7 +132,8 @@ def main(args: argparse.Namespace) -> None:
# Use a standard dataloader
episodic_loader = DataLoader(dataset=episodic_dataset,
batch_size=1,
num_workers=data_config.num_workers)
num_workers=data_config.num_workers,
worker_init_fn=worker_init_fn_)

# Training or validation loop
for i, (support, query, support_labels, query_labels) in enumerate(episodic_loader):
@@ -140,17 +145,20 @@ def main(args: argparse.Namespace) -> None:
support_labels.unique().size(0),
list(support.size()),
list(query.size())))
print(support_labels)
break

# Form a batch dataset
batch_dataset = pipeline.make_batch_pipeline(dataset_spec_list=all_dataset_specs,
split=split,
data_config=data_config)
data_config=data_config,
)

# Use a standard dataloader
batch_loader = DataLoader(dataset=batch_dataset,
batch_size=data_config.batch_size,
num_workers=data_config.num_workers)
num_workers=data_config.num_workers,
worker_init_fn=worker_init_fn_)
# Training or validation loop
for i, (input, target) in enumerate(batch_loader):
input, target = input.to(device), target.long().to(device, non_blocking=True)
1 change: 1 addition & 0 deletions pytorch_meta_dataset/config.py
@@ -15,6 +15,7 @@ def __init__(
self.path = args.data_path
self.batch_size = args.batch_size
self.num_workers = args.num_workers
self.shuffle_queue_size = args.shuffle_queue_size

# Transforms and augmentations
self.image_size = args.image_size
44 changes: 22 additions & 22 deletions pytorch_meta_dataset/dataset_spec.py
@@ -309,31 +309,31 @@ def get_classes(self,
"""
return get_classes(split, self.classes_per_split)

def to_dict(self, ret_Dict):
def to_dict(self):
"""Returns a dictionary for serialization to JSON.
Each member is converted to an elementary type that can be serialized to
JSON readily.
"""
# Start with the dict representation of the namedtuple
ret_dict = self._asdict()
red_dict = self._asdict()
# Add the class name for reconstruction when deserialized
ret_Dict['__class__'] = self.__class__.__name__
red_dict['__class__'] = self.__class__.__name__
# Convert Split enum instances to their name (string)
ret_Dict['classes_per_split'] = {
red_dict['classes_per_split'] = {
split.name: count
for split, count in six.iteritems(ret_Dict['classes_per_split'])
for split, count in six.iteritems(red_dict['classes_per_split'])
}
# Convert binary class names to unicode strings if necessary
class_names = {}
for class_id, name in six.iteritems(ret_Dict['class_names']):
for class_id, name in six.iteritems(red_dict['class_names']):
if isinstance(name, six.binary_type):
name = name.decode()
elif isinstance(name, np.integer):
name = six.text_type(name)
class_names[class_id] = name
ret_Dict['class_names'] = class_names
return ret_dict
red_dict['class_names'] = class_names
return red_dict


class BiLevelDatasetSpecification(
@@ -520,22 +520,22 @@ class id's relative to the split (between 0 and num classes in split).

return rel_class_ids, class_ids

def to_dict(self, ret_Dict):
def to_dict(self, red_dict):
"""Returns a dictionary for serialization to JSON.
Each member is converted to an elementary type that can be serialized to
JSON readily.
"""
# Start with the dict representation of the namedtuple
ret_dict = self._asdict()
red_dict = self._asdict()
# Add the class name for reconstruction when deserialized
ret_Dict['__class__'] = self.__class__.__name__
red_dict['__class__'] = self.__class__.__name__
# Convert Split enum instances to their name (string)
ret_Dict['superclasses_per_split'] = {
red_dict['superclasses_per_split'] = {
split.name: count
for split, count in six.iteritems(ret_Dict['superclasses_per_split'])
for split, count in six.iteritems(red_dict['superclasses_per_split'])
}
return ret_dict
return red_dict


class HierarchicalDatasetSpecification(
@@ -691,32 +691,32 @@ def get_total_images_per_class(self, class_id=None, pool=None):
return self.images_per_class[s][n]
raise ValueError('Class id {} not found.'.format(class_id))

def to_dict(self, ret_Dict):
def to_dict(self, red_dict):
"""Returns a dictionary for serialization to JSON.
Each member is converted to an elementary type that can be serialized to
JSON readily.
"""
# Start with the dict representation of the namedtuple
ret_dict = self._asdict()
red_dict = self._asdict()
# Add the class name for reconstruction when deserialized
ret_Dict['__class__'] = self.__class__.__name__
red_dict['__class__'] = self.__class__.__name__
# Convert the graph for each split into a serializable form
split_subgraphs = {}
for split, subgraph in six.iteritems(ret_Dict['split_subgraphs']):
for split, subgraph in six.iteritems(red_dict['split_subgraphs']):
exported_subgraph = imagenet_specification.export_graph(subgraph)
split_subgraphs[split.name] = exported_subgraph
ret_Dict['split_subgraphs'] = split_subgraphs
red_dict['split_subgraphs'] = split_subgraphs
# WordNet synsets to their WordNet ID as a string in images_per_class.
images_per_class = {}
for split, synset_counts in six.iteritems(ret_Dict['images_per_class']):
for split, synset_counts in six.iteritems(red_dict['images_per_class']):
wn_id_counts = {
synset.wn_id: count for synset, count in six.iteritems(synset_counts)
}
images_per_class[split.name] = wn_id_counts
ret_Dict['images_per_class'] = images_per_class
red_dict['images_per_class'] = images_per_class

return ret_dict
return red_dict


def as_dataset_spec(dct: Dict[str, Any]):
53 changes: 40 additions & 13 deletions pytorch_meta_dataset/pipeline.py
@@ -4,15 +4,14 @@
import torch
from .transform import get_transforms
import numpy as np
from .utils import cycle, Split
from .utils import Split, cycle_
from typing import List, Union
from .dataset_spec import HierarchicalDatasetSpecification as HDS
from .dataset_spec import BiLevelDatasetSpecification as BDS
from .dataset_spec import DatasetSpecification as DS
from .config import EpisodeDescriptionConfig, DataConfig
from tfrecord.torch.dataset import TFRecordDataset
from .sampling import EpisodeDescriptionSampler
RNG = np.random.RandomState(seed=None)


def make_episode_pipeline(dataset_spec_list: List[Union[HDS, BDS, DS]],
@@ -39,6 +38,7 @@ def make_episode_pipeline(dataset_spec_list: List[Union[HDS, BDS, DS]],
for i in range(len(dataset_spec_list)):
episode_reader = reader.Reader(dataset_spec=dataset_spec_list[i],
split=split,
shuffle_queue_size=data_config.shuffle_queue_size,
offset=0)
class_datasets = episode_reader.construct_class_datasets()
sampler = sampling.EpisodeDescriptionSampler(
@@ -76,6 +76,7 @@ def make_batch_pipeline(dataset_spec_list: List[Union[HDS, BDS, DS]],
for dataset_spec in dataset_spec_list:
batch_reader = reader.Reader(dataset_spec=dataset_spec,
split=split,
shuffle_queue_size=data_config.shuffle_queue_size,
offset=offset)

class_datasets = batch_reader.construct_class_datasets()
@@ -97,26 +98,27 @@ def __init__(self,
max_support_size: int,
max_query_size: int):
super(EpisodicDataset).__init__()
self.class_datasets = [cycle(dataset) for dataset in class_datasets]
self.class_datasets = class_datasets
self.sampler = sampler
self.transforms = transforms
self.max_query_size = max_query_size
self.max_support_size = max_support_size
self.random_gen = np.random.RandomState()

def __iter__(self):
while True:
episode_description = self.sampler.sample_episode_description()
episode_description = self.sampler.sample_episode_description(self.random_gen)
support_images = []
support_labels = []
query_images = []
query_labels = []
episode_classes = list({class_ for class_, _, _ in episode_description})
for class_id, nb_support, nb_query in episode_description:
for _ in range(nb_support):
sample_dic = next(self.class_datasets[class_id])
sample_dic = self.get_next(class_id)
support_images.append(self.transforms(sample_dic['image']).unsqueeze(0))
for _ in range(nb_query):
sample_dic = next(self.class_datasets[class_id])
sample_dic = self.get_next(class_id)
query_images.append(self.transforms(sample_dic['image']).unsqueeze(0))
support_labels.extend([episode_classes.index(class_id)] * nb_support)
query_labels.extend([episode_classes.index(class_id)] * nb_query)
@@ -126,31 +128,56 @@ def __iter__(self):
query_labels = torch.tensor(query_labels)
yield support_images, query_images, support_labels, query_labels

def get_next(self, class_id):
try:
sample_dic = next(self.class_datasets[class_id])
except:
self.class_datasets[class_id] = cycle_(self.class_datasets[class_id])
sample_dic = next(self.class_datasets[class_id])
return sample_dic


class BatchDataset(torch.utils.data.IterableDataset):
def __init__(self,
class_datasets: List[TFRecordDataset],
transforms: torchvision.transforms):
super(BatchDataset).__init__()
self.class_datasets = [cycle(dataset) for dataset in class_datasets]
self.class_datasets = class_datasets
self.transforms = transforms

def __iter__(self):
while True:
rand_class = RNG.randint(len(self.class_datasets))
sample_dic = next(self.class_datasets[rand_class])
rand_class = self.random_gen.randint(len(self.class_datasets))
sample_dic = self.get_next(rand_class)
transformed_image = self.transforms(sample_dic['image'])
target = sample_dic['label'][0]
yield transformed_image, target

def get_next(self, class_id):
try:
sample_dic = next(self.class_datasets[class_id])
except:
self.class_datasets[class_id] = cycle_(self.class_datasets[class_id])
sample_dic = next(self.class_datasets[class_id])
return sample_dic


class ZipDataset(torch.utils.data.IterableDataset):
def __init__(self,
dataset_list: List[EpisodicDataset]):
self.episodic_dataset_list = [cycle(dataset) for dataset in dataset_list]
self.dataset_list = dataset_list
self.random_gen = np.random.RandomState()

def __iter__(self):
while True:
rand_source = RNG.randint(len(self.episodic_dataset_list))
next_e = next(self.episodic_dataset_list[rand_source])
yield next_e
rand_source = self.random_gen.randint(len(self.dataset_list))
next_e = self.get_next(rand_source)
yield next_e

def get_next(self, source_id):
try:
dataset = next(self.dataset_list[source_id])
except:
self.dataset_list[source_id] = iter(self.dataset_list[source_id])
dataset = next(self.dataset_list[source_id])
return dataset
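
The get_next/try-except pattern above replaces the eager [cycle(dataset) for dataset in class_datasets] wrapping that used to happen in __init__. Generators cannot be pickled, which is a problem whenever the dataset object has to be serialized for DataLoader workers (the generator-pickling problem referenced in #5), so cycling is now deferred until the first next() call fails inside the worker. cycle_ itself comes from pytorch_meta_dataset.utils, which is not expanded on this page, so the following is only a plausible sketch of the helper:

# Hypothetical sketch of cycle_ (the real helper is in pytorch_meta_dataset/utils.py,
# not shown in this diff): restart iteration over a re-iterable dataset forever,
# without storing a generator on the dataset object at construction time.
def cycle_(iterable):
    while True:
        for item in iterable:
            yield item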
5 changes: 4 additions & 1 deletion pytorch_meta_dataset/reader.py
@@ -20,6 +20,7 @@ class Reader(object):
def __init__(self,
dataset_spec: Union[HDS, BDS, DS],
split: Split,
shuffle_queue_size: int,
offset: int):
"""Initializes a Reader from a source.
@@ -39,6 +40,7 @@ def __init__(self,
self.split = split
self.dataset_spec = dataset_spec
self.offset = offset
self.shuffle_queue_size = shuffle_queue_size

self.base_path = self.dataset_spec.path
self.class_set = self.dataset_spec.get_classes(self.split)
@@ -78,7 +80,8 @@ def decode_image(features, offset):
dataset = TFRecordDataset(data_path=filename,
index_path=index_path,
description=description,
transform=decode_fn)
transform=decode_fn,
shuffle_queue_size=self.shuffle_queue_size)

class_datasets.append(dataset)

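
shuffle_queue_size is handed straight to the tfrecord package's TFRecordDataset, which (as the new flag's help text says) shuffles by holding a bounded number of samples in memory rather than permuting the whole record file. The effect is roughly that of the illustrative buffer shuffle below; this is a sketch of the idea, not the library's actual code:

# Illustrative buffer-based shuffle: hold `queue_size` samples and emit a
# random one whenever a new sample arrives. Larger queues mix better at the
# cost of memory, which is the trade-off --shuffle_queue_size exposes.
import random


def buffer_shuffle(iterable, queue_size):
    buffer = []
    for item in iterable:
        if len(buffer) < queue_size:
            buffer.append(item)
            continue
        idx = random.randrange(queue_size)
        buffer[idx], item = item, buffer[idx]
        yield item
    # Drain whatever is left once the source is exhausted.
    random.shuffle(buffer)
    yield from buffer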