Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, TranslationC
}

static class SplittableProcessElementsTranslator<
InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker<RestrictionT>>
InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker<RestrictionT, ?>>
implements TransformTranslator<ProcessElements<InputT, OutputT, RestrictionT, TrackerT>> {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ public TimerInternals timerInternals() {
(StateInternalsFactory<String>) this.currentKeyStateInternals.getFactory();

@SuppressWarnings({ "rawtypes", "unchecked" })
ProcessFn<InputT, OutputT, Object, RestrictionTracker<Object>>
ProcessFn<InputT, OutputT, Object, RestrictionTracker<Object, Object>>
splittableDoFn = (ProcessFn) doFn;
splittableDoFn.setStateInternalsFactory(stateInternalsFactory);
TimerInternalsFactory<String> timerInternalsFactory = key -> currentKeyTimerInternals;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public void simpleProcess(ProcessContext ctxt) {
ctxt.output(ctxt.element().getValue() + 1);
}
};
private abstract static class SomeTracker implements RestrictionTracker<Void> {}
private abstract static class SomeTracker extends RestrictionTracker<Void, Void> {}
private DoFn<KV<String, Integer>, Integer> splittableDoFn =
new DoFn<KV<String, Integer>, Integer>() {
@ProcessElement
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ public int hashCode() {

private static class SplittableDropElementsFn extends DoFn<KV<Long, String>, Void> {
@ProcessElement
public void proc(ProcessContext context, RestrictionTracker<Integer> restriction) {
public void proc(ProcessContext context, RestrictionTracker<Integer, ?> restriction) {
context.output(null);
}

Expand All @@ -241,7 +241,7 @@ public Integer restriction(KV<Long, String> elem) {
}

@NewTracker
public RestrictionTracker<Integer> newTracker(Integer restriction) {
public RestrictionTracker<Integer, ?> newTracker(Integer restriction) {
throw new UnsupportedOperationException("Should never be called; only to test translation");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,18 @@ public SomeRestrictionTracker newTracker() {
}
}

private static class SomeRestrictionTracker implements RestrictionTracker<SomeRestriction> {
private static class SomeRestrictionTracker extends RestrictionTracker<SomeRestriction, Void> {
private final SomeRestriction someRestriction;

public SomeRestrictionTracker(SomeRestriction someRestriction) {
this.someRestriction = someRestriction;
}

@Override
protected boolean tryClaimImpl(Void position) {
return false;
}

@Override
public SomeRestriction currentRestriction() {
return someRestriction;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.MapState;
Expand All @@ -49,6 +50,7 @@
import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.CombineFnUtil;
import org.joda.time.Instant;

Expand Down Expand Up @@ -126,25 +128,25 @@ public InMemoryStateBinder(StateContext<?> c) {
@Override
public <T> ValueState<T> bindValue(
StateTag<ValueState<T>> address, Coder<T> coder) {
return new InMemoryValue<>();
return new InMemoryValue<>(coder);
}

@Override
public <T> BagState<T> bindBag(
final StateTag<BagState<T>> address, Coder<T> elemCoder) {
return new InMemoryBag<>();
return new InMemoryBag<>(elemCoder);
}

@Override
public <T> SetState<T> bindSet(StateTag<SetState<T>> spec, Coder<T> elemCoder) {
return new InMemorySet<>();
return new InMemorySet<>(elemCoder);
}

@Override
public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
StateTag<MapState<KeyT, ValueT>> spec,
Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) {
return new InMemoryMap<>();
return new InMemoryMap<>(mapKeyCoder, mapValueCoder);
}

@Override
Expand All @@ -153,7 +155,7 @@ public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
StateTag<CombiningState<InputT, AccumT, OutputT>> address,
Coder<AccumT> accumCoder,
final CombineFn<InputT, AccumT, OutputT> combineFn) {
return new InMemoryCombiningState<>(combineFn);
return new InMemoryCombiningState<>(combineFn, accumCoder);
}

@Override
Expand All @@ -178,9 +180,15 @@ public WatermarkHoldState bindWatermark(
*/
public static final class InMemoryValue<T>
implements ValueState<T>, InMemoryState<InMemoryValue<T>> {
private final Coder<T> coder;

private boolean isCleared = true;
private @Nullable T value = null;

public InMemoryValue(Coder<T> coder) {
this.coder = coder;
}

@Override
public void clear() {
// Even though we're clearing we can't remove this from the in-memory state map, since
Expand All @@ -207,10 +215,10 @@ public void write(T input) {

@Override
public InMemoryValue<T> copy() {
InMemoryValue<T> that = new InMemoryValue<>();
InMemoryValue<T> that = new InMemoryValue<>(coder);
if (!this.isCleared) {
that.isCleared = this.isCleared;
that.value = this.value;
that.value = uncheckedClone(coder, this.value);
}
return that;
}
Expand Down Expand Up @@ -305,14 +313,16 @@ public InMemoryWatermarkHold<W> copy() {
public static final class InMemoryCombiningState<InputT, AccumT, OutputT>
implements CombiningState<InputT, AccumT, OutputT>,
InMemoryState<InMemoryCombiningState<InputT, AccumT, OutputT>> {
private boolean isCleared = true;
private final CombineFn<InputT, AccumT, OutputT> combineFn;
private final Coder<AccumT> accumCoder;
private boolean isCleared = true;
private AccumT accum;

public InMemoryCombiningState(
CombineFn<InputT, AccumT, OutputT> combineFn) {
CombineFn<InputT, AccumT, OutputT> combineFn, Coder<AccumT> accumCoder) {
this.combineFn = combineFn;
accum = combineFn.createAccumulator();
this.accumCoder = accumCoder;
}

@Override
Expand Down Expand Up @@ -378,10 +388,10 @@ public boolean isCleared() {
@Override
public InMemoryCombiningState<InputT, AccumT, OutputT> copy() {
InMemoryCombiningState<InputT, AccumT, OutputT> that =
new InMemoryCombiningState<>(combineFn);
new InMemoryCombiningState<>(combineFn, accumCoder);
if (!this.isCleared) {
that.isCleared = this.isCleared;
that.addAccum(accum);
that.addAccum(uncheckedClone(accumCoder, accum));
}
return that;
}
Expand All @@ -391,8 +401,13 @@ public InMemoryCombiningState<InputT, AccumT, OutputT> copy() {
* An {@link InMemoryState} implementation of {@link BagState}.
*/
public static final class InMemoryBag<T> implements BagState<T>, InMemoryState<InMemoryBag<T>> {
private final Coder<T> elemCoder;
private List<T> contents = new ArrayList<>();

public InMemoryBag(Coder<T> elemCoder) {
this.elemCoder = elemCoder;
}

@Override
public void clear() {
// Even though we're clearing we can't remove this from the in-memory state map, since
Expand Down Expand Up @@ -442,8 +457,10 @@ public Boolean read() {

@Override
public InMemoryBag<T> copy() {
InMemoryBag<T> that = new InMemoryBag<>();
that.contents.addAll(this.contents);
InMemoryBag<T> that = new InMemoryBag<>(elemCoder);
for (T elem : this.contents) {
that.contents.add(uncheckedClone(elemCoder, elem));
}
return that;
}
}
Expand All @@ -452,8 +469,13 @@ public InMemoryBag<T> copy() {
* An {@link InMemoryState} implementation of {@link SetState}.
*/
public static final class InMemorySet<T> implements SetState<T>, InMemoryState<InMemorySet<T>> {
private final Coder<T> elemCoder;
private Set<T> contents = new HashSet<>();

public InMemorySet(Coder<T> elemCoder) {
this.elemCoder = elemCoder;
}

@Override
public void clear() {
contents = new HashSet<>();
Expand Down Expand Up @@ -513,8 +535,10 @@ public Boolean read() {

@Override
public InMemorySet<T> copy() {
InMemorySet<T> that = new InMemorySet<>();
that.contents.addAll(this.contents);
InMemorySet<T> that = new InMemorySet<>(elemCoder);
for (T elem : this.contents) {
that.contents.add(uncheckedClone(elemCoder, elem));
}
return that;
}
}
Expand All @@ -524,8 +548,16 @@ public InMemorySet<T> copy() {
*/
public static final class InMemoryMap<K, V> implements
MapState<K, V>, InMemoryState<InMemoryMap<K, V>> {
private final Coder<K> keyCoder;
private final Coder<V> valueCoder;

private Map<K, V> contents = new HashMap<>();

public InMemoryMap(Coder<K> keyCoder, Coder<V> valueCoder) {
this.keyCoder = keyCoder;
this.valueCoder = valueCoder;
}

@Override
public void clear() {
contents = new HashMap<>();
Expand Down Expand Up @@ -600,9 +632,22 @@ public boolean isCleared() {

@Override
public InMemoryMap<K, V> copy() {
InMemoryMap<K, V> that = new InMemoryMap<>();
InMemoryMap<K, V> that = new InMemoryMap<>(keyCoder, valueCoder);
for (Map.Entry<K, V> entry : this.contents.entrySet()) {
that.contents.put(
uncheckedClone(keyCoder, entry.getKey()), uncheckedClone(valueCoder, entry.getValue()));
}
that.contents.putAll(this.contents);
return that;
}
}

/** Like {@link CoderUtils#clone} but without a checked exception. */
private static <T> T uncheckedClone(Coder<T> coder, T value) {
try {
return CoderUtils.clone(coder, value);
} catch (CoderException e) {
throw new RuntimeException(e);
}
}
}
Loading