Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return labels for FER2013 if possible #8452

Merged
merged 7 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 56 additions & 16 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2442,28 +2442,68 @@ def inject_fake_data(self, tmpdir, config):
base_folder = os.path.join(tmpdir, "fer2013")
os.makedirs(base_folder)

use_icml = config.pop("use_icml", False)
use_fer = config.pop("use_fer", False)

num_samples = 5
with open(os.path.join(base_folder, f"{config['split']}.csv"), "w", newline="") as file:
writer = csv.DictWriter(
file,
fieldnames=("emotion", "pixels") if config["split"] == "train" else ("pixels",),
quoting=csv.QUOTE_NONNUMERIC,
quotechar='"',
)
writer.writeheader()
for _ in range(num_samples):
row = dict(
pixels=" ".join(
str(pixel) for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
)

if use_icml or use_fer:
pixels_key, usage_key = (" pixels", " Usage") if use_icml else ("pixels", "Usage")
fieldnames = ("emotion", usage_key, pixels_key) if use_icml else ("emotion", pixels_key, usage_key)
filename = "icml_face_data.csv" if use_icml else "fer2013.csv"
with open(os.path.join(base_folder, filename), "w", newline="") as file:
writer = csv.DictWriter(
file,
fieldnames=fieldnames,
quoting=csv.QUOTE_NONNUMERIC,
quotechar='"',
)
if config["split"] == "train":
row["emotion"] = str(int(torch.randint(0, 7, ())))
writer.writeheader()
for i in range(num_samples):
row = {
"emotion": str(int(torch.randint(0, 7, ()))),
usage_key: "Training" if i % 2 else "PublicTest",
pixels_key: " ".join(
str(pixel)
for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
),
}

writer.writerow(row)
else:
with open(os.path.join(base_folder, f"{config['split']}.csv"), "w", newline="") as file:
writer = csv.DictWriter(
file,
fieldnames=("emotion", "pixels") if config["split"] == "train" else ("pixels",),
quoting=csv.QUOTE_NONNUMERIC,
quotechar='"',
)
writer.writeheader()
for _ in range(num_samples):
row = dict(
pixels=" ".join(
str(pixel)
for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
)
)
if config["split"] == "train":
row["emotion"] = str(int(torch.randint(0, 7, ())))

writer.writerow(row)
writer.writerow(row)

return num_samples

def test_icml_file(self):
config = {"split": "test"}
with self.create_dataset(config=config) as (dataset, _):
assert all(s[1] is None for s in dataset)

for split in ("train", "test"):
for d in ({"use_icml": True}, {"use_fer": True}):
config = {"split": split, **d}
with self.create_dataset(config=config) as (dataset, _):
assert all(s[1] is not None for s in dataset)


class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.GTSRB
Expand Down
65 changes: 55 additions & 10 deletions torchvision/datasets/fer2013.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,21 @@ class FER2013(VisionDataset):
"""`FER2013
<https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.

.. note::
This dataset can return test labels only if ``fer2013.csv`` OR
``icml_face_data.csv`` are present in ``root/fer2013/``. If only
``train.csv`` and ``test.csv`` are present, the test labels are set to
``None``.

Args:
root (str or ``pathlib.Path``): Root directory of dataset where directory
``root/fer2013`` exists.
``root/fer2013`` exists. This directory may contain either
``fer2013.csv``, ``icml_face_data.csv``, or both ``train.csv`` and
``test.csv``. Precendence is given in that order, i.e. if
``fer2013.csv`` is present then the rest of the files will be
ignored. All these (combinations of) files contain the same data and
are supported for convenience, but only ``fer2013.csv`` and
``icml_face_data.csv`` are able to return non-None test labels.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``
Expand All @@ -25,6 +37,25 @@ class FER2013(VisionDataset):
_RESOURCES = {
"train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
"test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
# The fer2013.csv and icml_face_data.csv files contain both train and
# tests instances, and unlike test.csv they contain the labels for the
# test instances. We give these 2 files precedence over train.csv and
# test.csv. And yes, they both contain the same data, but with different
# column names (note the spaces) and ordering:
# $ head -n 1 fer2013.csv icml_face_data.csv train.csv test.csv
# ==> fer2013.csv <==
# emotion,pixels,Usage
#
# ==> icml_face_data.csv <==
# emotion, Usage, pixels
#
# ==> train.csv <==
# emotion,pixels
#
# ==> test.csv <==
# pixels
"fer": ("fer2013.csv", "f8428a1edbd21e88f42c73edd2a14f95"),
"icml": ("icml_face_data.csv", "b114b9e04e6949e5fe8b6a98b3892b1d"),
}

def __init__(
Expand All @@ -34,11 +65,13 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
self._split = verify_str_arg(split, "split", ("train", "test"))
super().__init__(root, transform=transform, target_transform=target_transform)

base_folder = pathlib.Path(self.root) / "fer2013"
file_name, md5 = self._RESOURCES[self._split]
use_fer_file = (base_folder / self._RESOURCES["fer"][0]).exists()
use_icml_file = not use_fer_file and (base_folder / self._RESOURCES["icml"][0]).exists()
file_name, md5 = self._RESOURCES["fer" if use_fer_file else "icml" if use_icml_file else self._split]
data_file = base_folder / file_name
if not check_integrity(str(data_file), md5=md5):
raise RuntimeError(
Expand All @@ -47,14 +80,26 @@ def __init__(
f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
)

pixels_key = " pixels" if use_icml_file else "pixels"
usage_key = " Usage" if use_icml_file else "Usage"

def get_img(row):
return torch.tensor([int(idx) for idx in row[pixels_key].split()], dtype=torch.uint8).reshape(48, 48)

def get_label(row):
if use_fer_file or use_icml_file or self._split == "train":
return int(row["emotion"])
else:
return None

with open(data_file, "r", newline="") as file:
self._samples = [
(
torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
int(row["emotion"]) if "emotion" in row else None,
)
for row in csv.DictReader(file)
]
rows = (row for row in csv.DictReader(file))

if use_fer_file or use_icml_file:
valid_keys = ("Training",) if self._split == "train" else ("PublicTest", "PrivateTest")
rows = (row for row in rows if row[usage_key] in valid_keys)

self._samples = [(get_img(row), get_label(row)) for row in rows]

def __len__(self) -> int:
return len(self._samples)
Expand Down
Loading