Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import java.util.function.Function;
import java.util.stream.Collectors;

public class ClientBuilder {
public class ClientBuilder<T> {
private static AtomicInteger counter = new AtomicInteger(0);
private final String name;

Expand All @@ -49,10 +49,10 @@ public class ClientBuilder {

private final int retries;

private final ReactiveSocketConnector<SocketAddress> connector;
private final ReactiveSocketConnector<T> connector;
private final Function<Throwable, Boolean> retryThisException;

private final Publisher<List<SocketAddress>> source;
private final Publisher<Collection<T>> source;

private ClientBuilder(
String name,
Expand All @@ -61,8 +61,8 @@ private ClientBuilder(
long connectTimeout, TimeUnit connectTimeoutUnit,
double backupQuantile,
int retries, Function<Throwable, Boolean> retryThisException,
ReactiveSocketConnector<SocketAddress> connector,
Publisher<List<SocketAddress>> source
ReactiveSocketConnector<T> connector,
Publisher<Collection<T>> source
) {
this.name = name;
this.executor = executor;
Expand All @@ -77,8 +77,8 @@ private ClientBuilder(
this.source = source;
}

public ClientBuilder withRequestTimeout(long timeout, TimeUnit unit) {
return new ClientBuilder(
public ClientBuilder<T> withRequestTimeout(long timeout, TimeUnit unit) {
return new ClientBuilder<>(
name,
executor,
timeout, unit,
Expand All @@ -90,8 +90,8 @@ public ClientBuilder withRequestTimeout(long timeout, TimeUnit unit) {
);
}

public ClientBuilder withConnectTimeout(long timeout, TimeUnit unit) {
return new ClientBuilder(
public ClientBuilder<T> withConnectTimeout(long timeout, TimeUnit unit) {
return new ClientBuilder<>(
name,
executor,
requestTimeout, requestTimeoutUnit,
Expand All @@ -103,8 +103,8 @@ public ClientBuilder withConnectTimeout(long timeout, TimeUnit unit) {
);
}

public ClientBuilder withExecutor(ScheduledExecutorService executor) {
return new ClientBuilder(
public ClientBuilder<T> withExecutor(ScheduledExecutorService executor) {
return new ClientBuilder<>(
name,
executor,
requestTimeout, requestTimeoutUnit,
Expand All @@ -116,8 +116,8 @@ public ClientBuilder withExecutor(ScheduledExecutorService executor) {
);
}

public ClientBuilder withConnector(ReactiveSocketConnector<SocketAddress> connector) {
return new ClientBuilder(
public ClientBuilder<T> withConnector(ReactiveSocketConnector<T> connector) {
return new ClientBuilder<>(
name,
executor,
requestTimeout, requestTimeoutUnit,
Expand All @@ -129,8 +129,8 @@ public ClientBuilder withConnector(ReactiveSocketConnector<SocketAddress> connec
);
}

public ClientBuilder withSource(Publisher<List<SocketAddress>> source) {
return new ClientBuilder(
public ClientBuilder<T> withSource(Publisher<Collection<T>> source) {
return new ClientBuilder<>(
name,
executor,
requestTimeout, requestTimeoutUnit,
Expand All @@ -150,52 +150,47 @@ public ReactiveSocket build() {
throw new IllegalStateException("Please configure the connector!");
}

ReactiveSocketConnector<SocketAddress> filterConnector = connector
ReactiveSocketConnector<T> filterConnector = connector
.chain(socket -> new TimeoutSocket(socket, requestTimeout, requestTimeoutUnit, executor))
.chain(DrainingSocket::new);

Publisher<List<ReactiveSocketFactory<SocketAddress>>> factories =
Publisher<? extends Collection<ReactiveSocketFactory<T>>> factories =
sourceToFactory(source, filterConnector);

return new LoadBalancer(factories);
return new LoadBalancer<T>(factories);
}

private Publisher<List<ReactiveSocketFactory<SocketAddress>>> sourceToFactory(
Publisher<List<SocketAddress>> source,
ReactiveSocketConnector<SocketAddress> connector
private Publisher<? extends Collection<ReactiveSocketFactory<T>>> sourceToFactory(
Publisher<? extends Collection<T>> source,
ReactiveSocketConnector<T> connector
) {
return subscriber ->
source.subscribe(new Subscriber<List<SocketAddress>>() {
private Map<SocketAddress, ReactiveSocketFactory<SocketAddress>> current;
source.subscribe(new Subscriber<Collection<T>>() {
private Map<T, ReactiveSocketFactory<T>> current;

@Override
public void onSubscribe(Subscription s) {
subscriber.onSubscribe(s);
current = new HashMap<>();
current = Collections.emptyMap();
}

@Override
public void onNext(List<SocketAddress> socketAddresses) {
socketAddresses.stream()
.filter(sa -> !current.containsKey(sa))
.map(connector::toFactory)
.map(factory -> factory.chain(TimeoutFactory.asChainFunction(connectTimeout, connectTimeoutUnit,
executor)))
.map(FailureAwareFactory::new)
.forEach(factory -> current.put(factory.remote(), factory));

Set<SocketAddress> addresses = new HashSet<>(socketAddresses);
Iterator<Map.Entry<SocketAddress, ReactiveSocketFactory<SocketAddress>>> it =
current.entrySet().iterator();
while (it.hasNext()) {
SocketAddress sa = it.next().getKey();
if (! addresses.contains(sa)) {
it.remove();
public void onNext(Collection<T> socketAddresses) {
Map<T, ReactiveSocketFactory<T>> next = new HashMap<>(socketAddresses.size());
for (T sa: socketAddresses) {
ReactiveSocketFactory<T> factory = current.get(sa);
if (factory == null) {
ReactiveSocketFactory<T> newFactory = connector.toFactory(sa);
newFactory = new TimeoutFactory<>(newFactory, connectTimeout, connectTimeoutUnit, executor);
newFactory = new FailureAwareFactory<>(newFactory);
next.put(sa, newFactory);
} else {
next.put(sa, factory);
}
}

List<ReactiveSocketFactory<SocketAddress>> factories =
current.values().stream().collect(Collectors.toList());
current = next;
List<ReactiveSocketFactory<T>> factories = new ArrayList<>(current.values());
subscriber.onNext(factories);
}

Expand All @@ -207,17 +202,14 @@ public void onNext(List<SocketAddress> socketAddresses) {
});
}

public static ClientBuilder instance() {
return new ClientBuilder(
public static <T> ClientBuilder<T> instance() {
return new ClientBuilder<>(
"rs-loadbalancer-" + counter.incrementAndGet(),
Executors.newScheduledThreadPool(4, new ThreadFactory() {
@Override
public Thread newThread(Runnable r) {
Thread thread = new Thread(r);
thread.setName("reactivesocket-scheduler-thread");
thread.setDaemon(true);
return thread;
}
Executors.newScheduledThreadPool(4, runnable -> {
Thread thread = new Thread(runnable);
thread.setName("reactivesocket-scheduler-thread");
thread.setDaemon(true);
return thread;
}),
1, TimeUnit.SECONDS,
10, TimeUnit.SECONDS,
Expand Down
Loading