
AUC Metric #15

Open
JoaoMoranguinho opened this issue Nov 26, 2020 · 2 comments

@JoaoMoranguinho

Hello!

I've been using your code and would like to know what changes or additions are needed to obtain the AUC metric, at least for the testing phase.

Best regards,
João Moura

@utayao
Owner

utayao commented Nov 28, 2020

The most direct way is to write an AUC calculation function in metrics.py. Alternatively, you can modify the code to return the model output (the classification predictions) and then calculate the metric outside the Keras API.
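As a rough illustration of the second option, here is a minimal sketch in plain Python. It assumes you have already collected, for each test bag, the instance-level predictions and one 0/1 bag label; the names `roc_auc`, `bag_instance_preds`, and `bag_labels` and the numbers are hypothetical and not part of the repository. Each bag is reduced to a single score with max pooling (the same reduction used by the metric posted later in this thread), and AUC is computed as the fraction of positive/negative bag pairs that are ranked correctly.

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as the probability that a positive bag outranks a negative one.

    A tie between a positive and a negative score counts as half a win
    (the Mann-Whitney U formulation).
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative bag")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical test-phase data: one list of instance-level predictions per bag
# and one 0/1 label per bag. In practice these would come from the model output.
bag_instance_preds = [[0.1, 0.9, 0.3], [0.2, 0.1], [0.4, 0.8], [0.3, 0.85]]
bag_labels = [1, 0, 1, 0]

# Reduce each bag to a single score with max pooling over its instances.
bag_scores = [max(preds) for preds in bag_instance_preds]

print(roc_auc(bag_labels, bag_scores))  # 0.75
```

The pairwise loop is quadratic in the number of bags, which is fine for a small test set. If scikit-learn is installed, `sklearn.metrics.roc_auc_score(bag_labels, bag_scores)` should give the same value. Note that `tf.keras.metrics.AUC` approximates the curve with a fixed number of thresholds, so its result can differ slightly from this exact computation.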

@JoaoMoranguinho
Author

Hi!

I ended up doing just that. Take a look at the implementation and tell me what you think.

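import tensorflow as tf
from tensorflow.keras import backend as K

class bag_auc(tf.keras.metrics.Metric):
    """Bag-level AUC; needs eager execution because it calls .numpy()."""

    def __init__(self, name="bag_auc", **kwargs):
        super(bag_auc, self).__init__(name=name, **kwargs)
        self.true_labels = []
        self.pred_labels = []

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Bag label: mean of the labels in the batch (assumes one bag per batch).
        b_true = K.mean(y_true, axis=0, keepdims=False)
        self.true_labels.append(b_true.numpy())
        # Bag score: maximum prediction in the batch.
        pred_max = K.max(y_pred, axis=0, keepdims=False)
        self.pred_labels.append(pred_max.numpy())

    def result(self):
        # Compute AUC over all bags seen since the last reset.
        m = tf.keras.metrics.AUC()
        m.update_state(self.true_labels, self.pred_labels)
        return m.result().numpy()

    def reset_state(self):
        # Clear stored bags so they do not carry over between epochs or evaluations.
        self.true_labels = []
        self.pred_labels = []

    reset_states = reset_state  # older TF 2.x releases call reset_states()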

Note: model.compile() must be called with run_eagerly=True, because update_state calls .numpy() on the tensors, which only works in eager mode.
