diff --git a/java/core/src/java/org/apache/orc/impl/mask/RedactMaskFactory.java b/java/core/src/java/org/apache/orc/impl/mask/RedactMaskFactory.java index 95741397a4..ef10fe628d 100644 --- a/java/core/src/java/org/apache/orc/impl/mask/RedactMaskFactory.java +++ b/java/core/src/java/org/apache/orc/impl/mask/RedactMaskFactory.java @@ -17,6 +17,7 @@ */ package org.apache.orc.impl.mask; +import org.apache.commons.lang.StringUtils; import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; import org.apache.hadoop.hive.ql.exec.vector.ColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DecimalColumnVector; @@ -25,15 +26,19 @@ import org.apache.hadoop.hive.ql.exec.vector.TimestampColumnVector; import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.io.Text; -import org.apache.orc.TypeDescription; import org.apache.orc.DataMask; +import org.apache.orc.TypeDescription; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Calendar; +import java.util.Map; +import java.util.SortedMap; import java.util.TimeZone; +import java.util.TreeMap; import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; import java.util.regex.Pattern; /** @@ -114,6 +119,9 @@ public class RedactMaskFactory extends MaskFactory { private final boolean maskDate; private final boolean maskTimestamp; + // index tuples that are not to be masked + private final SortedMap unmaskIndexRanges = new TreeMap(); + public RedactMaskFactory(String... params) { ByteBuffer param = params.length < 1 ? ByteBuffer.allocate(0) : ByteBuffer.wrap(params[0].getBytes(StandardCharsets.UTF_8)); @@ -131,6 +139,9 @@ public RedactMaskFactory(String... params) { String[] timeParams; if (params.length < 2) { timeParams = null; + } else if(StringUtils.isBlank(params[1])) { + // Date params blank + timeParams = null; } else { timeParams = params[1].split("\\W+"); } @@ -146,6 +157,19 @@ public RedactMaskFactory(String... 
params) { maskTimestamp = maskDate || (HOUR_REPLACEMENT != UNMASKED_DATE) || (MINUTE_REPLACEMENT != UNMASKED_DATE) || (SECOND_REPLACEMENT != UNMASKED_DATE); + + /* un-mask range */ + String[] unmaskIndexes; + if(params.length >= 3 && !StringUtils.isBlank(params[2])) { + unmaskIndexes = params[2].split(","); + + for(int i=0; i < unmaskIndexes.length; i++ ) { + String[] pair = unmaskIndexes[i].trim().split(":"); + unmaskIndexRanges.put(Integer.parseInt(pair[0]), Integer.parseInt(pair[1])); + } + + } + } @Override @@ -451,6 +475,16 @@ static int getDateParam(String[] dateParams, int posn, * @return the masked value */ public long maskLong(long value) { + + /* check whether unmasking range provided */ + try { + if (!unmaskIndexRanges.isEmpty()) { + return unmaskRangeLongValue(value); + } + } catch(final IndexOutOfBoundsException e) { + // Bad range, we move on and return the mask without un-masking. + } + long base; if (DIGIT_REPLACEMENT == 0) { return 0; @@ -521,6 +555,7 @@ public long maskLong(long value) { } else { base *= 1_111_111_111_111_111_111L; } + return DIGIT_REPLACEMENT * base; } @@ -601,6 +636,16 @@ public long maskLong(long value) { * @return the */ public double maskDouble(double value) { + + /* check whether unmasking range provided */ + try { + if (!unmaskIndexRanges.isEmpty()) { + return unmaskRangeDoubleValue(value); + } + } catch(final IndexOutOfBoundsException e) { + // Bad range, we move on and return the mask without un-masking. + } + double base; // It seems better to mask 0 to 9.99999 rather than 9.99999e-308. if (value == 0 || DIGIT_REPLACEMENT == 0) { @@ -695,9 +740,23 @@ int maskDate(int daysSinceEpoch) { * @return the masked value. */ HiveDecimalWritable maskDecimal(HiveDecimalWritable source) { - String str = DIGIT_PATTERN.matcher(source.toString()). 
- replaceAll(Integer.toString(DIGIT_REPLACEMENT)); - return new HiveDecimalWritable(str); + // No unmasking range + if(unmaskIndexRanges.isEmpty()) { + String str = DIGIT_PATTERN.matcher(source.toString()).replaceAll(Integer.toString(DIGIT_REPLACEMENT)); + return new HiveDecimalWritable(str); + } else { + final StringBuffer result = new StringBuffer(); + // get the ranges that need to be masked + for(final Map.Entry map : getInverseIndexRange(source.toString().length()).entrySet() ) { + final Matcher m = DIGIT_PATTERN.matcher(source.toString()).region(map.getKey(), map.getValue() + 1); + while(m.find()) { + m.appendReplacement(result, Integer.toString(DIGIT_REPLACEMENT)); + } + m.appendTail(result); + } + return new HiveDecimalWritable(result.toString()); + } + } /** @@ -815,14 +874,20 @@ void maskString(BytesColumnVector source, int row, BytesColumnVector target) { byte[] outputBuffer = target.getValPreallocatedBytes(); int outputOffset = target.getValPreallocatedStart(); int outputStart = outputOffset; + + int index = 0; while (sourceBytes.remaining() > 0) { int cp = Text.bytesToCodePoint(sourceBytes); // Find the replacement for the current character. int replacement = getReplacement(cp); - if (replacement == UNMASKED_CHAR) { + if (replacement == UNMASKED_CHAR || isIndexInUnmaskRange(index, source.length[row])) { replacement = cp; } + + // increment index + index++; + int len = getCodepointLength(replacement); // If the translation will overflow the buffer, we need to resize. @@ -854,4 +919,174 @@ void maskString(BytesColumnVector source, int row, BytesColumnVector target) { } target.setValPreallocated(row, outputOffset - outputStart); } + + /** + * A function that accepts the original value and the computed Mask then tries + * to un-mask any masked values. + *

+ * Returns the value which are partially masked (if configured), else + * returns the mask. + * @param value + * @return + */ + long unmaskRangeLongValue(final long value) throws IndexOutOfBoundsException { + + final StringBuffer result = new StringBuffer(); + unmaskRangeDigitHelper(String.valueOf(value), result); + return Long.parseLong(result.toString()); + } + + /** + * + * @param value + * @return + */ + double unmaskRangeDoubleValue(final double value) { + + final StringBuffer result = new StringBuffer(); + unmaskRangeDigitHelper(String.valueOf(value), result); + return Double.valueOf(result.toString()); + + } + + /** + * A helper method that does partial masking based on the the
+ * un-marking arguments. + * @param value + * @param result + */ + void unmaskRangeDigitHelper(final String value, final StringBuffer result) { + // get the ranges that need to be masked + + for (final Map.Entry map : getInverseIndexRange( + value.length()).entrySet()) { + final Matcher m = DIGIT_PATTERN.matcher(value) + .region(map.getKey(), map.getValue() + 1); + while (m.find()) { + m.appendReplacement(result, Integer.toString(DIGIT_REPLACEMENT)); + } + m.appendTail(result); + } + } + + /** + * Given an index and length of a string + * find out whether it is in a given un-mask range. + * @param index + * @param length + * @return true if the index is in un-mask range else false. + */ + private boolean isIndexInUnmaskRange(final int index, final int length) { + + for(final Map.Entry pair : unmaskIndexRanges.entrySet()) { + int start = 0; + int end = 0; + + if(pair.getKey() >= 0) { + // for positive indexes + start = pair.getKey(); + } else { + // for negative indexes + start = length + pair.getKey(); + } + + if(pair.getValue() >= 0) { + // for positive indexes + end = pair.getValue(); + } else { + // for negative indexes + end = length + pair.getValue(); + } + + // if the given index is in range + if(index >= start && index <= end ) { + return true; + } + + } + + return false; + } + + /** + * A helper method that converts negative indexes to positive given the length + * of string and returns a sorted map based on keys (start index) + * @param length + * @return + */ + private SortedMap getPositiveUnmaskRangeIndexes(final int length) { + + /* Always return a sorted map */ + final SortedMap result = new TreeMap(); + + for(final Map.Entry pair : unmaskIndexRanges.entrySet()) { + int start = 0; + int end = 0; + + if(pair.getKey() >= 0) { + // for positive indexes + start = pair.getKey(); + } else { + // for negative indexes + start = length + pair.getKey(); + } + + if(pair.getValue() >= 0) { + // for positive indexes + end = pair.getValue(); + } else { + // for 
negative indexes + end = length + pair.getValue(); + } + + result.put(start, end); + + } + + return result; + } + + /** + * Given the length
+ * of the string being masked, this method returns the index ranges that lie outside the configured un-mask ranges.
+ * i.e. for unmask range, return masked range. + * NOTE: This function requires the Map to be sorted. + * @param length + * @return + */ + private Map getInverseIndexRange(final int length) { + + final SortedMap inverse = new TreeMap<>(); + + // Normalize the range indexes + final SortedMap range = getPositiveUnmaskRangeIndexes(length); + + int startIndex = 0; + int endIndex = 0; + + for(final Map.Entry pair : range.entrySet()) { + + // Unmasking states from first index so we move on + if(pair.getKey() == 0) { + startIndex = pair.getValue() + 1; + endIndex = pair.getValue(); + continue; + } + + inverse.put(startIndex, pair.getKey() - 1); + startIndex = pair.getValue() + 1; + + endIndex = (endIndex >= pair.getValue()) ? endIndex : pair.getValue(); + + } + + // final range + if(endIndex < length -1) { + inverse.put(startIndex, length -1 ); + } + + return inverse; + + } + } diff --git a/java/core/src/test/org/apache/orc/impl/mask/TestUnmaskRange.java b/java/core/src/test/org/apache/orc/impl/mask/TestUnmaskRange.java new file mode 100644 index 0000000000..53dbd0d658 --- /dev/null +++ b/java/core/src/test/org/apache/orc/impl/mask/TestUnmaskRange.java @@ -0,0 +1,157 @@ +package org.apache.orc.impl.mask; + +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +import org.apache.hadoop.hive.ql.exec.vector.BytesColumnVector; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; + +import static org.junit.Assert.assertEquals; + +/** + * Test Unmask option + */ +public class TestUnmaskRange { + + public TestUnmaskRange() { + super(); + } + + /* Test for Long */ + @Test + public void testSimpleLongRangeMask() { + RedactMaskFactory mask = new RedactMaskFactory("9", "", "0:2"); + long result = mask.maskLong(123456); + assertEquals(123_999, result); + + // negative index + mask = new RedactMaskFactory("9", "", "-3:-1"); + result = mask.maskLong(123456); + assertEquals(999_456, result); + + // out of range mask, return the original mask + mask = new RedactMaskFactory("9", "", "7:10"); + result = mask.maskLong(123456); + assertEquals(999999, result); + + } + + @Test + public void testDefaultRangeMask() { + RedactMaskFactory mask = new RedactMaskFactory("9", "", ""); + long result = mask.maskLong(123456); + assertEquals(999999, result); + + mask = new RedactMaskFactory("9"); + result = mask.maskLong(123456); + assertEquals(999999, result); + + } + + @Test + public void testCCRangeMask() { + long cc = 4716885592186382L; + long maskedCC = 4716_77777777_6382L; + // Range unmask for first 4 and last 4 of credit card number + final RedactMaskFactory mask = new RedactMaskFactory("Xx7", "", "0:3,-4:-1"); + long result = mask.maskLong(cc); + + assertEquals(String.valueOf(cc).length(), String.valueOf(result).length()); + assertEquals(4716_77777777_6382L, result); + } + + /* Tests for Double */ + @Test + public void 
testSimpleDoubleRangeMask() { + RedactMaskFactory mask = new RedactMaskFactory("Xx7", "", "0:2"); + assertEquals(1237.77, mask.maskDouble(1234.99), 0.000001); + assertEquals(12377.7, mask.maskDouble(12345.9), 0.000001); + + mask = new RedactMaskFactory("Xx7", "", "-3:-1"); + assertEquals(7774.9, mask.maskDouble(1234.9), 0.000001); + + } + + /* test for String */ + @Test + public void testStringRangeMask() { + + BytesColumnVector source = new BytesColumnVector(); + BytesColumnVector target = new BytesColumnVector(); + target.reset(); + + byte[] input = "Mary had 1 little lamb!!".getBytes(StandardCharsets.UTF_8); + source.setRef(0, input, 0, input.length); + + // Set a 4 byte chinese character (U+2070E), which is letter other + input = "\uD841\uDF0E".getBytes(StandardCharsets.UTF_8); + source.setRef(1, input, 0, input.length); + + RedactMaskFactory mask = new RedactMaskFactory("", "", "0:3, -5:-1"); + for(int r=0; r < 2; ++r) { + mask.maskString(source, r, target); + } + + assertEquals("Mary xxx 9 xxxxxx xamb!!", new String(target.vector[0], + target.start[0], target.length[0], StandardCharsets.UTF_8)); + assertEquals("\uD841\uDF0E", new String(target.vector[1], + target.start[1], target.length[1], StandardCharsets.UTF_8)); + + // test defaults, no-unmask range + mask = new RedactMaskFactory(); + for(int r=0; r < 2; ++r) { + mask.maskString(source, r, target); + } + + assertEquals("Xxxx xxx 9 xxxxxx xxxx..", new String(target.vector[0], + target.start[0], target.length[0], StandardCharsets.UTF_8)); + assertEquals("ª", new String(target.vector[1], + target.start[1], target.length[1], StandardCharsets.UTF_8)); + + + // test out of range string mask + mask = new RedactMaskFactory("", "", "-1:-5"); + for(int r=0; r < 2; ++r) { + mask.maskString(source, r, target); + } + + assertEquals("Xxxx xxx 9 xxxxxx xxxx..", new String(target.vector[0], + target.start[0], target.length[0], StandardCharsets.UTF_8)); + assertEquals("ª", new String(target.vector[1], + target.start[1], 
target.length[1], StandardCharsets.UTF_8)); + + } + + /* test for Decimal */ + @Test + public void testDecimalRangeMask() { + + RedactMaskFactory mask = new RedactMaskFactory("Xx7", "", "0:3"); + assertEquals(new HiveDecimalWritable("123477.777"), + mask.maskDecimal(new HiveDecimalWritable("123456.789"))); + + // try with a reverse index + mask = new RedactMaskFactory("Xx7", "", "-3:-1, 0:3"); + assertEquals(new HiveDecimalWritable("123477777.777654"), + mask.maskDecimal(new HiveDecimalWritable("123456789.987654"))); + + } + +}