nmode's Git Repositories - signal-cli/commitdiff
Add pre key cleanup and improve refresh
author AsamK <asamk@gmx.de>
Sun, 1 Oct 2023 14:01:46 +0000 (16:01 +0200)
committer AsamK <asamk@gmx.de>
Sun, 1 Oct 2023 14:29:41 +0000 (16:29 +0200)
lib/src/main/java/org/asamk/signal/manager/config/ServiceConfig.java
lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java
lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java
lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java
lib/src/main/java/org/asamk/signal/manager/storage/prekeys/KyberPreKeyStore.java
lib/src/main/java/org/asamk/signal/manager/storage/prekeys/PreKeyStore.java
lib/src/main/java/org/asamk/signal/manager/storage/prekeys/SignedPreKeyStore.java
lib/src/main/java/org/asamk/signal/manager/storage/protocol/SignalProtocolStore.java

index 5b6bce475ee03aeb54b550a75885481b5a81fdf6..7b2ba140de7cfce44235c80fc579e857dde9d784 100644 (file)
@@ -11,6 +11,7 @@ import java.security.KeyStoreException;
 import java.security.NoSuchAlgorithmException;
 import java.security.cert.CertificateException;
 import java.util.List;
+import java.util.concurrent.TimeUnit;
 
 import okhttp3.Interceptor;
 
