Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,32 @@ public XGBoostClassificationTrainer(int numTrees, double eta, double gamma, int
postConfig();
}

/**
 * Create an XGBoost trainer, specifying the booster and tree building algorithm.
 *
 * @param boosterType The base learning algorithm.
 * @param treeMethod The tree building algorithm if using a tree booster.
 * @param numTrees Number of trees to boost.
 * @param eta Step size shrinkage parameter (default 0.3, range [0,1]).
 * @param gamma Minimum loss reduction to make a split (default 0, range [0,inf]).
 * @param maxDepth Maximum tree depth (default 6, range [1,inf]).
 * @param minChildWeight Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).
 * @param subsample Subsample size for each tree (default 1, range (0,1]).
 * @param featureSubsample Subsample features for each tree (default 1, range (0,1]).
 * @param lambda L2 regularization term on weights (default 1).
 * @param alpha L1 regularization term on weights (default 0).
 * @param nThread Number of threads to use (default 4).
 * @param verbosity Set the logging verbosity of the native library.
 * @param seed RNG seed.
 */
public XGBoostClassificationTrainer(BoosterType boosterType, TreeMethod treeMethod, int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, LoggingVerbosity verbosity, long seed) {
    // All hyperparameters are stored by the superclass; postConfig() then
    // builds the native parameter map and sets the classification objective.
    super(boosterType, treeMethod, numTrees, eta, gamma, maxDepth, minChildWeight, subsample, featureSubsample, lambda, alpha, nThread, verbosity, seed);
    postConfig();
}

