Skip to content

Commit

Permalink
FeatureHasher should have an option to not hash the values.
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp committed Dec 19, 2022
1 parent 4874f38 commit d25b8b6
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,8 +24,11 @@
import java.util.List;

/**
* A feature transformer maps a list of features to a new list of features
* Useful for example to apply the hashing trick to a set of features
* A feature transformer maps a list of features to a new list of features.
* Useful for example to apply the hashing trick to a set of features.
* <p>
* Note a list of features returned by a {@code FeatureTransformer} may contain
* duplicate features, and should be reduced to ensure that each feature is unique.
*/
public interface FeatureTransformer extends Configurable, Provenancable<ConfiguredObjectProvenance> {

Expand Down
66 changes: 58 additions & 8 deletions Data/src/main/java/org/tribuo/data/text/impl/FeatureHasher.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -36,15 +36,62 @@ public class FeatureHasher implements FeatureTransformer {

private static final Logger logger = Logger.getLogger(FeatureHasher.class.getName());

/**
* Default value for the hash function seed.
*/
public static final int DEFAULT_HASH_SEED = 38495;

/**
* Default value for the value hash function seed.
*/
public static final int DEFAULT_HASH_VALUE_SEED = 77777;

@Config(mandatory = true,description="Dimension to map the hash into.")
private int dimension;

@Config(description = "Seed used in the hash function.")
private int hashSeed = DEFAULT_HASH_SEED;

@Config(description = "Seed used for value hash function.")
private int hashValueSeed = DEFAULT_HASH_VALUE_SEED;

@Config(description = "Preserve input feature value.")
private boolean preserveValue = false;

/**
* Constructs a feature hasher using the supplied hash dimension.
* <p>
* Note the hasher also hashes the feature value into {-1, 1}.
* @param dimension The dimension to reduce the hashed features into.
*/
public FeatureHasher(int dimension) {
this(dimension, DEFAULT_HASH_SEED, DEFAULT_HASH_VALUE_SEED, false);
}

/**
* Constructs a feature hasher using the supplied hash dimension.
* @param dimension The dimension to reduce the hashed features into.
* @param preserveValue If true the feature value is used unaltered in the new features,
* if false it is hashed into the values {-1, 1}.
*/
public FeatureHasher(int dimension, boolean preserveValue) {
this(dimension, DEFAULT_HASH_SEED, DEFAULT_HASH_VALUE_SEED, preserveValue);
}

/**
* Constructs a feature hasher using the supplied hash dimension and seed values.
* @param dimension The dimension to reduce the hashed features into.
* @param hashSeed The seed used in the murmurhash computation.
* @param hashValueSeed The seed used in the murmurhash computation for the feature value,
* unused if {@code preserveValue} is true.
* @param preserveValue If true the feature value is used unaltered in the new features,
* if false it is hashed into the values {-1, 1}.
*/
public FeatureHasher(int dimension, int hashSeed, int hashValueSeed, boolean preserveValue) {
this.dimension = dimension;
this.hashSeed = hashSeed;
this.hashValueSeed = hashValueSeed;
this.preserveValue = preserveValue;
}

/**
Expand All @@ -58,20 +105,23 @@ public List<Feature> map(String tag, List<Feature> features) {
List<Feature> hashedFeatures = new ArrayList<>();

for (Feature feature : features) {
int hash = MurmurHash3.murmurhash3_x86_32(feature.getName(), 0, feature.getName().length(), 38495);
//int bit = hash & 1;
int bit = MurmurHash3.murmurhash3_x86_32(feature.getName(), 0, feature.getName().length(), 77777) & 1;
int hash = MurmurHash3.murmurhash3_x86_32(feature.getName(), 0, feature.getName().length(), hashSeed);
hash = hash >>> 1;
int code = hash % dimension;

int change = bit == 1 ? 1 : -1;

Feature newFeature = new Feature(tag + "-hash="+code,change);
double value;
if (preserveValue) {
value = feature.getValue();
} else {
int bit = MurmurHash3.murmurhash3_x86_32(feature.getName(), 0, feature.getName().length(), hashValueSeed) & 1;
value = bit == 1 ? 1 : -1;
}

Feature newFeature = new Feature(tag + "-hash="+code, value);
hashedFeatures.add(newFeature);
}

return hashedFeatures;

}

@Override
Expand Down
50 changes: 40 additions & 10 deletions Data/src/main/java/org/tribuo/data/text/impl/TokenPipeline.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,16 +43,19 @@ public class TokenPipeline implements TextPipeline {
private List<FeatureTransformer> transformers = new ArrayList<>();
private FeatureAggregator aggregator;

@Config(mandatory = true,description="Use term counting, otherwise emit binary features.")
@Config(mandatory = true, description = "Use term counting, otherwise emit binary features.")
private boolean termCounting;

@Config(description="Dimension to map the hash into.")
@Config(description = "Dimension to map the hash into.")
private int hashDim = -1;

@Config(mandatory = true,description="Tokenizer to use.")
@Config(description = "Should feature hashing preserve the value?")
private boolean hashPreserveValue = true;

@Config(mandatory = true, description = "Tokenizer to use.")
private Tokenizer tokenizer;

@Config(description="n in the n-gram to emit.")
@Config(description = "n in the n-gram to emit.")
private int ngram = 2;

/**
Expand All @@ -73,7 +76,31 @@ public class TokenPipeline implements TextPipeline {
public TokenPipeline(Tokenizer tokenizer, int ngram, boolean termCounting) {
this(tokenizer, ngram, termCounting, -1);
}


/**
* Creates a new token pipeline.
*
* @param tokenizer The tokenizer to use to split up the text into words
* (i.e., features.)
* @param ngram The maximum size of ngram features to add to the features
* generated by the pipeline. A value of {@code n} means that ngram
* features of size 1-n will be generated. A good standard value to use is
* 2, which means that unigram and bigram features will be generated. You
* will very likely see diminishing returns for larger values of
* {@code n} but there will be times when they will be necessary.
* @param termCounting If {@code true}, multiple occurrences of terms
* in the document will be counted and the count will be used as the value
* of the features that are produced.
* @param dimension The maximum dimension for the feature space. If this value
* is greater than 0, then at most {@code dimension} features will be
* through the use of a hashing function that will collapse the feature
* space. This {@code TokenPipeline} will preserve the feature values when hashing,
* w.
*/
public TokenPipeline(Tokenizer tokenizer, int ngram, boolean termCounting, int dimension) {
this(tokenizer, ngram, termCounting, dimension, true);
}

/**
* Creates a new token pipeline.
*
Expand All @@ -88,16 +115,19 @@ public TokenPipeline(Tokenizer tokenizer, int ngram, boolean termCounting) {
* @param termCounting If {@code true}, multiple occurrences of terms
* in the document will be counted and the count will be used as the value
* of the features that are produced.
* @param dimension The maximum dimension for the feature space. If this value
* @param dimension The maximum dimension for the feature space. If this value
* is greater than 0, then at most {@code dimension} features will be
* through the use of a hashing function that will collapse the feature
* through the use of a hashing function that will collapse the feature
* space.
* @param hashPreserveValue If true, the hash function preserves the feature value, if false
* it hashes it into the values {-1, 1}.
*/
public TokenPipeline(Tokenizer tokenizer, int ngram, boolean termCounting, int dimension) {
public TokenPipeline(Tokenizer tokenizer, int ngram, boolean termCounting, int dimension, boolean hashPreserveValue) {
this.tokenizer = tokenizer;
this.ngram = ngram;
this.hashDim = dimension;
this.termCounting = termCounting;
this.hashPreserveValue = hashPreserveValue;
postConfig();
}

Expand All @@ -115,7 +145,7 @@ public void postConfig() {
processors.add(new NgramProcessor(tokenizer,i,1));
}
if (hashDim > 0) {
transformers.add(new FeatureHasher(hashDim));
transformers.add(new FeatureHasher(hashDim, hashPreserveValue));
}
if (termCounting) {
aggregator = new SumAggregator();
Expand Down
30 changes: 29 additions & 1 deletion Data/src/test/java/org/tribuo/data/text/TextPipelineTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -101,6 +101,34 @@ public void testTokenPipeline() {
assertTrue(featureList.contains(new Feature("2-N=input/text",1.0)));
}

@Test
public void testHashingTokenPipeline() {
String input = "This is some input text.";

TokenPipeline pipeline = new TokenPipeline(new BreakIteratorTokenizer(Locale.US),2,true, 10);

List<Feature> featureList = pipeline.process("test",input);

assertTrue(featureList.contains(new Feature("test-hash=1",1.0)));
assertTrue(featureList.contains(new Feature("test-hash=2",2.0)));
assertTrue(featureList.contains(new Feature("test-hash=3",5.0)));
assertTrue(featureList.contains(new Feature("test-hash=5",1.0)));
assertTrue(featureList.contains(new Feature("test-hash=6",1.0)));
assertTrue(featureList.contains(new Feature("test-hash=7",1.0)));

TokenPipeline hashedValuePipeline = new TokenPipeline(new BreakIteratorTokenizer(Locale.US),2,true, 10, false);

List<Feature> hashedValueFeatureList = hashedValuePipeline.process("test",input);

assertTrue(hashedValueFeatureList.contains(new Feature("test-hash=1",1.0)));
assertTrue(hashedValueFeatureList.contains(new Feature("test-hash=2",0.0)));
assertTrue(hashedValueFeatureList.contains(new Feature("test-hash=3",-1.0)));
assertTrue(hashedValueFeatureList.contains(new Feature("test-hash=5",-1.0)));
assertTrue(hashedValueFeatureList.contains(new Feature("test-hash=6",1.0)));
assertTrue(hashedValueFeatureList.contains(new Feature("test-hash=7",-1.0)));
}


@Test
public void testTokenPipelineTagging() {
String input = "This is some input text.";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.data.text.impl;

import org.junit.jupiter.api.Test;
import org.tribuo.Feature;

import java.util.ArrayList;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class FeatureHasherTest {

@Test
public void negativeValuesTest() {
List<Feature> input = new ArrayList<>();
Feature posValue = new Feature("Testing", 2.0);
input.add(posValue);
Feature negValue = new Feature("Test",2.0);
input.add(negValue);

FeatureHasher preserving = new FeatureHasher(10, true);
FeatureHasher notPreserving = new FeatureHasher(10, false);

List<Feature> preservingOutput = preserving.map("test", input);
List<Feature> notPreservingOutput = notPreserving.map("test", input);

assertEquals(2.0, preservingOutput.get(0).getValue());
assertEquals(2.0, preservingOutput.get(1).getValue());

assertEquals(1.0, notPreservingOutput.get(0).getValue());
assertEquals(-1.0, notPreservingOutput.get(1).getValue());
}

}

0 comments on commit d25b8b6

Please sign in to comment.