@@ -19,6 +20,10 @@ public class ServiceConfig {
     public final static int PREKEY_MINIMUM_COUNT = 10;
     public final static int PREKEY_BATCH_SIZE = 100;
     public final static int PREKEY_MAXIMUM_ID = Medium.MAX_VALUE;
+    public static final long PREKEY_ARCHIVE_AGE = TimeUnit.DAYS.toMillis(30); // non-active signed / last resort kyber keys older than this may be deleted
+    public static final long PREKEY_STALE_AGE = TimeUnit.DAYS.toMillis(90); // one-time keys marked stale longer ago than this may be deleted
+    public static final long SIGNED_PREKEY_ROTATE_AGE = TimeUnit.DAYS.toMillis(2); // active signed / last resort kyber keys older than this are regenerated
+
     public final static int MAX_ATTACHMENT_SIZE = 150 * 1024 * 1024;
     public final static long MAX_ENVELOPE_SIZE = 0;
     public final static long AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE = 10 * 1024 * 1024;
index 54eab17bf521304fd4979e2e78e55a51578804ee..34d11171943430a89dede2e32811f93532d79bcf 100644 (file)
@@ -5,6 +5,7 @@ import org.asamk.signal.manager.internal.SignalDependencies;
 import org.asamk.signal.manager.storage.SignalAccount;
 import org.asamk.signal.manager.util.KeyUtils;
 import org.signal.libsignal.protocol.IdentityKeyPair;
+import org.signal.libsignal.protocol.InvalidKeyIdException;
 import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
 import org.signal.libsignal.protocol.state.PreKeyRecord;
 import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
@@ -18,6 +19,9 @@ import org.whispersystems.signalservice.internal.push.OneTimePreKeyCounts;
 import java.io.IOException;
 import java.util.List;
 
+import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_STALE_AGE;
+import static org.asamk.signal.manager.config.ServiceConfig.SIGNED_PREKEY_ROTATE_AGE;
+
 public class PreKeyHelper {
 
     private final static Logger logger = LoggerFactory.getLogger(PreKeyHelper.class);
@@ -38,30 +42,6 @@ public class PreKeyHelper {
     }
 
     public void refreshPreKeysIfNecessary(ServiceIdType serviceIdType) throws IOException {
-        OneTimePreKeyCounts preKeyCounts;
-        try {
-            preKeyCounts = dependencies.getAccountManager().getPreKeyCounts(serviceIdType);
-        } catch (AuthorizationFailedException e) {
-            logger.debug("Failed to get pre key count, ignoring: " + e.getClass().getSimpleName());
-            preKeyCounts = new OneTimePreKeyCounts(0, 0);
-        }
-        if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
-            logger.debug("Refreshing {} ec pre keys, because only {} of min {} pre keys remain",
-                    serviceIdType,
-                    preKeyCounts.getEcCount(),
-                    ServiceConfig.PREKEY_MINIMUM_COUNT);
-            refreshPreKeys(serviceIdType);
-        }
-        if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
-            logger.debug("Refreshing {} kyber pre keys, because only {} of min {} pre keys remain",
-                    serviceIdType,
-                    preKeyCounts.getKyberCount(),
-                    ServiceConfig.PREKEY_MINIMUM_COUNT);
-            refreshKyberPreKeys(serviceIdType);
-        }
-    }
-
-    private void refreshPreKeys(ServiceIdType serviceIdType) throws IOException {
         final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
         if (identityKeyPair == null) {
             return;
@@ -70,28 +50,73 @@ public class PreKeyHelper {
         if (accountId == null) {
             return;
         }
+        // Ask the server how many one-time pre keys remain; an authorization failure is treated as none remaining.
+        OneTimePreKeyCounts preKeyCounts;
         try {
-            refreshPreKeys(serviceIdType, identityKeyPair);
+            preKeyCounts = dependencies.getAccountManager().getPreKeyCounts(serviceIdType);
+        } catch (AuthorizationFailedException e) {
+            logger.debug("Failed to get pre key count, ignoring: " + e.getClass().getSimpleName());
+            preKeyCounts = new OneTimePreKeyCounts(0, 0);
+        }
+        // Each part stays null unless regenerated below (null parts are assumed to be left unchanged by the upload).
+        SignedPreKeyRecord signedPreKeyRecord = null;
+        List<PreKeyRecord> preKeyRecords = null;
+        KyberPreKeyRecord lastResortKyberPreKeyRecord = null;
+        List<KyberPreKeyRecord> kyberPreKeyRecords = null;
+        // EC one-time keys and the signed pre key: on any failure, reset their id offsets and regenerate both.
+        try {
+            if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
+                logger.debug("Refreshing {} ec pre keys, because only {} of min {} pre keys remain",
+                        serviceIdType,
+                        preKeyCounts.getEcCount(),
+                        ServiceConfig.PREKEY_MINIMUM_COUNT);
+                preKeyRecords = generatePreKeys(serviceIdType);
+            }
+            if (signedPreKeyNeedsRefresh(serviceIdType)) {
+                logger.debug("Refreshing {} signed pre key.", serviceIdType);
+                signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
+            }
         } catch (Exception e) {
             logger.warn("Failed to store new pre keys, resetting preKey id offset", e);
             account.resetPreKeyOffsets(serviceIdType);
-            refreshPreKeys(serviceIdType, identityKeyPair);
+            preKeyRecords = generatePreKeys(serviceIdType);
+            signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
+        }
+        // Kyber one-time keys and the last resort kyber key follow the same reset-and-regenerate pattern.
+        try {
+            if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
+                logger.debug("Refreshing {} kyber pre keys, because only {} of min {} pre keys remain",
+                        serviceIdType,
+                        preKeyCounts.getKyberCount(),
+                        ServiceConfig.PREKEY_MINIMUM_COUNT);
+                kyberPreKeyRecords = generateKyberPreKeys(serviceIdType, identityKeyPair);
+            }
+            if (lastResortKyberPreKeyNeedsRefresh(serviceIdType)) {
+                logger.debug("Refreshing {} last resort kyber pre key.", serviceIdType);
+                lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair);
+            }
+        } catch (Exception e) {
+            logger.warn("Failed to store new kyber pre keys, resetting preKey id offset", e);
+            account.resetKyberPreKeyOffsets(serviceIdType);
+            kyberPreKeyRecords = generateKyberPreKeys(serviceIdType, identityKeyPair);
+            lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair);
+        }
+        // Upload only when at least one part was regenerated.
+        if (signedPreKeyRecord != null
+                || preKeyRecords != null
+                || lastResortKyberPreKeyRecord != null
+                || kyberPreKeyRecords != null) {
+            final var preKeyUpload = new PreKeyUpload(serviceIdType,
+                    identityKeyPair.getPublicKey(),
+                    signedPreKeyRecord,
+                    preKeyRecords,
+                    lastResortKyberPreKeyRecord,
+                    kyberPreKeyRecords);
+            dependencies.getAccountManager().setPreKeys(preKeyUpload);
         }
-    }
 
