Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -86,7 +86,7 @@ public void postConfig() {

@Override
public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> instanceProvenance) {
return(train(examples, instanceProvenance, INCREMENT_INVOCATION_COUNT));
return train(examples, instanceProvenance, INCREMENT_INVOCATION_COUNT) ;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -236,16 +236,21 @@ public synchronized void setInvocationCount(int invocationCount){
}

rng = new SplittableRandom(seed);
SplittableRandom localRNG;
trainInvocationCounter = 0;
for (int invocationCounter = 0; invocationCounter < invocationCount; invocationCounter++){
localRNG = rng.split();
trainInvocationCounter++;

for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++){
SplittableRandom localRNG = rng.split();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this creates an "unused variable" compiler warning in my IDE. Do you care?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the localRNG? That's what I get in intellij, but it's intentional, it's the only way to advance the rng to the right state. I'm not sure if there is a suppress warning flag that will work across all IDEs, but if there is we can do that. We'll need to do something for the static analysis too, but that can be a problem for another day.

}

}

private float accuracy(List<Prediction<Label>> predictions, Dataset<Label> examples, float[] weights) {
/**
* Compute the accuracy of a set of predictions.
* @param predictions The base learner predictions.
* @param examples The training examples.
* @param weights The current example weights.
* @return The accuracy.
*/
private static float accuracy(List<Prediction<Label>> predictions, Dataset<Label> examples, float[] weights) {
float correctSum = 0;
float total = 0;
for (int i = 0; i < predictions.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -172,7 +172,7 @@ public void testSetInvocationCount() {

// The number of times to call train before final training.
// Original trainer will be trained numOfInvocations + 1 times
// New trainer will have it's invocation count set to numOfInvocations then trained once
// New trainer will have its invocation count set to numOfInvocations then trained once
int numOfInvocations = 2;

// Create the first model and train it numOfInvocations + 1 times
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -55,7 +55,7 @@ public class MultinomialNaiveBayesTrainer implements Trainer<Label>, WeightedExa
@Config(description="Smoothing parameter.")
private double alpha = 1.0;

private int invocationCount = 0;
private int trainInvocationCount = 0;

/**
* Constructs a multinomial naive bayes trainer using a smoothing value of 1.0.
Expand Down Expand Up @@ -111,7 +111,7 @@ public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runPr
}
TrainerProvenance trainerProvenance = getProvenance();
ModelProvenance provenance = new ModelProvenance(MultinomialNaiveBayesModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), trainerProvenance, runProvenance);
invocationCount++;
trainInvocationCount++;

SparseVector[] labelVectors = new SparseVector[labelInfos.size()];

Expand All @@ -129,7 +129,7 @@ public Model<Label> train(Dataset<Label> examples, Map<String, Provenance> runPr

@Override
public int getInvocationCount() {
return invocationCount;
return trainInvocationCount;
}

@Override
Expand All @@ -138,7 +138,7 @@ public void setInvocationCount(int invocationCount) {
throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
}

this.invocationCount = invocationCount;
this.trainInvocationCount = invocationCount;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -238,12 +238,9 @@ public synchronized void setInvocationCount(int invocationCount){
}

rng = new SplittableRandom(seed);
SplittableRandom localRNG;
trainInvocationCounter = 0;

for (int invocationCounter = 0; invocationCounter < invocationCount; invocationCounter++){
localRNG = rng.split();
trainInvocationCounter++;
for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++){
SplittableRandom localRNG = rng.split();
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -109,7 +109,7 @@ public void testSetInvocationCount() {

// The number of times to call train before final training.
// Original trainer will be trained numOfInvocations + 1 times
// New trainer will have it's invocation count set to numOfInvocations then trained once
// New trainer will have its invocation count set to numOfInvocations then trained once
int numOfInvocations = 2;

// Create the first model and train it numOfInvocations + 1 times
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -318,12 +318,9 @@ public synchronized void setInvocationCount(int invocationCount){
}

rng = new SplittableRandom(seed);
SplittableRandom localRNG;
trainInvocationCounter = 0;

for (int invocationCounter = 0; invocationCounter < invocationCount; invocationCounter++){
localRNG = rng.split();
trainInvocationCounter++;
for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++){
SplittableRandom localRNG = rng.split();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.test.Helpers;

import java.util.logging.Level;
import java.util.logging.Logger;

import static org.junit.jupiter.api.Assertions.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;

/**
* Smoke tests for k-means.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ public int getInvocationCount() {
}

@Override
public void setInvocationCount(int invocationCount) {
public synchronized void setInvocationCount(int invocationCount) {
if(invocationCount < 0){
throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -276,12 +276,16 @@ public int getInvocationCount() {
}

@Override
public void setInvocationCount(int invocationCount) {
public synchronized void setInvocationCount(int invocationCount) {
if(invocationCount < 0){
throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
}

this.trainInvocationCounter = invocationCount;
rng = new SplittableRandom(seed);

for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++){
SplittableRandom localRNG = rng.split();
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -75,7 +75,7 @@ public enum Distance {
@Config(description="The threading model to use.")
private Backend backend = Backend.THREADPOOL;

private int invocationCount = 0;
private int trainInvocationCount = 0;

/**
* For olcut.
Expand Down Expand Up @@ -131,7 +131,7 @@ public Model<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance
if(invocationCount != INCREMENT_INVOCATION_COUNT){
setInvocationCount(invocationCount);
}
invocationCount++;
trainInvocationCount++;

ModelProvenance provenance = new ModelProvenance(KNNModel.class.getName(), OffsetDateTime.now(), examples.getProvenance(), getProvenance(), runProvenance);

Expand All @@ -145,7 +145,7 @@ public String toString() {

@Override
public int getInvocationCount() {
return invocationCount;
return trainInvocationCount;
}

@Override
Expand All @@ -154,7 +154,7 @@ public void setInvocationCount(int invocationCount) {
throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
}

this.invocationCount = invocationCount;
this.trainInvocationCount = invocationCount;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,10 @@ public synchronized void setInvocationCount(int invocationCount){
}

rng = new SplittableRandom(seed);
SplittableRandom localRNG;
for (int invocationCounter = 0; invocationCounter < invocationCount; invocationCounter++){
localRNG = rng.split();
}
trainInvocationCounter = invocationCount;

for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++){
SplittableRandom localRNG = rng.split();
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -147,12 +147,9 @@ public synchronized void setInvocationCount(int invocationCount){
}

rng = new SplittableRandom(seed);
SplittableRandom localRNG;
trainInvocationCounter = 0;

for (int invocationCounter = 0; invocationCounter < invocationCount; invocationCounter++){
localRNG = rng.split();
trainInvocationCounter++;
for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++){
SplittableRandom localRNG = rng.split();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ public int getInvocationCount() {
}

@Override
public void setInvocationCount(int invocationCount) {
public synchronized void setInvocationCount(int invocationCount) {
if(invocationCount < 0){
throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
}
Expand Down
10 changes: 8 additions & 2 deletions Core/src/main/java/org/tribuo/SparseTrainer.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -52,5 +52,11 @@ default public SparseModel<T> train(Dataset<T> examples) {
* @param invocationCount The state of the RNG the trainer should be set to before training
* @return a predictive model that can be used to generate predictions for new examples.
*/
public SparseModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount);
@Override
public default SparseModel<T> train(Dataset<T> examples, Map<String, Provenance> runProvenance, int invocationCount) {
synchronized (this){
setInvocationCount(invocationCount);
return train(examples, runProvenance);
}
}
}
8 changes: 5 additions & 3 deletions Core/src/main/java/org/tribuo/Trainer.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -89,11 +89,13 @@ public default Model<T> train(Dataset<T> examples, Map<String, Provenance> runPr
* This is used when reproducing a Tribuo-trained model by setting the state of the RNG to
* what it was at when Tribuo trained the original model by simulating invocations of the train method.
* This method should ALWAYS be overridden, and the default method is purely for compatibility.
* <p>
* In a future major release this default implementation will be removed.
* @param invocationCount the number of invocations of the train method to simulate
*/
default public void setInvocationCount(int invocationCount){
Logger.getLogger(this.getClass().getName()).warning("This class is using the default implementation of " +
"setInvocationCount and so might not behave as expected. We highly recommend overriding this method " +
"to function as per the documentation.");
"setInvocationCount and so might not behave as expected when reproduced. We highly recommend overriding " +
"this method as per the documentation.");
}
}
9 changes: 3 additions & 6 deletions Core/src/main/java/org/tribuo/ensemble/BaggingTrainer.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -191,12 +191,9 @@ public synchronized void setInvocationCount(int invocationCount){
}

rng = new SplittableRandom(seed);
SplittableRandom localRNG;
trainInvocationCounter = 0;

for (int invocationCounter = 0; invocationCounter < invocationCount; invocationCounter++){
localRNG = rng.split();
trainInvocationCounter++;
for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++){
SplittableRandom localRNG = rng.split();
}

}
Expand Down
2 changes: 1 addition & 1 deletion Core/src/main/java/org/tribuo/hash/HashingTrainer.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down
Loading