Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class WebSocketClient implements WebSocket.Listener {
private final Consumer<CharSequence> onTextCallback;

private final AtomicBoolean attemptReconnect = new AtomicBoolean();
private final AtomicBoolean keepPinging = new AtomicBoolean();
private CountDownLatch pingCountdownLatch;

/**
Expand Down Expand Up @@ -90,6 +91,7 @@ public void onOpen(WebSocket webSocket) {
connectCallback.run();
}
logger.log(Level.INFO, "Connected to " + uri);
keepPinging.set(true);
new Thread(new PingRunnable()).start();
}

Expand Down Expand Up @@ -134,7 +136,7 @@ public CompletionStage<?> onClose(WebSocket webSocket,
* is called.
*/
public void sendPing() {
logger.log(Level.FINE, "Sending ping");
logger.log(Level.FINE, Thread.currentThread().getName() + " Sending ping");
webSocket.sendPing(ByteBuffer.allocate(0));
}

Expand Down Expand Up @@ -175,6 +177,7 @@ public CompletionStage<?> onText(WebSocket webSocket,
* @param reason Custom reason text.
*/
public void close(String reason) {
keepPinging.set(false);
webSocket.sendClose(1000, reason);
}

Expand All @@ -197,7 +200,7 @@ private class PingRunnable implements Runnable {

@Override
public void run() {
while (true) {
while (keepPinging.get()) {
pingCountdownLatch = new CountDownLatch(1);
sendPing();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.web.socket.PingMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.time.Instant;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
* Utility class for handling web socket messages. In the context of the save-and-restore service,
* only messages from server are expected. Client messages are logged, but do not invoke any behavior.
* Utility class for handling web socket messages.
*/
@SuppressWarnings("nls")
public class WebSocket {
Expand All @@ -41,11 +44,16 @@ public class WebSocket {

private final WebSocketSession session;
private final String id;

private final String description;
private final Logger logger = Logger.getLogger(WebSocket.class.getName());

private final ObjectMapper objectMapper;

/**
* Keeps track of when this session was used for a ping/pong exchange. Should be set to non-null value ONLY
* when an actual pong was received by {@link WebSocketHandler}.
*/
private Instant lastPinged;

/**
* Constructor
*/
Expand All @@ -58,6 +66,8 @@ public WebSocket(ObjectMapper objectMapper, WebSocketSession webSocketSession) {
writeThread.setName("Web Socket Write Thread " + this.id);
writeThread.setDaemon(true);
writeThread.start();
InetSocketAddress inetSocketAddress = webSocketSession.getRemoteAddress();
this.description = this.id + "/" + (inetSocketAddress != null ? inetSocketAddress.getAddress().toString() : "IP address unknown");
}

/**
Expand All @@ -70,6 +80,14 @@ public String getId() {
return id;
}

/**
*
* @return A description containing the session ID and - if available - the associated IP address.
*/
public String getDescription() {
return description;
}

/**
* @param message Potentially long message
* @return Message shorted to 200 chars
Expand Down Expand Up @@ -137,7 +155,7 @@ private void writeQueuedMessages() {
}

/**
* Called when client sends a general message
* Called when client sends a generic message
*
* @param message {@link TextMessage}, its payload is expected to be JSON.
*/
Expand All @@ -150,12 +168,6 @@ public void handleTextMessage(TextMessage message) throws Exception {
logger.log(Level.INFO, "Client message type: " + type);
}

/**
* Clears all PVs
*
* <p>Web socket calls this onClose(),
* but context may also call this again just in case
*/
public void dispose() {
// Exit write thread
try {
Expand All @@ -166,8 +178,35 @@ public void dispose() {
// TODO: is this needed?
session.close();
} catch (Throwable ex) {
logger.log(Level.WARNING, "Error disposing " + getId(), ex);
logger.log(Level.WARNING, "Error disposing " + description, ex);
}
logger.log(Level.INFO, () -> "Web socket " + description + " closed");
}

/**
* Sets the time of last received pong message.
* @param instant Time of last received pong message.
*/
public synchronized void setLastPinged(Instant instant) {
this.lastPinged = instant;
}

/**
*
* @return The time of last received pong message.
*/
public synchronized Instant getLastPinged() {
return lastPinged;
}

/**
* Sends a {@link PingMessage} to peer.
*/
public void sendPing() {
try {
session.sendMessage(new PingMessage());
} catch (IOException e) {
logger.log(Level.WARNING, "Failed to send ping message", e);
}
logger.log(Level.INFO, () -> "Web socket " + session.getId() + " closed");
}
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,5 @@
/*
* Copyright (C) 2023 European Spallation Source ERIC.
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
* Copyright (C) 2025 European Spallation Source ERIC.
*
*/

Expand All @@ -24,6 +10,7 @@
import org.phoebus.applications.saveandrestore.model.websocket.SaveAndRestoreWebSocketMessage;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.lang.NonNull;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PongMessage;
Expand All @@ -33,14 +20,27 @@

import javax.annotation.PreDestroy;
import java.io.EOFException;
import java.net.InetSocketAddress;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
* Single web socket end-point routing messages to active {@link WebSocket} instances.
*
* <p>
* In some cases web socket clients may become stale/disconnected for various reasons, e.g. network issues. The
* {@link #afterConnectionClosed(WebSocketSession, CloseStatus)} is not necessarily called in those case.
* To make sure the {@link #sockets} collection does not contain stale clients, a scheduled job runs once per hour to
* ping all clients, and set the time when the pong response was received. Another scheduled job will check
* the last received pong message timestamp and - if older than 70 minutes - consider the client session dead
* and dispose of it.
* </p>
*/
@Component
public class WebSocketHandler extends TextWebSocketHandler {
Expand All @@ -49,7 +49,7 @@ public class WebSocketHandler extends TextWebSocketHandler {
* List of active {@link WebSocket}
*/
@SuppressWarnings("unused")
private List<WebSocket> sockets = new ArrayList<>();
private List<WebSocket> sockets = Collections.synchronizedList(new ArrayList<>());

@SuppressWarnings("unused")
@Autowired
Expand All @@ -69,8 +69,11 @@ public class WebSocketHandler extends TextWebSocketHandler {
public void handleTextMessage(@NonNull WebSocketSession session, @NonNull TextMessage message) {
try {
// Find the WebSocket instance associated with the WebSocketSession
Optional<WebSocket> webSocketOptional =
sockets.stream().filter(webSocket -> webSocket.getId().equals(session.getId())).findFirst();
Optional<WebSocket> webSocketOptional;
synchronized (sockets){
webSocketOptional =
sockets.stream().filter(webSocket -> webSocket.getId().equals(session.getId())).findFirst();
}
if (webSocketOptional.isEmpty()) {
return; // Should only happen in case of timing issues?
}
Expand All @@ -87,7 +90,8 @@ public void handleTextMessage(@NonNull WebSocketSession session, @NonNull TextMe
*/
@Override
public void afterConnectionEstablished(@NonNull WebSocketSession session) {
logger.log(Level.INFO, "Opening web socket session from remote " + session.getRemoteAddress().getAddress());
InetSocketAddress inetSocketAddress = session.getRemoteAddress();
logger.log(Level.INFO, "Opening web socket session from remote " + (inetSocketAddress != null ? inetSocketAddress.getAddress().toString() : "<unknown IP address>"));
WebSocket webSocket = new WebSocket(objectMapper, session);
sockets.add(webSocket);
}
Expand All @@ -101,10 +105,12 @@ public void afterConnectionEstablished(@NonNull WebSocketSession session) {
*/
@Override
public void afterConnectionClosed(@NonNull WebSocketSession session, @NonNull CloseStatus status) {
Optional<WebSocket> webSocketOptional =
sockets.stream().filter(webSocket -> webSocket.getId().equals(session.getId())).findFirst();
Optional<WebSocket> webSocketOptional;
synchronized (sockets){
webSocketOptional = sockets.stream().filter(webSocket -> webSocket.getId().equals(session.getId())).findFirst();
}
if (webSocketOptional.isPresent()) {
logger.log(Level.INFO, "Closing web socket session from remote " + session.getRemoteAddress().getAddress());
logger.log(Level.INFO, "Closing web socket session " + webSocketOptional.get().getDescription());
webSocketOptional.get().dispose();
sockets.remove(webSocketOptional.get());
}
Expand All @@ -126,20 +132,22 @@ public void handleTransportError(@NonNull WebSocketSession session, @NonNull Thr
}

/**
* Called when client sends ping message, i.e. a pong message is sent and time for last message
* Called when client sends ping message, i.e. a pong message is sent and time for last pong response message
* in the {@link WebSocket} instance is refreshed.
*
* @param session Associated {@link WebSocketSession}
* @param message See {@link PongMessage}
*/
@Override
protected void handlePongMessage(@NonNull WebSocketSession session, @NonNull PongMessage message) {
logger.log(Level.INFO, "Got pong");
logger.log(Level.FINE, "Got pong for session " + session.getId());
// Find the WebSocket instance associated with this WebSocketSession
Optional<WebSocket> webSocketOptional =
sockets.stream().filter(webSocket -> webSocket.getId().equals(session.getId())).findFirst();
if (webSocketOptional.isEmpty()) {
return; // Should only happen in case of timing issues?
Optional<WebSocket> webSocketOptional;
synchronized (sockets) {
webSocketOptional = sockets.stream().filter(webSocket -> webSocket.getId().equals(session.getId())).findFirst();
}
if (webSocketOptional.isPresent()) {
webSocketOptional.get().setLastPinged(Instant.now());
}
}

Expand All @@ -155,19 +163,65 @@ private String shorten(final String message) {

@PreDestroy
public void cleanup() {
sockets.forEach(s -> {
logger.log(Level.INFO, "Disposing socket " + s.getId());
s.dispose();
});
synchronized (sockets) {
sockets.forEach(s -> {
logger.log(Level.INFO, "Disposing socket " + s.getDescription());
s.dispose();
});
}
}

public void sendMessage(SaveAndRestoreWebSocketMessage webSocketMessage) {
sockets.forEach(ws -> {
try {
ws.queueMessage(objectMapper.writeValueAsString(webSocketMessage));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
});
synchronized (sockets) {
sockets.forEach(ws -> {
try {
ws.queueMessage(objectMapper.writeValueAsString(webSocketMessage));
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
});
}
}

/**
* Sends a ping message to all clients contained in {@link #sockets}.
* <p>
* This is scheduled to run at the top of each hour, i.e. 00.00, 01.00...23.00
* </p>
*
*/
@SuppressWarnings("unused")
@Scheduled(cron = "* 0 * * * *")
public void pingClients(){
synchronized (sockets) {
sockets.forEach(WebSocket::sendPing);
}
}

/**
* For each client in {@link #sockets}, checks the timestamp of last received pong message. If this is older
* than 70 minutes, the socket is considered dead, and then disposed.
* <p>
* This is scheduled to run 5 minutes past each hour, i.e. 00.05, 01.05...23.05
* </p>
*
*/
@SuppressWarnings("unused")
@Scheduled(cron = "* 5 * * * *")
public void cleanUpDeadSockets(){
List<WebSocket> deadSockets = new ArrayList<>();
Instant now = Instant.now();
synchronized (sockets) {
sockets.forEach(s -> {
Instant lastPinged = s.getLastPinged();
if (lastPinged != null && lastPinged.isBefore(now.minus(70, ChronoUnit.MINUTES))) {
deadSockets.add(s);
}
});
deadSockets.forEach(d -> {
sockets.remove(d);
d.dispose();
});
}
}
}