Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP prettyToString(ConfusionMatrix<MultiLabel>) and labelConfusionMat… #128

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

package org.tribuo.classification.evaluation;

import java.util.logging.Logger;

import org.tribuo.classification.Classifiable;
import org.tribuo.evaluation.metrics.EvaluationMetric.Average;
import org.tribuo.evaluation.metrics.MetricTarget;

import java.util.logging.Logger;

/**
* Static functions for computing classification metrics based on a {@link ConfusionMatrix}.
*/
Expand Down Expand Up @@ -60,7 +60,7 @@ public static <T extends Classifiable<T>> double accuracy(T label, ConfusionMatr
double support = cm.support(label);
// handle div-by-zero
if (support == 0d) {
logger.warning("No predictions: accuracy ill-defined");
logger.warning("No predictions for " + label + ": accuracy ill-defined");
return Double.NaN;
}
return cm.tp(label) / cm.support(label);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@

package org.tribuo.classification.evaluation;

import java.util.Arrays;
import java.util.List;

import org.junit.jupiter.api.Test;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.tribuo.classification.Utils.label;
import static org.tribuo.classification.Utils.mkDomain;
import static org.tribuo.classification.Utils.mkPrediction;
import static org.junit.jupiter.api.Assertions.assertEquals;


public class LabelConfusionMatrixTest {
Expand All @@ -38,7 +38,8 @@ public void testMulticlass() {
mkPrediction("a", "a"),
mkPrediction("c", "b"),
mkPrediction("b", "b"),
mkPrediction("b", "c")
mkPrediction("b", "c"),
mkPrediction("a", "b")
);
ImmutableOutputInfo<Label> domain = mkDomain(predictions);
LabelConfusionMatrix cm = new LabelConfusionMatrix(domain, predictions);
Expand All @@ -54,25 +55,25 @@ public void testMulticlass() {
assertEquals(1, cm.tp(a));
assertEquals(0, cm.fp(a));
assertEquals(3, cm.tn(a));
assertEquals(0, cm.fn(a));
assertEquals(1, cm.support(a));
assertEquals(1, cm.fn(a));
assertEquals(2, cm.support(a));

assertEquals(1, cm.tp(b));
assertEquals(1, cm.fp(b));
assertEquals(2, cm.fp(b));
assertEquals(1, cm.tn(b));
assertEquals(1, cm.fn(b));
assertEquals(2, cm.support(b));

assertEquals(0, cm.tp(c));
assertEquals(1, cm.fp(c));
assertEquals(2, cm.tn(c));
assertEquals(3, cm.tn(c));
assertEquals(1, cm.fn(c));
assertEquals(1, cm.support(c));

assertEquals(4, cm.support());
assertEquals(5, cm.support());
String cmToString = cm.toString();
assertEquals(" a b c\n" +
"a 1 0 0\n" +
"a 1 1 0\n" +
"b 0 1 1\n" +
"c 0 1 0\n", cmToString);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@

package org.tribuo.multilabel.evaluation;

import java.util.List;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
Expand All @@ -25,10 +30,6 @@
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.multilabel.MultiLabelFactory;

import java.util.List;
import java.util.Set;
import java.util.function.Function;

/**
* A {@link ConfusionMatrix} which accepts {@link MultiLabel}s.
*
Expand Down Expand Up @@ -158,15 +159,18 @@ public double confusion(MultiLabel predicted, MultiLabel truth) {

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("[");
for (int i = 0; i < mcm.length; i++) {
DenseMatrix cm = mcm[i];
sb.append(cm.toString());
sb.append("\n");
}
sb.append("]");
return sb.toString();
return getDomain().getDomain().stream()
.map(multiLabel -> {
final int tp = (int) tp(multiLabel);
final int fn = (int) fn(multiLabel);
final int fp = (int) fp(multiLabel);
final int tn = (int) tn(multiLabel);
return String.join("\n",
multiLabel.toString(),
String.format(" [tn: %,d fn: %,d]", tn, fn),
String.format(" [fp: %,d tp: %,d]", fp, tp));
}
).collect(Collectors.joining("\n"));
}

static ConfusionMatrixTuple tabulate(ImmutableOutputInfo<MultiLabel> domain, List<Prediction<MultiLabel>> predictions) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,37 @@

package org.tribuo.multilabel;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.MutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.evaluation.ClassifierEvaluation;
import org.tribuo.classification.evaluation.ConfusionMatrix;
import org.tribuo.classification.evaluation.LabelConfusionMatrix;
import org.tribuo.classification.sgd.linear.LinearSGDTrainer;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;
import org.tribuo.impl.ListExample;
import org.tribuo.multilabel.baseline.IndependentMultiLabelTrainer;
import org.tribuo.multilabel.evaluation.MultiLabelEvaluator;
import org.tribuo.multilabel.example.MultiLabelDataGenerator;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.tribuo.test.Helpers;

import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import com.oracle.labs.mlrg.olcut.util.Pair;

import static org.junit.jupiter.api.Assertions.assertEquals;

Expand Down Expand Up @@ -67,4 +81,97 @@ public void testIndependentBinaryPredictions() {
Helpers.testModelSerialization(model,MultiLabel.class);
}

@Test
public void testMultiLabelConfusionMatrixToStrings() {
Dataset<MultiLabel> train = MultiLabelDataGenerator.generateTrainData();
Dataset<MultiLabel> test = MultiLabelDataGenerator.generateTestData();

IndependentMultiLabelTrainer trainer = new IndependentMultiLabelTrainer(
new LogisticRegressionTrainer());
Model<MultiLabel> model = trainer.train(train);

ClassifierEvaluation<MultiLabel> evaluation = new MultiLabelEvaluator()
.evaluate(model, test);

System.out.println(evaluation);

// MultiLabelConfusionMatrix toString() hard to interpret
final ConfusionMatrix<MultiLabel> mcm = evaluation.getConfusionMatrix();

System.out.println("new toString()");
System.out.println(mcm);

System.out.println("\npredictions");
evaluation.getPredictions().forEach(System.out::println);

final List<Prediction<MultiLabel>> predictions = evaluation.getPredictions();
System.out.println("\nsingleLabelConfusionMatrix");
System.out.println(singleLabelConfusionMatrix(predictions));
}

public static LabelConfusionMatrix singleLabelConfusionMatrix(final List<Prediction<MultiLabel>> predictions) {
final List<Prediction<Label>> singleLabelPredictions = mkSingleLabelPredictions(predictions);
ImmutableOutputInfo<Label> domain = mkDomain(singleLabelPredictions);
LabelConfusionMatrix cm = new LabelConfusionMatrix(domain, singleLabelPredictions);
return cm;
}

public static List<Prediction<Label>> mkSingleLabelPredictions(List<Prediction<MultiLabel>> predictions) {
return mkSingleLabelPredictions(predictions, false);
}

public static List<Prediction<Label>> mkSingleLabelPredictions(List<Prediction<MultiLabel>> predictions,
final boolean falseNegativeHeuristic) {
return predictions.stream()
.flatMap(p -> {
final Set<Label> trueLabels = p.getExample().getOutput().getLabelSet();
final Set<Label> predicted = p.getOutput().getLabelSet();
// intersection(trueLabels, predicted) = true positives
// predicted - trueLabels = false positives
// trueLabels - predicted = false negatives
return Stream.concat(predicted.stream().map(pred -> {
if (trueLabels.contains(pred)) {
return mkPrediction(pred.getLabel(), pred.getLabel());
} else if (trueLabels.size() == 1) {
return mkPrediction(trueLabels.iterator().next().getLabel(), pred.getLabel());
} else {
// arbitrarily pick first trueLabel
return mkPrediction(trueLabels.iterator().next().getLabel(), pred.getLabel());
}
}),
!falseNegativeHeuristic ? Stream.of() :
// partially represent false negatives by calling them false positives tied to some predicted label if there is one
trueLabels.stream().filter(t -> !predicted.contains(t)).flatMap(fnTrueLabel -> {
if (predicted.isEmpty()) {
// nothing to pin this on
return Stream.of();
} else if (predicted.size() == 1) {
return Stream.of(mkPrediction(fnTrueLabel.getLabel(), predicted.iterator().next().getLabel()));
} else {
// arbitrarily pick first predicted label
return Stream.of(mkPrediction(fnTrueLabel.getLabel(), predicted.iterator().next().getLabel()));
}
})
);
}).collect(Collectors.toList());
}

// FIXME HACK copied from Classification/Core/src/test/java/org/tribuo/classification/Utils.java

public static Prediction<Label> mkPrediction(String trueVal, String predVal) {
LabelFactory factory = new LabelFactory();
Example<Label> example = new ListExample<>(factory.generateOutput(trueVal));
example.add(new Feature("noop", 1d));
Prediction<Label> prediction = new Prediction<>(factory.generateOutput(predVal), 0, example);
return prediction;
}

public static ImmutableOutputInfo<Label> mkDomain(List<Prediction<Label>> predictions) {
final MutableOutputInfo<Label> info = new LabelFactory().generateInfo();
for (Prediction<Label> p : predictions) {
info.observe(p.getExample().getOutput());
info.observe(p.getOutput()); // TODO? LN added
}
return info.generateImmutableOutputInfo();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,23 @@

package org.tribuo.multilabel.evaluation;

import java.util.Arrays;
import java.util.List;

import org.junit.jupiter.api.Test;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.multilabel.MultiLabel;
import org.junit.jupiter.api.Test;

import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.tribuo.multilabel.IndependentMultiLabelTest.singleLabelConfusionMatrix;
import static org.tribuo.multilabel.Utils.getUnknown;
import static org.tribuo.multilabel.Utils.label;
import static org.tribuo.multilabel.Utils.mkDomain;
import static org.tribuo.multilabel.Utils.mkPrediction;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class MultiLabelConfusionMatrixTest {

Expand Down Expand Up @@ -158,6 +159,11 @@ public void testSingleLabel() {
assertEquals(1, cm.support(c));

assertEquals(4, cm.support());

System.out.println("new toString()");
System.out.println(cm);
System.out.println("\nsingleLabelConfusionMatrix");
System.out.println(singleLabelConfusionMatrix(predictions));
}

@Test
Expand Down Expand Up @@ -231,6 +237,11 @@ public void testMultiLabel() {
assertEquals(1, cm.support(c));

assertEquals(5, cm.support());

System.out.println("new toString()");
System.out.println(cm);
System.out.println("\nsingleLabelConfusionMatrix");
System.out.println(singleLabelConfusionMatrix(predictions));
}


Expand Down