diff --git a/src/main/kotlin/cash/atto/node/Event.kt b/src/main/kotlin/cash/atto/node/Event.kt index cb177728..0da62d02 100644 --- a/src/main/kotlin/cash/atto/node/Event.kt +++ b/src/main/kotlin/cash/atto/node/Event.kt @@ -7,12 +7,18 @@ import org.springframework.stereotype.Component import org.springframework.transaction.reactive.TransactionSynchronization import org.springframework.transaction.reactive.TransactionSynchronizationManager import reactor.core.publisher.Mono +import java.net.InetAddress import java.time.Instant interface Event { val timestamp: Instant } +data class InboundConnectionRequested( + val address: InetAddress, + override val timestamp: Instant = Instant.now(), +) : Event + @Component class EventPublisher( private val publisher: ApplicationEventPublisher, diff --git a/src/main/kotlin/cash/atto/node/network/BannedMonitor.kt b/src/main/kotlin/cash/atto/node/network/BannedMonitor.kt index 1660b62f..9089efa7 100644 --- a/src/main/kotlin/cash/atto/node/network/BannedMonitor.kt +++ b/src/main/kotlin/cash/atto/node/network/BannedMonitor.kt @@ -1,18 +1,26 @@ package cash.atto.node.network +import com.github.benmanes.caffeine.cache.Caffeine +import com.github.benmanes.caffeine.cache.Scheduler import org.springframework.context.event.EventListener import org.springframework.stereotype.Component import java.net.InetAddress -import java.util.concurrent.ConcurrentHashMap +import java.time.Duration @Component object BannedMonitor { - private val set = ConcurrentHashMap.newKeySet() + private val banned = + Caffeine + .newBuilder() + .scheduler(Scheduler.systemScheduler()) + .expireAfterWrite(Duration.ofHours(1)) + .build() + .asMap() @EventListener fun store(banned: NodeBanned) { - set.add(banned.address) + BannedMonitor.banned[banned.address] = true } - fun isBanned(address: InetAddress): Boolean = set.contains(address) + fun isBanned(address: InetAddress): Boolean = banned.containsKey(address) } diff --git a/src/main/kotlin/cash/atto/node/network/ChallengeStore.kt b/src/main/kotlin/cash/atto/node/network/ChallengeStore.kt index 4c4954b2..dff7627c 100644 --- a/src/main/kotlin/cash/atto/node/network/ChallengeStore.kt +++ b/src/main/kotlin/cash/atto/node/network/ChallengeStore.kt @@ -17,13 +17,10 @@ internal object ChallengeStore { .scheduler(Scheduler.systemScheduler()) .expireAfterWrite(5, TimeUnit.SECONDS) .maximumSize(100_000) - .build() + .build() .asMap() - fun remove( - publicUri: URI, - challenge: String, - ): Boolean = challenges.remove(publicUri, challenge) + fun remove(challenge: String): URI? = challenges.remove(challenge) fun generate(publicUri: URI): String { val challengePrefix = publicUri.toString().toByteArray() @@ -33,7 +30,7 @@ internal object ChallengeStore { (challengePrefix + it).toHex() } - challenges[publicUri] = challenge + challenges[challenge] = publicUri return challenge } diff --git a/src/main/kotlin/cash/atto/node/network/NetworkProcessor.kt b/src/main/kotlin/cash/atto/node/network/NetworkProcessor.kt index 06ba70c0..cc1780bc 100644 --- a/src/main/kotlin/cash/atto/node/network/NetworkProcessor.kt +++ b/src/main/kotlin/cash/atto/node/network/NetworkProcessor.kt @@ -10,6 +10,8 @@ import cash.atto.commons.fromHexToByteArray import cash.atto.commons.isValid import cash.atto.commons.toByteArray import cash.atto.node.CacheSupport +import cash.atto.node.EventPublisher +import cash.atto.node.InboundConnectionRequested import cash.atto.node.transaction.Transaction import cash.atto.protocol.AttoKeepAlive import cash.atto.protocol.AttoNode @@ -49,9 +51,9 @@ import org.springframework.context.event.EventListener import org.springframework.core.env.Environment import org.springframework.scheduling.annotation.Scheduled import org.springframework.stereotype.Component +import java.net.InetAddress import java.net.InetSocketAddress import java.net.URI -import java.security.SecureRandom import java.time.Duration import java.util.concurrent.Executors import kotlin.time.Duration.Companion.seconds @@ -64,6 +66,7 @@ class NetworkProcessor( environment: Environment, private val networkProperties: NetworkProperties, private val connectionManager: NodeConnectionManager, + private val eventPublisher: EventPublisher, ) : CacheSupport { private val logger = KotlinLogging.logger {} @@ -74,8 +77,6 @@ class NetworkProcessor( const val CONNECTION_TIMEOUT_IN_SECONDS = 5L } - val random = SecureRandom.getInstanceStrong()!! - private val httpClient = HttpClient(io.ktor.client.engine.cio.CIO) { install( @@ -115,6 +116,7 @@ class NetworkProcessor( .newBuilder() .scheduler(Scheduler.systemScheduler()) .expireAfterWrite(Duration.ofSeconds(CONNECTION_TIMEOUT_IN_SECONDS)) + .maximumSize(10_000) .build>() .asMap() @@ -143,13 +145,35 @@ class NetworkProcessor( post("/handshakes") { try { val remoteHost = call.request.origin.remoteHost + + if (BannedMonitor.isBanned(InetAddress.getByName(remoteHost))) { + logger.trace { "Rejected handshake from banned address $remoteHost" } + call.respond(HttpStatusCode.Forbidden) + return@post + } + + eventPublisher.publish(InboundConnectionRequested(InetAddress.getByName(remoteHost))) + val counterResponse = call.receive() - val node = counterResponse.node - val publicUri = node.publicUri val challenge = counterResponse.challenge - if (!ChallengeStore.remove(publicUri, challenge)) { - logger.trace { "Received invalid challenge request from $publicUri $remoteHost $counterResponse" } + val publicUri = ChallengeStore.remove(challenge) + if (publicUri == null) { + logger.trace { "Received invalid challenge request from $remoteHost $counterResponse" } + call.respond(HttpStatusCode.BadRequest) + return@post + } + + val node = counterResponse.node + + if (node.publicUri != publicUri) { + logger.trace { "Node publicUri ${node.publicUri} doesn't match expected $publicUri from $remoteHost" } + call.respond(HttpStatusCode.BadRequest) + return@post + } + + if (counterResponse.genesis != genesisTransaction.hash) { + logger.trace { "Received mismatched genesis hash from $publicUri $remoteHost $counterResponse" } call.respond(HttpStatusCode.BadRequest) return@post } @@ -203,20 +227,47 @@ class NetworkProcessor( logger.trace { "New websocket connection attempt from $remoteHost" } - val publicUri = URI(call.request.headers[PUBLIC_URI_HEADER]!!) + if (BannedMonitor.isBanned(InetAddress.getByName(remoteHost))) { + logger.trace { "Rejected websocket from banned address $remoteHost" } + call.respond(HttpStatusCode.Forbidden) + return@webSocket + } - if (publicUri == thisNode.publicUri) { - logger.trace { "Can't connect as a server to $publicUri. This uri is this node." } + eventPublisher.publish(InboundConnectionRequested(InetAddress.getByName(remoteHost))) + + val publicUriHeader = call.request.headers[PUBLIC_URI_HEADER] + val challengeHeader = call.request.headers[CHALLENGE_HEADER] + + if (publicUriHeader == null || challengeHeader == null) { + logger.trace { "Missing required headers from $remoteHost" } call.respond(HttpStatusCode.BadRequest) return@webSocket } - val challenge = call.request.headers[CHALLENGE_HEADER]!! + val publicUri = URI(publicUriHeader) - logger.trace { "Headers received: publicUri=$publicUri, challenge=$challenge" } + val scheme = publicUri.scheme + if (scheme != "ws" && scheme != "wss") { + logger.trace { "Invalid URI scheme '$scheme' from $remoteHost" } + call.respond(HttpStatusCode.BadRequest) + return@webSocket + } - if (!challenge.isChallengePrefixValid()) { + if (networkProperties.loopbackBlocked && InetAddress.getByName(publicUri.host).isLoopbackAddress) { + logger.trace { "Loopback address not allowed for $publicUri from $remoteHost" } call.respond(HttpStatusCode.BadRequest) + return@webSocket + } + + if (publicUri == thisNode.publicUri) { + logger.trace { "Can't connect as a server to $publicUri. This uri is this node." } + call.respond(HttpStatusCode.BadRequest) + return@webSocket + } + + logger.trace { "Headers received: publicUri=$publicUri, challenge=$challengeHeader" } + + if (!challengeHeader.isChallengePrefixValid()) { logger.trace { "Received invalid challenge prefix request from $publicUri $remoteHost" } call.respond(HttpStatusCode.BadRequest) return@webSocket @@ -240,11 +291,11 @@ class NetworkProcessor( val counterChallenge = ChallengeStore.generate(publicUri) val counterResponse = CounterChallengeResponse( - challenge, + challengeHeader, genesisTransaction.hash, thisNode, timestamp, - signer.sign(AttoChallenge(challenge.fromHexToByteArray()), timestamp), + signer.sign(AttoChallenge(challengeHeader.fromHexToByteArray()), timestamp), counterChallenge, ) val result = @@ -271,7 +322,7 @@ class NetworkProcessor( val response = result.body() - if (!ChallengeStore.remove(publicUri, counterChallenge)) { + if (ChallengeStore.remove(counterChallenge) == null) { connectingMap.remove(publicUri) logger.trace { "Received invalid challenge response from $publicUri $remoteHost $response" } call.respond(HttpStatusCode.BadRequest) @@ -280,6 +331,13 @@ class NetworkProcessor( val node = response.node + if (node.publicUri != publicUri) { + connectingMap.remove(publicUri) + logger.trace { "Node publicUri ${node.publicUri} doesn't match header $publicUri from $remoteHost" } + call.respond(HttpStatusCode.BadRequest) + return@webSocket + } + val counterHash = AttoHash.hash(64, node.publicKey.value, counterChallenge.fromHexToByteArray(), response.timestamp.toByteArray()) @@ -317,7 +375,7 @@ class NetworkProcessor( } @Scheduled(fixedRate = 1_000) - suspend fun boostrap() { + suspend fun bootstrap() { networkProperties .defaultNodes .asSequence() @@ -406,6 +464,11 @@ class NetworkProcessor( return } + if (node.publicUri != publicUri) { + logger.trace { "Node publicUri ${node.publicUri} doesn't match expected $publicUri" } + return + } + connectionManager.manage(node, connectionSocketAddress, session) } catch (e: Exception) { logger.trace(e) { "Exception while trying to connect to $publicUri" } diff --git a/src/main/kotlin/cash/atto/node/network/NetworkProperties.kt b/src/main/kotlin/cash/atto/node/network/NetworkProperties.kt index 26a40670..f6af1e90 100644 --- a/src/main/kotlin/cash/atto/node/network/NetworkProperties.kt +++ b/src/main/kotlin/cash/atto/node/network/NetworkProperties.kt @@ -8,4 +8,5 @@ import org.springframework.context.annotation.Configuration class NetworkProperties { var expirationTimeInSeconds: Long = 300 var defaultNodes: MutableSet = HashSet() + var loopbackBlocked: Boolean = true } diff --git a/src/main/kotlin/cash/atto/node/network/NodeConnectionManager.kt b/src/main/kotlin/cash/atto/node/network/NodeConnectionManager.kt index aeeccbfe..efb57a45 100644 --- a/src/main/kotlin/cash/atto/node/network/NodeConnectionManager.kt +++ b/src/main/kotlin/cash/atto/node/network/NodeConnectionManager.kt @@ -11,6 +11,7 @@ import com.github.benmanes.caffeine.cache.Scheduler import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.websocket.Frame import io.ktor.websocket.WebSocketSession +import io.ktor.websocket.close import io.ktor.websocket.readBytes import jakarta.annotation.PreDestroy import kotlinx.coroutines.CoroutineScope @@ -92,9 +93,10 @@ class NodeConnectionManager( val publicUri = node.publicUri val connection = NodeConnection(node, connectionSocketAddress, session) - if (connectionMap.putIfAbsent(publicUri, connection) != null) { - logger.trace { "Connection to ${node.publicUri} already managed. New connection will be ignored" } - connection.disconnect() + val previousConnection = connectionMap.put(publicUri, connection) + if (previousConnection != null) { + logger.trace { "Connection to ${node.publicUri} already managed. Replacing with new connection (last wins)" } + previousConnection.disconnect() } try { @@ -163,14 +165,13 @@ class NodeConnectionManager( @EventListener fun send(networkMessage: BroadcastNetworkMessage<*>) { - val strategy = networkMessage.strategy val message = networkMessage.payload logger.trace { "Sending $networkMessage" } connectionMap.values .asSequence() - .filter { strategy.shouldBroadcast(it.node) } + .filter { networkMessage.accepts(it.node.publicUri, it.node) } .forEach { connection -> scope.launch { send(connection.node.publicUri, message) @@ -178,6 +179,17 @@ class NodeConnectionManager( } } + @EventListener + fun ban(event: NodeBanned) { + connectionMap + .entries + .filter { it.value.connectionInetSocketAddress.address == event.address } + .forEach { (uri, _) -> + logger.info { "Disconnecting $uri due to ban of ${event.address}" } + connectionMap.remove(uri) + } + } + @Scheduled(fixedRate = 10_000) fun keepAlive() { val sample = connectionMap.toMap().values.randomOrNull() @@ -185,12 +197,6 @@ class NodeConnectionManager( send(BroadcastNetworkMessage(strategy = BroadcastStrategy.EVERYONE, payload = message)) } - private fun BroadcastStrategy.shouldBroadcast(node: AttoNode): Boolean = - when (this) { - BroadcastStrategy.EVERYONE -> true - BroadcastStrategy.VOTERS -> node.isVoter() - } - private inner class NodeConnection( val node: AttoNode, val connectionInetSocketAddress: InetSocketAddress, @@ -205,8 +211,13 @@ class NodeConnectionManager( .onCompletion { cause -> logger.info(cause) { "Disconnected from ${node.publicUri}" } } .map { it.readBytes() } - fun disconnect() { - session.cancel() + suspend fun disconnect() { + try { + session.close() + } catch (e: Exception) { + logger.trace(e) { "Exception during graceful close of ${node.publicUri}, cancelling session" } + session.cancel() + } } suspend fun send(message: ByteArray) { diff --git a/src/main/kotlin/cash/atto/node/network/guardian/Guardian.kt b/src/main/kotlin/cash/atto/node/network/guardian/Guardian.kt index 4d61844d..ab99784a 100644 --- a/src/main/kotlin/cash/atto/node/network/guardian/Guardian.kt +++ b/src/main/kotlin/cash/atto/node/network/guardian/Guardian.kt @@ -3,6 +3,7 @@ package cash.atto.node.network.guardian import cash.atto.commons.AttoPublicKey import cash.atto.node.CacheSupport import cash.atto.node.EventPublisher +import cash.atto.node.InboundConnectionRequested import cash.atto.node.network.DirectNetworkMessage import cash.atto.node.network.InboundNetworkMessage import cash.atto.node.network.NodeBanned @@ -75,6 +76,14 @@ class Guardian( } } + @EventListener + fun count(event: InboundConnectionRequested) { + val socketAddress = InetSocketAddress(event.address, 0) + statisticsMap.compute(socketAddress) { _, v -> + (v ?: 0UL) + 1UL + } + } + @EventListener fun add(nodeEvent: NodeConnected) { if (nodeEvent.node.isNotVoter()) { @@ -89,37 +98,39 @@ class Guardian( voterMap.remove(nodeEvent.connectionSocketAddress) } - @Scheduled(fixedRate = 1, timeUnit = TimeUnit.SECONDS) + @Scheduled(fixedRate = 15, timeUnit = TimeUnit.SECONDS) @Synchronized fun guard() { val newSnapshot = statisticsMap.toMap() val differenceMap = calculateDifference(newSnapshot, snapshot) - val median = median(extractVoters(differenceMap).values) + snapshot = newSnapshot + + val voterValues = extractVoters(differenceMap).values + if (voterValues.isEmpty()) { + return + } + + val median = median(voterValues) if (median < guardianProperties.minimalMedian) { return } + val threshold = median * guardianProperties.toleranceMultiplier + val mergedDifferenceMap = differenceMap .entries .groupBy({ it.key.address }, { it.value }) .mapValues { it.value.sum() } - val maliciousActors = - mergedDifferenceMap - .entries - .associateBy({ it.value }, { it.key }) - .toSortedMap() - .tailMap(median * guardianProperties.toleranceMultiplier) - - maliciousActors.forEach { - logger.info { "Banning ${it.value}. Received ${it.key} requests while median of voters is $median per second" } - eventPublisher.publish(NodeBanned(it.value)) - } - - snapshot = newSnapshot + mergedDifferenceMap + .filter { it.value >= threshold } + .forEach { (address, count) -> + logger.info { "Banning $address. Received $count requests while median of voters is $median per second" } + eventPublisher.publish(NodeBanned(address)) + } } private fun calculateDifference( @@ -143,9 +154,7 @@ class Guardian( } private fun median(hits: Collection): ULong { - if (hits.isEmpty()) { - return ULong.MAX_VALUE - } + require(hits.isNotEmpty()) { "Cannot compute median of empty collection" } val sortedHits = hits.sorted() val middle = sortedHits.size / 2 return if (sortedHits.size % 2 == 0) { diff --git a/src/test/kotlin/cash/atto/node/network/peer/PeerStepDefinition.kt b/src/test/kotlin/cash/atto/node/network/peer/PeerStepDefinition.kt index d2a79d74..c3d3fca2 100644 --- a/src/test/kotlin/cash/atto/node/network/peer/PeerStepDefinition.kt +++ b/src/test/kotlin/cash/atto/node/network/peer/PeerStepDefinition.kt @@ -26,7 +26,7 @@ class PeerStepDefinition( nodeStepDefinition.startNeighbour(shortId) nodeStepDefinition.setAsDefaultNode() runBlocking { - networkProcessor.boostrap() + networkProcessor.bootstrap() } checkPeer("THIS", shortId) @@ -36,7 +36,7 @@ class PeerStepDefinition( @When("default handshake starts") fun startDefaultHandshake() { runBlocking { - networkProcessor.boostrap() + networkProcessor.bootstrap() } } diff --git a/src/test/resources/application-default.yaml b/src/test/resources/application-default.yaml index 33f04b8b..55c07fd8 100644 --- a/src/test/resources/application-default.yaml +++ b/src/test/resources/application-default.yaml @@ -13,6 +13,8 @@ spring: include-stacktrace: ALWAYS atto: + network: + loopback-blocked: false transaction: prioritization: frequency: 1