/**
* This gives direct access to the XGBoost parameter map.
* <p>
Expand All @@ -128,7 +154,7 @@ protected XGBoostClassificationTrainer() { }
public void postConfig() {
super.postConfig();
parameters.put("objective", "multi:softprob");
if(!evalMetric.isEmpty()) {
if (!evalMetric.isEmpty()) {
parameters.put("eval_metric", evalMetric);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,45 @@
import com.oracle.labs.mlrg.olcut.config.Option;
import org.tribuo.Trainer;
import org.tribuo.classification.ClassificationOptions;
import org.tribuo.common.xgboost.XGBoostTrainer;

import java.util.logging.Level;
import java.util.logging.Logger;

/**
* CLI options for training an XGBoost classifier.
*/
public class XGBoostOptions implements ClassificationOptions<XGBoostClassificationTrainer> {
private static final Logger logger = Logger.getLogger(XGBoostOptions.class.getName());

@Option(longName = "xgb-booster-type", usage = "Weak learning algorithm.")
public XGBoostTrainer.BoosterType xgbBoosterType = XGBoostTrainer.BoosterType.GBTREE;
@Option(longName = "xgb-tree-method", usage = "Tree building algorithm.")
public XGBoostTrainer.TreeMethod xgbTreeMethod = XGBoostTrainer.TreeMethod.AUTO;
@Option(longName = "xgb-ensemble-size", usage = "Number of trees in the ensemble.")
public int xgbEnsembleSize = -1;
@Option(longName = "xgb-alpha", usage = "L1 regularization term for weights (default 0).")
@Option(longName = "xgb-alpha", usage = "L1 regularization term for weights.")
public float xbgAlpha = 0.0f;
@Option(longName = "xgb-min-weight", usage = "Minimum sum of instance weights needed in a leaf (default 1, range [0,inf]).")
@Option(longName = "xgb-min-weight", usage = "Minimum sum of instance weights needed in a leaf (range [0,Infinity]).")
public float xgbMinWeight = 1;
@Option(longName = "xgb-max-depth", usage = "Max tree depth (default 6, range (0,inf]).")
@Option(longName = "xgb-max-depth", usage = "Max tree depth (range (0,Integer.MAX_VALUE]).")
public int xgbMaxDepth = 6;
@Option(longName = "xgb-eta", usage = "Step size shrinkage parameter (default 0.3, range [0,1]).")
@Option(longName = "xgb-eta", usage = "Step size shrinkage parameter (range [0,1]).")
public float xgbEta = 0.3f;
@Option(longName = "xgb-subsample-features", usage = "Subsample features for each tree (default 1, range (0,1]).")
public float xgbSubsampleFeatures;
@Option(longName = "xgb-gamma", usage = "Minimum loss reduction to make a split (default 0, range [0,inf]).")
@Option(longName = "xgb-subsample-features", usage = "Subsample features for each tree (range (0,1]).")
public float xgbSubsampleFeatures = 0.0f;
@Option(longName = "xgb-gamma", usage = "Minimum loss reduction to make a split (range [0,Infinity]).")
public float xgbGamma = 0.0f;
@Option(longName = "xgb-lambda", usage = "L2 regularization term for weights (default 1).")
@Option(longName = "xgb-lambda", usage = "L2 regularization term for weights.")
public float xgbLambda = 1.0f;
@Option(longName = "xgb-quiet", usage = "Make the XGBoost training procedure quiet.")
@Option(longName = "xgb-quiet", usage = "Deprecated, use xgb-loglevel.")
public boolean xgbQuiet;
@Option(longName = "xgb-subsample", usage = "Subsample size for each tree (default 1, range (0,1]).")
@Option(longName = "xgb-loglevel", usage = "Make the XGBoost training procedure quiet.")
public XGBoostTrainer.LoggingVerbosity xgbLogLevel = XGBoostTrainer.LoggingVerbosity.WARNING;
@Option(longName = "xgb-subsample", usage = "Subsample size for each tree (range (0,1]).")
public float xgbSubsample = 1.0f;
@Option(longName = "xgb-num-threads", usage = "Number of threads to use (default 4, range (1, num hw threads)).")
public int xgbNumThreads;
@Option(longName = "xgb-num-threads", usage = "Number of threads to use (range (1, num hw threads)). The default of 0 means use all hw threads.")
public int xgbNumThreads = 0;
@Option(longName = "xgb-seed", usage = "Sets the random seed for XGBoost.")
private long xgbSeed = Trainer.DEFAULT_SEED;

Expand All @@ -54,6 +66,10 @@ public XGBoostClassificationTrainer getTrainer() {
if (xgbEnsembleSize == -1) {
throw new IllegalArgumentException("Please supply the number of trees.");
}
return new XGBoostClassificationTrainer(xgbEnsembleSize, xgbEta, xgbGamma, xgbMaxDepth, xgbMinWeight, xgbSubsample, xgbSubsampleFeatures, xgbLambda, xbgAlpha, xgbNumThreads, xgbQuiet, xgbSeed);
if (xgbQuiet) {
logger.log(Level.WARNING,"Silencing XGBoost, overriding logging verbosity. Please switch to the 'xgb-loglevel' argument.");
xgbLogLevel = XGBoostTrainer.LoggingVerbosity.SILENT;
}
return new XGBoostClassificationTrainer(xgbBoosterType, xgbTreeMethod, xgbEnsembleSize, xgbEta, xgbGamma, xgbMaxDepth, xgbMinWeight, xgbSubsample, xgbSubsampleFeatures, xgbLambda, xbgAlpha, xgbNumThreads, xgbLogLevel, xgbSeed);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.tribuo.classification.example.LabelledDataGenerator;
import org.tribuo.common.xgboost.XGBoostFeatureImportance;
import org.tribuo.common.xgboost.XGBoostModel;
import org.tribuo.common.xgboost.XGBoostTrainer;
import org.tribuo.data.text.TextDataSource;
import org.tribuo.data.text.TextFeatureExtractor;
import org.tribuo.data.text.impl.BasicPipeline;
Expand Down Expand Up @@ -61,6 +62,15 @@ public class TestXGBoost {

private static final XGBoostClassificationTrainer t = new XGBoostClassificationTrainer(50);

private static final XGBoostClassificationTrainer dart = new XGBoostClassificationTrainer(
XGBoostTrainer.BoosterType.DART,XGBoostTrainer.TreeMethod.AUTO,50,0.3,0,6,1,1,1,1,0,1, XGBoostTrainer.LoggingVerbosity.SILENT,42);

private static final XGBoostClassificationTrainer linear = new XGBoostClassificationTrainer(
XGBoostTrainer.BoosterType.LINEAR,XGBoostTrainer.TreeMethod.AUTO,50,0.3,0,6,1,1,1,1,0,1, XGBoostTrainer.LoggingVerbosity.SILENT,42);

private static final XGBoostClassificationTrainer gbtree = new XGBoostClassificationTrainer(
XGBoostTrainer.BoosterType.GBTREE,XGBoostTrainer.TreeMethod.HIST,50,0.3,0,6,1,1,1,1,0,1, XGBoostTrainer.LoggingVerbosity.SILENT,42);

private static final int[] NUM_TREES = new int[]{1,5,10,50};

//on Windows, this resolves to some nonsense like this: /C:/workspace/Classification/XGBoost/target/test-classes/test_input.tribuo
Expand Down Expand Up @@ -168,8 +178,8 @@ private void checkPrediction(String msgPrefix, XGBoostModel<Label> model, Predic
}
}

public Model<Label> testXGBoost(Pair<Dataset<Label>,Dataset<Label>> p) {
Model<Label> m = t.train(p.getA());
public static Model<Label> testXGBoost(XGBoostClassificationTrainer trainer, Pair<Dataset<Label>,Dataset<Label>> p) {
Model<Label> m = trainer.train(p.getA());
LabelEvaluator e = new LabelEvaluator();
LabelEvaluation evaluation = e.evaluate(m,p.getB());
Map<String, List<Pair<String,Double>>> features = m.getTopFeatures(3);
Expand Down Expand Up @@ -205,20 +215,23 @@ public void testFeatureImportanceSmokeTest() {
@Test
public void testDenseData() {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
Model<Label> model = testXGBoost(p);
Model<Label> model = testXGBoost(t,p);
Helpers.testModelSerialization(model,Label.class);
testXGBoost(dart,p);
testXGBoost(linear,p);
testXGBoost(gbtree,p);
}

@Test
public void testSparseData() {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.sparseTrainTest();
testXGBoost(p);
testXGBoost(t,p);
}

@Test
public void testSparseBinaryData() {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.binarySparseTrainTest();
testXGBoost(p);
testXGBoost(t,p);
}

@Test
Expand Down
5 changes: 5 additions & 0 deletions Common/XGBoost/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>

<dependencies>
<dependency>
<groupId>${project.groupId}</groupId>
Expand Down Expand Up @@ -68,6 +69,10 @@
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-actor_2.12</artifactId>
</exclusion>
<exclusion>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_2.12</artifactId>
</exclusion>
<exclusion>
<groupId>org.scala-lang.modules</groupId>
<artifactId>scala-java8-compat_2.12</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,56 @@ public abstract class XGBoostTrainer<T extends Output<T>> implements Trainer<T>,

private static final Logger logger = Logger.getLogger(XGBoostTrainer.class.getName());

protected final Map<String, Object> parameters = new HashMap<>();
/**
 * The tree building algorithm used by the native XGBoost library.
 */
public enum TreeMethod {
    /**
     * XGBoost picks {@link TreeMethod#EXACT} or {@link TreeMethod#APPROX}
     * based on the size of the dataset.
     */
    AUTO("auto"),
    /**
     * The exact greedy algorithm, which enumerates every split candidate.
     */
    EXACT("exact"),
    /**
     * An approximate greedy algorithm, built on a quantile sketch of the data
     * and a gradient histogram.
     */
    APPROX("approx"),
    /**
     * A faster, histogram optimized, approximate algorithm.
     */
    HIST("hist"),
    /**
     * The {@link TreeMethod#HIST} algorithm implemented on a GPU.
     * <p>
     * Note: GPU computation may not be supported on all platforms, and Tribuo is not tested with XGBoost GPU support.
     */
    GPU_HIST("gpu_hist");

    /**
     * The name of this tree method as understood by the XGBoost native library.
     */
    public final String paramName;

    TreeMethod(String nativeName) {
        this.paramName = nativeName;
    }
}

/**
 * The logging verbosity of the native library.
 * <p>
 * Mirrors XGBoost's integer verbosity levels: 0 is silent, 3 is debug.
 */
public enum LoggingVerbosity {
    /** No output from the native library. */
    SILENT(0),
    /** Warnings only. */
    WARNING(1),
    /** Informational messages and warnings. */
    INFO(2),
    /** Full debug output. */
    DEBUG(3);

    /**
     * The integer verbosity level passed to the native library.
     */
    public final int value;

    LoggingVerbosity(int level) {
        this.value = level;
    }
}

/**
* The type of XGBoost model.
Expand All @@ -104,6 +153,8 @@ public enum BoosterType {
}
}

protected final Map<String, Object> parameters = new HashMap<>();

@Config(mandatory = true,description="The number of trees to build.")
protected int numTrees;

Expand Down Expand Up @@ -134,12 +185,22 @@ public enum BoosterType {
@Config(description="The number of threads to use at training time.")
private int nThread = 4;

@Config(description="Quiesce all the logging output from the XGBoost C library.")
/**
* Deprecated by XGBoost in favour of the verbosity field.
*/
@Deprecated
@Config(description="Quiesce all the logging output from the XGBoost C library. Deprecated in favour of 'verbosity'.")
private int silent = 1;

@Config(description="Logging verbosity, 0 is silent, 3 is debug.")
private LoggingVerbosity verbosity = LoggingVerbosity.SILENT;

@Config(description="Type of the weak learner.")
private BoosterType booster = BoosterType.GBTREE;

@Config(description="The tree building algorithm to use.")
private TreeMethod treeMethod = TreeMethod.AUTO;

@Config(description="The RNG seed.")
private long seed = Trainer.DEFAULT_SEED;

Expand All @@ -155,6 +216,8 @@ protected XGBoostTrainer(int numTrees, int numThreads, boolean silent) {

/**
* Create an XGBoost trainer.
* <p>
* Sets the boosting algorithm to {@link BoosterType#GBTREE} and the tree building algorithm to {@link TreeMethod#AUTO}.
*
* @param numTrees Number of trees to boost.
* @param eta Step size shrinkage parameter (default 0.3, range [0,1]).
Expand All @@ -173,9 +236,36 @@ protected XGBoostTrainer(int numTrees, int numThreads, boolean silent) {
* @param seed RNG seed.
*/
protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, boolean silent, long seed) {
// Legacy constructor: maps the deprecated boolean 'silent' flag onto the
// LoggingVerbosity enum (true -> SILENT, false -> INFO) and delegates to the
// primary constructor with the default booster (GBTREE) and tree method (AUTO).
this(BoosterType.GBTREE,TreeMethod.AUTO,numTrees,eta,gamma,maxDepth,minChildWeight,subsample,featureSubsample,lambda,alpha,nThread,silent ? LoggingVerbosity.SILENT : LoggingVerbosity.INFO,seed);
}

/**
* Create an XGBoost trainer.
*
* @param boosterType The base learning algorithm.
* @param treeMethod The tree building algorithm if using a tree booster.
* @param numTrees Number of trees to boost.
* @param eta Step size shrinkage parameter (default 0.3, range [0,1]).
* @param gamma Minimum loss reduction to make a split (default 0, range
* [0,inf]).
* @param maxDepth Maximum tree depth (default 6, range [1,inf]).
* @param minChildWeight Minimum sum of instance weights needed in a leaf
* (default 1, range [0, inf]).
* @param subsample Subsample size for each tree (default 1, range (0,1]).
* @param featureSubsample Subsample features for each tree (default 1,
* range (0,1]).
* @param lambda L2 regularization term on weights (default 1).
* @param alpha L1 regularization term on weights (default 0).
* @param nThread Number of threads to use (default 4).
* @param verbosity Set the logging verbosity of the native library.
* @param seed RNG seed.
*/
protected XGBoostTrainer(BoosterType boosterType, TreeMethod treeMethod, int numTrees, double eta, double gamma, int maxDepth, double minChildWeight, double subsample, double featureSubsample, double lambda, double alpha, int nThread, LoggingVerbosity verbosity, long seed) {
if (numTrees < 1) {
throw new IllegalArgumentException("Must supply a positive number of trees. Received " + numTrees);
}
this.booster = boosterType;
this.treeMethod = treeMethod;
this.numTrees = numTrees;
this.eta = eta;
this.gamma = gamma;
Expand All @@ -186,7 +276,8 @@ protected XGBoostTrainer(int numTrees, double eta, double gamma, int maxDepth, d
this.lambda = lambda;
this.alpha = alpha;
this.nThread = nThread;
this.silent = silent ? 1 : 0;
this.verbosity = verbosity;
this.silent = 0; // silent is deprecated
this.seed = seed;
}

Expand Down Expand Up @@ -227,8 +318,13 @@ public void postConfig() {
parameters.put("alpha", alpha);
parameters.put("nthread", nThread);
parameters.put("seed", seed);
parameters.put("silent", silent);
if (silent == 1) {
parameters.put("verbosity", 0);
} else {
parameters.put("verbosity", verbosity.value);
}
parameters.put("booster", booster.paramName);
parameters.put("tree_method", treeMethod.paramName);
}

@Override
Expand Down
Loading