From 002a87d3baf8d9a2c1e3bfc8f3dde05ec891173f Mon Sep 17 00:00:00 2001 From: AsamK Date: Sun, 1 Oct 2023 16:01:46 +0200 Subject: [PATCH] Add pre key cleanup and improve refresh --- .../signal/manager/config/ServiceConfig.java | 5 + .../signal/manager/helper/PreKeyHelper.java | 188 +++++++++++------- .../manager/storage/AccountDatabase.java | 12 +- .../signal/manager/storage/SignalAccount.java | 27 ++- .../storage/prekeys/KyberPreKeyStore.java | 86 +++++++- .../manager/storage/prekeys/PreKeyStore.java | 58 +++++- .../storage/prekeys/SignedPreKeyStore.java | 29 +++ .../storage/protocol/SignalProtocolStore.java | 14 +- 8 files changed, 330 insertions(+), 89 deletions(-) diff --git a/lib/src/main/java/org/asamk/signal/manager/config/ServiceConfig.java b/lib/src/main/java/org/asamk/signal/manager/config/ServiceConfig.java index 5b6bce47..7b2ba140 100644 --- a/lib/src/main/java/org/asamk/signal/manager/config/ServiceConfig.java +++ b/lib/src/main/java/org/asamk/signal/manager/config/ServiceConfig.java @@ -11,6 +11,7 @@ import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.cert.CertificateException; import java.util.List; +import java.util.concurrent.TimeUnit; import okhttp3.Interceptor; @@ -19,6 +20,10 @@ public class ServiceConfig { public final static int PREKEY_MINIMUM_COUNT = 10; public final static int PREKEY_BATCH_SIZE = 100; public final static int PREKEY_MAXIMUM_ID = Medium.MAX_VALUE; + public static final long PREKEY_ARCHIVE_AGE = TimeUnit.DAYS.toMillis(30); + public static final long PREKEY_STALE_AGE = TimeUnit.DAYS.toMillis(90); + public static final long SIGNED_PREKEY_ROTATE_AGE = TimeUnit.DAYS.toMillis(2); + public final static int MAX_ATTACHMENT_SIZE = 150 * 1024 * 1024; public final static long MAX_ENVELOPE_SIZE = 0; public final static long AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE = 10 * 1024 * 1024; diff --git a/lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java b/lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java index 54eab17b..34d11171 100644 --- a/lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java +++ b/lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java @@ -5,6 +5,7 @@ import org.asamk.signal.manager.internal.SignalDependencies; import org.asamk.signal.manager.storage.SignalAccount; import org.asamk.signal.manager.util.KeyUtils; import org.signal.libsignal.protocol.IdentityKeyPair; +import org.signal.libsignal.protocol.InvalidKeyIdException; import org.signal.libsignal.protocol.state.KyberPreKeyRecord; import org.signal.libsignal.protocol.state.PreKeyRecord; import org.signal.libsignal.protocol.state.SignedPreKeyRecord; @@ -18,6 +19,9 @@ import org.whispersystems.signalservice.internal.push.OneTimePreKeyCounts; import java.io.IOException; import java.util.List; +import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_STALE_AGE; +import static org.asamk.signal.manager.config.ServiceConfig.SIGNED_PREKEY_ROTATE_AGE; + public class PreKeyHelper { private final static Logger logger = LoggerFactory.getLogger(PreKeyHelper.class); @@ -38,30 +42,6 @@ public class PreKeyHelper { } public void refreshPreKeysIfNecessary(ServiceIdType serviceIdType) throws IOException { - OneTimePreKeyCounts preKeyCounts; - try { - preKeyCounts = dependencies.getAccountManager().getPreKeyCounts(serviceIdType); - } catch (AuthorizationFailedException e) { - logger.debug("Failed to get pre key count, ignoring: " + e.getClass().getSimpleName()); - preKeyCounts = new OneTimePreKeyCounts(0, 0); - } - if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) { - logger.debug("Refreshing {} ec pre keys, because only {} of min {} pre keys remain", - serviceIdType, - preKeyCounts.getEcCount(), - ServiceConfig.PREKEY_MINIMUM_COUNT); - refreshPreKeys(serviceIdType); - } - if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) { - logger.debug("Refreshing {} kyber pre keys, because only {} of min {} pre keys remain", - serviceIdType, - preKeyCounts.getKyberCount(), - ServiceConfig.PREKEY_MINIMUM_COUNT); - refreshKyberPreKeys(serviceIdType); - } - } - - private void refreshPreKeys(ServiceIdType serviceIdType) throws IOException { final var identityKeyPair = account.getIdentityKeyPair(serviceIdType); if (identityKeyPair == null) { return; @@ -70,28 +50,73 @@ public class PreKeyHelper { if (accountId == null) { return; } + + OneTimePreKeyCounts preKeyCounts; try { - refreshPreKeys(serviceIdType, identityKeyPair); + preKeyCounts = dependencies.getAccountManager().getPreKeyCounts(serviceIdType); + } catch (AuthorizationFailedException e) { + logger.debug("Failed to get pre key count, ignoring: " + e.getClass().getSimpleName()); + preKeyCounts = new OneTimePreKeyCounts(0, 0); + } + + SignedPreKeyRecord signedPreKeyRecord = null; + List preKeyRecords = null; + KyberPreKeyRecord lastResortKyberPreKeyRecord = null; + List kyberPreKeyRecords = null; + + try { + if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) { + logger.debug("Refreshing {} ec pre keys, because only {} of min {} pre keys remain", + serviceIdType, + preKeyCounts.getEcCount(), + ServiceConfig.PREKEY_MINIMUM_COUNT); + preKeyRecords = generatePreKeys(serviceIdType); + } + if (signedPreKeyNeedsRefresh(serviceIdType)) { + logger.debug("Refreshing {} signed pre key.", serviceIdType); + signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair); + } } catch (Exception e) { logger.warn("Failed to store new pre keys, resetting preKey id offset", e); account.resetPreKeyOffsets(serviceIdType); - refreshPreKeys(serviceIdType, identityKeyPair); + preKeyRecords = generatePreKeys(serviceIdType); + signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair); + } + + try { + if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) { + logger.debug("Refreshing {} kyber pre keys, because only {} of min {} pre keys remain", + serviceIdType, + preKeyCounts.getKyberCount(), + ServiceConfig.PREKEY_MINIMUM_COUNT); + kyberPreKeyRecords = generateKyberPreKeys(serviceIdType, identityKeyPair); + } + if (lastResortKyberPreKeyNeedsRefresh(serviceIdType)) { + logger.debug("Refreshing {} last resort kyber pre key.", serviceIdType); + lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair); + } + } catch (Exception e) { + logger.warn("Failed to store new kyber pre keys, resetting preKey id offset", e); + account.resetKyberPreKeyOffsets(serviceIdType); + kyberPreKeyRecords = generateKyberPreKeys(serviceIdType, identityKeyPair); + lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair); + } + + if (signedPreKeyRecord != null + || preKeyRecords != null + || lastResortKyberPreKeyRecord != null + || kyberPreKeyRecords != null) { + final var preKeyUpload = new PreKeyUpload(serviceIdType, + identityKeyPair.getPublicKey(), + signedPreKeyRecord, + preKeyRecords, + lastResortKyberPreKeyRecord, + kyberPreKeyRecords); + dependencies.getAccountManager().setPreKeys(preKeyUpload); } - } - private void refreshPreKeys( - final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair - ) throws IOException { - final var oneTimePreKeys = generatePreKeys(serviceIdType); - final var signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair); - - final var preKeyUpload = new PreKeyUpload(serviceIdType, - identityKeyPair.getPublicKey(), - signedPreKeyRecord, - oneTimePreKeys, - null, - null); - dependencies.getAccountManager().setPreKeys(preKeyUpload); + cleanSignedPreKeys((serviceIdType)); + cleanOneTimePreKeys(serviceIdType); } private List generatePreKeys(ServiceIdType serviceIdType) { @@ -103,6 +128,21 @@ public class PreKeyHelper { return records; } + private boolean signedPreKeyNeedsRefresh(ServiceIdType serviceIdType) { + final var accountData = account.getAccountData(serviceIdType); + + final var activeSignedPreKeyId = accountData.getPreKeyMetadata().getActiveSignedPreKeyId(); + if (activeSignedPreKeyId == -1) { + return true; + } + try { + final var signedPreKeyRecord = accountData.getSignedPreKeyStore().loadSignedPreKey(activeSignedPreKeyId); + return signedPreKeyRecord.getTimestamp() < System.currentTimeMillis() - SIGNED_PREKEY_ROTATE_AGE; + } catch (InvalidKeyIdException e) { + return true; + } + } + private SignedPreKeyRecord generateSignedPreKey(ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair) { final var signedPreKeyId = account.getNextSignedPreKeyId(serviceIdType); @@ -112,39 +152,6 @@ public class PreKeyHelper { return record; } - private void refreshKyberPreKeys(ServiceIdType serviceIdType) throws IOException { - final var identityKeyPair = account.getIdentityKeyPair(serviceIdType); - if (identityKeyPair == null) { - return; - } - final var accountId = account.getAccountId(serviceIdType); - if (accountId == null) { - return; - } - try { - refreshKyberPreKeys(serviceIdType, identityKeyPair); - } catch (Exception e) { - logger.warn("Failed to store new pre keys, resetting preKey id offset", e); - account.resetKyberPreKeyOffsets(serviceIdType); - refreshKyberPreKeys(serviceIdType, identityKeyPair); - } - } - - private void refreshKyberPreKeys( - final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair - ) throws IOException { - final var oneTimePreKeys = generateKyberPreKeys(serviceIdType, identityKeyPair); - final var lastResortPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair); - - final var preKeyUpload = new PreKeyUpload(serviceIdType, - identityKeyPair.getPublicKey(), - null, - null, - lastResortPreKeyRecord, - oneTimePreKeys); - dependencies.getAccountManager().setPreKeys(preKeyUpload); - } - private List generateKyberPreKeys( ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair ) { @@ -156,6 +163,22 @@ public class PreKeyHelper { return records; } + private boolean lastResortKyberPreKeyNeedsRefresh(ServiceIdType serviceIdType) { + final var accountData = account.getAccountData(serviceIdType); + + final var activeLastResortKyberPreKeyId = accountData.getPreKeyMetadata().getActiveLastResortKyberPreKeyId(); + if (activeLastResortKyberPreKeyId == -1) { + return true; + } + try { + final var kyberPreKeyRecord = accountData.getKyberPreKeyStore() + .loadKyberPreKey(activeLastResortKyberPreKeyId); + return kyberPreKeyRecord.getTimestamp() < System.currentTimeMillis() - SIGNED_PREKEY_ROTATE_AGE; + } catch (InvalidKeyIdException e) { + return true; + } + } + private KyberPreKeyRecord generateLastResortKyberPreKey( ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair ) { @@ -166,4 +189,23 @@ public class PreKeyHelper { return record; } + + private void cleanSignedPreKeys(ServiceIdType serviceIdType) { + final var accountData = account.getAccountData(serviceIdType); + + final var activeSignedPreKeyId = accountData.getPreKeyMetadata().getActiveSignedPreKeyId(); + accountData.getSignedPreKeyStore().removeOldSignedPreKeys(activeSignedPreKeyId); + + final var activeLastResortKyberPreKeyId = accountData.getPreKeyMetadata().getActiveLastResortKyberPreKeyId(); + accountData.getKyberPreKeyStore().removeOldLastResortKyberPreKeys(activeLastResortKyberPreKeyId); + } + + private void cleanOneTimePreKeys(ServiceIdType serviceIdType) { + long threshold = System.currentTimeMillis() - PREKEY_STALE_AGE; + int minCount = 200; + + final var accountData = account.getAccountData(serviceIdType); + accountData.getPreKeyStore().deleteAllStaleOneTimeEcPreKeys(threshold, minCount); + accountData.getKyberPreKeyStore().deleteAllStaleOneTimeKyberPreKeys(threshold, minCount); + } } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java b/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java index ee202d39..fb34b440 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java @@ -30,7 +30,7 @@ import java.util.UUID; public class AccountDatabase extends Database { private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class); - private static final long DATABASE_VERSION = 15; + private static final long DATABASE_VERSION = 16; private AccountDatabase(final HikariDataSource dataSource) { super(logger, DATABASE_VERSION, dataSource); @@ -493,5 +493,15 @@ public class AccountDatabase extends Database { """); } } + if (oldVersion < 16) { + logger.debug("Updating database: Adding stale_timestamp prekey field"); + try (final var statement = connection.createStatement()) { + statement.executeUpdate(""" + ALTER TABLE pre_key ADD COLUMN stale_timestamp INTEGER; + ALTER TABLE kyber_pre_key ADD COLUMN stale_timestamp INTEGER; + ALTER TABLE kyber_pre_key ADD COLUMN timestamp INTEGER DEFAULT 0; + """); + } + } } } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java index 3ff8a864..18a195e6 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java @@ -622,6 +622,11 @@ public class SignalAccount implements Closeable { } else { aciAccountData.preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset(); } + if (rootNode.hasNonNull("activeSignedPreKeyId")) { + aciAccountData.preKeyMetadata.activeSignedPreKeyId = rootNode.get("activeSignedPreKeyId").asInt(-1); + } else { + aciAccountData.preKeyMetadata.activeSignedPreKeyId = -1; + } if (rootNode.hasNonNull("pniPreKeyIdOffset")) { pniAccountData.preKeyMetadata.preKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1); } else { @@ -632,6 +637,11 @@ public class SignalAccount implements Closeable { } else { pniAccountData.preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset(); } + if (rootNode.hasNonNull("pniActiveSignedPreKeyId")) { + pniAccountData.preKeyMetadata.activeSignedPreKeyId = rootNode.get("pniActiveSignedPreKeyId").asInt(-1); + } else { + pniAccountData.preKeyMetadata.activeSignedPreKeyId = -1; + } if (rootNode.hasNonNull("kyberPreKeyIdOffset")) { aciAccountData.preKeyMetadata.kyberPreKeyIdOffset = rootNode.get("kyberPreKeyIdOffset").asInt(1); } else { @@ -1003,8 +1013,10 @@ public class SignalAccount implements Closeable { .put("storageManifestVersion", storageManifestVersion == -1 ? null : storageManifestVersion) .put("preKeyIdOffset", aciAccountData.getPreKeyMetadata().preKeyIdOffset) .put("nextSignedPreKeyId", aciAccountData.getPreKeyMetadata().nextSignedPreKeyId) + .put("activeSignedPreKeyId", aciAccountData.getPreKeyMetadata().activeSignedPreKeyId) .put("pniPreKeyIdOffset", pniAccountData.getPreKeyMetadata().preKeyIdOffset) .put("pniNextSignedPreKeyId", pniAccountData.getPreKeyMetadata().nextSignedPreKeyId) + .put("pniActiveSignedPreKeyId", pniAccountData.getPreKeyMetadata().activeSignedPreKeyId) .put("kyberPreKeyIdOffset", aciAccountData.getPreKeyMetadata().kyberPreKeyIdOffset) .put("activeLastResortKyberPreKeyId", aciAccountData.getPreKeyMetadata().activeLastResortKyberPreKeyId) @@ -1058,6 +1070,7 @@ public class SignalAccount implements Closeable { final var preKeyMetadata = getAccountData(serviceIdType).getPreKeyMetadata(); preKeyMetadata.preKeyIdOffset = getRandomPreKeyIdOffset(); preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset(); + preKeyMetadata.activeSignedPreKeyId = -1; save(); } @@ -1072,6 +1085,7 @@ public class SignalAccount implements Closeable { records.size(), serviceIdType, preKeyMetadata.preKeyIdOffset); + accountData.signalProtocolStore.markAllOneTimeEcPreKeysStaleIfNecessary(System.currentTimeMillis()); for (var record : records) { if (preKeyMetadata.preKeyIdOffset != record.getId()) { logger.error("Invalid pre key id {}, expected {}", record.getId(), preKeyMetadata.preKeyIdOffset); @@ -1095,6 +1109,7 @@ public class SignalAccount implements Closeable { } accountData.getSignedPreKeyStore().storeSignedPreKey(record.getId(), record); preKeyMetadata.nextSignedPreKeyId = (preKeyMetadata.nextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID; + preKeyMetadata.activeSignedPreKeyId = record.getId(); save(); } @@ -1112,6 +1127,7 @@ public class SignalAccount implements Closeable { records.size(), serviceIdType, preKeyMetadata.kyberPreKeyIdOffset); + accountData.signalProtocolStore.markAllOneTimeEcPreKeysStaleIfNecessary(System.currentTimeMillis()); for (var record : records) { if (preKeyMetadata.kyberPreKeyIdOffset != record.getId()) { logger.error("Invalid kyber pre key id {}, expected {}", @@ -1745,6 +1761,7 @@ public class SignalAccount implements Closeable { private int preKeyIdOffset = 1; private int nextSignedPreKeyId = 1; + private int activeSignedPreKeyId = -1; private int kyberPreKeyIdOffset = 1; private int activeLastResortKyberPreKeyId = -1; @@ -1756,6 +1773,10 @@ public class SignalAccount implements Closeable { return nextSignedPreKeyId; } + public int getActiveSignedPreKeyId() { + return activeSignedPreKeyId; + } + public int getKyberPreKeyIdOffset() { return kyberPreKeyIdOffset; } @@ -1819,17 +1840,17 @@ public class SignalAccount implements Closeable { SignalAccount.this::isMultiDevice)); } - private PreKeyStore getPreKeyStore() { + public PreKeyStore getPreKeyStore() { return getOrCreate(() -> preKeyStore, () -> preKeyStore = new PreKeyStore(getAccountDatabase(), serviceIdType)); } - private SignedPreKeyStore getSignedPreKeyStore() { + public SignedPreKeyStore getSignedPreKeyStore() { return getOrCreate(() -> signedPreKeyStore, () -> signedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), serviceIdType)); } - private KyberPreKeyStore getKyberPreKeyStore() { + public KyberPreKeyStore getKyberPreKeyStore() { return getOrCreate(() -> kyberPreKeyStore, () -> kyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), serviceIdType)); } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/KyberPreKeyStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/KyberPreKeyStore.java index 9ce660b2..4067e2fe 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/KyberPreKeyStore.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/KyberPreKeyStore.java @@ -15,6 +15,8 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.util.List; +import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_ARCHIVE_AGE; + public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore { private static final String TABLE_KYBER_PRE_KEY = "kyber_pre_key"; @@ -33,6 +35,8 @@ public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore { key_id INTEGER NOT NULL, serialized BLOB NOT NULL, is_last_resort INTEGER NOT NULL, + stale_timestamp INTEGER, + timestamp INTEGER DEFAULT 0, UNIQUE(account_id_type, key_id) ) STRICT; """); @@ -104,8 +108,8 @@ public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore { public void storeKyberPreKey(final int keyId, final KyberPreKeyRecord record, final boolean isLastResort) { final var sql = ( """ - INSERT INTO %s (account_id_type, key_id, serialized, is_last_resort) - VALUES (?, ?, ?, ?) + INSERT INTO %s (account_id_type, key_id, serialized, is_last_resort, timestamp) + VALUES (?, ?, ?, ?, ?) """ ).formatted(TABLE_KYBER_PRE_KEY); try (final var connection = database.getConnection()) { @@ -114,6 +118,7 @@ public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore { statement.setInt(2, keyId); statement.setBytes(3, record.serialize()); statement.setBoolean(4, isLastResort); + statement.setLong(5, record.getTimestamp()); statement.executeUpdate(); } } catch (SQLException e) { @@ -209,13 +214,86 @@ public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore { } } + public void removeOldLastResortKyberPreKeys(int activeLastResortKyberPreKeyId) { + final var sql = ( + """ + DELETE FROM %s AS p + WHERE p._id IN ( + SELECT p._id + FROM %s AS p + WHERE p.account_id_type = ? + AND p.is_last_resort = TRUE + AND p.key_id != ? + AND p.timestamp < ? + ORDER BY p.timestamp DESC + LIMIT -1 OFFSET 1 + ) + """ + ).formatted(TABLE_KYBER_PRE_KEY, TABLE_KYBER_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setInt(2, activeLastResortKyberPreKeyId); + statement.setLong(3, System.currentTimeMillis() - PREKEY_ARCHIVE_AGE); + statement.executeUpdate(); + } + } catch (SQLException e) { + throw new RuntimeException("Failed update kyber_pre_key store", e); + } + } + @Override public void deleteAllStaleOneTimeKyberPreKeys(final long threshold, final int minCount) { - //TODO + final var sql = ( + """ + DELETE FROM %s AS p + WHERE p.account_id_type = ?1 + AND p.stale_timestamp < ?2 + AND p.is_last_resort = FALSE + AND p._id NOT IN ( + SELECT _id + FROM %s p2 + WHERE p2.account_id_type = ?1 + ORDER BY + CASE WHEN p2.stale_timestamp IS NULL THEN 1 ELSE 0 END DESC, + p2.stale_timestamp DESC, + p2._id DESC + LIMIT ?3 + ) + """ + ).formatted(TABLE_KYBER_PRE_KEY, TABLE_KYBER_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setLong(2, threshold); + statement.setInt(3, minCount); + final var rowCount = statement.executeUpdate(); + if (rowCount > 0) { + logger.debug("Deleted {} stale one time kyber pre keys", rowCount); + } + } + } catch (SQLException e) { + throw new RuntimeException("Failed update kyber_pre_key store", e); + } } @Override public void markAllOneTimeKyberPreKeysStaleIfNecessary(final long staleTime) { - //TODO + final var sql = ( + """ + UPDATE %s + SET stale_timestamp = ? + WHERE account_id_type = ? AND stale_timestamp IS NULL AND is_last_resort = FALSE + """ + ).formatted(TABLE_KYBER_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setLong(1, staleTime); + statement.setInt(2, accountIdType); + statement.executeUpdate(); + } + } catch (SQLException e) { + throw new RuntimeException("Failed update kyber_pre_key store", e); + } } } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/PreKeyStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/PreKeyStore.java index f6dc271e..c3ac9632 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/PreKeyStore.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/PreKeyStore.java @@ -9,6 +9,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair; import org.signal.libsignal.protocol.state.PreKeyRecord; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.whispersystems.signalservice.api.SignalServicePreKeyStore; import org.whispersystems.signalservice.api.push.ServiceIdType; import java.sql.Connection; @@ -16,7 +17,7 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.util.Collection; -public class PreKeyStore implements org.signal.libsignal.protocol.state.PreKeyStore { +public class PreKeyStore implements SignalServicePreKeyStore { private static final String TABLE_PRE_KEY = "pre_key"; private final static Logger logger = LoggerFactory.getLogger(PreKeyStore.class); @@ -34,6 +35,7 @@ public class PreKeyStore implements org.signal.libsignal.protocol.state.PreKeySt key_id INTEGER NOT NULL, public_key BLOB NOT NULL, private_key BLOB NOT NULL, + stale_timestamp INTEGER, UNIQUE(account_id_type, key_id) ) STRICT; """); @@ -181,4 +183,58 @@ public class PreKeyStore implements org.signal.libsignal.protocol.state.PreKeySt return null; } } + + @Override + public void deleteAllStaleOneTimeEcPreKeys(final long threshold, final int minCount) { + final var sql = ( + """ + DELETE FROM %s AS p + WHERE p.account_id_type = ?1 + AND p.stale_timestamp < ?2 + AND p._id NOT IN ( + SELECT _id + FROM %s AS p2 + WHERE p2.account_id_type = ?1 + ORDER BY + CASE WHEN p2.stale_timestamp IS NULL THEN 1 ELSE 0 END DESC, + p2.stale_timestamp DESC, + p2._id DESC + LIMIT ?3 + ) + """ + ).formatted(TABLE_PRE_KEY, TABLE_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setLong(2, threshold); + statement.setInt(3, minCount); + final var rowCount = statement.executeUpdate(); + if (rowCount > 0) { + logger.debug("Deleted {} stale one time pre keys", rowCount); + } + } + } catch (SQLException e) { + throw new RuntimeException("Failed update pre_key store", e); + } + } + + @Override + public void markAllOneTimeEcPreKeysStaleIfNecessary(final long staleTime) { + final var sql = ( + """ + UPDATE %s + SET stale_timestamp = ? + WHERE account_id_type = ? AND stale_timestamp IS NULL + """ + ).formatted(TABLE_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setLong(1, staleTime); + statement.setInt(2, accountIdType); + statement.executeUpdate(); + } + } catch (SQLException e) { + throw new RuntimeException("Failed update pre_key store", e); + } + } } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/SignedPreKeyStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/SignedPreKeyStore.java index d618f6c3..7c726f16 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/SignedPreKeyStore.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/SignedPreKeyStore.java @@ -18,6 +18,8 @@ import java.util.Collection; import java.util.List; import java.util.Objects; +import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_ARCHIVE_AGE; + public class SignedPreKeyStore implements org.signal.libsignal.protocol.state.SignedPreKeyStore { private static final String TABLE_SIGNED_PRE_KEY = "signed_pre_key"; @@ -144,6 +146,33 @@ public class SignedPreKeyStore implements org.signal.libsignal.protocol.state.Si } } + public void removeOldSignedPreKeys(int activePreKeyId) { + final var sql = ( + """ + DELETE FROM %s AS p + WHERE p._id IN ( + SELECT p._id + FROM %s AS p + WHERE p.account_id_type = ? + AND p.key_id != ? + AND p.timestamp < ? + ORDER BY p.timestamp DESC + LIMIT -1 OFFSET 1 + ) + """ + ).formatted(TABLE_SIGNED_PRE_KEY, TABLE_SIGNED_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setInt(2, activePreKeyId); + statement.setLong(3, System.currentTimeMillis() - PREKEY_ARCHIVE_AGE); + statement.executeUpdate(); + } + } catch (SQLException e) { + throw new RuntimeException("Failed update signed_pre_key store", e); + } + } + void addLegacySignedPreKeys(final Collection signedPreKeys) { logger.debug("Migrating legacy signedPreKeys to database"); long start = System.nanoTime(); diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/protocol/SignalProtocolStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/protocol/SignalProtocolStore.java index f042f0e2..ea1a5c6f 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/protocol/SignalProtocolStore.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/protocol/SignalProtocolStore.java @@ -9,12 +9,12 @@ import org.signal.libsignal.protocol.groups.state.SenderKeyRecord; import org.signal.libsignal.protocol.state.IdentityKeyStore; import org.signal.libsignal.protocol.state.KyberPreKeyRecord; import org.signal.libsignal.protocol.state.PreKeyRecord; -import org.signal.libsignal.protocol.state.PreKeyStore; import org.signal.libsignal.protocol.state.SessionRecord; import org.signal.libsignal.protocol.state.SignedPreKeyRecord; import org.signal.libsignal.protocol.state.SignedPreKeyStore; import org.whispersystems.signalservice.api.SignalServiceAccountDataStore; import org.whispersystems.signalservice.api.SignalServiceKyberPreKeyStore; +import org.whispersystems.signalservice.api.SignalServicePreKeyStore; import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore; import org.whispersystems.signalservice.api.SignalServiceSessionStore; import org.whispersystems.signalservice.api.push.DistributionId; @@ -27,7 +27,7 @@ import java.util.function.Supplier; public class SignalProtocolStore implements SignalServiceAccountDataStore { - private final PreKeyStore preKeyStore; + private final SignalServicePreKeyStore preKeyStore; private final SignedPreKeyStore signedPreKeyStore; private final SignalServiceKyberPreKeyStore kyberPreKeyStore; private final SignalServiceSessionStore sessionStore; @@ -36,7 +36,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore { private final Supplier isMultiDevice; public SignalProtocolStore( - final PreKeyStore preKeyStore, + final SignalServicePreKeyStore preKeyStore, final SignedPreKeyStore signedPreKeyStore, final SignalServiceKyberPreKeyStore kyberPreKeyStore, final SignalServiceSessionStore sessionStore, @@ -254,12 +254,12 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore { } @Override - public void deleteAllStaleOneTimeEcPreKeys(final long l, final int i) { - // TODO + public void deleteAllStaleOneTimeEcPreKeys(final long threshold, final int minCount) { + preKeyStore.deleteAllStaleOneTimeEcPreKeys(threshold, minCount); } @Override - public void markAllOneTimeEcPreKeysStaleIfNecessary(final long l) { - // TODO + public void markAllOneTimeEcPreKeysStaleIfNecessary(final long staleTime) { + preKeyStore.markAllOneTimeEcPreKeysStaleIfNecessary(staleTime); } } -- 2.50.1