Skip to content
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -76,4 +76,15 @@ public Iterator<Pair<Integer, Event>> iterator() {

return list.iterator();
}

/**
 * Checks whether the supplied output info has the same domain and ids as this one.
 * <p>
 * Any other {@code ImmutableAnomalyInfo} is accepted immediately. Any other implementation
 * must contain exactly two outputs, with the anomalous and expected events mapped to the
 * ids of their respective event types.
 * @param other The output info to compare against.
 * @return True if the domains and ids match.
 */
@Override
public boolean domainAndIDEquals(ImmutableOutputInfo<Event> other) {
    if (other instanceof ImmutableAnomalyInfo) {
        return true;
    }
    // Checks run in the same order as a short-circuiting conjunction: size first, then each id.
    if (other.size() != 2) {
        return false;
    }
    if (other.getID(AnomalyFactory.ANOMALOUS_EVENT) != AnomalyFactory.ANOMALOUS_EVENT.getType().getID()) {
        return false;
    }
    return other.getID(AnomalyFactory.EXPECTED_EVENT) == AnomalyFactory.EXPECTED_EVENT.getType().getID();
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -31,7 +31,7 @@
import java.util.logging.Logger;

/**
* An ImmutableOutputInfo object for Labels.
* An {@link ImmutableOutputInfo} object for {@link Label}s.
* <p>
* Gives each unique label an id number. Also counts each label occurrence like {@link MutableLabelInfo} does,
* though the counts are frozen in this object.
Expand Down Expand Up @@ -168,6 +168,23 @@ public Iterator<Pair<Integer, Label>> iterator() {
return new ImmutableInfoIterator(idLabelMap);
}

/**
 * Checks whether the supplied output info has the same domain and ids as this one.
 * <p>
 * The two are considered equal when they have the same size and every id in this
 * info resolves, in {@code other}, to a label with the same name.
 * @param other The output info to compare against.
 * @return True if the domains and ids match.
 */
@Override
public boolean domainAndIDEquals(ImmutableOutputInfo<Label> other) {
    if (size() != other.size()) {
        return false;
    }
    for (Map.Entry<Integer,String> entry : idLabelMap.entrySet()) {
        Label candidate = other.getOutput(entry.getKey());
        // A missing id or a differently named label means the domains disagree.
        if (candidate == null || !candidate.label.equals(entry.getValue())) {
            return false;
        }
    }
    return true;
}

/**
* An iterator that converts {@link Map.Entry} into {@link Pair}s on the way out.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,10 @@
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.classification.Classifiable;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.function.ToDoubleFunction;

/**
Expand Down Expand Up @@ -128,6 +132,42 @@ public default double tn() {
return sumOverOutputs(getDomain(), this::tn);
}

/**
 * Returns the outputs this confusion matrix has seen.
 * <p>
 * This default implementation exists for compatibility reasons and will be removed
 * in a future major release. It returns the full output domain.
 * @return The set of observed outputs.
 */
public default Set<T> observed() {
    ImmutableOutputInfo<T> outputInfo = getDomain();
    return outputInfo.getDomain();
}

/**
 * Returns the label order this confusion matrix uses in {@code toString}.
 * <p>
 * This default implementation exists for compatibility reasons and will be removed
 * in a future major release. It copies the output domain, in that set's iteration
 * order, into a new list and returns it wrapped as unmodifiable.
 * @return An unmodifiable list containing the label order.
 */
public default List<T> getLabelOrder() {
    List<T> domainLabels = new ArrayList<>(getDomain().getDomain());
    return Collections.unmodifiableList(domainLabels);
}

/**
 * Sets the label order this confusion matrix uses in {@code toString}.
 * <p>
 * If the label order is a subset of the labels in the domain, only the
 * labels present in the label order will be displayed.
 * <p>
 * The default implementation is a no-op: the supplied list is ignored and no
 * label order is stored. It is provided for backwards compatibility reasons and
 * should be overridden in all subclasses to ensure correct behaviour. This default
 * implementation will be removed in a future major release.
 * @param labelOrder The label order.
 */
public default void setLabelOrder(List<T> labelOrder) {}

/**
* Sums the supplied getter over the domain.
* @param domain The domain to sum over.
Expand All @@ -142,4 +182,5 @@ static <T extends Classifiable<T>> double sumOverOutputs(ImmutableOutputInfo<T>
}
return total;
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,7 @@
import org.tribuo.math.la.DenseMatrix;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
Expand Down Expand Up @@ -66,7 +67,8 @@ public final class LabelConfusionMatrix implements ConfusionMatrix<Label> {
/**
* Creates a confusion matrix from the supplied predictions, using the label info
* from the supplied model.
* @param model The model to use for the label information.
*
* @param model The model to use for the label information.
* @param predictions The predictions.
*/
public LabelConfusionMatrix(Model<Label> model, List<Prediction<Label>> predictions) {
Expand All @@ -75,21 +77,24 @@ public LabelConfusionMatrix(Model<Label> model, List<Prediction<Label>> predicti

/**
* Creates a confusion matrix from the supplied predictions and label info.
* @throws IllegalArgumentException If the domain doesn't contain all the predictions.
* @param domain The label information.
*
* @param domain The label information.
* @param predictions The predictions.
* @throws IllegalArgumentException If the domain doesn't contain all the predictions.
*/
public LabelConfusionMatrix(ImmutableOutputInfo<Label> domain, List<Prediction<Label>> predictions) {
this.domain = domain;
this.total = predictions.size();
this.cm = new DenseMatrix(domain.size(), domain.size());
this.occurrences = new HashMap<>();
this.observed = new HashSet<>();
this.labelOrder = Collections.unmodifiableList(new ArrayList<>(domain.getDomain()));
tabulate(predictions);
}

/**
* Aggregate the predictions into this confusion matrix.
*
* @param predictions The predictions to aggregate.
*/
private void tabulate(List<Prediction<Label>> predictions) {
Expand All @@ -101,7 +106,7 @@ private void tabulate(List<Prediction<Label>> predictions) {
if (y.getLabel().equals(Label.UNKNOWN)) {
throw new IllegalArgumentException("Prediction with unknown ground truth. Unable to evaluate.");
}
occurrences.merge(y,1d, Double::sum);
occurrences.merge(y, 1d, Double::sum);
observed.add(y);
observed.add(p);
int iy = getIDOrThrow(y);
Expand All @@ -115,6 +120,11 @@ public ImmutableOutputInfo<Label> getDomain() {
return domain;
}

/**
 * Returns the labels recorded as observed by this confusion matrix.
 * @return An unmodifiable view of the observed label set.
 */
@Override
public Set<Label> observed() {
    Set<Label> readOnlyView = Collections.unmodifiableSet(observed);
    return readOnlyView;
}

@Override
public double support() {
return total;
Expand Down Expand Up @@ -170,7 +180,8 @@ public double confusion(Label predicted, Label trueClass) {

/**
* A convenience method for extracting the appropriate label statistic.
* @param cls The label to check.
*
* @param cls The label to check.
* @param getter The get function which accepts a label id.
* @return The statistic for that label id.
*/
Expand All @@ -186,6 +197,7 @@ private double compute(Label cls, ToDoubleFunction<Integer> getter) {
/**
* Gets the id for the supplied label, or throws an {@link IllegalArgumentException} if it's
* an unknown label.
*
* @param key The label.
* @return The int id for that label.
*/
Expand All @@ -199,21 +211,37 @@ private int getIDOrThrow(Label key) {

/**
* Sets the label order used in {@link #toString}.
* @param labelOrder The label order to use.
* <p>
* If the label order is a subset of the labels in the domain, only the
* labels present in the label order will be displayed.
*
* @param newLabelOrder The label order to use.
*/
@Override
public void setLabelOrder(List<Label> newLabelOrder) {
if (newLabelOrder == null || newLabelOrder.isEmpty()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we know what size the labelset should be at this point? Can we easily check for that too while we're at it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This actually allows you to reduce the set of labels that you print (and was designed that way many years ago), if you only care about showing a subset of them. However that's not properly documented, nor do I have a test for it, so I should add one.

throw new IllegalArgumentException("Label order must be non-null and non-empty.");
}
this.labelOrder = Collections.unmodifiableList(new ArrayList<>(newLabelOrder));
}

/**
* Gets the current label order.
*
* The order is held as an unmodifiable list; it is initialised from the domain's labels at construction and replaced by {@link #setLabelOrder}.
* @return The label order.
*/
public void setLabelOrder(List<Label> labelOrder) {
this.labelOrder = labelOrder;
public List<Label> getLabelOrder() {
return labelOrder;
}

@Override
public String toString() {
if (labelOrder == null) {
labelOrder = new ArrayList<>(domain.getDomain());
}
labelOrder.retainAll(observed);

List<Label> curOrder = new ArrayList<>(labelOrder);
curOrder.retainAll(observed);

int maxLen = Integer.MIN_VALUE;
for (Label label : labelOrder) {
for (Label label : curOrder) {
maxLen = Math.max(label.getLabel().length(), maxLen);
maxLen = Math.max(String.format(" %,d", (int)(double)occurrences.getOrDefault(label,0.0)).length(), maxLen);
}
Expand All @@ -229,14 +257,14 @@ public String toString() {

//
// Labels across the top for predicted.
for (Label predictedLabel : labelOrder) {
for (Label predictedLabel : curOrder) {
sb.append(String.format(predictedLabelFormat, predictedLabel.getLabel()));
}
sb.append('\n');

for (Label trueLabel : labelOrder) {
for (Label trueLabel : curOrder) {
sb.append(String.format(trueLabelFormat, trueLabel.getLabel()));
for (Label predictedLabel : labelOrder) {
for (Label predictedLabel : curOrder) {
int confusion = (int) confusion(predictedLabel, trueLabel);
sb.append(String.format(countFormat, confusion));
}
Expand All @@ -250,9 +278,6 @@ public String toString() {
* @return The confusion matrix as a HTML table.
*/
public String toHTML() {
if (labelOrder == null) {
labelOrder = new ArrayList<>(domain.getDomain());
}
Set<Label> labelsToPrint = new LinkedHashSet<>(labelOrder);
labelsToPrint.retainAll(observed);
StringBuilder sb = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -95,6 +95,11 @@ public interface LabelEvaluation extends ClassifierEvaluation<Label> {

/**
* Returns a HTML formatted String representing this evaluation.
* <p>
* Uses the label order of the confusion matrix, which can be used to display
* a subset of the per-label metrics. When only a subset is displayed, the total row
* covers just that subset rather than all the predictions; however, the accuracy
* and averaged metrics still cover all the predictions.
* @return A HTML formatted String.
*/
default String toHTML() {
Expand All @@ -106,12 +111,18 @@ default String toHTML() {
* appropriate tabs and newlines, suitable for display on a terminal.
* It can be used as an implementation of the {@link EvaluationRenderer}
* functional interface.
* <p>
* Uses the label order of the confusion matrix, which can be used to display
* a subset of the per-label metrics. When only a subset is displayed, the total row
* covers just that subset rather than all the predictions; however, the accuracy
* and averaged metrics still cover all the predictions.
* @param evaluation The evaluation to format.
* @return Formatted output showing the main results of the evaluation.
*/
public static String toFormattedString(LabelEvaluation evaluation) {
ConfusionMatrix<Label> cm = evaluation.getConfusionMatrix();
List<Label> labelOrder = new ArrayList<>(cm.getDomain().getDomain());
List<Label> labelOrder = new ArrayList<>(cm.getLabelOrder());
labelOrder.retainAll(cm.observed());
StringBuilder sb = new StringBuilder();
int tp = 0;
int fn = 0;
Expand Down Expand Up @@ -151,7 +162,7 @@ public static String toFormattedString(LabelEvaluation evaluation) {
sb.append(String.format(labelFormatString, "Total"));
sb.append(String.format("%,12d%,12d%,12d%,12d%n", n, tp, fn, fp));
sb.append(String.format(labelFormatString, "Accuracy"));
sb.append(String.format("%60.3f%n", (double) tp / n));
sb.append(String.format("%60.3f%n", evaluation.accuracy()));
sb.append(String.format(labelFormatString, "Micro Average"));
sb.append(String.format("%60.3f%12.3f%12.3f%n",
evaluation.microAveragedRecall(),
Expand All @@ -169,15 +180,20 @@ public static String toFormattedString(LabelEvaluation evaluation) {

/**
* This method produces a HTML formatted String output, with
* appropriate tabs and newlines, suitable for integation into a webpage.
* appropriate tabs and newlines, suitable for integration into a webpage.
* It can be used as an implementation of the {@link EvaluationRenderer}
* functional interface.
* <p>
* Uses the label order of the confusion matrix, which can be used to display
* a subset of the per-label metrics. When only a subset is displayed, the total row
* covers just that subset rather than all the predictions; however, the accuracy
* and averaged metrics still cover all the predictions.
* @param evaluation The evaluation to format.
* @return Formatted HTML output showing the main results of the evaluation.
*/
public static String toHTML(LabelEvaluation evaluation) {
ConfusionMatrix<Label> cm = evaluation.getConfusionMatrix();
List<Label> labelOrder = new ArrayList<>(cm.getDomain().getDomain());
List<Label> labelOrder = cm.getLabelOrder();
StringBuilder sb = new StringBuilder();
int tp = 0;
int fn = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,14 @@ public ConfusionMatrix<Label> getConfusionMatrix() {
public EvaluationProvenance getProvenance() {
    // Hands back the provenance held by this evaluation.
    return provenance;
}

/**
* This produces a formatted String suitable for a terminal.
* @return A formatted String representing this {@code LabelEvaluationImpl}.
* This method produces a nicely formatted String output, with
* appropriate tabs and newlines, suitable for display on a terminal.
* <p>
* Uses the label order of the confusion matrix, which can be used to display
* a subset of the per-label metrics. When only a subset is displayed, the total row
* covers just that subset rather than all the predictions; however, the accuracy
* and averaged metrics still cover all the predictions.
* @return Formatted output showing the main results of the evaluation.
*/
@Override
public String toString() {
Expand Down
Loading