diff --git a/sql/src/main/java/org/apache/druid/sql/http/ArrayWriter.java b/sql/src/main/java/org/apache/druid/sql/http/ArrayWriter.java index c177cf398158..678df49630e2 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/ArrayWriter.java +++ b/sql/src/main/java/org/apache/druid/sql/http/ArrayWriter.java @@ -36,6 +36,9 @@ public ArrayWriter(final OutputStream outputStream, final ObjectMapper jsonMappe { this.jsonGenerator = jsonMapper.getFactory().createGenerator(outputStream); this.outputStream = outputStream; + + // Disable automatic JSON termination, so clients can detect truncated responses. + jsonGenerator.configure(JsonGenerator.Feature.AUTO_CLOSE_JSON_CONTENT, false); } @Override diff --git a/sql/src/main/java/org/apache/druid/sql/http/ObjectWriter.java b/sql/src/main/java/org/apache/druid/sql/http/ObjectWriter.java index b1623a53cf86..ac7b0cf1405a 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/ObjectWriter.java +++ b/sql/src/main/java/org/apache/druid/sql/http/ObjectWriter.java @@ -36,6 +36,9 @@ public ObjectWriter(final OutputStream outputStream, final ObjectMapper jsonMapp { this.jsonGenerator = jsonMapper.getFactory().createGenerator(outputStream); this.outputStream = outputStream; + + // Disable automatic JSON termination, so clients can detect truncated responses. + jsonGenerator.configure(JsonGenerator.Feature.AUTO_CLOSE_JSON_CONTENT, false); } @Override diff --git a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java index 02f596adbdaa..e1d68ef2ecce 100644 --- a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java +++ b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java @@ -95,9 +95,11 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; import java.util.stream.Collectors; @@ -123,6 +125,7 @@ public class SqlResourceTest extends CalciteTestBase private final SettableSupplier> validateAndAuthorizeLatchSupplier = new SettableSupplier<>(); private final SettableSupplier> planLatchSupplier = new SettableSupplier<>(); private final SettableSupplier> executeLatchSupplier = new SettableSupplier<>(); + private final SettableSupplier, Sequence>> sequenceMapFnSupplier = new SettableSupplier<>(); private boolean sleep = false; @@ -254,7 +257,8 @@ public SqlLifecycle factorize() System.nanoTime(), validateAndAuthorizeLatchSupplier, planLatchSupplier, - executeLatchSupplier + executeLatchSupplier, + sequenceMapFnSupplier ); } }, @@ -505,6 +509,76 @@ public void testArrayResultFormat() throws Exception ); } + @Test + public void testArrayResultFormatWithErrorAfterFirstRow() throws Exception + { + sequenceMapFnSupplier.set(errorAfterSecondRowMapFn()); + + final String query = "SELECT cnt FROM foo"; + final Pair response = + doPostRaw(new SqlQuery(query, ResultFormat.ARRAY, false, null, null), req); + + // Truncated response: missing final ] + Assert.assertNull(response.lhs); + Assert.assertEquals("[[1],[1]", response.rhs); + } + + @Test + public void testObjectResultFormatWithErrorAfterFirstRow() throws Exception + { + sequenceMapFnSupplier.set(errorAfterSecondRowMapFn()); + + final String query = "SELECT cnt FROM foo"; + final Pair response = + doPostRaw(new SqlQuery(query, ResultFormat.OBJECT, false, null, null), req); + + // Truncated response: missing final ] + Assert.assertNull(response.lhs); + Assert.assertEquals("[{\"cnt\":1},{\"cnt\":1}", response.rhs); + } + + @Test + public void testArrayLinesResultFormatWithErrorAfterFirstRow() throws Exception + { + sequenceMapFnSupplier.set(errorAfterSecondRowMapFn()); + + final String query = "SELECT cnt FROM foo"; + final Pair response = + doPostRaw(new SqlQuery(query, ResultFormat.ARRAYLINES, false, null, null), req); + + // Truncated response: missing final LFLF + Assert.assertNull(response.lhs); + Assert.assertEquals("[1]\n[1]", response.rhs); + } + + @Test + public void testObjectLinesResultFormatWithErrorAfterFirstRow() throws Exception + { + sequenceMapFnSupplier.set(errorAfterSecondRowMapFn()); + + final String query = "SELECT cnt FROM foo"; + final Pair response = + doPostRaw(new SqlQuery(query, ResultFormat.OBJECTLINES, false, null, null), req); + + // Truncated response: missing final LFLF + Assert.assertNull(response.lhs); + Assert.assertEquals("{\"cnt\":1}\n{\"cnt\":1}", response.rhs); + } + + @Test + public void testCsvResultFormatWithErrorAfterFirstRow() throws Exception + { + sequenceMapFnSupplier.set(errorAfterSecondRowMapFn()); + + final String query = "SELECT cnt FROM foo"; + final Pair response = + doPostRaw(new SqlQuery(query, ResultFormat.CSV, false, null, null), req); + + // Truncated response: missing final LFLF + Assert.assertNull(response.lhs); + Assert.assertEquals("1\n1\n", response.rhs); + } + @Test public void testArrayResultFormatWithHeader() throws Exception { @@ -1128,7 +1202,14 @@ private Pair doPostRaw(final SqlQuery query, final HttpS if (response.getStatus() == 200) { final StreamingOutput output = (StreamingOutput) response.getEntity(); final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - output.write(baos); + try { + output.write(baos); + } + catch (Exception ignored) { + // Suppress errors and return the response so far. Similar to what the real web server would do, if it + // started writing a 200 OK and then threw an exception in the middle. + } + return Pair.of( null, new String(baos.toByteArray(), StandardCharsets.UTF_8) @@ -1180,11 +1261,26 @@ private HttpServletRequest mockRequestForCancel() return req; } + private static Function, Sequence> errorAfterSecondRowMapFn() + { + return results -> { + final AtomicLong rows = new AtomicLong(); + return results.map(row -> { + if (rows.incrementAndGet() == 3) { + throw new ISE("Oh no!"); + } else { + return row; + } + }); + }; + } + private static class TestSqlLifecycle extends SqlLifecycle { private final SettableSupplier> validateAndAuthorizeLatchSupplier; private final SettableSupplier> planLatchSupplier; private final SettableSupplier> executeLatchSupplier; + private final SettableSupplier, Sequence>> sequenceMapFnSupplier; private TestSqlLifecycle( PlannerFactory plannerFactory, @@ -1195,13 +1291,15 @@ private TestSqlLifecycle( long startNs, SettableSupplier> validateAndAuthorizeLatchSupplier, SettableSupplier> planLatchSupplier, - SettableSupplier> executeLatchSupplier + SettableSupplier> executeLatchSupplier, + SettableSupplier, Sequence>> sequenceMapFnSupplier ) { super(plannerFactory, emitter, requestLogger, queryScheduler, startMs, startNs); this.validateAndAuthorizeLatchSupplier = validateAndAuthorizeLatchSupplier; this.planLatchSupplier = planLatchSupplier; this.executeLatchSupplier = executeLatchSupplier; + this.sequenceMapFnSupplier = sequenceMapFnSupplier; } @Override @@ -1253,9 +1351,12 @@ public void plan() throws RelConversionException @Override public Sequence execute() { + final Function, Sequence> sequenceMapFn = + Optional.ofNullable(sequenceMapFnSupplier.get()).orElse(Function.identity()); + if (executeLatchSupplier.get() != null) { if (executeLatchSupplier.get().rhs) { - Sequence sequence = super.execute(); + Sequence sequence = sequenceMapFn.apply(super.execute()); executeLatchSupplier.get().lhs.countDown(); return sequence; } else { @@ -1267,10 +1368,10 @@ public Sequence execute() catch (InterruptedException e) { throw new RuntimeException(e); } - return super.execute(); + return sequenceMapFn.apply(super.execute()); } } else { - return super.execute(); + return sequenceMapFn.apply(super.execute()); } } }