diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala index 2fb1e3ca8d..a405c96efe 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala @@ -22,7 +22,7 @@ import org.sqlite.SQLiteConnection import scodec.Codec import scodec.bits.{BitVector, ByteVector} -import java.sql.{Connection, ResultSet, Statement} +import java.sql.{Connection, ResultSet, Statement, Timestamp} import java.util.UUID import javax.sql.DataSource import scala.collection.immutable.Queue @@ -123,18 +123,16 @@ trait JdbcUtils { def getByteVector32FromHexNullable(columnLabel: String): Option[ByteVector32] = { val s = rs.getString(columnLabel) - if (rs.wasNull()) None else { - Some(ByteVector32(ByteVector.fromValidHex(s))) - } + if (rs.wasNull()) None else Some(ByteVector32(ByteVector.fromValidHex(s))) } def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_)) def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel)) - def getByteVectorNullable(columnLabel: String): ByteVector = { + def getByteVectorNullable(columnLabel: String): Option[ByteVector] = { val result = rs.getBytes(columnLabel) - if (rs.wasNull()) ByteVector.empty else ByteVector(result) + if (rs.wasNull()) None else Some(ByteVector(result)) } def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel))) @@ -164,6 +162,11 @@ trait JdbcUtils { if (rs.wasNull()) None else Some(MilliSatoshi(result)) } + def getTimestampNullable(label: String): Option[Timestamp] = { + val result = rs.getTimestamp(label) + if (rs.wasNull()) None else Some(result) + } + } object ExtendedResultSet { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala index 19244891f0..2e44c80ccc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala @@ -29,7 +29,8 @@ import fr.acinq.eclair.transactions.Transactions.PlaceHolderPubKey import fr.acinq.eclair.{MilliSatoshi, MilliSatoshiLong} import grizzled.slf4j.Logging -import java.sql.Statement +import java.sql.{Statement, Timestamp} +import java.time.Instant import java.util.UUID import javax.sql.DataSource import scala.collection.immutable.Queue @@ -40,7 +41,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { import ExtendedResultSet._ val DB_NAME = "audit" - val CURRENT_VERSION = 5 + val CURRENT_VERSION = 6 case class RelayedPart(channelId: ByteVector32, amount: MilliSatoshi, direction: String, relayType: String, timestamp: Long) @@ -52,15 +53,25 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX relayed_trampoline_payment_hash_idx ON relayed_trampoline(payment_hash)") } + def migration56(statement: Statement): Unit = { + statement.executeUpdate("ALTER TABLE sent ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE received ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE relayed ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE relayed_trampoline ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE network_fees ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE channel_events ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE channel_errors ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + } + getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)") + statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") statement.executeUpdate("CREATE INDEX sent_timestamp_idx ON sent(timestamp)") statement.executeUpdate("CREATE INDEX received_timestamp_idx ON received(timestamp)") @@ -74,6 +85,10 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { case Some(v@4) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration45(statement) + migration56(statement) + case Some(v@5) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration56(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -90,7 +105,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setBoolean(4, e.isFunder) statement.setBoolean(5, e.isPrivate) statement.setString(6, e.event.label) - statement.setLong(7, System.currentTimeMillis) + statement.setTimestamp(7, Timestamp.from(Instant.now())) statement.executeUpdate() } } @@ -109,7 +124,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setString(7, e.paymentPreimage.toHex) statement.setString(8, e.recipientNodeId.value.toHex) statement.setString(9, p.toChannelId.toHex) - statement.setLong(10, p.timestamp) + statement.setTimestamp(10, Timestamp.from(Instant.ofEpochMilli(p.timestamp))) statement.addBatch() }) statement.executeBatch() @@ -124,7 +139,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setLong(1, p.amount.toLong) statement.setString(2, e.paymentHash.toHex) statement.setString(3, p.fromChannelId.toHex) - statement.setLong(4, p.timestamp) + statement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(p.timestamp))) statement.addBatch() }) statement.executeBatch() @@ -143,7 +158,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setString(1, e.paymentHash.toHex) statement.setLong(2, nextTrampolineAmount.toLong) statement.setString(3, nextTrampolineNodeId.value.toHex) - statement.setLong(4, e.timestamp) + statement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(e.timestamp))) statement.executeUpdate() } // trampoline relayed payments do MPP aggregation and may have M inputs and N outputs @@ -156,7 +171,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setString(3, p.channelId.toHex) statement.setString(4, p.direction) statement.setString(5, p.relayType) - statement.setLong(6, e.timestamp) + statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(e.timestamp))) statement.executeUpdate() } } @@ -171,7 +186,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setString(3, e.tx.txid.toHex) statement.setLong(4, e.fee.toLong) statement.setString(5, e.txType) - statement.setLong(6, System.currentTimeMillis) + statement.setTimestamp(6, Timestamp.from(Instant.now())) statement.executeUpdate() } } @@ -189,7 +204,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setString(3, errorName) statement.setString(4, errorMessage) statement.setBoolean(5, e.isFatal) - statement.setLong(6, System.currentTimeMillis) + statement.setTimestamp(6, Timestamp.from(Instant.now())) statement.executeUpdate() } } @@ -197,9 +212,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listSent(from: Long, to: Long): Seq[PaymentSent] = inTransaction { pg => - using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ?")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ?")) { statement => + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() var sentByParentId = Map.empty[UUID, PaymentSent] while (rs.next()) { @@ -210,7 +225,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { MilliSatoshi(rs.getLong("fees_msat")), rs.getByteVector32FromHex("to_channel_id"), None, // we don't store the route in the audit DB - rs.getLong("timestamp")) + rs.getTimestamp("timestamp").getTime) val sent = sentByParentId.get(parentId) match { case Some(s) => s.copy(parts = s.parts :+ part) case None => PaymentSent( @@ -229,9 +244,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listReceived(from: Long, to: Long): Seq[PaymentReceived] = inTransaction { pg => - using(pg.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ?")) { statement => + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() var receivedByHash = Map.empty[ByteVector32, PaymentReceived] while (rs.next()) { @@ -239,7 +254,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { val part = PaymentReceived.PartialPayment( MilliSatoshi(rs.getLong("amount_msat")), rs.getByteVector32FromHex("from_channel_id"), - rs.getLong("timestamp")) + rs.getTimestamp("timestamp").getTime) val received = receivedByHash.get(paymentHash) match { case Some(r) => r.copy(parts = r.parts :+ part) case None => PaymentReceived(paymentHash, Seq(part)) @@ -253,9 +268,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] = inTransaction { pg => var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)] - using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp >= ? AND timestamp < ?")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement => + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() while (rs.next()) { val paymentHash = rs.getByteVector32FromHex("payment_hash") @@ -264,9 +279,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { trampolineByHash += (paymentHash -> (amount, nodeId)) } } - using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement => + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]] while (rs.next()) { @@ -276,7 +291,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { MilliSatoshi(rs.getLong("amount_msat")), rs.getString("direction"), rs.getString("relay_type"), - rs.getLong("timestamp")) + rs.getTimestamp("timestamp").getTime) relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part)) } relayedByHash.flatMap { @@ -300,9 +315,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] = inTransaction { pg => - using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement => + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() var q: Queue[NetworkFee] = Queue() while (rs.next()) { @@ -312,7 +327,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { txId = rs.getByteVector32FromHex("tx_id"), fee = Satoshi(rs.getLong("fee_sat")), txType = rs.getString("tx_type"), - timestamp = rs.getLong("timestamp")) + timestamp = rs.getTimestamp("timestamp").getTime) } q } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala index dae8dfc6fb..29e6c20ed2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala @@ -27,7 +27,8 @@ import fr.acinq.eclair.db.pg.PgUtils.PgLock import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec import grizzled.slf4j.Logging -import java.sql.Statement +import java.sql.{Statement, Timestamp} +import java.time.Instant import javax.sql.DataSource import scala.collection.immutable.Queue @@ -38,7 +39,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit import lock._ val DB_NAME = "channels" - val CURRENT_VERSION = 3 + val CURRENT_VERSION = 4 inTransaction { pg => using(pg.createStatement()) { statement => @@ -51,14 +52,28 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit statement.executeUpdate("ALTER TABLE local_channels ADD COLUMN closed_timestamp BIGINT") } + def migration34(statement: Statement): Unit = { + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN created_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + created_timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_payment_sent_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_payment_sent_timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_payment_received_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_payment_received_timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_connected_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_connected_timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN closed_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + closed_timestamp * interval '1 millisecond'") + + statement.executeUpdate("ALTER TABLE htlc_infos ALTER COLUMN commitment_number SET DATA TYPE BIGINT USING commitment_number::BIGINT") + } + getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp BIGINT, last_payment_sent_timestamp BIGINT, last_payment_received_timestamp BIGINT, last_connected_timestamp BIGINT, closed_timestamp BIGINT)") - statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") + statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp TIMESTAMP WITH TIME ZONE, last_payment_sent_timestamp TIMESTAMP WITH TIME ZONE, last_payment_received_timestamp TIMESTAMP WITH TIME ZONE, last_connected_timestamp TIMESTAMP WITH TIME ZONE, closed_timestamp TIMESTAMP WITH TIME ZONE)") + statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number BIGINT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") case Some(v@2) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration23(statement) + migration34(statement) + case Some(v@3) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration34(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -89,7 +104,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit private def updateChannelMetaTimestampColumn(channelId: ByteVector32, columnName: String): Unit = { inTransaction { pg => using(pg.prepareStatement(s"UPDATE local_channels SET $columnName=? WHERE channel_id=?")) { statement => - statement.setLong(1, System.currentTimeMillis) + statement.setTimestamp(1, Timestamp.from(Instant.now())) statement.setString(2, channelId.toHex) statement.executeUpdate() } @@ -152,7 +167,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit withLock { pg => using(pg.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement => statement.setString(1, channelId.toHex) - statement.setString(2, commitmentNumber.toString) + statement.setLong(2, commitmentNumber) val rs = statement.executeQuery var q: Queue[(ByteVector32, CltvExpiry)] = Queue() while (rs.next()) { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala index 55082bf16f..70a62846c3 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala @@ -23,7 +23,6 @@ import fr.acinq.eclair.db.ChannelsDb import fr.acinq.eclair.db.DbEventHandler.ChannelEvent import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics import fr.acinq.eclair.db.Monitoring.Tags.DbBackends -import fr.acinq.eclair.payment.{ChannelPaymentRelayed, PaymentEvent, PaymentReceived, PaymentRelayed, PaymentSent} import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec import grizzled.slf4j.Logging @@ -64,7 +63,7 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging { getVersion(statement, DB_NAME) match { case None => statement.executeUpdate("CREATE TABLE local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0, created_timestamp INTEGER, last_payment_sent_timestamp INTEGER, last_payment_received_timestamp INTEGER, last_connected_timestamp INTEGER, closed_timestamp INTEGER)") - statement.executeUpdate("CREATE TABLE htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") + statement.executeUpdate("CREATE TABLE htlc_infos (channel_id BLOB NOT NULL, commitment_number INTEGER NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") case Some(v@1) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala index c38755a666..34279ebea6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala @@ -4,8 +4,8 @@ import akka.actor.ActorSystem import com.opentable.db.postgres.embedded.EmbeddedPostgres import com.zaxxer.hikari.HikariConfig import fr.acinq.eclair.db._ -import fr.acinq.eclair.db.pg.PgUtils.PgLock import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler +import fr.acinq.eclair.db.pg.PgUtils.{PgLock, getVersion, using} import org.postgresql.jdbc.PgConnection import org.sqlite.SQLiteConnection @@ -64,6 +64,7 @@ object TestDatabases { // @formatter:off override val connection: PgConnection = pg.getPostgresDatabase.getConnection.asInstanceOf[PgConnection] + // NB: we use a lazy val here: databases won't be initialized until we reference that variable override lazy val db: Databases = Databases.PostgresDatabases(hikariConfig, UUID.randomUUID(), lock, jdbcUrlFile_opt = Some(jdbcUrlFile), readOnlyUser_opt = None) override def close(): Unit = pg.close() // @formatter:on @@ -77,4 +78,23 @@ object TestDatabases { // @formatter:on } + def migrationCheck(dbs: TestDatabases, + initializeTables: Connection => Unit, + dbName: String, + targetVersion: Int, + postCheck: Connection => Unit + ): Unit = { + val connection = dbs.connection + // initialize the database to a previous version and populate data + initializeTables(connection) + // this will trigger the initialization of tables and the migration + val _ = dbs.db + // check that db version was updated + using(connection.createStatement()) { statement => + assert(getVersion(statement, dbName).contains(targetVersion)) + } + // post-migration checks + postCheck(connection) + } + } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala index 3e0fe7acfa..59ec639baa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala @@ -18,7 +18,7 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{ByteVector32, SatoshiLong, Transaction} -import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases} +import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, migrationCheck} import fr.acinq.eclair._ import fr.acinq.eclair.channel.Helpers.Closing.MutualClose import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid, RemoteError} @@ -26,7 +26,7 @@ import fr.acinq.eclair.db.AuditDb.Stats import fr.acinq.eclair.db.DbEventHandler.ChannelEvent import fr.acinq.eclair.db.jdbc.JdbcUtils.using import fr.acinq.eclair.db.pg.PgAuditDb -import fr.acinq.eclair.db.pg.PgUtils.{inTransaction, setVersion} +import fr.acinq.eclair.db.pg.PgUtils.{getVersion, setVersion} import fr.acinq.eclair.db.sqlite.SqliteAuditDb import fr.acinq.eclair.payment._ import fr.acinq.eclair.transactions.Transactions.PlaceHolderPubKey @@ -34,8 +34,9 @@ import fr.acinq.eclair.wire.protocol.Error import org.scalatest.Tag import org.scalatest.funsuite.AnyFunSuite +import java.sql.Timestamp +import java.time.Instant import java.util.UUID -import javax.sql.DataSource import scala.concurrent.duration._ import scala.util.Random @@ -182,13 +183,20 @@ class AuditDbSpec extends AnyFunSuite { } } - test("handle migration version 1 -> 5") { - forAllDbs { - case _: TestPgDatabases => // no migration - case dbs: TestSqliteDatabases => - import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion - val connection = dbs.connection + test("migrate sqlite audit database v1 -> v5") { + + val dbs = TestSqliteDatabases() + val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, PrivateKey(ByteVector32.One).publicKey, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil) + val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None) + val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None) + val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84003 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil) + val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) + val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true) + + migrationCheck( + dbs = dbs, + initializeTables = connection => { // simulate existing previous version db using(connection.createStatement()) { statement => statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)") @@ -208,17 +216,6 @@ class AuditDbSpec extends AnyFunSuite { setVersion(statement, "audit", 1) } - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(1)) - } - - val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, PrivateKey(ByteVector32.One).publicKey, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil) - val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None) - val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None) - val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84003 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil) - val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) - val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true) - // add a row (no ID on sent) using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement => statement.setLong(1, ps.recipientAmount.toLong) @@ -229,15 +226,12 @@ class AuditDbSpec extends AnyFunSuite { statement.setLong(6, ps.timestamp) statement.executeUpdate() } - - val migratedDb = new SqliteAuditDb(connection) - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(5)) - } - + }, + dbName = "audit", + targetVersion = 5, + postCheck = connection => { // existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default - assert(migratedDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))))) + assert(dbs.audit.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))))) val postMigrationDb = new SqliteAuditDb(connection) @@ -252,16 +246,19 @@ class AuditDbSpec extends AnyFunSuite { // the old record will have the UNKNOWN_UUID but the new ones will have their actual id val expected = Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))), ps1) assert(postMigrationDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === expected) - } + } + ) } - test("handle migration version 2 -> 5") { - forAllDbs { - case _: TestPgDatabases => // no migration - case dbs: TestSqliteDatabases => - import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion - val connection = dbs.connection + test("migrate sqlite audit database v2 -> v5") { + val dbs = TestSqliteDatabases() + + val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) + val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true) + migrationCheck( + dbs = dbs, + initializeTables = connection => { // simulate existing previous version db using(connection.createStatement()) { statement => statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)") @@ -280,39 +277,39 @@ class AuditDbSpec extends AnyFunSuite { setVersion(statement, "audit", 2) } - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(2)) - } - - val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) - val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true) - - val migratedDb = new SqliteAuditDb(connection) - + }, + dbName = "audit", + targetVersion = 5, + postCheck = connection => { + val migratedDb = dbs.audit using(connection.createStatement()) { statement => assert(getVersion(statement, "audit").contains(5)) } - migratedDb.add(e1) val postMigrationDb = new SqliteAuditDb(connection) - using(connection.createStatement()) { statement => assert(getVersion(statement, "audit").contains(5)) } - postMigrationDb.add(e2) - } + } + ) } - test("handle migration version 3 -> 5") { - forAllDbs { - case _: TestPgDatabases => // no migration - case dbs: TestSqliteDatabases => - import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion - val connection = dbs.connection + test("migrate sqlite audit database v3 -> v5") { + + val dbs = TestSqliteDatabases() + + val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 100) + val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 110) + val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil) + val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105) + val relayed2 = ChannelPaymentRelayed(650 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 115) + + migrationCheck( + dbs = dbs, + initializeTables = connection => { // simulate existing previous version db using(connection.createStatement()) { statement => statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)") @@ -334,14 +331,6 @@ class AuditDbSpec extends AnyFunSuite { setVersion(statement, "audit", 3) } - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(3)) - } - - val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 100) - val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 110) - val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil) - for (pp <- Seq(pp1, pp2)) { using(connection.prepareStatement("INSERT INTO sent (amount_msat, fees_msat, payment_hash, payment_preimage, to_channel_id, timestamp, id) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => statement.setLong(1, pp.amount.toLong) @@ -355,9 +344,6 @@ class AuditDbSpec extends AnyFunSuite { } } - val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105) - val relayed2 = ChannelPaymentRelayed(650 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 115) - for (relayed <- Seq(relayed1, relayed2)) { using(connection.prepareStatement("INSERT INTO relayed (amount_in_msat, amount_out_msat, payment_hash, from_channel_id, to_channel_id, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement => statement.setLong(1, relayed.amountIn.toLong) @@ -369,12 +355,14 @@ class AuditDbSpec extends AnyFunSuite { statement.executeUpdate() } } - - val migratedDb = new SqliteAuditDb(connection) + }, + dbName = "audit", + targetVersion = 5, + postCheck = connection => { + val migratedDb = dbs.audit using(connection.createStatement()) { statement => assert(getVersion(statement, "audit").contains(5)) } - assert(migratedDb.listSent(50, 150).toSet === Set( ps1.copy(id = pp1.id, recipientAmount = pp1.amount, parts = pp1 :: Nil), ps1.copy(id = pp2.id, recipientAmount = pp2.amount, parts = pp2 :: Nil) @@ -382,214 +370,194 @@ class AuditDbSpec extends AnyFunSuite { assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2)) val postMigrationDb = new SqliteAuditDb(connection) - using(connection.createStatement()) { statement => assert(getVersion(statement, "audit").contains(5)) } - val ps2 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, randomKey.publicKey, Seq( PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 160), PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 165) )) val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150) - postMigrationDb.add(ps2) assert(postMigrationDb.listSent(155, 200) === Seq(ps2)) postMigrationDb.add(relayed3) assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3)) - } + } + ) } - test("handle migration version 4 -> 5") { - forAllDbs { - case dbs: TestPgDatabases => - import fr.acinq.eclair.db.pg.PgUtils.getVersion - implicit val datasource: DataSource = dbs.datasource + test("migrate audit database v4 -> v5/v6") { - // simulate existing previous version db - inTransaction { pg => - using(pg.createStatement()) { statement => - statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)") - - statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") - - setVersion(statement, "audit", 4) - } - } + val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105) + val relayed2 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(300 msat, randomBytes32), PaymentRelayed.Part(350 msat, randomBytes32)), Seq(PaymentRelayed.Part(600 msat, randomBytes32)), PlaceHolderPubKey, 0 msat, 110) - inTransaction { pg => - using(pg.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(4)) - } - } + forAllDbs { + case dbs: TestPgDatabases => + migrationCheck( + dbs = dbs, + initializeTables = connection => { + // simulate existing previous version db + using(connection.createStatement()) { statement => + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)") + + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") + + setVersion(statement, "audit", 4) + } - val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105) - val relayed2 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(300 msat, randomBytes32), PaymentRelayed.Part(350 msat, randomBytes32)), Seq(PaymentRelayed.Part(600 msat, randomBytes32)), PlaceHolderPubKey, 0 msat, 110) - - inTransaction { pg => - using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setString(1, relayed1.paymentHash.toHex) - statement.setLong(2, relayed1.amountIn.toLong) - statement.setString(3, relayed1.fromChannelId.toHex) - statement.setString(4, "IN") - statement.setString(5, "channel") - statement.setLong(6, relayed1.timestamp) - statement.executeUpdate() - } - using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setString(1, relayed1.paymentHash.toHex) - statement.setLong(2, relayed1.amountOut.toLong) - statement.setString(3, relayed1.toChannelId.toHex) - statement.setString(4, "OUT") - statement.setString(5, "channel") - statement.setLong(6, relayed1.timestamp) - statement.executeUpdate() - } - for (incoming <- relayed2.incoming) { - using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setString(1, relayed2.paymentHash.toHex) - statement.setLong(2, incoming.amount.toLong) - statement.setString(3, incoming.channelId.toHex) + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, relayed1.paymentHash.toHex) + statement.setLong(2, relayed1.amountIn.toLong) + statement.setString(3, relayed1.fromChannelId.toHex) statement.setString(4, "IN") - statement.setString(5, "trampoline") - statement.setLong(6, relayed2.timestamp) + statement.setString(5, "channel") + statement.setLong(6, relayed1.timestamp) statement.executeUpdate() } - } - for (outgoing <- relayed2.outgoing) { - using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setString(1, relayed2.paymentHash.toHex) - statement.setLong(2, outgoing.amount.toLong) - statement.setString(3, outgoing.channelId.toHex) + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, relayed1.paymentHash.toHex) + statement.setLong(2, relayed1.amountOut.toLong) + statement.setString(3, relayed1.toChannelId.toHex) statement.setString(4, "OUT") - statement.setString(5, "trampoline") - statement.setLong(6, relayed2.timestamp) + statement.setString(5, "channel") + statement.setLong(6, relayed1.timestamp) statement.executeUpdate() } - } - } - - val migratedDb = new PgAuditDb()(datasource) - inTransaction { pg => - using(pg.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(5)) - } - } - - assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2)) - - val postMigrationDb = new PgAuditDb()(datasource) + for (incoming <- relayed2.incoming) { + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, relayed2.paymentHash.toHex) + statement.setLong(2, incoming.amount.toLong) + statement.setString(3, incoming.channelId.toHex) + statement.setString(4, "IN") + statement.setString(5, "trampoline") + statement.setLong(6, relayed2.timestamp) + statement.executeUpdate() + } + } + for (outgoing <- relayed2.outgoing) { + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, relayed2.paymentHash.toHex) + statement.setLong(2, outgoing.amount.toLong) + statement.setString(3, outgoing.channelId.toHex) + statement.setString(4, "OUT") + statement.setString(5, "trampoline") + statement.setLong(6, relayed2.timestamp) + statement.executeUpdate() + } + } + }, + dbName = "audit", + targetVersion = 6, + postCheck = connection => { + val migratedDb = dbs.audit + using(connection.createStatement()) { statement => + assert(getVersion(statement, "audit").contains(6)) + } + assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2)) - inTransaction { pg => - using(pg.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(5)) + val postMigrationDb = new PgAuditDb()(dbs.datasource) + using(connection.createStatement()) { statement => + assert(getVersion(statement, "audit").contains(6)) + } + val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150) + postMigrationDb.add(relayed3) + assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3)) } - } - - val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150) - - postMigrationDb.add(relayed3) - assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3)) + ) case dbs: TestSqliteDatabases => - import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion - val connection = dbs.connection - - // simulate existing previous version db - using(connection.createStatement()) { statement => - statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, recipient_amount_msat INTEGER NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, recipient_node_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, channel_id BLOB NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)") - - statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") - - setVersion(statement, "audit", 4) - } - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(4)) - } + migrationCheck( + dbs = dbs, + initializeTables = connection => { + // simulate existing previous version db + using(connection.createStatement()) { statement => + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, recipient_amount_msat INTEGER NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, recipient_node_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, channel_id BLOB NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)") + + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") + + setVersion(statement, "audit", 4) + } - val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105) - val relayed2 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(300 msat, randomBytes32), PaymentRelayed.Part(350 msat, randomBytes32)), Seq(PaymentRelayed.Part(600 msat, randomBytes32)), PlaceHolderPubKey, 0 msat, 110) + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setBytes(1, relayed1.paymentHash.toArray) + statement.setLong(2, relayed1.amountIn.toLong) + statement.setBytes(3, relayed1.fromChannelId.toArray) + statement.setString(4, "IN") + statement.setString(5, "channel") + statement.setLong(6, relayed1.timestamp) + statement.executeUpdate() + } + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setBytes(1, relayed1.paymentHash.toArray) + statement.setLong(2, relayed1.amountOut.toLong) + statement.setBytes(3, relayed1.toChannelId.toArray) + statement.setString(4, "OUT") + statement.setString(5, "channel") + statement.setLong(6, relayed1.timestamp) + statement.executeUpdate() + } + for (incoming <- relayed2.incoming) { + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setBytes(1, relayed2.paymentHash.toArray) + statement.setLong(2, incoming.amount.toLong) + statement.setBytes(3, incoming.channelId.toArray) + statement.setString(4, "IN") + statement.setString(5, "trampoline") + statement.setLong(6, relayed2.timestamp) + statement.executeUpdate() + } + } + for (outgoing <- relayed2.outgoing) { + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setBytes(1, relayed2.paymentHash.toArray) + statement.setLong(2, outgoing.amount.toLong) + statement.setBytes(3, outgoing.channelId.toArray) + statement.setString(4, "OUT") + statement.setString(5, "trampoline") + statement.setLong(6, relayed2.timestamp) + statement.executeUpdate() + } + } + }, + dbName = "audit", + targetVersion = 5, + postCheck = connection => { + val migratedDb = dbs.audit + using(connection.createStatement()) { statement => + assert(getVersion(statement, "audit").contains(5)) + } + assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2)) - using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setBytes(1, relayed1.paymentHash.toArray) - statement.setLong(2, relayed1.amountIn.toLong) - statement.setBytes(3, relayed1.fromChannelId.toArray) - statement.setString(4, "IN") - statement.setString(5, "channel") - statement.setLong(6, relayed1.timestamp) - statement.executeUpdate() - } - using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setBytes(1, relayed1.paymentHash.toArray) - statement.setLong(2, relayed1.amountOut.toLong) - statement.setBytes(3, relayed1.toChannelId.toArray) - statement.setString(4, "OUT") - statement.setString(5, "channel") - statement.setLong(6, relayed1.timestamp) - statement.executeUpdate() - } - for (incoming <- relayed2.incoming) { - using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setBytes(1, relayed2.paymentHash.toArray) - statement.setLong(2, incoming.amount.toLong) - statement.setBytes(3, incoming.channelId.toArray) - statement.setString(4, "IN") - statement.setString(5, "trampoline") - statement.setLong(6, relayed2.timestamp) - statement.executeUpdate() - } - } - for (outgoing <- relayed2.outgoing) { - using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setBytes(1, relayed2.paymentHash.toArray) - statement.setLong(2, outgoing.amount.toLong) - statement.setBytes(3, outgoing.channelId.toArray) - statement.setString(4, "OUT") - statement.setString(5, "trampoline") - statement.setLong(6, relayed2.timestamp) - statement.executeUpdate() + val postMigrationDb = new SqliteAuditDb(connection) + using(connection.createStatement()) { statement => + assert(getVersion(statement, "audit").contains(5)) + } + val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150) + postMigrationDb.add(relayed3) + assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3)) } - } - - val migratedDb = new SqliteAuditDb(connection) - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(5)) - } - - assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2)) - - val postMigrationDb = new SqliteAuditDb(connection) - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(5)) - } - - val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150) - - postMigrationDb.add(relayed3) - assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3)) + ) } } @@ -605,7 +573,7 @@ class AuditDbSpec extends AnyFunSuite { if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray) statement.setString(4, "IN") statement.setString(5, "unknown") // invalid relay type - statement.setLong(6, 10) + if (isPg) statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(10))) else statement.setLong(6, 10) statement.executeUpdate() } @@ -615,7 +583,7 @@ class AuditDbSpec extends AnyFunSuite { if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray) statement.setString(4, "UP") // invalid direction statement.setString(5, "channel") - statement.setLong(6, 20) + if (isPg) statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(20))) else statement.setLong(6, 20) statement.executeUpdate() } @@ -628,7 +596,7 @@ class AuditDbSpec extends AnyFunSuite { if (isPg) statement.setString(3, channelId.toHex) else statement.setBytes(3, channelId.toArray) statement.setString(4, "IN") // missing a corresponding OUT statement.setString(5, "channel") - statement.setLong(6, 30) + if (isPg) statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(30))) else statement.setLong(6, 30) statement.executeUpdate() } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala index 84994a695d..d26ad9976c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala @@ -18,23 +18,25 @@ package fr.acinq.eclair.db import com.softwaremill.quicklens._ import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases} +import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, migrationCheck} +import fr.acinq.eclair.db.ChannelsDbSpec.{getPgTimestamp, getTimestamp, testCases} import fr.acinq.eclair.db.DbEventHandler.ChannelEvent import fr.acinq.eclair.db.jdbc.JdbcUtils.using -import fr.acinq.eclair.db.pg.PgChannelsDb import fr.acinq.eclair.db.pg.PgUtils.{getVersion, setVersion} +import fr.acinq.eclair.db.pg.{PgChannelsDb, PgUtils} import fr.acinq.eclair.db.sqlite.SqliteChannelsDb import fr.acinq.eclair.db.sqlite.SqliteUtils.ExtendedResultSet._ import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec -import fr.acinq.eclair.{CltvExpiry, ShortChannelId, randomBytes32} +import fr.acinq.eclair.{CltvExpiry, ShortChannelId, TestDatabases, randomBytes32} import org.scalatest.funsuite.AnyFunSuite import scodec.bits.ByteVector -import java.sql.SQLException +import java.sql.{Connection, SQLException} import java.util.concurrent.Executors -import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future} import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future} +import scala.util.Random class ChannelsDbSpec extends AnyFunSuite { @@ -107,56 +109,42 @@ class ChannelsDbSpec extends AnyFunSuite { test("channel metadata") { forAllDbs { dbs => val db = dbs.channels - val connection = dbs.connection val channel1 = ChannelCodecsSpec.normal val channel2 = channel1.modify(_.commitments.channelId).setTo(randomBytes32) - def getTimestamp(channelId: ByteVector32, columnName: String): Option[Long] = { - using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement => - // data type differs depending on underlying database system - dbs match { - case _: TestPgDatabases => statement.setString(1, channelId.toHex) - case _: TestSqliteDatabases => statement.setBytes(1, channelId.toArray) - } - val rs = statement.executeQuery() - rs.next() - rs.getLongNullable(columnName) - } - } - // first we add channels db.addOrUpdateChannel(channel1) db.addOrUpdateChannel(channel2) // make sure initially all metadata are empty - assert(getTimestamp(channel1.channelId, "created_timestamp").isEmpty) - assert(getTimestamp(channel1.channelId, "last_payment_sent_timestamp").isEmpty) - assert(getTimestamp(channel1.channelId, "last_payment_received_timestamp").isEmpty) - assert(getTimestamp(channel1.channelId, "last_connected_timestamp").isEmpty) - assert(getTimestamp(channel1.channelId, "closed_timestamp").isEmpty) + assert(getTimestamp(dbs, channel1.channelId, "created_timestamp").isEmpty) + assert(getTimestamp(dbs, channel1.channelId, "last_payment_sent_timestamp").isEmpty) + assert(getTimestamp(dbs, channel1.channelId, "last_payment_received_timestamp").isEmpty) + assert(getTimestamp(dbs, channel1.channelId, "last_connected_timestamp").isEmpty) + assert(getTimestamp(dbs, channel1.channelId, "closed_timestamp").isEmpty) db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.Created) - assert(getTimestamp(channel1.channelId, "created_timestamp").nonEmpty) + assert(getTimestamp(dbs, channel1.channelId, "created_timestamp").nonEmpty) db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.PaymentSent) - assert(getTimestamp(channel1.channelId, "last_payment_sent_timestamp").nonEmpty) + assert(getTimestamp(dbs, channel1.channelId, "last_payment_sent_timestamp").nonEmpty) db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.PaymentReceived) - assert(getTimestamp(channel1.channelId, "last_payment_received_timestamp").nonEmpty) + assert(getTimestamp(dbs, channel1.channelId, "last_payment_received_timestamp").nonEmpty) db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.Connected) - assert(getTimestamp(channel1.channelId, "last_connected_timestamp").nonEmpty) + assert(getTimestamp(dbs, channel1.channelId, "last_connected_timestamp").nonEmpty) db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.Closed(null)) - assert(getTimestamp(channel1.channelId, "closed_timestamp").nonEmpty) + assert(getTimestamp(dbs, channel1.channelId, "closed_timestamp").nonEmpty) // make sure all metadata are still empty for channel 2 - assert(getTimestamp(channel2.channelId, "created_timestamp").isEmpty) - assert(getTimestamp(channel2.channelId, "last_payment_sent_timestamp").isEmpty) - assert(getTimestamp(channel2.channelId, "last_payment_received_timestamp").isEmpty) - assert(getTimestamp(channel2.channelId, "last_connected_timestamp").isEmpty) - assert(getTimestamp(channel2.channelId, "closed_timestamp").isEmpty) + assert(getTimestamp(dbs, channel2.channelId, "created_timestamp").isEmpty) + assert(getTimestamp(dbs, channel2.channelId, "last_payment_sent_timestamp").isEmpty) + assert(getTimestamp(dbs, channel2.channelId, "last_payment_received_timestamp").isEmpty) + assert(getTimestamp(dbs, channel2.channelId, "last_connected_timestamp").isEmpty) + assert(getTimestamp(dbs, channel2.channelId, "closed_timestamp").isEmpty) } } @@ -175,13 +163,22 @@ class ChannelsDbSpec extends AnyFunSuite { setVersion(statement, "channels", 1) } - // insert 1 row - val channel = ChannelCodecsSpec.normal - val data = stateDataCodec.encode(channel).require.toByteArray - using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?)")) { statement => - statement.setBytes(1, channel.channelId.toArray) - statement.setBytes(2, data) - statement.executeUpdate() + // insert data + for (testCase <- testCases) { + using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?)")) { statement => + statement.setBytes(1, testCase.channelId.toArray) + statement.setBytes(2, testCase.data.toArray) + statement.executeUpdate() + } + for (commitmentNumber <- testCase.commitmentNumbers) { + using(sqlite.prepareStatement("INSERT INTO htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)")) { statement => + statement.setBytes(1, testCase.channelId.toArray) + statement.setLong(2, commitmentNumber) + statement.setBytes(3, randomBytes32.toArray) + statement.setLong(4, 500000 + Random.nextInt(500000)) + statement.executeUpdate() + } + } } // check that db migration works @@ -189,71 +186,193 @@ class ChannelsDbSpec extends AnyFunSuite { using(sqlite.createStatement()) { statement => assert(getVersion(statement, "channels").contains(3)) } - assert(db.listLocalChannels() === List(channel)) - db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail + assert(db.listLocalChannels().size === testCases.size) + for (testCase <- testCases) { + db.updateChannelMeta(testCase.channelId, ChannelEvent.EventType.Created) // this call must not fail + for (commitmentNumber <- testCase.commitmentNumbers) { + assert(db.listHtlcInfos(testCase.channelId, commitmentNumber).size === testCase.commitmentNumbers.count(_ == commitmentNumber)) + } + } } } - test("migrate channel database v2 -> v3") { + test("migrate channel database v2 -> v3/v4") { + def postCheck(channelsDb: ChannelsDb): Unit = { + assert(channelsDb.listLocalChannels().size === testCases.filterNot(_.isClosed).size) + for (testCase <- testCases.filterNot(_.isClosed)) { + channelsDb.updateChannelMeta(testCase.channelId, ChannelEvent.EventType.Created) // this call must not fail + for (commitmentNumber <- testCase.commitmentNumbers) { + assert(channelsDb.listHtlcInfos(testCase.channelId, commitmentNumber).size === testCase.commitmentNumbers.count(_ == commitmentNumber)) + } + } + } + forAllDbs { case dbs: TestPgDatabases => - val pg = dbs.connection + migrationCheck( + dbs = dbs, + initializeTables = connection => { + // initialize a v2 database + using(connection.createStatement()) { statement => + statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") + setVersion(statement, "channels", 2) + } + // insert data + testCases.foreach { testCase => + using(connection.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, ?)")) { statement => + statement.setString(1, testCase.channelId.toHex) + statement.setBytes(2, testCase.data.toArray) + statement.setBoolean(3, testCase.isClosed) + statement.executeUpdate() + for (commitmentNumber <- testCase.commitmentNumbers) { + using(connection.prepareStatement("INSERT INTO htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)")) { statement => + statement.setString(1, testCase.channelId.toHex) + statement.setLong(2, commitmentNumber) + statement.setString(3, randomBytes32.toHex) + statement.setLong(4, 500000 + Random.nextInt(500000)) + statement.executeUpdate() + } + } + } + } + }, + dbName = "channels", + targetVersion = 4, + postCheck = _ => postCheck(dbs.channels) + ) + case dbs: TestSqliteDatabases => + migrationCheck( + dbs = dbs, + initializeTables = connection => { + // create a v2 channels database + using(connection.createStatement()) { statement => + statement.execute("PRAGMA foreign_keys = ON") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") + setVersion(statement, "channels", 2) + } + // insert data + testCases.foreach { testCase => + using(connection.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, ?)")) { statement => + statement.setBytes(1, testCase.channelId.toArray) + statement.setBytes(2, testCase.data.toArray) + statement.setBoolean(3, testCase.isClosed) + statement.executeUpdate() + for (commitmentNumber <- testCase.commitmentNumbers) { + using(connection.prepareStatement("INSERT INTO htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)")) { statement => + statement.setBytes(1, testCase.channelId.toArray) + statement.setLong(2, commitmentNumber) + statement.setBytes(3, randomBytes32.toArray) + statement.setLong(4, 500000 + Random.nextInt(500000)) + statement.executeUpdate() + } + } + } + } + }, + dbName = "channels", + targetVersion = 3, + postCheck = _ => postCheck(dbs.channels) + ) + } + } - // create a v2 channels database - using(pg.createStatement()) { statement => - statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") - setVersion(statement, "channels", 2) + test("migrate pg channel database v3->v4") { + val dbs = TestPgDatabases() + + migrationCheck( + dbs = dbs, + initializeTables = connection => { + using(connection.createStatement()) { statement => + // initialize a v3 database + statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp BIGINT, last_payment_sent_timestamp BIGINT, last_payment_received_timestamp BIGINT, last_connected_timestamp BIGINT, closed_timestamp BIGINT)") + statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") + statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") + PgUtils.setVersion(statement, "channels", 3) } - - // insert 1 row - val channel = ChannelCodecsSpec.normal - val data = stateDataCodec.encode(channel).require.toByteArray - using(pg.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, ?)")) { statement => - statement.setString(1, channel.channelId.toHex) - statement.setBytes(2, data) - statement.setBoolean(3, false) - statement.executeUpdate() + // insert data + testCases.foreach { testCase => + using(connection.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed, created_timestamp, last_payment_sent_timestamp, last_payment_received_timestamp, last_connected_timestamp, closed_timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, testCase.channelId.toHex) + statement.setBytes(2, testCase.data.toArray) + statement.setBoolean(3, testCase.isClosed) + statement.setObject(4, testCase.createdTimestamp.orNull) + statement.setObject(5, testCase.lastPaymentSentTimestamp.orNull) + statement.setObject(6, testCase.lastPaymentReceivedTimestamp.orNull) + statement.setObject(7, testCase.lastConnectedTimestamp.orNull) + statement.setObject(8, testCase.closedTimestamp.orNull) + statement.executeUpdate() + } } - - // check that db migration works - val db = dbs.channels - using(pg.createStatement()) { statement => - assert(getVersion(statement, "channels").contains(3)) + }, + dbName = "channels", + targetVersion = 4, + postCheck = connection => { + assert(dbs.channels.listLocalChannels().size === testCases.filterNot(_.isClosed).size) + testCases.foreach { testCase => + assert(getPgTimestamp(connection, testCase.channelId, "created_timestamp") === testCase.createdTimestamp) + assert(getPgTimestamp(connection, testCase.channelId, "last_payment_sent_timestamp") === testCase.lastPaymentSentTimestamp) + assert(getPgTimestamp(connection, testCase.channelId, "last_payment_received_timestamp") === testCase.lastPaymentReceivedTimestamp) + assert(getPgTimestamp(connection, testCase.channelId, "last_connected_timestamp") === testCase.lastConnectedTimestamp) + assert(getPgTimestamp(connection, testCase.channelId, "closed_timestamp") === testCase.closedTimestamp) } - assert(db.listLocalChannels() === List(channel)) - db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail - - case dbs: TestSqliteDatabases => - val sqlite = dbs.connection + } + ) - // create a v2 channels database - using(sqlite.createStatement()) { statement => - statement.execute("PRAGMA foreign_keys = ON") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") - setVersion(statement, "channels", 2) - } + } +} + +object ChannelsDbSpec { + + case class TestCase(channelId: ByteVector32, + data: ByteVector, + isClosed: Boolean, + createdTimestamp: Option[Long], + lastPaymentSentTimestamp: Option[Long], + lastPaymentReceivedTimestamp: Option[Long], + lastConnectedTimestamp: Option[Long], + closedTimestamp: Option[Long], + commitmentNumbers: Seq[Int] + ) + + private val data = stateDataCodec.encode(ChannelCodecsSpec.normal).require.bytes + val testCases: Seq[TestCase] = for (_ <- 0 until 10) yield TestCase( + channelId = randomBytes32, + data = data, + isClosed = Random.nextBoolean(), + createdTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None, + lastPaymentSentTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None, + lastPaymentReceivedTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None, + lastConnectedTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None, + closedTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None, + commitmentNumbers = for (_ <- 0 until Random.nextInt(10)) yield Random.nextInt(5) // there will be repetitions, on purpose + ) + + def getTimestamp(dbs: TestDatabases, channelId: ByteVector32, columnName: String): Option[Long] = { + dbs match { + case _: TestPgDatabases => getPgTimestamp(dbs.connection, channelId, columnName) + case _: TestSqliteDatabases => getSqliteTimestamp(dbs.connection, channelId, columnName) + } + } - // insert 1 row - val channel = ChannelCodecsSpec.normal - val data = stateDataCodec.encode(channel).require.toByteArray - using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?, ?)")) { statement => - statement.setBytes(1, channel.channelId.toArray) - statement.setBytes(2, data) - statement.setBoolean(3, false) - statement.executeUpdate() - } + def getSqliteTimestamp(connection: Connection, channelId: ByteVector32, columnName: String): Option[Long] = { + using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement => + statement.setBytes(1, channelId.toArray) + val rs = statement.executeQuery() + rs.next() + rs.getLongNullable(columnName) + } + } - // check that db migration works - val db = dbs.channels - using(sqlite.createStatement()) { statement => - assert(getVersion(statement, "channels").contains(3)) - } - assert(db.listLocalChannels() === List(channel)) - db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail + def getPgTimestamp(connection: Connection, channelId: ByteVector32, columnName: String): Option[Long] = { + using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement => + statement.setString(1, channelId.toHex) + val rs = statement.executeQuery() + rs.next() + rs.getTimestampNullable(columnName).map(_.getTime) } } } \ No newline at end of file