Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/main/kotlin/cash/atto/node/Event.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@ import org.springframework.stereotype.Component
import org.springframework.transaction.reactive.TransactionSynchronization
import org.springframework.transaction.reactive.TransactionSynchronizationManager
import reactor.core.publisher.Mono
import java.net.InetAddress
import java.time.Instant

interface Event {
val timestamp: Instant
}

data class InboundConnectionRequested(
val address: InetAddress,
override val timestamp: Instant = Instant.now(),
) : Event

@Component
class EventPublisher(
private val publisher: ApplicationEventPublisher,
Expand Down
16 changes: 12 additions & 4 deletions src/main/kotlin/cash/atto/node/network/BannedMonitor.kt
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
package cash.atto.node.network

import com.github.benmanes.caffeine.cache.Caffeine
import com.github.benmanes.caffeine.cache.Scheduler
import org.springframework.context.event.EventListener
import org.springframework.stereotype.Component
import java.net.InetAddress
import java.util.concurrent.ConcurrentHashMap
import java.time.Duration

@Component
object BannedMonitor {
private val set = ConcurrentHashMap.newKeySet<InetAddress>()
private val banned =
Caffeine
.newBuilder()
.scheduler(Scheduler.systemScheduler())
.expireAfterWrite(Duration.ofHours(1))
.build<InetAddress, Boolean>()
.asMap()

@EventListener
fun store(banned: NodeBanned) {
set.add(banned.address)
BannedMonitor.banned[banned.address] = true
}

fun isBanned(address: InetAddress): Boolean = set.contains(address)
fun isBanned(address: InetAddress): Boolean = banned.containsKey(address)
}
9 changes: 3 additions & 6 deletions src/main/kotlin/cash/atto/node/network/ChallengeStore.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@ internal object ChallengeStore {
.scheduler(Scheduler.systemScheduler())
.expireAfterWrite(5, TimeUnit.SECONDS)
.maximumSize(100_000)
.build<URI, String>()
.build<String, URI>()
.asMap()

fun remove(
publicUri: URI,
challenge: String,
): Boolean = challenges.remove(publicUri, challenge)
fun remove(challenge: String): URI? = challenges.remove(challenge)

fun generate(publicUri: URI): String {
val challengePrefix = publicUri.toString().toByteArray()
Expand All @@ -33,7 +30,7 @@ internal object ChallengeStore {
(challengePrefix + it).toHex()
}

challenges[publicUri] = challenge
challenges[challenge] = publicUri
return challenge
}

Expand Down
97 changes: 80 additions & 17 deletions src/main/kotlin/cash/atto/node/network/NetworkProcessor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import cash.atto.commons.fromHexToByteArray
import cash.atto.commons.isValid
import cash.atto.commons.toByteArray
import cash.atto.node.CacheSupport
import cash.atto.node.EventPublisher
import cash.atto.node.InboundConnectionRequested
import cash.atto.node.transaction.Transaction
import cash.atto.protocol.AttoKeepAlive
import cash.atto.protocol.AttoNode
Expand Down Expand Up @@ -49,9 +51,9 @@ import org.springframework.context.event.EventListener
import org.springframework.core.env.Environment
import org.springframework.scheduling.annotation.Scheduled
import org.springframework.stereotype.Component
import java.net.InetAddress
import java.net.InetSocketAddress
import java.net.URI
import java.security.SecureRandom
import java.time.Duration
import java.util.concurrent.Executors
import kotlin.time.Duration.Companion.seconds
Expand All @@ -64,6 +66,7 @@ class NetworkProcessor(
environment: Environment,
private val networkProperties: NetworkProperties,
private val connectionManager: NodeConnectionManager,
private val eventPublisher: EventPublisher,
) : CacheSupport {
private val logger = KotlinLogging.logger {}

Expand All @@ -74,8 +77,6 @@ class NetworkProcessor(
const val CONNECTION_TIMEOUT_IN_SECONDS = 5L
}

val random = SecureRandom.getInstanceStrong()!!

private val httpClient =
HttpClient(io.ktor.client.engine.cio.CIO) {
install(
Expand Down Expand Up @@ -115,6 +116,7 @@ class NetworkProcessor(
.newBuilder()
.scheduler(Scheduler.systemScheduler())
.expireAfterWrite(Duration.ofSeconds(CONNECTION_TIMEOUT_IN_SECONDS))
.maximumSize(10_000)
.build<URI, MutableSharedFlow<AttoNode>>()
.asMap()

Expand Down Expand Up @@ -143,13 +145,35 @@ class NetworkProcessor(
post("/handshakes") {
try {
val remoteHost = call.request.origin.remoteHost

if (BannedMonitor.isBanned(InetAddress.getByName(remoteHost))) {
logger.trace { "Rejected handshake from banned address $remoteHost" }
call.respond(HttpStatusCode.Forbidden)
return@post
}

eventPublisher.publish(InboundConnectionRequested(InetAddress.getByName(remoteHost)))

val counterResponse = call.receive<CounterChallengeResponse>()
val node = counterResponse.node
val publicUri = node.publicUri
val challenge = counterResponse.challenge

if (!ChallengeStore.remove(publicUri, challenge)) {
logger.trace { "Received invalid challenge request from $publicUri $remoteHost $counterResponse" }
val publicUri = ChallengeStore.remove(challenge)
if (publicUri == null) {
logger.trace { "Received invalid challenge request from $remoteHost $counterResponse" }
call.respond(HttpStatusCode.BadRequest)
return@post
}

val node = counterResponse.node

if (node.publicUri != publicUri) {
logger.trace { "Node publicUri ${node.publicUri} doesn't match expected $publicUri from $remoteHost" }
call.respond(HttpStatusCode.BadRequest)
return@post
}

if (counterResponse.genesis != genesisTransaction.hash) {
logger.trace { "Received mismatched genesis hash from $publicUri $remoteHost $counterResponse" }
call.respond(HttpStatusCode.BadRequest)
return@post
}
Expand Down Expand Up @@ -203,20 +227,47 @@ class NetworkProcessor(

logger.trace { "New websocket connection attempt from $remoteHost" }

val publicUri = URI(call.request.headers[PUBLIC_URI_HEADER]!!)
if (BannedMonitor.isBanned(InetAddress.getByName(remoteHost))) {
logger.trace { "Rejected websocket from banned address $remoteHost" }
call.respond(HttpStatusCode.Forbidden)
return@webSocket
}

if (publicUri == thisNode.publicUri) {
logger.trace { "Can't connect as a server to $publicUri. This uri is this node." }
eventPublisher.publish(InboundConnectionRequested(InetAddress.getByName(remoteHost)))

val publicUriHeader = call.request.headers[PUBLIC_URI_HEADER]
val challengeHeader = call.request.headers[CHALLENGE_HEADER]

if (publicUriHeader == null || challengeHeader == null) {
logger.trace { "Missing required headers from $remoteHost" }
call.respond(HttpStatusCode.BadRequest)
return@webSocket
}

val challenge = call.request.headers[CHALLENGE_HEADER]!!
val publicUri = URI(publicUriHeader)

logger.trace { "Headers received: publicUri=$publicUri, challenge=$challenge" }
val scheme = publicUri.scheme
if (scheme != "ws" && scheme != "wss") {
logger.trace { "Invalid URI scheme '$scheme' from $remoteHost" }
call.respond(HttpStatusCode.BadRequest)
return@webSocket
}

if (!challenge.isChallengePrefixValid()) {
if (networkProperties.loopbackBlocked && InetAddress.getByName(publicUri.host).isLoopbackAddress) {
logger.trace { "Loopback address not allowed for $publicUri from $remoteHost" }
call.respond(HttpStatusCode.BadRequest)
return@webSocket
}

if (publicUri == thisNode.publicUri) {
logger.trace { "Can't connect as a server to $publicUri. This uri is this node." }
call.respond(HttpStatusCode.BadRequest)
return@webSocket
}

logger.trace { "Headers received: publicUri=$publicUri, challenge=$challengeHeader" }

if (!challengeHeader.isChallengePrefixValid()) {
logger.trace { "Received invalid challenge prefix request from $publicUri $remoteHost" }
call.respond(HttpStatusCode.BadRequest)
return@webSocket
Expand All @@ -240,11 +291,11 @@ class NetworkProcessor(
val counterChallenge = ChallengeStore.generate(publicUri)
val counterResponse =
CounterChallengeResponse(
challenge,
challengeHeader,
genesisTransaction.hash,
thisNode,
timestamp,
signer.sign(AttoChallenge(challenge.fromHexToByteArray()), timestamp),
signer.sign(AttoChallenge(challengeHeader.fromHexToByteArray()), timestamp),
counterChallenge,
)
val result =
Expand All @@ -271,7 +322,7 @@ class NetworkProcessor(

val response = result.body<ChallengeResponse>()

if (!ChallengeStore.remove(publicUri, counterChallenge)) {
if (ChallengeStore.remove(counterChallenge) == null) {
connectingMap.remove(publicUri)
logger.trace { "Received invalid challenge response from $publicUri $remoteHost $response" }
call.respond(HttpStatusCode.BadRequest)
Expand All @@ -280,6 +331,13 @@ class NetworkProcessor(

val node = response.node

if (node.publicUri != publicUri) {
connectingMap.remove(publicUri)
logger.trace { "Node publicUri ${node.publicUri} doesn't match header $publicUri from $remoteHost" }
call.respond(HttpStatusCode.BadRequest)
return@webSocket
}

val counterHash =
AttoHash.hash(64, node.publicKey.value, counterChallenge.fromHexToByteArray(), response.timestamp.toByteArray())

Expand Down Expand Up @@ -317,7 +375,7 @@ class NetworkProcessor(
}

@Scheduled(fixedRate = 1_000)
suspend fun boostrap() {
suspend fun bootstrap() {
networkProperties
.defaultNodes
.asSequence()
Expand Down Expand Up @@ -406,6 +464,11 @@ class NetworkProcessor(
return
}

if (node.publicUri != publicUri) {
logger.trace { "Node publicUri ${node.publicUri} doesn't match expected $publicUri" }
return
}

connectionManager.manage(node, connectionSocketAddress, session)
} catch (e: Exception) {
logger.trace(e) { "Exception while trying to connect to $publicUri" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ import org.springframework.context.annotation.Configuration
class NetworkProperties {
var expirationTimeInSeconds: Long = 300
var defaultNodes: MutableSet<String> = HashSet()
var loopbackBlocked: Boolean = true
}
37 changes: 24 additions & 13 deletions src/main/kotlin/cash/atto/node/network/NodeConnectionManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import com.github.benmanes.caffeine.cache.Scheduler
import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.websocket.Frame
import io.ktor.websocket.WebSocketSession
import io.ktor.websocket.close
import io.ktor.websocket.readBytes
import jakarta.annotation.PreDestroy
import kotlinx.coroutines.CoroutineScope
Expand Down Expand Up @@ -92,9 +93,10 @@ class NodeConnectionManager(
val publicUri = node.publicUri
val connection = NodeConnection(node, connectionSocketAddress, session)

if (connectionMap.putIfAbsent(publicUri, connection) != null) {
logger.trace { "Connection to ${node.publicUri} already managed. New connection will be ignored" }
connection.disconnect()
val previousConnection = connectionMap.put(publicUri, connection)
if (previousConnection != null) {
logger.trace { "Connection to ${node.publicUri} already managed. Replacing with new connection (last wins)" }
previousConnection.disconnect()
}

try {
Expand Down Expand Up @@ -163,34 +165,38 @@ class NodeConnectionManager(

@EventListener
fun send(networkMessage: BroadcastNetworkMessage<*>) {
val strategy = networkMessage.strategy
val message = networkMessage.payload

logger.trace { "Sending $networkMessage" }

connectionMap.values
.asSequence()
.filter { strategy.shouldBroadcast(it.node) }
.filter { networkMessage.accepts(it.node.publicUri, it.node) }
.forEach { connection ->
scope.launch {
send(connection.node.publicUri, message)
}
}
}

@EventListener
fun ban(event: NodeBanned) {
connectionMap
.entries
.filter { it.value.connectionInetSocketAddress.address == event.address }
.forEach { (uri, _) ->
logger.info { "Disconnecting $uri due to ban of ${event.address}" }
connectionMap.remove(uri)
}
}

@Scheduled(fixedRate = 10_000)
fun keepAlive() {
val sample = connectionMap.toMap().values.randomOrNull()
val message = AttoKeepAlive(sample?.node?.publicUri)
send(BroadcastNetworkMessage(strategy = BroadcastStrategy.EVERYONE, payload = message))
}

private fun BroadcastStrategy.shouldBroadcast(node: AttoNode): Boolean =
when (this) {
BroadcastStrategy.EVERYONE -> true
BroadcastStrategy.VOTERS -> node.isVoter()
}

private inner class NodeConnection(
val node: AttoNode,
val connectionInetSocketAddress: InetSocketAddress,
Expand All @@ -205,8 +211,13 @@ class NodeConnectionManager(
.onCompletion { cause -> logger.info(cause) { "Disconnected from ${node.publicUri}" } }
.map { it.readBytes() }

fun disconnect() {
session.cancel()
suspend fun disconnect() {
try {
session.close()
} catch (e: Exception) {
logger.trace(e) { "Exception during graceful close of ${node.publicUri}, cancelling session" }
session.cancel()
}
}

suspend fun send(message: ByteArray) {
Expand Down
Loading
Loading