Low accuracy when using SupportSetFolder. #149

Open
TeddyPorfiris opened this issue May 24, 2024 · 1 comment
Labels
question Further information is requested

Comments

TeddyPorfiris commented May 24, 2024

Hello! Thanks so much for easyfsl, it's fantastic. I am testing my Prototypical Network (trained on miniImageNet) with SupportSetFolder. When I test it on the attached folder dataset1 (containing photos from the internet), I get very accurate results. But when I test it on the attached folder dataset2 (containing photos I took myself), I get very inaccurate results. If you could help me figure out why this is, I'd appreciate it so much. Thanks again.

pip install easyfsl

import torch
from typing import Optional

from PIL import Image
from torch import nn
from torchvision import transforms
from torchvision.models import resnet18

from easyfsl.datasets import SupportSetFolder
from easyfsl.methods import FewShotClassifier

class PrototypicalNetworks(FewShotClassifier):
    def __init__(
        self,
        backbone: Optional[nn.Module] = None,
    ):
        """
        Initialize the Prototypical Networks Few-Shot Classifier
        Args:
            backbone: the feature extractor used by the method. Must output a tensor of the
                appropriate shape (depending on the method).
                If None is passed, the backbone will be initialized as nn.Identity().
        """
        super().__init__(backbone=backbone)

    def forward(
        self,
        support_images: torch.Tensor,  # Support images
        support_labels: torch.Tensor,  # Support labels
        query_images: torch.Tensor,    # Query images
    ) -> torch.Tensor:
        """
        Predict classification labels.
        Args:
            support_images: images of the support set of shape (n_support, **image_shape)
            support_labels: labels of support set images of shape (n_support, )
            query_images: images of the query set of shape (n_query, **image_shape)
        Returns:
            a prediction of classification scores for query images of shape (n_query, n_classes)
        """
        # Compute the class prototypes from the support set. This also
        # extracts and stores the support features internally, so a separate
        # compute_features call on the support images is not needed.
        self.compute_prototypes_and_store_support_set(support_images, support_labels)

        # Classify queries by (negative) L2 distance to the prototypes
        z_query = self.compute_features(query_images)
        logits = self.l2_distance_to_prototypes(z_query)
        return self.softmax_if_specified(logits)

    @staticmethod
    def is_transductive() -> bool:
        # Prototypical Networks only use the support set to build prototypes,
        # so the method is inductive
        return False


# Initialize the backbone: a pretrained ResNet18 whose fully connected layer
# is replaced by a Flatten layer, so it outputs flat feature vectors
convolutional_network = resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()

# Create the Prototypical Networks model using the ResNet18 as feature extractor
device = "cuda"
model_path = '/content/MIN_model.pth'
model = PrototypicalNetworks(convolutional_network).to(device)
model.load_state_dict(torch.load(model_path))

# Define transformations to be applied to images
# (note: CenterCrop(348) is a no-op here, since Resize([348, 348]) already
# produces 348x348 images)
transform = transforms.Compose(
    [
        transforms.Resize([348, 348]),
        transforms.CenterCrop(348),
        transforms.ToTensor(),
    ]
)

support_set = SupportSetFolder(root='/content/dataset2/support_set', transform=transform, device=device)

query_image_path = '/content/dataset2/query_set/chips1.jpg'
query_image_PIL = Image.open(query_image_path)
query_images = transform(query_image_PIL).float().unsqueeze(0)  # add a batch dimension


with torch.no_grad():
    model.eval()
    class_names = support_set.classes
    print(f"Class names: {class_names}")

    # The forward pass recomputes the prototypes from the support set it
    # receives, so no separate process_support_set call is needed here.
    # SupportSetFolder was built with device=device, so its tensors are
    # already on the GPU.
    predicted_labels = model(
        support_set.get_images(),
        support_set.get_labels(),
        query_images.to(device),
    ).argmax(dim=1)

    predicted_classes = [support_set.classes[label] for label in predicted_labels]
    print(f"Predicted classes: {predicted_classes}")

Link to download MIN_model.pth: https://drive.google.com/file/d/1q6sfNYcYSTUJzEiHq1T-nJ5R31EZ8dio/view?usp=sharing
dataset1.zip
dataset2.zip

@ebennequin (Collaborator) commented:

There are numerous reasons why a model can perform well on one few-shot task and poorly on another. The good thing with few-shot learning is that, since the volumes are small, it is easy to investigate image by image what went wrong with a prediction and why a query was associated with the wrong class.

I suggest you use classic investigation tools (a confusion matrix, etc.) and complement them with a visualization of the poorly classified images and of the particular support-set images they are closest to.
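
For example, something along these lines: a rough sketch assuming the model, support_set, transform, and device from your snippet, plus a hypothetical query folder organized like the support folder (one sub-folder per class, so the true labels are known). The .samples attribute comes from torchvision's ImageFolder, which SupportSetFolder extends.

import torch
from pathlib import Path
from PIL import Image

# Hypothetical layout, mirroring the support folder: query_set/<class_name>/*.jpg
query_root = Path("/content/dataset2/query_set")

model.eval()
with torch.no_grad():
    # Store the prototypes, and keep the individual support features around
    model.process_support_set(support_set.get_images(), support_set.get_labels())
    support_features = model.compute_features(support_set.get_images())
    # (path, class_index) pairs inherited from ImageFolder
    support_paths = [path for path, _ in support_set.samples]

    n_classes = len(support_set.classes)
    confusion = torch.zeros(n_classes, n_classes, dtype=torch.long)

    for true_index, class_name in enumerate(support_set.classes):
        for image_path in sorted((query_root / class_name).glob("*.jpg")):
            query = transform(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
            query_features = model.compute_features(query)

            # l2_distance_to_prototypes returns negative distances,
            # so argmax picks the nearest prototype
            predicted_index = model.l2_distance_to_prototypes(query_features).argmax(dim=1).item()
            confusion[true_index, predicted_index] += 1

            # Nearest individual support image, often more telling than the prototype alone
            distances = torch.cdist(query_features, support_features).squeeze(0)
            nearest_path = support_paths[distances.argmin().item()]
            if predicted_index != true_index:
                print(
                    f"{image_path.name}: true {class_name}, "
                    f"predicted {support_set.classes[predicted_index]}, "
                    f"closest support image: {nearest_path}"
                )

    print("Confusion matrix (rows = true class, columns = predicted class):")
    print(confusion)

The torch.cdist call mirrors what l2_distance_to_prototypes does internally, just against individual support features instead of class prototypes, so you can see exactly which support image each misclassified query lands next to.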