-    private void refreshPreKeys(
-            final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
-    ) throws IOException {
-        final var oneTimePreKeys = generatePreKeys(serviceIdType);
-        final var signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
-
-        final var preKeyUpload = new PreKeyUpload(serviceIdType,
-                identityKeyPair.getPublicKey(),
-                signedPreKeyRecord,
-                oneTimePreKeys,
-                null,
-                null);
-        dependencies.getAccountManager().setPreKeys(preKeyUpload);
+        cleanSignedPreKeys(serviceIdType); // also prunes archived last resort kyber keys
+        cleanOneTimePreKeys(serviceIdType);
     }
 
     private List<PreKeyRecord> generatePreKeys(ServiceIdType serviceIdType) {
@@ -103,6 +128,21 @@ public class PreKeyHelper {
         return records;
     }
 
+    private boolean signedPreKeyNeedsRefresh(ServiceIdType serviceIdType) {
+        final var accountData = account.getAccountData(serviceIdType);
+        // Refresh when no active id is recorded (-1), the active key can't be loaded, or it is older than SIGNED_PREKEY_ROTATE_AGE.
+        final var activeSignedPreKeyId = accountData.getPreKeyMetadata().getActiveSignedPreKeyId();
+        if (activeSignedPreKeyId == -1) {
+            return true;
+        }
+        try {
+            final var signedPreKeyRecord = accountData.getSignedPreKeyStore().loadSignedPreKey(activeSignedPreKeyId);
+            return signedPreKeyRecord.getTimestamp() < System.currentTimeMillis() - SIGNED_PREKEY_ROTATE_AGE;
+        } catch (InvalidKeyIdException e) { // the store could not load a key for the active id
+            return true;
+        }
+    }
+
     private SignedPreKeyRecord generateSignedPreKey(ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair) {
         final var signedPreKeyId = account.getNextSignedPreKeyId(serviceIdType);
 
@@ -112,39 +152,6 @@ public class PreKeyHelper {
         return record;
     }
 
-    private void refreshKyberPreKeys(ServiceIdType serviceIdType) throws IOException {
-        final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
-        if (identityKeyPair == null) {
-            return;
-        }
-        final var accountId = account.getAccountId(serviceIdType);
-        if (accountId == null) {
-            return;
-        }
-        try {
-            refreshKyberPreKeys(serviceIdType, identityKeyPair);
-        } catch (Exception e) {
-            logger.warn("Failed to store new pre keys, resetting preKey id offset", e);
-            account.resetKyberPreKeyOffsets(serviceIdType);
-            refreshKyberPreKeys(serviceIdType, identityKeyPair);
-        }
-    }
-
-    private void refreshKyberPreKeys(
-            final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
-    ) throws IOException {
-        final var oneTimePreKeys = generateKyberPreKeys(serviceIdType, identityKeyPair);
-        final var lastResortPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair);
-
-        final var preKeyUpload = new PreKeyUpload(serviceIdType,
-                identityKeyPair.getPublicKey(),
-                null,
-                null,
-                lastResortPreKeyRecord,
-                oneTimePreKeys);
-        dependencies.getAccountManager().setPreKeys(preKeyUpload);
-    }
-
     private List<KyberPreKeyRecord> generateKyberPreKeys(
             ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
     ) {
@@ -156,6 +163,22 @@ public class PreKeyHelper {
         return records;
     }
 
+    private boolean lastResortKyberPreKeyNeedsRefresh(ServiceIdType serviceIdType) {
+        final var accountData = account.getAccountData(serviceIdType);
+        // Same rule as the signed pre key: no active id (-1), not loadable, or older than SIGNED_PREKEY_ROTATE_AGE.
+        final var activeLastResortKyberPreKeyId = accountData.getPreKeyMetadata().getActiveLastResortKyberPreKeyId();
+        if (activeLastResortKyberPreKeyId == -1) {
+            return true;
+        }
+        try {
+            final var kyberPreKeyRecord = accountData.getKyberPreKeyStore()
+                    .loadKyberPreKey(activeLastResortKyberPreKeyId);
+            return kyberPreKeyRecord.getTimestamp() < System.currentTimeMillis() - SIGNED_PREKEY_ROTATE_AGE;
+        } catch (InvalidKeyIdException e) { // the store could not load a key for the active id
+            return true;
+        }
+    }
+
     private KyberPreKeyRecord generateLastResortKyberPreKey(
             ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair
     ) {
@@ -166,4 +189,23 @@ public class PreKeyHelper {
 
         return record;
     }
+
+    private void cleanSignedPreKeys(ServiceIdType serviceIdType) {
+        final var accountData = account.getAccountData(serviceIdType);
+        // Prune archived signed pre keys; the active id is passed in so that key is never deleted.
+        final var activeSignedPreKeyId = accountData.getPreKeyMetadata().getActiveSignedPreKeyId();
+        accountData.getSignedPreKeyStore().removeOldSignedPreKeys(activeSignedPreKeyId);
+        // Last resort kyber keys are rotated on the same schedule, so their archived copies are pruned here too.
+        final var activeLastResortKyberPreKeyId = accountData.getPreKeyMetadata().getActiveLastResortKyberPreKeyId();
+        accountData.getKyberPreKeyStore().removeOldLastResortKyberPreKeys(activeLastResortKyberPreKeyId);
+    }
+
+    private void cleanOneTimePreKeys(ServiceIdType serviceIdType) {
+        long threshold = System.currentTimeMillis() - PREKEY_STALE_AGE;
+        int minCount = 200;
+        // Delete one-time keys marked stale before threshold; the stores always retain the minCount most recent keys.
+        final var accountData = account.getAccountData(serviceIdType);
+        accountData.getPreKeyStore().deleteAllStaleOneTimeEcPreKeys(threshold, minCount);
+        accountData.getKyberPreKeyStore().deleteAllStaleOneTimeKyberPreKeys(threshold, minCount);
+    }
 }
index ee202d3928fe4ce25c096ec1ead78efdd76f392b..fb34b4409ea560aaf4baf5d0b80107978f650dc2 100644 (file)
@@ -30,7 +30,7 @@ import java.util.UUID;
 public class AccountDatabase extends Database {
 
     private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class);
-    private static final long DATABASE_VERSION = 15;
+    private static final long DATABASE_VERSION = 16;
 
     private AccountDatabase(final HikariDataSource dataSource) {
         super(logger, DATABASE_VERSION, dataSource);
@@ -493,5 +493,15 @@ public class AccountDatabase extends Database {
                                         """);
             }
         }
+        if (oldVersion < 16) { // stale_timestamp on both pre key tables, plus a timestamp column on kyber_pre_key (existing rows get 0)
+            logger.debug("Updating database: Adding stale_timestamp prekey field");
+            try (final var statement = connection.createStatement()) {
+                statement.executeUpdate("""
+                                        ALTER TABLE pre_key ADD COLUMN stale_timestamp INTEGER;
+                                        ALTER TABLE kyber_pre_key ADD COLUMN stale_timestamp INTEGER;
+                                        ALTER TABLE kyber_pre_key ADD COLUMN timestamp INTEGER DEFAULT 0;
+                                        """);
+            }
+        }
     }
 }
index 3ff8a864b8e9862746cd03a9863bb85e86347fc8..18a195e6a809db2c507efb5215ec6eadded92d65 100644 (file)
@@ -622,6 +622,11 @@ public class SignalAccount implements Closeable {
         } else {
             aciAccountData.preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset();
         }
+        if (rootNode.hasNonNull("activeSignedPreKeyId")) {
+            aciAccountData.preKeyMetadata.activeSignedPreKeyId = rootNode.get("activeSignedPreKeyId").asInt(-1);
+        } else {
+            aciAccountData.preKeyMetadata.activeSignedPreKeyId = -1;
+        }
         if (rootNode.hasNonNull("pniPreKeyIdOffset")) {
             pniAccountData.preKeyMetadata.preKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1);
         } else {
@@ -632,6 +637,11 @@ public class SignalAccount implements Closeable {
         } else {
             pniAccountData.preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset();
         }
+        if (rootNode.hasNonNull("pniActiveSignedPreKeyId")) {
+            pniAccountData.preKeyMetadata.activeSignedPreKeyId = rootNode.get("pniActiveSignedPreKeyId").asInt(-1);
+        } else {
+            pniAccountData.preKeyMetadata.activeSignedPreKeyId = -1;
+        }
         if (rootNode.hasNonNull("kyberPreKeyIdOffset")) {
             aciAccountData.preKeyMetadata.kyberPreKeyIdOffset = rootNode.get("kyberPreKeyIdOffset").asInt(1);
         } else {
@@ -1003,8 +1013,10 @@ public class SignalAccount implements Closeable {
                     .put("storageManifestVersion", storageManifestVersion == -1 ? null : storageManifestVersion)
                     .put("preKeyIdOffset", aciAccountData.getPreKeyMetadata().preKeyIdOffset)
                     .put("nextSignedPreKeyId", aciAccountData.getPreKeyMetadata().nextSignedPreKeyId)
+                    .put("activeSignedPreKeyId", aciAccountData.getPreKeyMetadata().activeSignedPreKeyId)
                     .put("pniPreKeyIdOffset", pniAccountData.getPreKeyMetadata().preKeyIdOffset)
                     .put("pniNextSignedPreKeyId", pniAccountData.getPreKeyMetadata().nextSignedPreKeyId)
+                    .put("pniActiveSignedPreKeyId", pniAccountData.getPreKeyMetadata().activeSignedPreKeyId)
                     .put("kyberPreKeyIdOffset", aciAccountData.getPreKeyMetadata().kyberPreKeyIdOffset)
                     .put("activeLastResortKyberPreKeyId",
                             aciAccountData.getPreKeyMetadata().activeLastResortKyberPreKeyId)
@@ -1058,6 +1070,7 @@ public class SignalAccount implements Closeable {
         final var preKeyMetadata = getAccountData(serviceIdType).getPreKeyMetadata();
         preKeyMetadata.preKeyIdOffset = getRandomPreKeyIdOffset();
         preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset();
+        preKeyMetadata.activeSignedPreKeyId = -1;
         save();
     }
 
@@ -1072,6 +1085,7 @@ public class SignalAccount implements Closeable {
                 records.size(),
                 serviceIdType,
                 preKeyMetadata.preKeyIdOffset);
+        accountData.signalProtocolStore.markAllOneTimeEcPreKeysStaleIfNecessary(System.currentTimeMillis());
         for (var record : records) {
             if (preKeyMetadata.preKeyIdOffset != record.getId()) {
                 logger.error("Invalid pre key id {}, expected {}", record.getId(), preKeyMetadata.preKeyIdOffset);
@@ -1095,6 +1109,7 @@ public class SignalAccount implements Closeable {
         }
         accountData.getSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
         preKeyMetadata.nextSignedPreKeyId = (preKeyMetadata.nextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID;
+        preKeyMetadata.activeSignedPreKeyId = record.getId();
         save();
     }
 
@@ -1112,6 +1127,7 @@ public class SignalAccount implements Closeable {
                 records.size(),
                 serviceIdType,
                 preKeyMetadata.kyberPreKeyIdOffset);
+        accountData.signalProtocolStore.markAllOneTimeKyberPreKeysStaleIfNecessary(System.currentTimeMillis()); // mark existing kyber (not EC) one-time keys stale before the new batch
         for (var record : records) {
             if (preKeyMetadata.kyberPreKeyIdOffset != record.getId()) {
                 logger.error("Invalid kyber pre key id {}, expected {}",
@@ -1745,6 +1761,7 @@ public class SignalAccount implements Closeable {
 
         private int preKeyIdOffset = 1;
         private int nextSignedPreKeyId = 1;
+        private int activeSignedPreKeyId = -1;
         private int kyberPreKeyIdOffset = 1;
         private int activeLastResortKyberPreKeyId = -1;
 
@@ -1756,6 +1773,10 @@ public class SignalAccount implements Closeable {
             return nextSignedPreKeyId;
         }
 
+        public int getActiveSignedPreKeyId() { // -1 means no active signed pre key is recorded
+            return activeSignedPreKeyId;
+        }
+
         public int getKyberPreKeyIdOffset() {
             return kyberPreKeyIdOffset;
         }
@@ -1819,17 +1840,17 @@ public class SignalAccount implements Closeable {
                             SignalAccount.this::isMultiDevice));
         }
 
-        private PreKeyStore getPreKeyStore() {
+        public PreKeyStore getPreKeyStore() { // public so PreKeyHelper can prune stale one-time EC keys
             return getOrCreate(() -> preKeyStore,
                     () -> preKeyStore = new PreKeyStore(getAccountDatabase(), serviceIdType));
         }
 
-        private SignedPreKeyStore getSignedPreKeyStore() {
+        public SignedPreKeyStore getSignedPreKeyStore() { // public so PreKeyHelper can load and prune signed pre keys
             return getOrCreate(() -> signedPreKeyStore,
                     () -> signedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), serviceIdType));
         }
 
-        private KyberPreKeyStore getKyberPreKeyStore() {
+        public KyberPreKeyStore getKyberPreKeyStore() { // public so PreKeyHelper can load and prune kyber pre keys
             return getOrCreate(() -> kyberPreKeyStore,
                     () -> kyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), serviceIdType));
         }
index 9ce660b205718f2b400463b2bf777be7963a5c7f..4067e2fe4e546359eaa50c2bcfa6a72dc10ed81b 100644 (file)
@@ -15,6 +15,8 @@ import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.util.List;
 
+import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_ARCHIVE_AGE;
+
 public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore {
 
     private static final String TABLE_KYBER_PRE_KEY = "kyber_pre_key";
@@ -33,6 +35,8 @@ public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore {
                                       key_id INTEGER NOT NULL,
                                       serialized BLOB NOT NULL,
                                       is_last_resort INTEGER NOT NULL,
+                                      stale_timestamp INTEGER,
+                                      timestamp INTEGER DEFAULT 0,
                                       UNIQUE(account_id_type, key_id)
                                     ) STRICT;
                                     """);
@@ -104,8 +108,8 @@ public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore {
     public void storeKyberPreKey(final int keyId, final KyberPreKeyRecord record, final boolean isLastResort) {
         final var sql = (
                 """
-                INSERT INTO %s (account_id_type, key_id, serialized, is_last_resort)
-                VALUES (?, ?, ?, ?)
+                INSERT INTO %s (account_id_type, key_id, serialized, is_last_resort, timestamp)
+                VALUES (?, ?, ?, ?, ?)
                 """
         ).formatted(TABLE_KYBER_PRE_KEY);
         try (final var connection = database.getConnection()) {
@@ -114,6 +118,7 @@ public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore {
                 statement.setInt(2, keyId);
                 statement.setBytes(3, record.serialize());
                 statement.setBoolean(4, isLastResort);
+                statement.setLong(5, record.getTimestamp());
                 statement.executeUpdate();
             }
         } catch (SQLException e) {
@@ -209,13 +214,86 @@ public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore {
         }
     }
 
+    public void removeOldLastResortKyberPreKeys(int activeLastResortKyberPreKeyId) {
+        final var sql = ( // delete non-active last resort keys older than PREKEY_ARCHIVE_AGE, except the newest of those
+                """
+                DELETE FROM %s AS p
+                WHERE p._id IN (
+                    SELECT p._id
+                    FROM %s AS p
+                    WHERE p.account_id_type = ?
+                        AND p.is_last_resort = TRUE
+                        AND p.key_id != ?
+                        AND p.timestamp < ?
+                    ORDER BY p.timestamp DESC
+                    LIMIT -1 OFFSET 1
+                )
+                """
+        ).formatted(TABLE_KYBER_PRE_KEY, TABLE_KYBER_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                statement.setInt(2, activeLastResortKyberPreKeyId);
+                statement.setLong(3, System.currentTimeMillis() - PREKEY_ARCHIVE_AGE);
+                statement.executeUpdate();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update kyber_pre_key store", e);
+        }
+    }
+
     @Override
     public void deleteAllStaleOneTimeKyberPreKeys(final long threshold, final int minCount) {
-        //TODO
+        final var sql = ( // retention subquery counts only one-time keys, so last resort keys never use up minCount slots
+                """
+                DELETE FROM %s AS p
+                WHERE p.account_id_type = ?1
+                    AND p.stale_timestamp < ?2
+                    AND p.is_last_resort = FALSE
+                    AND p._id NOT IN (
+                        SELECT _id
+                        FROM %s p2
+                        WHERE p2.account_id_type = ?1 AND p2.is_last_resort = FALSE
+                        ORDER BY
+                          CASE WHEN p2.stale_timestamp IS NULL THEN 1 ELSE 0 END DESC,
+                          p2.stale_timestamp DESC,
+                          p2._id DESC
+                        LIMIT ?3
+                    )
+                """
+        ).formatted(TABLE_KYBER_PRE_KEY, TABLE_KYBER_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                statement.setLong(2, threshold);
+                statement.setInt(3, minCount);
+                final var rowCount = statement.executeUpdate();
+                if (rowCount > 0) {
+                    logger.debug("Deleted {} stale one time kyber pre keys", rowCount);
+                }
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update kyber_pre_key store", e);
+        }
     }
 
     @Override
     public void markAllOneTimeKyberPreKeysStaleIfNecessary(final long staleTime) {
-        //TODO
+        final var sql = ( // last resort keys are never marked stale
+                """
+                UPDATE %s
+                SET stale_timestamp = ?
+                WHERE account_id_type = ? AND stale_timestamp IS NULL AND is_last_resort = FALSE
+                """
+        ).formatted(TABLE_KYBER_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setLong(1, staleTime);
+                statement.setInt(2, accountIdType);
+                statement.executeUpdate();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update kyber_pre_key store", e);
+        }
     }
 }
index f6dc271eea7765c982864deb475aac8a1811eb21..c3ac9632cfefcb11e2e1ee959598ff5ea946180f 100644 (file)
@@ -9,6 +9,7 @@ import org.signal.libsignal.protocol.ecc.ECKeyPair;
 import org.signal.libsignal.protocol.state.PreKeyRecord;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.whispersystems.signalservice.api.SignalServicePreKeyStore;
 import org.whispersystems.signalservice.api.push.ServiceIdType;
 
 import java.sql.Connection;
@@ -16,7 +17,7 @@ import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.util.Collection;
 
-public class PreKeyStore implements org.signal.libsignal.protocol.state.PreKeyStore {
+public class PreKeyStore implements SignalServicePreKeyStore {
 
     private static final String TABLE_PRE_KEY = "pre_key";
     private final static Logger logger = LoggerFactory.getLogger(PreKeyStore.class);
@@ -34,6 +35,7 @@ public class PreKeyStore implements org.signal.libsignal.protocol.state.PreKeySt
                                       key_id INTEGER NOT NULL,
                                       public_key BLOB NOT NULL,
                                       private_key BLOB NOT NULL,
+                                      stale_timestamp INTEGER,
                                       UNIQUE(account_id_type, key_id)
                                     ) STRICT;
                                     """);
@@ -181,4 +183,58 @@ public class PreKeyStore implements org.signal.libsignal.protocol.state.PreKeySt
             return null;
         }
     }
+
+    @Override
+    public void deleteAllStaleOneTimeEcPreKeys(final long threshold, final int minCount) {
+        final var sql = ( // keep the minCount most recent keys (unmarked first); delete the rest marked stale before threshold
+                """
+                DELETE FROM %s AS p
+                WHERE p.account_id_type = ?1
+                    AND p.stale_timestamp < ?2
+                    AND p._id NOT IN (
+                        SELECT _id
+                        FROM %s AS p2
+                        WHERE p2.account_id_type = ?1
+                        ORDER BY
+                          CASE WHEN p2.stale_timestamp IS NULL THEN 1 ELSE 0 END DESC,
+                          p2.stale_timestamp DESC,
+                          p2._id DESC
+                        LIMIT ?3
+                    )
+                """
+        ).formatted(TABLE_PRE_KEY, TABLE_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                statement.setLong(2, threshold);
+                statement.setInt(3, minCount);
+                final var rowCount = statement.executeUpdate();
+                if (rowCount > 0) {
+                    logger.debug("Deleted {} stale one time pre keys", rowCount);
+                }
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update pre_key store", e);
+        }
+    }
+
+    @Override
+    public void markAllOneTimeEcPreKeysStaleIfNecessary(final long staleTime) {
+        final var sql = ( // mark every not-yet-stale key of this account as stale at staleTime
+                """
+                UPDATE %s
+                SET stale_timestamp = ?
+                WHERE account_id_type = ? AND stale_timestamp IS NULL
+                """
+        ).formatted(TABLE_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setLong(1, staleTime);
+                statement.setInt(2, accountIdType);
+                statement.executeUpdate();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update pre_key store", e);
+        }
+    }
 }
index d618f6c372c6fcca1123f1a1beb8c90ae99d8611..7c726f16930394e96221886e3c876a87147be5a8 100644 (file)
@@ -18,6 +18,8 @@ import java.util.Collection;
 import java.util.List;
 import java.util.Objects;
 
+import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_ARCHIVE_AGE;
+
 public class SignedPreKeyStore implements org.signal.libsignal.protocol.state.SignedPreKeyStore {
 
     private static final String TABLE_SIGNED_PRE_KEY = "signed_pre_key";
@@ -144,6 +146,33 @@ public class SignedPreKeyStore implements org.signal.libsignal.protocol.state.Si
         }
     }
 
+    public void removeOldSignedPreKeys(int activePreKeyId) {
+        final var sql = ( // delete non-active keys older than PREKEY_ARCHIVE_AGE, except the newest of those
+                """
+                DELETE FROM %s AS p
+                WHERE p._id IN (
+                    SELECT p._id
+                    FROM %s AS p
+                    WHERE p.account_id_type = ?
+                        AND p.key_id != ?
+                        AND p.timestamp < ?
+                    ORDER BY p.timestamp DESC
+                    LIMIT -1 OFFSET 1
+                )
+                """
+        ).formatted(TABLE_SIGNED_PRE_KEY, TABLE_SIGNED_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                statement.setInt(2, activePreKeyId);
+                statement.setLong(3, System.currentTimeMillis() - PREKEY_ARCHIVE_AGE);
+                statement.executeUpdate();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update signed_pre_key store", e);
+        }
+    }
+
     void addLegacySignedPreKeys(final Collection<SignedPreKeyRecord> signedPreKeys) {
         logger.debug("Migrating legacy signedPreKeys to database");
         long start = System.nanoTime();
index f042f0e28dd990170b5fe4632de7534aa2c663f7..ea1a5c6fc795bdddeedf5b16303be09096fb7c7d 100644 (file)
@@ -9,12 +9,12 @@ import org.signal.libsignal.protocol.groups.state.SenderKeyRecord;
 import org.signal.libsignal.protocol.state.IdentityKeyStore;
 import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
 import org.signal.libsignal.protocol.state.PreKeyRecord;
-import org.signal.libsignal.protocol.state.PreKeyStore;
 import org.signal.libsignal.protocol.state.SessionRecord;
 import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
 import org.signal.libsignal.protocol.state.SignedPreKeyStore;
 import org.whispersystems.signalservice.api.SignalServiceAccountDataStore;
 import org.whispersystems.signalservice.api.SignalServiceKyberPreKeyStore;
+import org.whispersystems.signalservice.api.SignalServicePreKeyStore;
 import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore;
 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
 import org.whispersystems.signalservice.api.push.DistributionId;
@@ -27,7 +27,7 @@ import java.util.function.Supplier;
 
 public class SignalProtocolStore implements SignalServiceAccountDataStore {
 
-    private final PreKeyStore preKeyStore;
+    private final SignalServicePreKeyStore preKeyStore;
     private final SignedPreKeyStore signedPreKeyStore;
     private final SignalServiceKyberPreKeyStore kyberPreKeyStore;
     private final SignalServiceSessionStore sessionStore;
@@ -36,7 +36,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
     private final Supplier<Boolean> isMultiDevice;
 
     public SignalProtocolStore(
-            final PreKeyStore preKeyStore,
+            final SignalServicePreKeyStore preKeyStore,
             final SignedPreKeyStore signedPreKeyStore,
             final SignalServiceKyberPreKeyStore kyberPreKeyStore,
             final SignalServiceSessionStore sessionStore,
@@ -254,12 +254,12 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
     }
 
     @Override
-    public void deleteAllStaleOneTimeEcPreKeys(final long l, final int i) {
-        // TODO
+    public void deleteAllStaleOneTimeEcPreKeys(final long threshold, final int minCount) {
+        preKeyStore.deleteAllStaleOneTimeEcPreKeys(threshold, minCount); // delegate to the wrapped SignalServicePreKeyStore
     }
 
     @Override
-    public void markAllOneTimeEcPreKeysStaleIfNecessary(final long l) {
-        // TODO
+    public void markAllOneTimeEcPreKeysStaleIfNecessary(final long staleTime) {
+        preKeyStore.markAllOneTimeEcPreKeysStaleIfNecessary(staleTime); // delegate to the wrapped SignalServicePreKeyStore
     }
 }