From 3213ff4afc3b793bbebeb5cd8d1f6e7082575b48 Mon Sep 17 00:00:00 2001 From: pm47 Date: Thu, 15 Apr 2021 17:59:51 +0200 Subject: [PATCH 1/9] update channels db (timestamps + commit_number) Did some refactoring in tests and introduced a new `migrationCheck` helper method. Note that the change of data type in sqlite for the `commitment_number` field (from `BLOB` to `INTEGER`) is not a migration. If the table has been created before, it will stay like it was. It doesn't matter due to how sqlite stores data, and we make sure in tests that there is no regression. --- .../fr/acinq/eclair/db/jdbc/JdbcUtils.scala | 15 +- .../fr/acinq/eclair/db/pg/PgChannelsDb.scala | 27 +- .../eclair/db/sqlite/SqliteChannelsDb.scala | 7 +- .../scala/fr/acinq/eclair/TestDatabases.scala | 23 +- .../fr/acinq/eclair/db/ChannelsDbSpec.scala | 308 ++++++++++++------ 5 files changed, 269 insertions(+), 111 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala index 2fb1e3ca8d..a405c96efe 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala @@ -22,7 +22,7 @@ import org.sqlite.SQLiteConnection import scodec.Codec import scodec.bits.{BitVector, ByteVector} -import java.sql.{Connection, ResultSet, Statement} +import java.sql.{Connection, ResultSet, Statement, Timestamp} import java.util.UUID import javax.sql.DataSource import scala.collection.immutable.Queue @@ -123,18 +123,16 @@ trait JdbcUtils { def getByteVector32FromHexNullable(columnLabel: String): Option[ByteVector32] = { val s = rs.getString(columnLabel) - if (rs.wasNull()) None else { - Some(ByteVector32(ByteVector.fromValidHex(s))) - } + if (rs.wasNull()) None else Some(ByteVector32(ByteVector.fromValidHex(s))) } def getBitVectorOpt(columnLabel: String): Option[BitVector] = Option(rs.getBytes(columnLabel)).map(BitVector(_)) def getByteVector(columnLabel: String): ByteVector = ByteVector(rs.getBytes(columnLabel)) - def getByteVectorNullable(columnLabel: String): ByteVector = { + def getByteVectorNullable(columnLabel: String): Option[ByteVector] = { val result = rs.getBytes(columnLabel) - if (rs.wasNull()) ByteVector.empty else ByteVector(result) + if (rs.wasNull()) None else Some(ByteVector(result)) } def getByteVector32(columnLabel: String): ByteVector32 = ByteVector32(ByteVector(rs.getBytes(columnLabel))) @@ -164,6 +162,11 @@ trait JdbcUtils { if (rs.wasNull()) None else Some(MilliSatoshi(result)) } + def getTimestampNullable(label: String): Option[Timestamp] = { + val result = rs.getTimestamp(label) + if (rs.wasNull()) None else Some(result) + } + } object ExtendedResultSet { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala index dae8dfc6fb..e223a12c53 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala @@ -27,7 +27,8 @@ import fr.acinq.eclair.db.pg.PgUtils.PgLock import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec import grizzled.slf4j.Logging -import java.sql.Statement +import java.sql.{Statement, Timestamp} +import java.time.Instant import javax.sql.DataSource import scala.collection.immutable.Queue @@ -38,7 +39,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit import lock._ val DB_NAME = "channels" - val CURRENT_VERSION = 3 + val CURRENT_VERSION = 4 inTransaction { pg => using(pg.createStatement()) { statement => @@ -51,14 +52,28 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit statement.executeUpdate("ALTER TABLE local_channels ADD COLUMN closed_timestamp BIGINT") } + def migration34(statement: Statement): Unit = { + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN created_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + created_timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_payment_sent_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_payment_sent_timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_payment_received_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_payment_received_timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_connected_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_connected_timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN closed_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + closed_timestamp * interval '1 millisecond'") + + statement.executeUpdate("ALTER TABLE htlc_infos ALTER COLUMN commitment_number SET DATA TYPE BIGINT USING commitment_number::BIGINT") + } + getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp BIGINT, last_payment_sent_timestamp BIGINT, last_payment_received_timestamp BIGINT, last_connected_timestamp BIGINT, closed_timestamp BIGINT)") - statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") + statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp TIMESTAMP WITH TIME ZONE, last_payment_sent_timestamp TIMESTAMP WITH TIME ZONE, last_payment_received_timestamp TIMESTAMP WITH TIME ZONE, last_connected_timestamp TIMESTAMP WITH TIME ZONE, closed_timestamp TIMESTAMP WITH TIME ZONE)") + statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number BIGINT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") case Some(v@2) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration23(statement) + migration34(statement) + case Some(v@3) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration34(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -89,7 +104,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit private def updateChannelMetaTimestampColumn(channelId: ByteVector32, columnName: String): Unit = { inTransaction { pg => using(pg.prepareStatement(s"UPDATE local_channels SET $columnName=? WHERE channel_id=?")) { statement => - statement.setLong(1, System.currentTimeMillis) + statement.setTimestamp(1, Timestamp.from(Instant.now())) statement.setString(2, channelId.toHex) statement.executeUpdate() } @@ -152,7 +167,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit withLock { pg => using(pg.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement => statement.setString(1, channelId.toHex) - statement.setString(2, commitmentNumber.toString) + statement.setLong(2, commitmentNumber) val rs = statement.executeQuery var q: Queue[(ByteVector32, CltvExpiry)] = Queue() while (rs.next()) { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala index 55082bf16f..26590765b4 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala @@ -23,7 +23,6 @@ import fr.acinq.eclair.db.ChannelsDb import fr.acinq.eclair.db.DbEventHandler.ChannelEvent import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics import fr.acinq.eclair.db.Monitoring.Tags.DbBackends -import fr.acinq.eclair.payment.{ChannelPaymentRelayed, PaymentEvent, PaymentReceived, PaymentRelayed, PaymentSent} import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec import grizzled.slf4j.Logging @@ -63,9 +62,9 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging { getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0, created_timestamp INTEGER, last_payment_sent_timestamp INTEGER, last_payment_received_timestamp INTEGER, last_connected_timestamp INTEGER, closed_timestamp INTEGER)") - statement.executeUpdate("CREATE TABLE htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") - statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0, created_timestamp INTEGER, last_payment_sent_timestamp INTEGER, last_payment_received_timestamp INTEGER, last_connected_timestamp INTEGER, closed_timestamp INTEGER)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number INTEGER NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") case Some(v@1) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration12(statement) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala index c38755a666..4ab1358eb0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala @@ -4,7 +4,7 @@ import akka.actor.ActorSystem import com.opentable.db.postgres.embedded.EmbeddedPostgres import com.zaxxer.hikari.HikariConfig import fr.acinq.eclair.db._ -import fr.acinq.eclair.db.pg.PgUtils.PgLock +import fr.acinq.eclair.db.pg.PgUtils.{PgLock, getVersion, using} import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler import org.postgresql.jdbc.PgConnection import org.sqlite.SQLiteConnection @@ -64,6 +64,7 @@ object TestDatabases { // @formatter:off override val connection: PgConnection = pg.getPostgresDatabase.getConnection.asInstanceOf[PgConnection] + // NB: we use a lazy val here: databases won't be initialized until we reference that variable override lazy val db: Databases = Databases.PostgresDatabases(hikariConfig, UUID.randomUUID(), lock, jdbcUrlFile_opt = Some(jdbcUrlFile), readOnlyUser_opt = None) override def close(): Unit = pg.close() // @formatter:on @@ -77,4 +78,24 @@ object TestDatabases { // @formatter:on } + def migrationCheck( + dbs: TestDatabases, + initializeTables: Connection => Unit, + dbName: String, + targetVersion: Int, + postCheck: Connection => Unit + ): Unit = { + val connection = dbs.connection + // initialize the database to a previous version and populate data + initializeTables(connection) + // this will trigger the initialization of tables and the migration + val _ = dbs.db + // check that db version was updated + using(connection.createStatement()) { statement => + assert(getVersion(statement, dbName).contains(targetVersion)) + } + // post-migration checks + postCheck(connection) + } + } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala index 84994a695d..609e8ff0ff 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala @@ -18,23 +18,25 @@ package fr.acinq.eclair.db import com.softwaremill.quicklens._ import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases} +import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, migrationCheck} +import fr.acinq.eclair.db.ChannelsDbSpec.{getPgTimestamp, getTimestamp, testCases} import fr.acinq.eclair.db.DbEventHandler.ChannelEvent import fr.acinq.eclair.db.jdbc.JdbcUtils.using -import fr.acinq.eclair.db.pg.PgChannelsDb import fr.acinq.eclair.db.pg.PgUtils.{getVersion, setVersion} +import fr.acinq.eclair.db.pg.{PgChannelsDb, PgUtils} import fr.acinq.eclair.db.sqlite.SqliteChannelsDb import fr.acinq.eclair.db.sqlite.SqliteUtils.ExtendedResultSet._ import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec import fr.acinq.eclair.wire.internal.channel.ChannelCodecsSpec -import fr.acinq.eclair.{CltvExpiry, ShortChannelId, randomBytes32} +import fr.acinq.eclair.{CltvExpiry, ShortChannelId, TestDatabases, randomBytes32} import org.scalatest.funsuite.AnyFunSuite import scodec.bits.ByteVector -import java.sql.SQLException +import java.sql.{Connection, SQLException} import java.util.concurrent.Executors -import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future} import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future} +import scala.util.Random class ChannelsDbSpec extends AnyFunSuite { @@ -107,56 +109,42 @@ class ChannelsDbSpec extends AnyFunSuite { test("channel metadata") { forAllDbs { dbs => val db = dbs.channels - val connection = dbs.connection val channel1 = ChannelCodecsSpec.normal val channel2 = channel1.modify(_.commitments.channelId).setTo(randomBytes32) - def getTimestamp(channelId: ByteVector32, columnName: String): Option[Long] = { - using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement => - // data type differs depending on underlying database system - dbs match { - case _: TestPgDatabases => statement.setString(1, channelId.toHex) - case _: TestSqliteDatabases => statement.setBytes(1, channelId.toArray) - } - val rs = statement.executeQuery() - rs.next() - rs.getLongNullable(columnName) - } - } - // first we add channels db.addOrUpdateChannel(channel1) db.addOrUpdateChannel(channel2) // make sure initially all metadata are empty - assert(getTimestamp(channel1.channelId, "created_timestamp").isEmpty) - assert(getTimestamp(channel1.channelId, "last_payment_sent_timestamp").isEmpty) - assert(getTimestamp(channel1.channelId, "last_payment_received_timestamp").isEmpty) - assert(getTimestamp(channel1.channelId, "last_connected_timestamp").isEmpty) - assert(getTimestamp(channel1.channelId, "closed_timestamp").isEmpty) + assert(getTimestamp(dbs, channel1.channelId, "created_timestamp").isEmpty) + assert(getTimestamp(dbs, channel1.channelId, "last_payment_sent_timestamp").isEmpty) + assert(getTimestamp(dbs, channel1.channelId, "last_payment_received_timestamp").isEmpty) + assert(getTimestamp(dbs, channel1.channelId, "last_connected_timestamp").isEmpty) + assert(getTimestamp(dbs, channel1.channelId, "closed_timestamp").isEmpty) db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.Created) - assert(getTimestamp(channel1.channelId, "created_timestamp").nonEmpty) + assert(getTimestamp(dbs, channel1.channelId, "created_timestamp").nonEmpty) db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.PaymentSent) - assert(getTimestamp(channel1.channelId, "last_payment_sent_timestamp").nonEmpty) + assert(getTimestamp(dbs, channel1.channelId, "last_payment_sent_timestamp").nonEmpty) db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.PaymentReceived) - assert(getTimestamp(channel1.channelId, "last_payment_received_timestamp").nonEmpty) + assert(getTimestamp(dbs, channel1.channelId, "last_payment_received_timestamp").nonEmpty) db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.Connected) - assert(getTimestamp(channel1.channelId, "last_connected_timestamp").nonEmpty) + assert(getTimestamp(dbs, channel1.channelId, "last_connected_timestamp").nonEmpty) db.updateChannelMeta(channel1.channelId, ChannelEvent.EventType.Closed(null)) - assert(getTimestamp(channel1.channelId, "closed_timestamp").nonEmpty) + assert(getTimestamp(dbs, channel1.channelId, "closed_timestamp").nonEmpty) // make sure all metadata are still empty for channel 2 - assert(getTimestamp(channel2.channelId, "created_timestamp").isEmpty) - assert(getTimestamp(channel2.channelId, "last_payment_sent_timestamp").isEmpty) - assert(getTimestamp(channel2.channelId, "last_payment_received_timestamp").isEmpty) - assert(getTimestamp(channel2.channelId, "last_connected_timestamp").isEmpty) - assert(getTimestamp(channel2.channelId, "closed_timestamp").isEmpty) + assert(getTimestamp(dbs, channel2.channelId, "created_timestamp").isEmpty) + assert(getTimestamp(dbs, channel2.channelId, "last_payment_sent_timestamp").isEmpty) + assert(getTimestamp(dbs, channel2.channelId, "last_payment_received_timestamp").isEmpty) + assert(getTimestamp(dbs, channel2.channelId, "last_connected_timestamp").isEmpty) + assert(getTimestamp(dbs, channel2.channelId, "closed_timestamp").isEmpty) } } @@ -175,13 +163,22 @@ class ChannelsDbSpec extends AnyFunSuite { setVersion(statement, "channels", 1) } - // insert 1 row - val channel = ChannelCodecsSpec.normal - val data = stateDataCodec.encode(channel).require.toByteArray - using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?)")) { statement => - statement.setBytes(1, channel.channelId.toArray) - statement.setBytes(2, data) - statement.executeUpdate() + // insert data + for (testCase <- testCases) { + using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?)")) { statement => + statement.setBytes(1, testCase.channelId.toArray) + statement.setBytes(2, testCase.data.toArray) + statement.executeUpdate() + } + for (commitmentNumber <- testCase.commitmentNumbers) { + using(sqlite.prepareStatement("INSERT INTO htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)")) { statement => + statement.setBytes(1, testCase.channelId.toArray) + statement.setLong(2, commitmentNumber) + statement.setBytes(3, randomBytes32.toArray) + statement.setLong(4, 500000 + Random.nextInt(500000)) + statement.executeUpdate() + } + } } // check that db migration works @@ -189,71 +186,194 @@ class ChannelsDbSpec extends AnyFunSuite { using(sqlite.createStatement()) { statement => assert(getVersion(statement, "channels").contains(3)) } - assert(db.listLocalChannels() === List(channel)) - db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail + assert(db.listLocalChannels().size === testCases.size) + for (testCase <- testCases) { + db.updateChannelMeta(testCase.channelId, ChannelEvent.EventType.Created) // this call must not fail + for (commitmentNumber <- testCase.commitmentNumbers) { + assert(db.listHtlcInfos(testCase.channelId, commitmentNumber).size === testCase.commitmentNumbers.count(_ == commitmentNumber)) + } + } } } - test("migrate channel database v2 -> v3") { + test("migrate channel database v2 -> v3/v4") { + def postCheck(channelsDb: ChannelsDb): Unit = { + assert(channelsDb.listLocalChannels().size === testCases.filterNot(_.isClosed).size) + for (testCase <- testCases.filterNot(_.isClosed)) { + channelsDb.updateChannelMeta(testCase.channelId, ChannelEvent.EventType.Created) // this call must not fail + for (commitmentNumber <- testCase.commitmentNumbers) { + assert(channelsDb.listHtlcInfos(testCase.channelId, commitmentNumber).size === testCase.commitmentNumbers.count(_ == commitmentNumber)) + } + } + } + forAllDbs { case dbs: TestPgDatabases => - val pg = dbs.connection + migrationCheck( + dbs = dbs, + initializeTables = connection => { + // initialize a v2 database + using(connection.createStatement()) { statement => + statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") + setVersion(statement, "channels", 2) + } + // insert data + testCases.foreach { testCase => + using(connection.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, ?)")) { statement => + statement.setString(1, testCase.channelId.toHex) + statement.setBytes(2, testCase.data.toArray) + statement.setBoolean(3, testCase.isClosed) + statement.executeUpdate() + for (commitmentNumber <- testCase.commitmentNumbers) { + using(connection.prepareStatement("INSERT INTO htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)")) { statement => + statement.setString(1, testCase.channelId.toHex) + statement.setLong(2, commitmentNumber) + statement.setString(3, randomBytes32.toHex) + statement.setLong(4, 500000 + Random.nextInt(500000)) + statement.executeUpdate() + } + } + } + } + }, + dbName = "channels", + targetVersion = 4, + postCheck = _ => postCheck(dbs.channels) + ) + case dbs: TestSqliteDatabases => + migrationCheck( + dbs = dbs, + initializeTables = connection => { + // create a v2 channels database + using(connection.createStatement()) { statement => + statement.execute("PRAGMA foreign_keys = ON") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") + setVersion(statement, "channels", 2) + } + // insert data + testCases.foreach { testCase => + using(connection.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, ?)")) { statement => + statement.setBytes(1, testCase.channelId.toArray) + statement.setBytes(2, testCase.data.toArray) + statement.setBoolean(3, testCase.isClosed) + statement.executeUpdate() + for (commitmentNumber <- testCase.commitmentNumbers) { + using(connection.prepareStatement("INSERT INTO htlc_infos (channel_id, commitment_number, payment_hash, cltv_expiry) VALUES (?, ?, ?, ?)")) { statement => + statement.setBytes(1, testCase.channelId.toArray) + statement.setLong(2, commitmentNumber) + statement.setBytes(3, randomBytes32.toArray) + statement.setLong(4, 500000 + Random.nextInt(500000)) + statement.executeUpdate() + } + } + } + } + }, + dbName = "channels", + targetVersion = 3, + postCheck = _ => postCheck(dbs.channels) + ) + } + } - // create a v2 channels database - using(pg.createStatement()) { statement => - statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") - setVersion(statement, "channels", 2) + test("migrate pg channel database v3->v4") { + val dbs = TestPgDatabases() + + migrationCheck( + dbs = dbs, + initializeTables = connection => { + using(connection.createStatement()) { statement => + // initialize a v3 database + statement.executeUpdate("CREATE TABLE local_channels (channel_id TEXT NOT NULL PRIMARY KEY, data BYTEA NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT FALSE, created_timestamp BIGINT, last_payment_sent_timestamp BIGINT, last_payment_received_timestamp BIGINT, last_connected_timestamp BIGINT, closed_timestamp BIGINT)") + statement.executeUpdate("CREATE TABLE htlc_infos (channel_id TEXT NOT NULL, commitment_number TEXT NOT NULL, payment_hash TEXT NOT NULL, cltv_expiry BIGINT NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") + statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") + PgUtils.setVersion(statement, "channels", 3) } - - // insert 1 row - val channel = ChannelCodecsSpec.normal - val data = stateDataCodec.encode(channel).require.toByteArray - using(pg.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed) VALUES (?, ?, ?)")) { statement => - statement.setString(1, channel.channelId.toHex) - statement.setBytes(2, data) - statement.setBoolean(3, false) - statement.executeUpdate() + // insert data + testCases.foreach { testCase => + using(connection.prepareStatement("INSERT INTO local_channels (channel_id, data, is_closed, created_timestamp, last_payment_sent_timestamp, last_payment_received_timestamp, last_connected_timestamp, closed_timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, testCase.channelId.toHex) + statement.setBytes(2, testCase.data.toArray) + statement.setBoolean(3, testCase.isClosed) + statement.setObject(4, testCase.createdTimestamp.orNull) + statement.setObject(5, testCase.lastPaymentSentTimestamp.orNull) + statement.setObject(6, testCase.lastPaymentReceivedTimestamp.orNull) + statement.setObject(7, testCase.lastConnectedTimestamp.orNull) + statement.setObject(8, testCase.closedTimestamp.orNull) + statement.executeUpdate() + } } - - // check that db migration works - val db = dbs.channels - using(pg.createStatement()) { statement => - assert(getVersion(statement, "channels").contains(3)) + }, + dbName = "channels", + targetVersion = 4, + postCheck = connection => { + assert(dbs.channels.listLocalChannels().size === testCases.filterNot(_.isClosed).size) + testCases.foreach { testCase => + assert(getPgTimestamp(connection, testCase.channelId, "created_timestamp") === testCase.createdTimestamp) + assert(getPgTimestamp(connection, testCase.channelId, "last_payment_sent_timestamp") === testCase.lastPaymentSentTimestamp) + assert(getPgTimestamp(connection, testCase.channelId, "last_payment_received_timestamp") === testCase.lastPaymentReceivedTimestamp) + assert(getPgTimestamp(connection, testCase.channelId, "last_connected_timestamp") === testCase.lastConnectedTimestamp) + assert(getPgTimestamp(connection, testCase.channelId, "closed_timestamp") === testCase.closedTimestamp) } - assert(db.listLocalChannels() === List(channel)) - db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail - - case dbs: TestSqliteDatabases => - val sqlite = dbs.connection + } + ) - // create a v2 channels database - using(sqlite.createStatement()) { statement => - statement.execute("PRAGMA foreign_keys = ON") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number BLOB NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") - setVersion(statement, "channels", 2) - } + } +} + +object ChannelsDbSpec { + + case class TestCase( + channelId: ByteVector32, + data: ByteVector, + isClosed: Boolean, + createdTimestamp: Option[Long], + lastPaymentSentTimestamp: Option[Long], + lastPaymentReceivedTimestamp: Option[Long], + lastConnectedTimestamp: Option[Long], + closedTimestamp: Option[Long], + commitmentNumbers: Seq[Int] + ) + + private val data = stateDataCodec.encode(ChannelCodecsSpec.normal).require.bytes + val testCases: Seq[TestCase] = for (_ <- 0 until 10) yield TestCase( + channelId = randomBytes32, + data = data, + isClosed = Random.nextBoolean(), + createdTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None, + lastPaymentSentTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None, + lastPaymentReceivedTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None, + lastConnectedTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None, + closedTimestamp = if (Random.nextBoolean()) Some(Random.nextInt(Int.MaxValue)) else None, + commitmentNumbers = for (_ <- 0 until Random.nextInt(10)) yield Random.nextInt(5) // there will be repetitions, on purpose + ) + + def getTimestamp(dbs: TestDatabases, channelId: ByteVector32, columnName: String): Option[Long] = { + dbs match { + case _: TestPgDatabases => getPgTimestamp(dbs.connection, channelId, columnName) + case _: TestSqliteDatabases => getSqliteTimestamp(dbs.connection, channelId, columnName) + } + } - // insert 1 row - val channel = ChannelCodecsSpec.normal - val data = stateDataCodec.encode(channel).require.toByteArray - using(sqlite.prepareStatement("INSERT INTO local_channels VALUES (?, ?, ?)")) { statement => - statement.setBytes(1, channel.channelId.toArray) - statement.setBytes(2, data) - statement.setBoolean(3, false) - statement.executeUpdate() - } + def getSqliteTimestamp(connection: Connection, channelId: ByteVector32, columnName: String): Option[Long] = { + using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement => + statement.setBytes(1, channelId.toArray) + val rs = statement.executeQuery() + rs.next() + rs.getLongNullable(columnName) + } + } - // check that db migration works - val db = dbs.channels - using(sqlite.createStatement()) { statement => - assert(getVersion(statement, "channels").contains(3)) - } - assert(db.listLocalChannels() === List(channel)) - db.updateChannelMeta(channel.channelId, ChannelEvent.EventType.Created) // this call must not fail + def getPgTimestamp(connection: Connection, channelId: ByteVector32, columnName: String): Option[Long] = { + using(connection.prepareStatement(s"SELECT $columnName FROM local_channels WHERE channel_id=?")) { statement => + statement.setString(1, channelId.toHex) + val rs = statement.executeQuery() + rs.next() + rs.getTimestampNullable(columnName).map(_.getTime) } } } \ No newline at end of file From bac226ebddaca494d0edc3b279916bbfc912666c Mon Sep 17 00:00:00 2001 From: pm47 Date: Fri, 16 Apr 2021 15:13:51 +0200 Subject: [PATCH 2/9] update audit db (timestamps) --- .../fr/acinq/eclair/db/pg/PgAuditDb.scala | 85 +++++++++++-------- .../fr/acinq/eclair/db/AuditDbSpec.scala | 20 +++-- 2 files changed, 61 insertions(+), 44 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala index 19244891f0..1e10235729 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala @@ -29,7 +29,8 @@ import fr.acinq.eclair.transactions.Transactions.PlaceHolderPubKey import fr.acinq.eclair.{MilliSatoshi, MilliSatoshiLong} import grizzled.slf4j.Logging -import java.sql.Statement +import java.sql.{Statement, Timestamp} +import java.time.Instant import java.util.UUID import javax.sql.DataSource import scala.collection.immutable.Queue @@ -40,7 +41,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { import ExtendedResultSet._ val DB_NAME = "audit" - val CURRENT_VERSION = 5 + val CURRENT_VERSION = 6 case class RelayedPart(channelId: ByteVector32, amount: MilliSatoshi, direction: String, relayType: String, timestamp: Long) @@ -52,15 +53,25 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.executeUpdate("CREATE INDEX relayed_trampoline_payment_hash_idx ON relayed_trampoline(payment_hash)") } + def migration56(statement: Statement): Unit = { + statement.executeUpdate("ALTER TABLE sent ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE received ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE relayed ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE relayed_trampoline ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE network_fees ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE channel_events ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE channel_errors ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + } + getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)") + statement.executeUpdate("CREATE TABLE sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE relayed_trampoline (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, next_node_id TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") + statement.executeUpdate("CREATE TABLE channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp TIMESTAMP WITH TIME ZONE NOT NULL)") statement.executeUpdate("CREATE INDEX sent_timestamp_idx ON sent(timestamp)") statement.executeUpdate("CREATE INDEX received_timestamp_idx ON received(timestamp)") @@ -74,6 +85,10 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { case Some(v@4) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration45(statement) + migration56(statement) + case Some(v@5) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration56(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -90,7 +105,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setBoolean(4, e.isFunder) statement.setBoolean(5, e.isPrivate) statement.setString(6, e.event.label) - statement.setLong(7, System.currentTimeMillis) + statement.setTimestamp(7, Timestamp.from(Instant.now())) statement.executeUpdate() } } @@ -109,7 +124,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setString(7, e.paymentPreimage.toHex) statement.setString(8, e.recipientNodeId.value.toHex) statement.setString(9, p.toChannelId.toHex) - statement.setLong(10, p.timestamp) + statement.setTimestamp(10, Timestamp.from(Instant.ofEpochMilli(p.timestamp))) statement.addBatch() }) statement.executeBatch() @@ -124,7 +139,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setLong(1, p.amount.toLong) statement.setString(2, e.paymentHash.toHex) statement.setString(3, p.fromChannelId.toHex) - statement.setLong(4, p.timestamp) + statement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(p.timestamp))) statement.addBatch() }) statement.executeBatch() @@ -143,7 +158,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setString(1, e.paymentHash.toHex) statement.setLong(2, nextTrampolineAmount.toLong) statement.setString(3, nextTrampolineNodeId.value.toHex) - statement.setLong(4, e.timestamp) + statement.setTimestamp(4, Timestamp.from(Instant.ofEpochMilli(e.timestamp))) statement.executeUpdate() } // trampoline relayed payments do MPP aggregation and may have M inputs and N outputs @@ -156,7 +171,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setString(3, p.channelId.toHex) statement.setString(4, p.direction) statement.setString(5, p.relayType) - statement.setLong(6, e.timestamp) + statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(e.timestamp))) statement.executeUpdate() } } @@ -171,7 +186,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setString(3, e.tx.txid.toHex) statement.setLong(4, e.fee.toLong) statement.setString(5, e.txType) - statement.setLong(6, System.currentTimeMillis) + statement.setTimestamp(6, Timestamp.from(Instant.now())) statement.executeUpdate() } } @@ -189,7 +204,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { statement.setString(3, errorName) statement.setString(4, errorMessage) statement.setBoolean(5, e.isFatal) - statement.setLong(6, System.currentTimeMillis) + statement.setTimestamp(6, Timestamp.from(Instant.now())) statement.executeUpdate() } } @@ -197,9 +212,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listSent(from: Long, to: Long): Seq[PaymentSent] = inTransaction { pg => - using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ?")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ? ORDER BY timestamp")) { statement => + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() var sentByParentId = Map.empty[UUID, PaymentSent] while (rs.next()) { @@ -210,7 +225,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { MilliSatoshi(rs.getLong("fees_msat")), rs.getByteVector32FromHex("to_channel_id"), None, // we don't store the route in the audit DB - rs.getLong("timestamp")) + rs.getTimestamp("timestamp").getTime) val sent = sentByParentId.get(parentId) match { case Some(s) => s.copy(parts = s.parts :+ part) case None => PaymentSent( @@ -229,9 +244,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listReceived(from: Long, to: Long): Seq[PaymentReceived] = inTransaction { pg => - using(pg.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ? ORDER BY timestamp")) { statement => + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() var receivedByHash = Map.empty[ByteVector32, PaymentReceived] while (rs.next()) { @@ -239,7 +254,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { val part = PaymentReceived.PartialPayment( MilliSatoshi(rs.getLong("amount_msat")), rs.getByteVector32FromHex("from_channel_id"), - rs.getLong("timestamp")) + rs.getTimestamp("timestamp").getTime) val received = receivedByHash.get(paymentHash) match { case Some(r) => r.copy(parts = r.parts :+ part) case None => PaymentReceived(paymentHash, Seq(part)) @@ -253,9 +268,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] = inTransaction { pg => var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)] - using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp >= ? AND timestamp < ?")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement => + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() while (rs.next()) { val paymentHash = rs.getByteVector32FromHex("payment_hash") @@ -264,9 +279,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { trampolineByHash += (paymentHash -> (amount, nodeId)) } } - using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement => + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]] while (rs.next()) { @@ -276,7 +291,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { MilliSatoshi(rs.getLong("amount_msat")), rs.getString("direction"), rs.getString("relay_type"), - rs.getLong("timestamp")) + rs.getTimestamp("timestamp").getTime) relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part)) } relayedByHash.flatMap { @@ -300,9 +315,9 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] = inTransaction { pg => - using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => - statement.setLong(1, from) - statement.setLong(2, to) + using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement => + statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) + statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() var q: Queue[NetworkFee] = Queue() while (rs.next()) { @@ -312,7 +327,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { txId = rs.getByteVector32FromHex("tx_id"), fee = Satoshi(rs.getLong("fee_sat")), txType = rs.getString("tx_type"), - timestamp = rs.getLong("timestamp")) + timestamp = rs.getTimestamp("timestamp").getTime) } q } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala index 3e0fe7acfa..3931da18fe 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala @@ -34,6 +34,8 @@ import fr.acinq.eclair.wire.protocol.Error import org.scalatest.Tag import org.scalatest.funsuite.AnyFunSuite +import java.sql.Timestamp +import java.time.Instant import java.util.UUID import javax.sql.DataSource import scala.concurrent.duration._ @@ -182,7 +184,7 @@ class AuditDbSpec extends AnyFunSuite { } } - test("handle migration version 1 -> 5") { + test("migrate audit database v1 -> v5/v6") { forAllDbs { case _: TestPgDatabases => // no migration case dbs: TestSqliteDatabases => @@ -255,7 +257,7 @@ class AuditDbSpec extends AnyFunSuite { } } - test("handle migration version 2 -> 5") { + test("migrate audit database v2 -> v5/v6") { forAllDbs { case _: TestPgDatabases => // no migration case dbs: TestSqliteDatabases => @@ -306,7 +308,7 @@ class AuditDbSpec extends AnyFunSuite { } } - test("handle migration version 3 -> 5") { + test("migrate audit database v3 -> v5/v6") { forAllDbs { case _: TestPgDatabases => // no migration case dbs: TestSqliteDatabases => @@ -400,7 +402,7 @@ class AuditDbSpec extends AnyFunSuite { } } - test("handle migration version 4 -> 5") { + test("migrate audit database v4 -> v5/v6") { forAllDbs { case dbs: TestPgDatabases => import fr.acinq.eclair.db.pg.PgUtils.getVersion @@ -483,7 +485,7 @@ class AuditDbSpec extends AnyFunSuite { val migratedDb = new PgAuditDb()(datasource) inTransaction { pg => using(pg.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(5)) + assert(getVersion(statement, "audit").contains(6)) } } @@ -493,7 +495,7 @@ class AuditDbSpec extends AnyFunSuite { inTransaction { pg => using(pg.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(5)) + assert(getVersion(statement, "audit").contains(6)) } } @@ -605,7 +607,7 @@ class AuditDbSpec extends AnyFunSuite { if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray) statement.setString(4, "IN") statement.setString(5, "unknown") // invalid relay type - statement.setLong(6, 10) + if (isPg) statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(10))) else statement.setLong(6, 10) statement.executeUpdate() } @@ -615,7 +617,7 @@ class AuditDbSpec extends AnyFunSuite { if (isPg) statement.setString(3, randomBytes32.toHex) else statement.setBytes(3, randomBytes32.toArray) statement.setString(4, "UP") // invalid direction statement.setString(5, "channel") - statement.setLong(6, 20) + if (isPg) statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(20))) else statement.setLong(6, 20) statement.executeUpdate() } @@ -628,7 +630,7 @@ class AuditDbSpec extends AnyFunSuite { if (isPg) statement.setString(3, channelId.toHex) else statement.setBytes(3, channelId.toArray) statement.setString(4, "IN") // missing a corresponding OUT statement.setString(5, "channel") - statement.setLong(6, 30) + if (isPg) statement.setTimestamp(6, Timestamp.from(Instant.ofEpochMilli(30))) else statement.setLong(6, 30) statement.executeUpdate() } From c7a88f1b15bccb7b91f898b63c776cc2fbf174e6 Mon Sep 17 00:00:00 2001 From: pm47 Date: Tue, 20 Apr 2021 16:53:35 +0200 Subject: [PATCH 3/9] clean up audit tests --- .../fr/acinq/eclair/db/AuditDbSpec.scala | 460 ++++++++---------- 1 file changed, 213 insertions(+), 247 deletions(-) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala index 3931da18fe..6a17040c4c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala @@ -18,7 +18,7 @@ package fr.acinq.eclair.db import fr.acinq.bitcoin.Crypto.PrivateKey import fr.acinq.bitcoin.{ByteVector32, SatoshiLong, Transaction} -import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases} +import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases, migrationCheck} import fr.acinq.eclair._ import fr.acinq.eclair.channel.Helpers.Closing.MutualClose import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid, RemoteError} @@ -26,7 +26,7 @@ import fr.acinq.eclair.db.AuditDb.Stats import fr.acinq.eclair.db.DbEventHandler.ChannelEvent import fr.acinq.eclair.db.jdbc.JdbcUtils.using import fr.acinq.eclair.db.pg.PgAuditDb -import fr.acinq.eclair.db.pg.PgUtils.{inTransaction, setVersion} +import fr.acinq.eclair.db.pg.PgUtils.{getVersion, setVersion} import fr.acinq.eclair.db.sqlite.SqliteAuditDb import fr.acinq.eclair.payment._ import fr.acinq.eclair.transactions.Transactions.PlaceHolderPubKey @@ -37,7 +37,6 @@ import org.scalatest.funsuite.AnyFunSuite import java.sql.Timestamp import java.time.Instant import java.util.UUID -import javax.sql.DataSource import scala.concurrent.duration._ import scala.util.Random @@ -184,13 +183,20 @@ class AuditDbSpec extends AnyFunSuite { } } - test("migrate audit database v1 -> v5/v6") { - forAllDbs { - case _: TestPgDatabases => // no migration - case dbs: TestSqliteDatabases => - import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion - val connection = dbs.connection + test("migrate sqlite audit database v1 -> v5/v6") { + + val dbs = TestSqliteDatabases() + val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, PrivateKey(ByteVector32.One).publicKey, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil) + val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None) + val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None) + val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84003 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil) + val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) + val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true) + + migrationCheck( + dbs = dbs, + initializeTables = connection => { // simulate existing previous version db using(connection.createStatement()) { statement => statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)") @@ -210,17 +216,6 @@ class AuditDbSpec extends AnyFunSuite { setVersion(statement, "audit", 1) } - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(1)) - } - - val ps = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 42000 msat, PrivateKey(ByteVector32.One).publicKey, PaymentSent.PartialPayment(UUID.randomUUID(), 42000 msat, 1000 msat, randomBytes32, None) :: Nil) - val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 42001 msat, 1001 msat, randomBytes32, None) - val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 42002 msat, 1002 msat, randomBytes32, None) - val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 84003 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil) - val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) - val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true) - // add a row (no ID on sent) using(connection.prepareStatement("INSERT INTO sent VALUES (?, ?, ?, ?, ?, ?)")) { statement => statement.setLong(1, ps.recipientAmount.toLong) @@ -231,15 +226,12 @@ class AuditDbSpec extends AnyFunSuite { statement.setLong(6, ps.timestamp) statement.executeUpdate() } - - val migratedDb = new SqliteAuditDb(connection) - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(5)) - } - + }, + dbName = "audit", + targetVersion = 5, + postCheck = connection => { // existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default - assert(migratedDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))))) + assert(dbs.audit.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))))) val postMigrationDb = new SqliteAuditDb(connection) @@ -254,16 +246,19 @@ class AuditDbSpec extends AnyFunSuite { // the old record will have the UNKNOWN_UUID but the new ones will have their actual id val expected = Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))), ps1) assert(postMigrationDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === expected) - } + } + ) } - test("migrate audit database v2 -> v5/v6") { - forAllDbs { - case _: TestPgDatabases => // no migration - case dbs: TestSqliteDatabases => - import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion - val connection = dbs.connection + test("migrate sqlite audit database v2 -> v5") { + val dbs = TestSqliteDatabases() + + val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) + val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true) + migrationCheck( + dbs = dbs, + initializeTables = connection => { // simulate existing previous version db using(connection.createStatement()) { statement => statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)") @@ -282,39 +277,39 @@ class AuditDbSpec extends AnyFunSuite { setVersion(statement, "audit", 2) } - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(2)) - } - - val e1 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, LocalError(new RuntimeException("oops")), isFatal = true) - val e2 = ChannelErrorOccurred(null, randomBytes32, randomKey.publicKey, null, RemoteError(Error(randomBytes32, "remote oops")), isFatal = true) - - val migratedDb = new SqliteAuditDb(connection) - + }, + dbName = "audit", + targetVersion = 5, + postCheck = connection => { + val migratedDb = dbs.audit using(connection.createStatement()) { statement => assert(getVersion(statement, "audit").contains(5)) } - migratedDb.add(e1) val postMigrationDb = new SqliteAuditDb(connection) - using(connection.createStatement()) { statement => assert(getVersion(statement, "audit").contains(5)) } - postMigrationDb.add(e2) - } + } + ) } - test("migrate audit database v3 -> v5/v6") { - forAllDbs { - case _: TestPgDatabases => // no migration - case dbs: TestSqliteDatabases => - import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion - val connection = dbs.connection + test("migrate sqlite audit database v3 -> v5") { + + val dbs = TestSqliteDatabases() + + val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 100) + val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 110) + val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil) + val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105) + val relayed2 = ChannelPaymentRelayed(650 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 115) + + migrationCheck( + dbs = dbs, + initializeTables = connection => { // simulate existing previous version db using(connection.createStatement()) { statement => statement.executeUpdate("CREATE TABLE IF NOT EXISTS balance_updated (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, amount_msat INTEGER NOT NULL, capacity_sat INTEGER NOT NULL, reserve_sat INTEGER NOT NULL, timestamp INTEGER NOT NULL)") @@ -336,14 +331,6 @@ class AuditDbSpec extends AnyFunSuite { setVersion(statement, "audit", 3) } - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(3)) - } - - val pp1 = PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 100) - val pp2 = PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 110) - val ps1 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, PrivateKey(ByteVector32.One).publicKey, pp1 :: pp2 :: Nil) - for (pp <- Seq(pp1, pp2)) { using(connection.prepareStatement("INSERT INTO sent (amount_msat, fees_msat, payment_hash, payment_preimage, to_channel_id, timestamp, id) VALUES (?, ?, ?, ?, ?, ?, ?)")) { statement => statement.setLong(1, pp.amount.toLong) @@ -357,9 +344,6 @@ class AuditDbSpec extends AnyFunSuite { } } - val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105) - val relayed2 = ChannelPaymentRelayed(650 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 115) - for (relayed <- Seq(relayed1, relayed2)) { using(connection.prepareStatement("INSERT INTO relayed (amount_in_msat, amount_out_msat, payment_hash, from_channel_id, to_channel_id, timestamp) VALUES (?, ?, ?, ?, ?, ?)")) { statement => statement.setLong(1, relayed.amountIn.toLong) @@ -371,12 +355,14 @@ class AuditDbSpec extends AnyFunSuite { statement.executeUpdate() } } - - val migratedDb = new SqliteAuditDb(connection) + }, + dbName = "audit", + targetVersion = 5, + postCheck = connection => { + val migratedDb = dbs.audit using(connection.createStatement()) { statement => assert(getVersion(statement, "audit").contains(5)) } - assert(migratedDb.listSent(50, 150).toSet === Set( ps1.copy(id = pp1.id, recipientAmount = pp1.amount, parts = pp1 :: Nil), ps1.copy(id = pp2.id, recipientAmount = pp2.amount, parts = pp2 :: Nil) @@ -384,214 +370,194 @@ class AuditDbSpec extends AnyFunSuite { assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2)) val postMigrationDb = new SqliteAuditDb(connection) - using(connection.createStatement()) { statement => assert(getVersion(statement, "audit").contains(5)) } - val ps2 = PaymentSent(UUID.randomUUID(), randomBytes32, randomBytes32, 1100 msat, randomKey.publicKey, Seq( PaymentSent.PartialPayment(UUID.randomUUID(), 500 msat, 10 msat, randomBytes32, None, 160), PaymentSent.PartialPayment(UUID.randomUUID(), 600 msat, 5 msat, randomBytes32, None, 165) )) val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150) - postMigrationDb.add(ps2) assert(postMigrationDb.listSent(155, 200) === Seq(ps2)) postMigrationDb.add(relayed3) assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3)) - } + } + ) } test("migrate audit database v4 -> v5/v6") { - forAllDbs { - case dbs: TestPgDatabases => - import fr.acinq.eclair.db.pg.PgUtils.getVersion - implicit val datasource: DataSource = dbs.datasource - // simulate existing previous version db - inTransaction { pg => - using(pg.createStatement()) { statement => - statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)") - - statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") - - setVersion(statement, "audit", 4) - } - } + val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105) + val relayed2 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(300 msat, randomBytes32), PaymentRelayed.Part(350 msat, randomBytes32)), Seq(PaymentRelayed.Part(600 msat, randomBytes32)), PlaceHolderPubKey, 0 msat, 110) - inTransaction { pg => - using(pg.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(4)) - } - } + forAllDbs { + case dbs: TestPgDatabases => + migrationCheck( + dbs = dbs, + initializeTables = connection => { + // simulate existing previous version db + using(connection.createStatement()) { statement => + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat BIGINT NOT NULL, fees_msat BIGINT NOT NULL, recipient_amount_msat BIGINT NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash TEXT NOT NULL, payment_preimage TEXT NOT NULL, recipient_node_id TEXT NOT NULL, to_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat BIGINT NOT NULL, payment_hash TEXT NOT NULL, from_channel_id TEXT NOT NULL, timestamp BIGINT NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash TEXT NOT NULL, amount_msat BIGINT NOT NULL, channel_id TEXT NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp BIGINT NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, tx_id TEXT NOT NULL, fee_sat BIGINT NOT NULL, tx_type TEXT NOT NULL, timestamp BIGINT NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, capacity_sat BIGINT NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp BIGINT NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id TEXT NOT NULL, node_id TEXT NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal BOOLEAN NOT NULL, timestamp BIGINT NOT NULL)") + + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") + + setVersion(statement, "audit", 4) + } - val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105) - val relayed2 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(300 msat, randomBytes32), PaymentRelayed.Part(350 msat, randomBytes32)), Seq(PaymentRelayed.Part(600 msat, randomBytes32)), PlaceHolderPubKey, 0 msat, 110) - - inTransaction { pg => - using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setString(1, relayed1.paymentHash.toHex) - statement.setLong(2, relayed1.amountIn.toLong) - statement.setString(3, relayed1.fromChannelId.toHex) - statement.setString(4, "IN") - statement.setString(5, "channel") - statement.setLong(6, relayed1.timestamp) - statement.executeUpdate() - } - using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setString(1, relayed1.paymentHash.toHex) - statement.setLong(2, relayed1.amountOut.toLong) - statement.setString(3, relayed1.toChannelId.toHex) - statement.setString(4, "OUT") - statement.setString(5, "channel") - statement.setLong(6, relayed1.timestamp) - statement.executeUpdate() - } - for (incoming <- relayed2.incoming) { - using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setString(1, relayed2.paymentHash.toHex) - statement.setLong(2, incoming.amount.toLong) - statement.setString(3, incoming.channelId.toHex) + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, relayed1.paymentHash.toHex) + statement.setLong(2, relayed1.amountIn.toLong) + statement.setString(3, relayed1.fromChannelId.toHex) statement.setString(4, "IN") - statement.setString(5, "trampoline") - statement.setLong(6, relayed2.timestamp) + statement.setString(5, "channel") + statement.setLong(6, relayed1.timestamp) statement.executeUpdate() } - } - for (outgoing <- relayed2.outgoing) { - using(pg.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setString(1, relayed2.paymentHash.toHex) - statement.setLong(2, outgoing.amount.toLong) - statement.setString(3, outgoing.channelId.toHex) + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, relayed1.paymentHash.toHex) + statement.setLong(2, relayed1.amountOut.toLong) + statement.setString(3, relayed1.toChannelId.toHex) statement.setString(4, "OUT") - statement.setString(5, "trampoline") - statement.setLong(6, relayed2.timestamp) + statement.setString(5, "channel") + statement.setLong(6, relayed1.timestamp) statement.executeUpdate() } - } - } - - val migratedDb = new PgAuditDb()(datasource) - inTransaction { pg => - using(pg.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(6)) - } - } - - assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2)) - - val postMigrationDb = new PgAuditDb()(datasource) + for (incoming <- relayed2.incoming) { + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, relayed2.paymentHash.toHex) + statement.setLong(2, incoming.amount.toLong) + statement.setString(3, incoming.channelId.toHex) + statement.setString(4, "IN") + statement.setString(5, "trampoline") + statement.setLong(6, relayed2.timestamp) + statement.executeUpdate() + } + } + for (outgoing <- relayed2.outgoing) { + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setString(1, relayed2.paymentHash.toHex) + statement.setLong(2, outgoing.amount.toLong) + statement.setString(3, outgoing.channelId.toHex) + statement.setString(4, "OUT") + statement.setString(5, "trampoline") + statement.setLong(6, relayed2.timestamp) + statement.executeUpdate() + } + } + }, + dbName = "audit", + targetVersion = 6, + postCheck = connection => { + val migratedDb = dbs.audit + using(connection.createStatement()) { statement => + assert(getVersion(statement, "audit").contains(6)) + } + assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2)) - inTransaction { pg => - using(pg.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(6)) + val postMigrationDb = new PgAuditDb()(dbs.datasource) + using(connection.createStatement()) { statement => + assert(getVersion(statement, "audit").contains(6)) + } + val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150) + postMigrationDb.add(relayed3) + assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3)) } - } - - val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150) - - postMigrationDb.add(relayed3) - assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3)) + ) case dbs: TestSqliteDatabases => - import fr.acinq.eclair.db.sqlite.SqliteUtils.getVersion - val connection = dbs.connection - - // simulate existing previous version db - using(connection.createStatement()) { statement => - statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, recipient_amount_msat INTEGER NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, recipient_node_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, channel_id BLOB NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)") - - statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") - - setVersion(statement, "audit", 4) - } - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(4)) - } + migrationCheck( + dbs = dbs, + initializeTables = connection => { + // simulate existing previous version db + using(connection.createStatement()) { statement => + statement.executeUpdate("CREATE TABLE IF NOT EXISTS sent (amount_msat INTEGER NOT NULL, fees_msat INTEGER NOT NULL, recipient_amount_msat INTEGER NOT NULL, payment_id TEXT NOT NULL, parent_payment_id TEXT NOT NULL, payment_hash BLOB NOT NULL, payment_preimage BLOB NOT NULL, recipient_node_id BLOB NOT NULL, to_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS received (amount_msat INTEGER NOT NULL, payment_hash BLOB NOT NULL, from_channel_id BLOB NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS relayed (payment_hash BLOB NOT NULL, amount_msat INTEGER NOT NULL, channel_id BLOB NOT NULL, direction TEXT NOT NULL, relay_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS network_fees (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, tx_id BLOB NOT NULL, fee_sat INTEGER NOT NULL, tx_type TEXT NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_events (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, capacity_sat INTEGER NOT NULL, is_funder BOOLEAN NOT NULL, is_private BOOLEAN NOT NULL, event TEXT NOT NULL, timestamp INTEGER NOT NULL)") + statement.executeUpdate("CREATE TABLE IF NOT EXISTS channel_errors (channel_id BLOB NOT NULL, node_id BLOB NOT NULL, error_name TEXT NOT NULL, error_message TEXT NOT NULL, is_fatal INTEGER NOT NULL, timestamp INTEGER NOT NULL)") + + statement.executeUpdate("CREATE INDEX IF NOT EXISTS sent_timestamp_idx ON sent(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS received_timestamp_idx ON received(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_timestamp_idx ON relayed(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS relayed_payment_hash_idx ON relayed(payment_hash)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS network_fees_timestamp_idx ON network_fees(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_events_timestamp_idx ON channel_events(timestamp)") + statement.executeUpdate("CREATE INDEX IF NOT EXISTS channel_errors_timestamp_idx ON channel_errors(timestamp)") + + setVersion(statement, "audit", 4) + } - val relayed1 = ChannelPaymentRelayed(600 msat, 500 msat, randomBytes32, randomBytes32, randomBytes32, 105) - val relayed2 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(300 msat, randomBytes32), PaymentRelayed.Part(350 msat, randomBytes32)), Seq(PaymentRelayed.Part(600 msat, randomBytes32)), PlaceHolderPubKey, 0 msat, 110) + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setBytes(1, relayed1.paymentHash.toArray) + statement.setLong(2, relayed1.amountIn.toLong) + statement.setBytes(3, relayed1.fromChannelId.toArray) + statement.setString(4, "IN") + statement.setString(5, "channel") + statement.setLong(6, relayed1.timestamp) + statement.executeUpdate() + } + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setBytes(1, relayed1.paymentHash.toArray) + statement.setLong(2, relayed1.amountOut.toLong) + statement.setBytes(3, relayed1.toChannelId.toArray) + statement.setString(4, "OUT") + statement.setString(5, "channel") + statement.setLong(6, relayed1.timestamp) + statement.executeUpdate() + } + for (incoming <- relayed2.incoming) { + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setBytes(1, relayed2.paymentHash.toArray) + statement.setLong(2, incoming.amount.toLong) + statement.setBytes(3, incoming.channelId.toArray) + statement.setString(4, "IN") + statement.setString(5, "trampoline") + statement.setLong(6, relayed2.timestamp) + statement.executeUpdate() + } + } + for (outgoing <- relayed2.outgoing) { + using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => + statement.setBytes(1, relayed2.paymentHash.toArray) + statement.setLong(2, outgoing.amount.toLong) + statement.setBytes(3, outgoing.channelId.toArray) + statement.setString(4, "OUT") + statement.setString(5, "trampoline") + statement.setLong(6, relayed2.timestamp) + statement.executeUpdate() + } + } + }, + dbName = "audit", + targetVersion = 5, + postCheck = connection => { + val migratedDb = dbs.audit + using(connection.createStatement()) { statement => + assert(getVersion(statement, "audit").contains(5)) + } + assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2)) - using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setBytes(1, relayed1.paymentHash.toArray) - statement.setLong(2, relayed1.amountIn.toLong) - statement.setBytes(3, relayed1.fromChannelId.toArray) - statement.setString(4, "IN") - statement.setString(5, "channel") - statement.setLong(6, relayed1.timestamp) - statement.executeUpdate() - } - using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setBytes(1, relayed1.paymentHash.toArray) - statement.setLong(2, relayed1.amountOut.toLong) - statement.setBytes(3, relayed1.toChannelId.toArray) - statement.setString(4, "OUT") - statement.setString(5, "channel") - statement.setLong(6, relayed1.timestamp) - statement.executeUpdate() - } - for (incoming <- relayed2.incoming) { - using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setBytes(1, relayed2.paymentHash.toArray) - statement.setLong(2, incoming.amount.toLong) - statement.setBytes(3, incoming.channelId.toArray) - statement.setString(4, "IN") - statement.setString(5, "trampoline") - statement.setLong(6, relayed2.timestamp) - statement.executeUpdate() - } - } - for (outgoing <- relayed2.outgoing) { - using(connection.prepareStatement("INSERT INTO relayed VALUES (?, ?, ?, ?, ?, ?)")) { statement => - statement.setBytes(1, relayed2.paymentHash.toArray) - statement.setLong(2, outgoing.amount.toLong) - statement.setBytes(3, outgoing.channelId.toArray) - statement.setString(4, "OUT") - statement.setString(5, "trampoline") - statement.setLong(6, relayed2.timestamp) - statement.executeUpdate() + val postMigrationDb = new SqliteAuditDb(connection) + using(connection.createStatement()) { statement => + assert(getVersion(statement, "audit").contains(5)) + } + val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150) + postMigrationDb.add(relayed3) + assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3)) } - } - - val migratedDb = new SqliteAuditDb(connection) - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(5)) - } - - assert(migratedDb.listRelayed(100, 120) === Seq(relayed1, relayed2)) - - val postMigrationDb = new SqliteAuditDb(connection) - - using(connection.createStatement()) { statement => - assert(getVersion(statement, "audit").contains(5)) - } - - val relayed3 = TrampolinePaymentRelayed(randomBytes32, Seq(PaymentRelayed.Part(450 msat, randomBytes32), PaymentRelayed.Part(500 msat, randomBytes32)), Seq(PaymentRelayed.Part(800 msat, randomBytes32)), randomKey.publicKey, 700 msat, 150) - - postMigrationDb.add(relayed3) - assert(postMigrationDb.listRelayed(100, 160) === Seq(relayed1, relayed2, relayed3)) + ) } } From f418d5ba15810666550d6a5c20d173a55edd3ccf Mon Sep 17 00:00:00 2001 From: pm47 Date: Tue, 20 Apr 2021 17:49:16 +0200 Subject: [PATCH 4/9] add ORDER BY timestamp to listRelayed --- .../src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala index 1e10235729..e1fd4aa2a1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala @@ -268,7 +268,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] = inTransaction { pg => var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)] - using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement => + using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() From 641abaebc41915e8050f3f42168847a859391327 Mon Sep 17 00:00:00 2001 From: pm47 Date: Wed, 21 Apr 2021 21:21:03 +0200 Subject: [PATCH 5/9] remove alignment in sql statements --- .../scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala index e223a12c53..29e6c20ed2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala @@ -53,13 +53,13 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit } def migration34(statement: Statement): Unit = { - statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN created_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + created_timestamp * interval '1 millisecond'") - statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_payment_sent_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_payment_sent_timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN created_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + created_timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_payment_sent_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_payment_sent_timestamp * interval '1 millisecond'") statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_payment_received_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_payment_received_timestamp * interval '1 millisecond'") - statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_connected_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_connected_timestamp * interval '1 millisecond'") - statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN closed_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + closed_timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN last_connected_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + last_connected_timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE local_channels ALTER COLUMN closed_timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + closed_timestamp * interval '1 millisecond'") - statement.executeUpdate("ALTER TABLE htlc_infos ALTER COLUMN commitment_number SET DATA TYPE BIGINT USING commitment_number::BIGINT") + statement.executeUpdate("ALTER TABLE htlc_infos ALTER COLUMN commitment_number SET DATA TYPE BIGINT USING commitment_number::BIGINT") } getVersion(statement, DB_NAME) match { From 6b20b25fe85053f4b8b586e88a076af5f869af5c Mon Sep 17 00:00:00 2001 From: pm47 Date: Wed, 21 Apr 2021 21:24:20 +0200 Subject: [PATCH 6/9] remove redundant ORDER BY clauses --- .../src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala index e1fd4aa2a1..ed50a01baa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala @@ -212,7 +212,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listSent(from: Long, to: Long): Seq[PaymentSent] = inTransaction { pg => - using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ? ORDER BY timestamp")) { statement => + using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() @@ -244,7 +244,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listReceived(from: Long, to: Long): Seq[PaymentReceived] = inTransaction { pg => - using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ? ORDER BY timestamp")) { statement => + using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() @@ -268,7 +268,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] = inTransaction { pg => var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)] - using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement => + using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() @@ -279,7 +279,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { trampolineByHash += (paymentHash -> (amount, nodeId)) } } - using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement => + using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) val rs = statement.executeQuery() From c97900aaf8541bcfe42c240052f36647b665c639 Mon Sep 17 00:00:00 2001 From: pm47 Date: Wed, 21 Apr 2021 21:25:36 +0200 Subject: [PATCH 7/9] remove superfluous IF NOT EXISTS clauses --- .../scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala index 26590765b4..70a62846c3 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala @@ -62,9 +62,9 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging { getVersion(statement, DB_NAME) match { case None => - statement.executeUpdate("CREATE TABLE IF NOT EXISTS local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0, created_timestamp INTEGER, last_payment_sent_timestamp INTEGER, last_payment_received_timestamp INTEGER, last_connected_timestamp INTEGER, closed_timestamp INTEGER)") - statement.executeUpdate("CREATE TABLE IF NOT EXISTS htlc_infos (channel_id BLOB NOT NULL, commitment_number INTEGER NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") - statement.executeUpdate("CREATE INDEX IF NOT EXISTS htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") + statement.executeUpdate("CREATE TABLE local_channels (channel_id BLOB NOT NULL PRIMARY KEY, data BLOB NOT NULL, is_closed BOOLEAN NOT NULL DEFAULT 0, created_timestamp INTEGER, last_payment_sent_timestamp INTEGER, last_payment_received_timestamp INTEGER, last_connected_timestamp INTEGER, closed_timestamp INTEGER)") + statement.executeUpdate("CREATE TABLE htlc_infos (channel_id BLOB NOT NULL, commitment_number INTEGER NOT NULL, payment_hash BLOB NOT NULL, cltv_expiry INTEGER NOT NULL, FOREIGN KEY(channel_id) REFERENCES local_channels(channel_id))") + statement.executeUpdate("CREATE INDEX htlc_infos_idx ON htlc_infos(channel_id, commitment_number)") case Some(v@1) => logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") migration12(statement) From aba59abd14725ccf71acd5eed1224f874b27fb8c Mon Sep 17 00:00:00 2001 From: pm47 Date: Wed, 21 Apr 2021 21:27:20 +0200 Subject: [PATCH 8/9] nits --- .../scala/fr/acinq/eclair/TestDatabases.scala | 13 ++++++------- .../fr/acinq/eclair/db/AuditDbSpec.scala | 2 +- .../fr/acinq/eclair/db/ChannelsDbSpec.scala | 19 +++++++++---------- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala index 4ab1358eb0..34279ebea6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala @@ -4,8 +4,8 @@ import akka.actor.ActorSystem import com.opentable.db.postgres.embedded.EmbeddedPostgres import com.zaxxer.hikari.HikariConfig import fr.acinq.eclair.db._ -import fr.acinq.eclair.db.pg.PgUtils.{PgLock, getVersion, using} import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler +import fr.acinq.eclair.db.pg.PgUtils.{PgLock, getVersion, using} import org.postgresql.jdbc.PgConnection import org.sqlite.SQLiteConnection @@ -78,12 +78,11 @@ object TestDatabases { // @formatter:on } - def migrationCheck( - dbs: TestDatabases, - initializeTables: Connection => Unit, - dbName: String, - targetVersion: Int, - postCheck: Connection => Unit + def migrationCheck(dbs: TestDatabases, + initializeTables: Connection => Unit, + dbName: String, + targetVersion: Int, + postCheck: Connection => Unit ): Unit = { val connection = dbs.connection // initialize the database to a previous version and populate data diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala index 6a17040c4c..59ec639baa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/AuditDbSpec.scala @@ -183,7 +183,7 @@ class AuditDbSpec extends AnyFunSuite { } } - test("migrate sqlite audit database v1 -> v5/v6") { + test("migrate sqlite audit database v1 -> v5") { val dbs = TestSqliteDatabases() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala index 609e8ff0ff..d26ad9976c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala @@ -327,16 +327,15 @@ class ChannelsDbSpec extends AnyFunSuite { object ChannelsDbSpec { - case class TestCase( - channelId: ByteVector32, - data: ByteVector, - isClosed: Boolean, - createdTimestamp: Option[Long], - lastPaymentSentTimestamp: Option[Long], - lastPaymentReceivedTimestamp: Option[Long], - lastConnectedTimestamp: Option[Long], - closedTimestamp: Option[Long], - commitmentNumbers: Seq[Int] + case class TestCase(channelId: ByteVector32, + data: ByteVector, + isClosed: Boolean, + createdTimestamp: Option[Long], + lastPaymentSentTimestamp: Option[Long], + lastPaymentReceivedTimestamp: Option[Long], + lastConnectedTimestamp: Option[Long], + closedTimestamp: Option[Long], + commitmentNumbers: Seq[Int] ) private val data = stateDataCodec.encode(ChannelCodecsSpec.normal).require.bytes From 1ff9b7bd1bc11f2ddeb72f801f95896943f1b665 Mon Sep 17 00:00:00 2001 From: pm47 Date: Thu, 22 Apr 2021 09:57:16 +0200 Subject: [PATCH 9/9] fixup! remove alignment in sql statements --- .../main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala index ed50a01baa..2e44c80ccc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala @@ -54,13 +54,13 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { } def migration56(statement: Statement): Unit = { - statement.executeUpdate("ALTER TABLE sent ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") - statement.executeUpdate("ALTER TABLE received ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") - statement.executeUpdate("ALTER TABLE relayed ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE sent ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE received ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE relayed ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") statement.executeUpdate("ALTER TABLE relayed_trampoline ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") - statement.executeUpdate("ALTER TABLE network_fees ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") - statement.executeUpdate("ALTER TABLE channel_events ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") - statement.executeUpdate("ALTER TABLE channel_errors ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE network_fees ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE channel_events ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") + statement.executeUpdate("ALTER TABLE channel_errors ALTER COLUMN timestamp SET DATA TYPE TIMESTAMP WITH TIME ZONE USING timestamp with time zone 'epoch' + timestamp * interval '1 millisecond'") } getVersion(statement, DB_NAME) match {