Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Core/src/main/java/org/tribuo/Dataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.util.Queue;
import java.util.Set;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -248,6 +249,7 @@ public TransformerMap createTransformers(TransformationMap transformations) {
ArrayList<String> featureNames = new ArrayList<>(getFeatureMap().keySet());

// Validate map by checking no regex applies to multiple features.
logger.fine(String.format("Processing %d feature specific transforms", transformations.getFeatureTransformations().size()));
Map<String,List<Transformation>> featureTransformations = new HashMap<>();
for (Map.Entry<String,List<Transformation>> entry : transformations.getFeatureTransformations().entrySet()) {
// Compile the regex.
Expand Down Expand Up @@ -283,6 +285,9 @@ public TransformerMap createTransformers(TransformationMap transformations) {
}
if (!transformations.getGlobalTransformations().isEmpty()) {
// Append all the global transformations
int ntransform = featureNames.size();
logger.fine(String.format("Starting %,d global transformations", ntransform));
int ndone = 0;
for (String v : featureNames) {
// Create the queue of feature transformations for this feature
Queue<TransformStatistics> l = featureStats.computeIfAbsent(v, (k) -> new LinkedList<>());
Expand All @@ -293,6 +298,10 @@ public TransformerMap createTransformers(TransformationMap transformations) {
featureStats.put(v, l);
// Generate the sparse count initialised to the number of features.
sparseCount.putIfAbsent(v, new MutableLong(data.size()));
ndone++;
if(logger.isLoggable(Level.FINE) && ndone % 10000 == 0) {
logger.fine(String.format("Completed %,d of %,d global transformations", ndone, ntransform));
}
}
}

Expand Down
11 changes: 11 additions & 0 deletions Core/src/main/java/org/tribuo/MutableDataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,18 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
* A MutableDataset is a {@link Dataset} with a {@link MutableFeatureMap} which grows over time.
* Whenever an {@link Example} is added to the dataset it observes each feature and output
* keeping appropriate statistics in the {@link FeatureMap} and {@link OutputInfo}.
*/
public class MutableDataset<T extends Output<T>> extends Dataset<T> {

private static final Logger logger = Logger.getLogger(MutableDataset.class.getName());

private static final long serialVersionUID = 1L;

/**
Expand Down Expand Up @@ -193,11 +198,17 @@ public boolean isDense() {
*/
public void transform(TransformerMap transformerMap) {
    featureMap.clear();
    final int total = data.size();
    // Guard the format call so the %,d formatting work is skipped when FINE
    // logging is disabled (the per-10000 log below was already guarded;
    // this makes the two call sites consistent).
    if (logger.isLoggable(Level.FINE)) {
        logger.fine(String.format("Transforming %,d examples", total));
    }
    int transformed = 0;
    for (Example<T> example : data) {
        // Apply the transformers in place, then re-observe every feature so
        // the (cleared) feature map reflects the transformed values.
        example.transform(transformerMap);
        for (Feature f : example) {
            featureMap.add(f.getName(), f.getValue());
        }
        transformed++;
        if (transformed % 10000 == 0 && logger.isLoggable(Level.FINE)) {
            logger.fine(String.format("Transformed %,d/%,d", transformed, total));
        }
    }
    // Record the transformation in this dataset's provenance chain.
    transformProvenances.add(transformerMap.getProvenance());
}
Expand Down
51 changes: 38 additions & 13 deletions Core/src/main/java/org/tribuo/impl/ArrayExample.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,19 @@ public ArrayExample(T output, List<? extends Feature> features) {
*/
public ArrayExample(Example<T> other) {
    super(other);
    if (other instanceof ArrayExample) {
        // Fast path: copy the backing arrays directly, avoiding the
        // per-feature Feature object allocation of the iterator path.
        ArrayExample<T> src = (ArrayExample<T>) other;
        size = src.size;
        featureNames = Arrays.copyOf(src.featureNames, size);
        featureValues = Arrays.copyOf(src.featureValues, size);
    } else {
        // Generic path: pull each feature out via the Example iterator.
        int capacity = other.size();
        featureNames = new String[capacity];
        featureValues = new double[capacity];
        for (Feature f : other) {
            featureNames[size] = f.getName();
            featureValues[size] = f.getValue();
            size++;
        }
    }
}

Expand Down Expand Up @@ -408,14 +415,32 @@ public void set(Feature feature) {

@Override
public void transform(TransformerMap transformerMap) {
for (Map.Entry<String,List<Transformer>> e : transformerMap.entrySet()) {
int index = Arrays.binarySearch(featureNames,0,size,e.getKey());
if (index >= 0) {
double value = featureValues[index];
for (Transformer t : e.getValue()) {
value = t.transform(value);
if(transformerMap.size() < featureNames.length) {
//
// We have fewer transformers than feature names, so let's
// iterate through the map and find the features.
for(Map.Entry<String, List<Transformer>> e : transformerMap.entrySet()) {
int index = Arrays.binarySearch(featureNames, 0, size, e.getKey());
if(index >= 0) {
double value = featureValues[index];
for(Transformer t : e.getValue()) {
value = t.transform(value);
}
featureValues[index] = value;
}
}
} else {
//
// We have more transformers, so let's fetch them by name.
for(int i = 0; i < size; i++) {
List<Transformer> l = transformerMap.get(featureNames[i]);
if(l != null) {
double value = featureValues[i];
for(Transformer t : l) {
value = t.transform(value);
}
featureValues[i] = value;
}
featureValues[index] = value;
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions Core/src/main/java/org/tribuo/transform/TransformTrainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import java.time.OffsetDateTime;
import java.util.Map;
import java.util.logging.Logger;

/**
* A {@link Trainer} which encapsulates another trainer plus a {@link TransformationMap} object
Expand All @@ -38,6 +39,8 @@
* first call {@link MutableDataset#densify} on the datasets.
*/
public final class TransformTrainer<T extends Output<T>> implements Trainer<T> {

private static final Logger logger = Logger.getLogger(TransformTrainer.class.getName());

@Config(mandatory = true,description="Trainer to use.")
private Trainer<T> innerTrainer;
Expand Down Expand Up @@ -81,10 +84,16 @@ public TransformTrainer(Trainer<T> innerTrainer, TransformationMap transformatio

@Override
public TransformedModel<T> train(Dataset<T> examples, Map<String, Provenance> instanceProvenance) {

logger.fine(String.format("Creating transformers"));
TransformerMap transformerMap = examples.createTransformers(transformations);

logger.fine("Transforming data set");

Dataset<T> transformedDataset = transformerMap.transformDataset(examples,densify);

logger.fine("Running inner trainer");

Model<T> innerModel = innerTrainer.train(transformedDataset);

ModelProvenance provenance = new ModelProvenance(TransformedModel.class.getName(), OffsetDateTime.now(), transformedDataset.getProvenance(), getProvenance(), instanceProvenance);
Expand Down
13 changes: 11 additions & 2 deletions Core/src/main/java/org/tribuo/transform/TransformationMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
Expand All @@ -32,6 +33,8 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/**
Expand All @@ -51,7 +54,7 @@ public class TransformationMap implements Configurable, Provenancable<Configured
@Config(mandatory = true,description="Global transformations to apply after the feature specific transforms.")
private List<Transformation> globalTransformations;

@Config(mandatory = true,description="Feature specific transformations. Accepts regexes for feature names.")
@Config(description="Feature specific transformations. Accepts regexes for feature names.")
private Map<String,TransformationList> featureTransformationList = new HashMap<>();

private final Map<String,List<Transformation>> featureTransformations = new HashMap<>();
Expand All @@ -64,13 +67,14 @@ public class TransformationMap implements Configurable, Provenancable<Configured
private TransformationMap() {}

public TransformationMap(List<Transformation> globalTransformations, Map<String,List<Transformation>> featureTransformations) {
    // Defensive copy so later mutation of the caller's list cannot change us.
    this.globalTransformations = new ArrayList<>(globalTransformations);
    this.featureTransformations.putAll(featureTransformations);

    // Mirror the per-feature transformations into the @Config field so that
    // they are captured by the provenance system.
    featureTransformations.forEach(
            (name, list) -> featureTransformationList.put(name, new TransformationList(list)));
}

public TransformationMap(List<Transformation> globalTransformations) {
Expand All @@ -86,6 +90,11 @@ public TransformationMap(Map<String,List<Transformation>> featureTransformations
*/
@Override
public void postConfig() {
if(globalTransformations.isEmpty() && featureTransformationList.isEmpty()) {
throw new PropertyException("TransformationMap",
"Both global transformations and feature transformations can't be empty!");
}

for (Map.Entry<String,TransformationList> e : featureTransformationList.entrySet()) {
featureTransformations.put(e.getKey(),e.getValue().list);
}
Expand Down
28 changes: 28 additions & 0 deletions Core/src/main/java/org/tribuo/transform/TransformerMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.logging.Logger;

/**
* A collection of {@link Transformer}s which can be applied to a {@link Dataset}
Expand All @@ -49,6 +50,9 @@
* first call {@link MutableDataset#densify} on the datasets.
*/
public final class TransformerMap implements Provenancable<TransformerMapProvenance>, Serializable {

private static final Logger logger = Logger.getLogger(TransformerMap.class.getName());

private static final long serialVersionUID = 2L;

private final Map<String, List<Transformer>> map;
Expand Down Expand Up @@ -129,16 +133,40 @@ public <T extends Output<T>> MutableDataset<T> transformDataset(Dataset<T> datas
 * @return A deep copy of the dataset (and its examples) with the transformers applied to its features.
*/
public <T extends Output<T>> MutableDataset<T> transformDataset(Dataset<T> dataset, boolean densify) {
    logger.fine("Creating deep copy of data set");

    MutableDataset<T> newDataset = MutableDataset.createDeepCopy(dataset);

    if (densify) {
        newDataset.densify();
    }

    // Fixed: this was String.format(...) with no format arguments — log the
    // literal message directly.
    logger.fine("Transforming data set copy");

    newDataset.transform(this);

    return newDataset;
}

/**
 * Gets the size of the map.
 *
 * @return The number of feature names which have an associated list of transformers.
 */
public int size() {
    return map.size();
}

/**
 * Gets the list of transformers associated with a given feature name.
 * @param featureName The name of the feature whose transformers are requested.
 * @return The transformers associated with that feature name, or {@code null}
 * if no transformers were built for a feature with that name.
 */
public List<Transformer> get(String featureName) {
    return map.get(featureName);
}

@Override
public String toString() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package org.tribuo.transform.transformations;

import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.tribuo.transform.TransformStatistics;
import org.tribuo.transform.Transformation;
import org.tribuo.transform.TransformationProvenance;
import org.tribuo.transform.Transformer;

/**
* A feature transformation that computes the IDF for features and then transforms
* them with a TF-IDF weighting.
*/
/**
 * A feature transformation that computes the inverse document frequency (IDF)
 * for each feature over the dataset, then rescales feature values with a
 * TF-IDF weighting.
 */
public class IDFTransformation implements Transformation {

    // Lazily created on first call to getProvenance().
    private TransformationProvenance provenance;

    @Override
    public TransformStatistics createStats() {
        return new IDFStatistics();
    }

    @Override
    public TransformationProvenance getProvenance() {
        if (provenance == null) {
            provenance = new IDFTransformationProvenance();
        }
        return provenance;
    }

    /**
     * Accumulates the document frequency for a single feature.
     */
    private static class IDFStatistics implements TransformStatistics {

        /**
         * The document frequency for the feature that this statistic is
         * associated with. This is a count of the number of examples that the
         * feature occurs in.
         */
        private int df;

        /**
         * The number of examples that the feature did not occur in.
         */
        private int sparseObservances;

        @Override
        public void observeValue(double value) {
            // One more document (i.e., an example) has this feature; the
            // observed value itself is not needed to compute the IDF.
            df++;
        }

        @Override
        public void observeSparse() {
            // Deliberately empty: sparse occurrences are reported in bulk
            // through observeSparse(int) below.
        }

        @Override
        public void observeSparse(int count) {
            sparseObservances = count;
        }

        @Override
        public Transformer generateTransformer() {
            // Total document count N = examples containing the feature plus
            // examples in which it was absent.
            return new IDFTransformer(df, df + sparseObservances);
        }

    }

    /**
     * Applies the TF-IDF weighting {@code log(N/df) * (1 + log(tf))} to a
     * feature value treated as a term frequency.
     */
    private static class IDFTransformer implements Transformer {

        // Made final: both values are fixed at construction time.
        private final double df;

        private final double N;

        public IDFTransformer(int df, int N) {
            this.df = df;
            this.N = N;
        }

        @Override
        public double transform(double tf) {
            return Math.log(N / df) * (1 + Math.log(tf));
        }

    }

    /**
     * Provenance for {@link IDFTransformation}, which has no configurable state.
     */
    public final static class IDFTransformationProvenance implements TransformationProvenance {

        @Override
        public Map<String, Provenance> getConfiguredParameters() {
            // There is no configuration; emptyMap() avoids the redundant
            // unmodifiableMap(new HashMap<>()) allocation.
            return Collections.emptyMap();
        }

        @Override
        public String getClassName() {
            return IDFTransformation.class.getName();
        }

    }

}