Skip to content

Commit

Permalink
Implement microaggregation function mode with fallback distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
prasser committed Aug 24, 2024
1 parent 5d9add0 commit 668f8aa
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 1 deletion.
31 changes: 31 additions & 0 deletions src/main/org/deidentifier/arx/AttributeType.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction;
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionArithmeticMean;
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionGeometricMean;
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionInterval;
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionMedian;
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionMode;
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionModeWithDistributionFallback;
import org.deidentifier.arx.framework.check.distribution.DistributionAggregateFunction.DistributionAggregateFunctionSet;
import org.deidentifier.arx.io.CSVDataOutput;
import org.deidentifier.arx.io.CSVHierarchyInput;
Expand Down Expand Up @@ -732,6 +734,18 @@ public static MicroAggregationFunction createMedian(boolean ignoreMissingData) {
public static MicroAggregationFunction createMode() {
return createMode(true);
}

/**
* Creates a microaggregation function returning the mode. If more than one value qualifies as mode,
* the function draws from the qualifying values using the provided distribution. Ignores missing data.
*
* @param distribution Map from values to frequencies
* @param seed Seed to use for drawing, can be null
* @return
*/
public static MicroAggregationFunction createModeWithDistributionFallback(Map<String, Double> distribution, long seed) {
return createModeWithDistributionFallback(true, distribution, seed);
}

/**
* Creates a microaggregation function returning the mode.
Expand All @@ -743,6 +757,23 @@ public static MicroAggregationFunction createMode(boolean ignoreMissingData) {
return new MicroAggregationFunction(new DistributionAggregateFunctionMode(ignoreMissingData),
DataScale.NOMINAL, "Mode");
}

/**
* Creates a microaggregation function returning the mode. If more than one value qualifies as mode,
* the function draws from the qualifying values using the provided distribution.
*
* @param ignoreMissingData Should the function ignore missing data. Default is true.
* @param distribution Map from values to frequencies
* @param seed Seed to use for drawing, can be null
* @return
*/
public static MicroAggregationFunction createModeWithDistributionFallback(boolean ignoreMissingData,
Map<String, Double> distribution,
Long seed) {
return new MicroAggregationFunction(new DistributionAggregateFunctionModeWithDistributionFallback(ignoreMissingData, distribution, seed),
DataScale.NOMINAL, "Mode with distribution fallback");
}

/**
* Creates a microaggregation function returning sets. This variant will ignore missing data.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
*/
package org.deidentifier.arx.aggregates;

import java.util.LinkedHashMap;
import java.util.Map;

/**
* A frequency distribution.
*
Expand Down Expand Up @@ -44,4 +47,16 @@ public class StatisticsFrequencyDistribution {
this.count = count;
this.frequency = frequency;
}

/**
* Returns the distribution as a map
* @return
*/
public Map<String, Double> asMap() {
Map<String, Double> map = new LinkedHashMap<>();
for (int i = 0; i < frequency.length; i++) {
map.put(values[i], frequency[i]);
}
return map;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,22 @@
package org.deidentifier.arx.framework.check.distribution;

import java.io.Serializable;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;

import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.deidentifier.arx.DataType;
import org.deidentifier.arx.DataType.DataTypeWithRatioScale;

import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.IntDoubleOpenHashMap;

import cern.colt.GenericSorting;
import cern.colt.Swapper;
import cern.colt.function.IntComparator;
Expand Down Expand Up @@ -459,7 +465,6 @@ private <T> T getValueAt(List<T> values, List<Integer> frequencies, int index) {
}
}


/**
* This class calculates the mode for a given distribution.
*
Expand Down Expand Up @@ -583,6 +588,223 @@ private int getMode(Distribution distribution) {
}
}

/**
* This class calculates the mode for a given distribution falling back to drawing from multiple values that would qualify as
* mode using the provided distribution
*
* @author Fabian Prasser
*
*/
public static class DistributionAggregateFunctionModeWithDistributionFallback extends DistributionAggregateFunction {

/** SVUID. */
private static final long serialVersionUID = 6285156778817664604L;

/** Minimum */
private double minimum = 0d;

/** Maximum */
private double maximum = 0d;

/** Distribution*/
private Map<String, Double> distribution;

/** Integer distribution*/
private IntDoubleOpenHashMap intDistribution;

/** The seed to use*/
private Long seed;

/** The random source to use*/
private Random random;

/**
* Instantiates.
*
* @param ignoreMissingData
* @param distribution
* @param seed Maybe null
*/
public DistributionAggregateFunctionModeWithDistributionFallback(boolean ignoreMissingData,
Map<String, Double> distribution,
Long seed) {
super(ignoreMissingData, true);
this.distribution = distribution;
this.seed = seed;
if (this.seed == null) {
this.random = new SecureRandom();
} else {
this.random = new Random(this.seed);
}
}

/**
* Clone constructor
* @param ignoreMissingData
* @param minimum
* @param maximum
* @param distribution
* @param seed Maybe null
*/
private DistributionAggregateFunctionModeWithDistributionFallback(boolean ignoreMissingData,
double minimum,
double maximum,
Map<String, Double> distribution,
Long seed) {
this(ignoreMissingData, distribution, seed);
this.minimum = minimum;
this.maximum = maximum;
}

@Override
public <T> String aggregate(Distribution distribution) {

// Determine mode
int mode = getModeWithDistributionFallback(distribution);
return mode == -1 ? DataType.NULL_VALUE : dictionary[mode];
}

/**
* Clone method
*/
public DistributionAggregateFunctionModeWithDistributionFallback clone() {
DistributionAggregateFunctionModeWithDistributionFallback result = new DistributionAggregateFunctionModeWithDistributionFallback(this.ignoreMissingData,
this.minimum,
this.maximum,
this.distribution,
this.seed);
if (dictionary != null) {
result.initialize(dictionary, type);
}
return result;
}

@Override
public <T> double getError(Distribution distribution) {

if (!(type instanceof DataTypeWithRatioScale)) {
return 0d;
}

@SuppressWarnings("unchecked")
DataTypeWithRatioScale<T> rType = (DataTypeWithRatioScale<T>) this.type;
DoubleArrayList list = new DoubleArrayList();
Iterator<Double> it = DistributionIterator.createIteratorDouble(distribution, dictionary, rType);
while (it.hasNext()) {
Double value = it.next();
value = value == null ? (ignoreMissingData ? null : 0d) : value;
if (value != null) {
list.add(value);
}
}

// Determine and check mode
int mode = getModeWithDistributionFallback(distribution);
if (mode == -1) {
return 1d;
}

// Compute error
return getNMSE(minimum, maximum, Arrays.copyOf(list.elements(), list.size()),
rType.toDouble(rType.parse(dictionary[mode])));
}

@Override
public void initialize(String[] dictionary, DataType<?> type) {
super.initialize(dictionary, type);
if (type instanceof DataTypeWithRatioScale) {
double[] values = getMinMax(dictionary, (DataTypeWithRatioScale<?>)type);
this.minimum = values[0];
this.maximum = values[1];
}
intDistribution = new IntDoubleOpenHashMap();
int index = 0;
for (String value : dictionary) {
Double frequency = this.distribution.get(value);
if (frequency != null) {
intDistribution.put(index, frequency);
}
index++;
}
}

/**
* Returns the index of the most frequent element from the distribution. Draws from the most frequent items if
* there are multiple, using the provided distribution. Returns -1 if there is no such element.
* @param distribution
* @return
*/
private int getModeWithDistributionFallback(Distribution distribution) {

// Prepare
int[] buckets = distribution.getBuckets();
int max = -1;
IntArrayList mode = new IntArrayList();

// Iterate through distribution, collecting the mode
for (int i = 0; i < buckets.length; i += 2) {
int value = buckets[i];
int frequency = buckets[i + 1];
if (value != -1) {
// Same frequency
if (Math.abs(max - frequency) < 1e-9) {
mode.add(value);
// More frequent
} else if (frequency > max) {
max = frequency;
mode.clear();
mode.add(value);
}
}
}

// Weird
if (mode.isEmpty()) {
return -1;

// Exactly one mode
} else if (mode.size() == 1) {
return mode.get(0);

// Need to draw from distribution
} else {

// Collect frequencies
DoubleArrayList frequencies = new DoubleArrayList();
for (int i = 0; i < mode.size(); i++) {
int code = mode.get(i);
if (this.intDistribution.containsKey(code)) {
frequencies.add(this.intDistribution.get(code));
} else {
frequencies.add(0d);
}
}

// Convert frequencies to cumulative frequencies
for (int i = 1; i < frequencies.size(); i++) {
frequencies.set(i, frequencies.get(i) + frequencies.get(i - 1));
}

// Normalize
double maxFrequency = frequencies.get(frequencies.size() - 1);
for (int i = 0; i < frequencies.size(); i++) {
frequencies.set(i, frequencies.get(i) / maxFrequency);
}

// Draw
double r = random.nextDouble();
for (int i = 0; i < frequencies.size(); i++) {
if (r <= frequencies.get(i)) {
return mode.get(i);
}
}
}

// Should never happen
return -1;
}
}

/**
* This class calculates a set for a given distribution.
*
Expand Down

0 comments on commit 668f8aa

Please sign in to comment.