diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/ApproxEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/ApproxEqualsVisitor.java
new file mode 100644
index 00000000000..b23efd7f104
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/ApproxEqualsVisitor.java
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.vector.compare;
+
+import java.util.List;
+
+import org.apache.arrow.vector.BaseFixedWidthVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.complex.BaseRepeatedValueVector;
+import org.apache.arrow.vector.complex.FixedSizeListVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.NonNullableStructVector;
+import org.apache.arrow.vector.complex.UnionVector;
+
+/**
+ * Visitor that compares two vectors, treating floating point values as equal when they
+ * differ by no more than a configurable epsilon.
+ *
+ * <p>{@link Float4Vector} and {@link Float8Vector} ranges are compared approximately; all
+ * other fixed width vectors are delegated to {@link RangeEqualsVisitor}. Children of union,
+ * struct, list and fixed size list vectors are visited with the same epsilon.
+ */
+public class ApproxEqualsVisitor extends RangeEqualsVisitor {
+
+  /**
+   * The float/double values are treated as equal as long as the delta is {@code <=} epsilon.
+   */
+  private final float epsilon;
+
+  /**
+   * Constructs a visitor covering every value of {@code right}, with type checking enabled.
+   *
+   * @param right the vector to compare against
+   * @param epsilon the largest difference at which two values still count as equal
+   */
+  public ApproxEqualsVisitor(ValueVector right, float epsilon) {
+    this(right, epsilon, true);
+  }
+
+  /**
+   * Constructs a visitor covering every value of {@code right}.
+   *
+   * @param right the vector to compare against
+   * @param epsilon the largest difference at which two values still count as equal
+   * @param typeCheckNeeded forwarded to {@link RangeEqualsVisitor} and to nested list visitors
+   */
+  public ApproxEqualsVisitor(ValueVector right, float epsilon, boolean typeCheckNeeded) {
+    this(right, epsilon, typeCheckNeeded, 0, 0, right.getValueCount());
+  }
+
+  /**
+   * Constructs a visitor covering {@code length} values from the given start indices.
+   *
+   * @param right the vector to compare against
+   * @param epsilon the largest difference at which two values still count as equal
+   * @param typeCheckNeeded forwarded to {@link RangeEqualsVisitor} and to nested list visitors
+   * @param leftStart index of the first value compared in the left vector
+   * @param rightStart index of the first value compared in the right vector
+   * @param length number of values to compare
+   */
+  public ApproxEqualsVisitor(ValueVector right, float epsilon, boolean typeCheckNeeded,
+      int leftStart, int rightStart, int length) {
+    // NOTE(review): the start indices are passed as (rightStart, leftStart); confirm this is
+    // the parameter order of the five-argument RangeEqualsVisitor constructor.
+    super(right, rightStart, leftStart, length, typeCheckNeeded);
+    this.epsilon = epsilon;
+  }
+
+  /**
+   * Compares float and double vectors approximately; other fixed width vectors are handled
+   * by the superclass.
+   */
+  @Override
+  public Boolean visit(BaseFixedWidthVector left, Void value) {
+    if (left instanceof Float4Vector) {
+      return validate(left) && float4ApproxEquals((Float4Vector) left);
+    } else if (left instanceof Float8Vector) {
+      return validate(left) && float8ApproxEquals((Float8Vector) left);
+    } else {
+      return super.visit(left, value);
+    }
+  }
+
+  /** Compares the children of two union vectors pairwise with the same epsilon. */
+  @Override
+  protected boolean compareUnionVectors(UnionVector left) {
+
+    UnionVector rightVector = (UnionVector) right;
+
+    List<FieldVector> leftChildren = left.getChildrenFromFields();
+    List<FieldVector> rightChildren = rightVector.getChildrenFromFields();
+
+    if (leftChildren.size() != rightChildren.size()) {
+      return false;
+    }
+
+    for (int k = 0; k < leftChildren.size(); k++) {
+      ApproxEqualsVisitor visitor = new ApproxEqualsVisitor(rightChildren.get(k),
+          epsilon);
+      if (!leftChildren.get(k).accept(visitor, null)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /** Compares the children of two struct vectors by field name with the same epsilon. */
+  @Override
+  protected boolean compareStructVectors(NonNullableStructVector left) {
+
+    NonNullableStructVector rightVector = (NonNullableStructVector) right;
+
+    if (!left.getChildFieldNames().equals(rightVector.getChildFieldNames())) {
+      return false;
+    }
+
+    for (String name : left.getChildFieldNames()) {
+      ApproxEqualsVisitor visitor = new ApproxEqualsVisitor(rightVector.getChild(name),
+          epsilon);
+      if (!left.getChild(name).accept(visitor, null)) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+
+  /** Compares each non-null list in range by length and then element-wise with the same epsilon. */
+  @Override
+  protected boolean compareListVectors(ListVector left) {
+
+    final int offsetWidth = BaseRepeatedValueVector.OFFSET_WIDTH;
+
+    for (int i = 0; i < length; i++) {
+      int leftIndex = leftStart + i;
+      int rightIndex = rightStart + i;
+
+      boolean isNull = left.isNull(leftIndex);
+      if (isNull != right.isNull(rightIndex)) {
+        return false;
+      }
+
+      if (!isNull) {
+        final int startIndexLeft = left.getOffsetBuffer().getInt(leftIndex * offsetWidth);
+        final int endIndexLeft = left.getOffsetBuffer().getInt((leftIndex + 1) * offsetWidth);
+
+        final int startIndexRight = right.getOffsetBuffer().getInt(rightIndex * offsetWidth);
+        final int endIndexRight = right.getOffsetBuffer().getInt((rightIndex + 1) * offsetWidth);
+
+        if ((endIndexLeft - startIndexLeft) != (endIndexRight - startIndexRight)) {
+          return false;
+        }
+
+        ValueVector leftDataVector = left.getDataVector();
+        ValueVector rightDataVector = ((ListVector) right).getDataVector();
+
+        if (!leftDataVector.accept(new ApproxEqualsVisitor(rightDataVector, epsilon, typeCheckNeeded,
+            startIndexLeft, startIndexRight, (endIndexLeft - startIndexLeft)), null)) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  /** Compares each non-null fixed size list in range element-wise with the same epsilon. */
+  @Override
+  protected boolean compareFixedSizeListVectors(FixedSizeListVector left) {
+
+    final int listSize = left.getListSize();
+    if (listSize != ((FixedSizeListVector) right).getListSize()) {
+      return false;
+    }
+
+    for (int i = 0; i < length; i++) {
+      int leftIndex = leftStart + i;
+      int rightIndex = rightStart + i;
+
+      boolean isNull = left.isNull(leftIndex);
+      if (isNull != right.isNull(rightIndex)) {
+        return false;
+      }
+
+      if (!isNull) {
+        // Both lists hold exactly listSize elements, so only the start indices differ.
+        ValueVector leftDataVector = left.getDataVector();
+        ValueVector rightDataVector = ((FixedSizeListVector) right).getDataVector();
+
+        if (!leftDataVector.accept(new ApproxEqualsVisitor(rightDataVector, epsilon, typeCheckNeeded,
+            leftIndex * listSize, rightIndex * listSize, listSize), null)) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  /** Compares the float values in range, requiring matching null positions. */
+  private boolean float4ApproxEquals(Float4Vector left) {
+
+    for (int i = 0; i < length; i++) {
+      int leftIndex = leftStart + i;
+      int rightIndex = rightStart + i;
+
+      boolean isNull = left.isNull(leftIndex);
+
+      if (isNull != right.isNull(rightIndex)) {
+        return false;
+      }
+
+      if (!isNull) {
+        float leftValue = left.get(leftIndex);
+        float rightValue = ((Float4Vector) right).get(rightIndex);
+        if (!floatValuesApproxEqual(leftValue, rightValue)) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  /** Compares the double values in range, requiring matching null positions. */
+  private boolean float8ApproxEquals(Float8Vector left) {
+
+    for (int i = 0; i < length; i++) {
+      int leftIndex = leftStart + i;
+      int rightIndex = rightStart + i;
+
+      boolean isNull = left.isNull(leftIndex);
+
+      if (isNull != right.isNull(rightIndex)) {
+        return false;
+      }
+
+      if (!isNull) {
+        double leftValue = left.get(leftIndex);
+        double rightValue = ((Float8Vector) right).get(rightIndex);
+        if (!doubleValuesApproxEqual(leftValue, rightValue)) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Compares two floats. NaN only matches NaN and an infinity only matches the same infinity.
+   * These cases are handled explicitly because a NaN difference never compares greater than
+   * epsilon, which would otherwise make NaN equal to every value.
+   */
+  private boolean floatValuesApproxEqual(float leftValue, float rightValue) {
+    if (Float.isNaN(leftValue) || Float.isNaN(rightValue)) {
+      return Float.isNaN(leftValue) && Float.isNaN(rightValue);
+    }
+    if (Float.isInfinite(leftValue) || Float.isInfinite(rightValue)) {
+      return leftValue == rightValue;
+    }
+    return Math.abs(leftValue - rightValue) <= epsilon;
+  }
+
+  /**
+   * Compares two doubles with the same NaN and infinity rules as the float variant.
+   */
+  private boolean doubleValuesApproxEqual(double leftValue, double rightValue) {
+    if (Double.isNaN(leftValue) || Double.isNaN(rightValue)) {
+      return Double.isNaN(leftValue) && Double.isNaN(rightValue);
+    }
+    if (Double.isInfinite(leftValue) || Double.isInfinite(rightValue)) {
+      return leftValue == rightValue;
+    }
+    return Math.abs(leftValue - rightValue) <= epsilon;
+  }
+}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
index 273642f0329..d69cb7cd0f3 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java
@@ -66,7 +66,7 @@ public RangeEqualsVisitor(ValueVector right, int leftStart, int rightStart, int
   /**
    * Do some validation work, like type check and indices check.
    */
-  private boolean validate(ValueVector left) {
+  protected boolean validate(ValueVector left) {
 
     if (!compareValueVector(left, right)) {
       return false;
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
index eb35eadcfef..3757757b684 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java
@@ -25,6 +25,8 @@
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.ZeroVector;
@@ -33,8 +35,11 @@
 import org.apache.arrow.vector.complex.UnionVector;
 import org.apache.arrow.vector.complex.impl.NullableStructWriter;
 import org.apache.arrow.vector.complex.impl.UnionListWriter;
+import org.apache.arrow.vector.holders.NullableFloat4Holder;
+import org.apache.arrow.vector.holders.NullableFloat8Holder;
 import org.apache.arrow.vector.holders.NullableIntHolder;
 import org.apache.arrow.vector.holders.NullableUInt4Holder;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
 import org.apache.arrow.vector.types.Types;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.FieldType;
@@ -253,6 +258,191 @@ public void testEqualsWithOutTypeCheck() {
     }
   }
 
+  @Test
+  public void testFloat4ApproxEquals() {
+    try (final Float4Vector vector1 = new Float4Vector("float", allocator);
+         final Float4Vector vector2 = new Float4Vector("float", allocator);
+         final Float4Vector vector3 = new Float4Vector("float", allocator)) {
+
+      vector1.allocateNew(2);
+      vector1.setValueCount(2);
+      vector2.allocateNew(2);
+      vector2.setValueCount(2);
+      vector3.allocateNew(2);
+      vector3.setValueCount(2);
+
+      vector1.setSafe(0, 1.1f);
+      vector1.setSafe(1, 2.2f);
+
+      float epsilon = 1.0E-6f;
+
+      vector2.setSafe(0, 1.1f + epsilon / 2);
+      vector2.setSafe(1, 2.2f + epsilon / 2);
+
+      vector3.setSafe(0, 1.1f + epsilon * 2);
+      vector3.setSafe(1, 2.2f + epsilon * 2);
+
+      ApproxEqualsVisitor visitor = new ApproxEqualsVisitor(vector1, epsilon);
+
+      assertTrue(visitor.equals(vector2));
+      assertFalse(visitor.equals(vector3));
+    }
+  }
+
+  @Test
+  public void testFloat8ApproxEquals() {
+    try (final Float8Vector vector1 = new Float8Vector("float", allocator);
+         final Float8Vector vector2 = new Float8Vector("float", allocator);
+         final Float8Vector vector3 = new Float8Vector("float", allocator)) {
+
+      vector1.allocateNew(2);
+      vector1.setValueCount(2);
+      vector2.allocateNew(2);
+      vector2.setValueCount(2);
+      vector3.allocateNew(2);
+      vector3.setValueCount(2);
+
+      vector1.setSafe(0, 1.1);
+      vector1.setSafe(1, 2.2);
+
+      float epsilon = 1.0E-6f;
+
+      vector2.setSafe(0, 1.1 + epsilon / 2);
+      vector2.setSafe(1, 2.2 + epsilon / 2);
+
+      vector3.setSafe(0, 1.1 + epsilon * 2);
+      vector3.setSafe(1, 2.2 + epsilon * 2);
+
+      ApproxEqualsVisitor visitor = new ApproxEqualsVisitor(vector1, epsilon);
+
+      assertTrue(visitor.equals(vector2));
+      assertFalse(visitor.equals(vector3));
+    }
+  }
+
+  @Test
+  public void testStructVectorApproxEquals() {
+    try (final StructVector right = StructVector.empty("struct", allocator);
+         final StructVector left1 = StructVector.empty("struct", allocator);
+         final StructVector left2 = StructVector.empty("struct", allocator);
+    ) {
+      right.addOrGet("f0",
+          FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), Float4Vector.class);
+      right.addOrGet("f1",
+          FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), Float8Vector.class);
+      left1.addOrGet("f0",
+          FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), Float4Vector.class);
+      left1.addOrGet("f1",
+          FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), Float8Vector.class);
+      left2.addOrGet("f0",
+          FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)), Float4Vector.class);
+      left2.addOrGet("f1",
+          FieldType.nullable(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), Float8Vector.class);
+
+      final float epsilon = 1.0E-6f;
+
+      NullableStructWriter rightWriter = right.getWriter();
+      rightWriter.allocate();
+      writeStructVector(rightWriter, 1.1f, 2.2);
+      writeStructVector(rightWriter, 2.02f, 4.04);
+      rightWriter.setValueCount(2);
+
+      NullableStructWriter leftWriter1 = left1.getWriter();
+      leftWriter1.allocate();
+      writeStructVector(leftWriter1, 1.1f + epsilon / 2, 2.2 + epsilon / 2);
+      writeStructVector(leftWriter1, 2.02f - epsilon / 2, 4.04 - epsilon / 2);
+      leftWriter1.setValueCount(2);
+
+      NullableStructWriter leftWriter2 = left2.getWriter();
+      leftWriter2.allocate();
+      writeStructVector(leftWriter2, 1.1f + epsilon * 2, 2.2 + epsilon * 2);
+      writeStructVector(leftWriter2, 2.02f - epsilon * 2, 4.04 - epsilon * 2);
+      leftWriter2.setValueCount(2);
+
+      ApproxEqualsVisitor visitor = new ApproxEqualsVisitor(right, epsilon);
+      assertTrue(visitor.equals(left1));
+      assertFalse(visitor.equals(left2));
+    }
+  }
+
+  @Test
+  public void testUnionVectorApproxEquals() {
+    try (final UnionVector right = new UnionVector("union", allocator, null);
+         final UnionVector left1 = new UnionVector("union", allocator, null);
+         final UnionVector left2 = new UnionVector("union", allocator, null);) {
+
+      final NullableFloat4Holder float4Holder = new NullableFloat4Holder();
+      float4Holder.value = 1.01f;
+      float4Holder.isSet = 1;
+
+      final NullableFloat8Holder float8Holder = new NullableFloat8Holder();
+      float8Holder.value = 2.02;
+      float8Holder.isSet = 1;
+
+      final float epsilon = 1.0E-6f;
+
+      right.setType(0, Types.MinorType.FLOAT4);
+      right.setSafe(0, float4Holder);
+      right.setType(1, Types.MinorType.FLOAT8);
+      right.setSafe(1, float8Holder);
+      right.setValueCount(2);
+
+      float4Holder.value += epsilon / 2;
+      float8Holder.value += epsilon / 2;
+
+      left1.setType(0, Types.MinorType.FLOAT4);
+      left1.setSafe(0, float4Holder);
+      left1.setType(1, Types.MinorType.FLOAT8);
+      left1.setSafe(1, float8Holder);
+      left1.setValueCount(2);
+
+      float4Holder.value += epsilon * 2;
+      float8Holder.value += epsilon * 2;
+
+      left2.setType(0, Types.MinorType.FLOAT4);
+      left2.setSafe(0, float4Holder);
+      left2.setType(1, Types.MinorType.FLOAT8);
+      left2.setSafe(1, float8Holder);
+      left2.setValueCount(2);
+
+      ApproxEqualsVisitor visitor = new ApproxEqualsVisitor(right, epsilon);
+      assertTrue(visitor.equals(left1));
+      assertFalse(visitor.equals(left2));
+    }
+  }
+
+  @Test
+  public void testListVectorApproxEquals() {
+    try (final ListVector right = ListVector.empty("list", allocator);
+         final ListVector left1 = ListVector.empty("list", allocator);
+         final ListVector left2 = ListVector.empty("list", allocator);) {
+
+      final float epsilon = 1.0E-6f;
+
+      UnionListWriter rightWriter = right.getWriter();
+      rightWriter.allocate();
+      writeListVector(rightWriter, new double[] {1, 2});
+      writeListVector(rightWriter, new double[] {1.01, 2.02});
+      rightWriter.setValueCount(2);
+
+      UnionListWriter leftWriter1 = left1.getWriter();
+      leftWriter1.allocate();
+      writeListVector(leftWriter1, new double[] {1, 2});
+      writeListVector(leftWriter1, new double[] {1.01 + epsilon / 2, 2.02 - epsilon / 2});
+      leftWriter1.setValueCount(2);
+
+      UnionListWriter leftWriter2 = left2.getWriter();
+      leftWriter2.allocate();
+      writeListVector(leftWriter2, new double[] {1, 2});
+      writeListVector(leftWriter2, new double[] {1.01 + epsilon * 2, 2.02 - epsilon * 2});
+      leftWriter2.setValueCount(2);
+
+      ApproxEqualsVisitor visitor = new ApproxEqualsVisitor(right, epsilon);
+      assertTrue(visitor.equals(left1));
+      assertFalse(visitor.equals(left2));
+    }
+  }
+
   private void writeStructVector(NullableStructWriter writer, int value1, long value2) {
     writer.start();
     writer.integer("f0").writeInt(value1);
@@ -260,6 +450,13 @@ private void writeStructVector(NullableStructWriter writer, int value1, long val
     writer.end();
   }
 
+  private void writeStructVector(NullableStructWriter writer, float value1, double value2) {
+    writer.start();
+    writer.float4("f0").writeFloat4(value1);
+    writer.float8("f1").writeFloat8(value2);
+    writer.end();
+  }
+
   private void writeListVector(UnionListWriter writer, int[] values) {
     writer.startList();
     for (int v: values) {
@@ -267,4 +464,12 @@ private void writeListVector(UnionListWriter writer, int[] values) {
     }
     writer.endList();
   }
+
+  private void writeListVector(UnionListWriter writer, double[] values) {
+    writer.startList();
+    for (double v: values) {
+      writer.float8().writeFloat8(v);
+    }
+    writer.endList();
+  }
 }