nmode's Git Repositories - signal-cli/blobdiff - lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java
Add pre key cleanup and improve refresh
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / helper / PreKeyHelper.java
index 54eab17bf521304fd4979e2e78e55a51578804ee..34d11171943430a89dede2e32811f93532d79bcf 100644 (file)
@@ -5,6 +5,7 @@ import org.asamk.signal.manager.internal.SignalDependencies;
 import org.asamk.signal.manager.storage.SignalAccount;
 import org.asamk.signal.manager.util.KeyUtils;
 import org.signal.libsignal.protocol.IdentityKeyPair;
+import org.signal.libsignal.protocol.InvalidKeyIdException;
 import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
 import org.signal.libsignal.protocol.state.PreKeyRecord;
 import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
@@ -18,6 +19,9 @@ import org.whispersystems.signalservice.internal.push.OneTimePreKeyCounts;
 import java.io.IOException;
 import java.util.List;
 
+import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_STALE_AGE;
+import static org.asamk.signal.manager.config.ServiceConfig.SIGNED_PREKEY_ROTATE_AGE;
+
 public class PreKeyHelper {
 
     private final static Logger logger = LoggerFactory.getLogger(PreKeyHelper.class);
@@ -38,30 +42,6 @@ public class PreKeyHelper {
     }
 
     public void refreshPreKeysIfNecessary(ServiceIdType serviceIdType) throws IOException {
-        OneTimePreKeyCounts preKeyCounts;
-        try {
-            preKeyCounts = dependencies.getAccountManager().getPreKeyCounts(serviceIdType);
-        } catch (AuthorizationFailedException e) {
-            logger.debug("Failed to get pre key count, ignoring: " + e.getClass().getSimpleName());
-            preKeyCounts = new OneTimePreKeyCounts(0, 0);
-        }
-        if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
-            logger.debug("Refreshing {} ec pre keys, because only {} of min {} pre keys remain",
-                    serviceIdType,
-                    preKeyCounts.getEcCount(),
-                    ServiceConfig.PREKEY_MINIMUM_COUNT);
-            refreshPreKeys(serviceIdType);
-        }
-        if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
-            logger.debug("Refreshing {} kyber pre keys, because only {} of min {} pre keys remain",
-                    serviceIdType,
-                    preKeyCounts.getKyberCount(),
-                    ServiceConfig.PREKEY_MINIMUM_COUNT);
-            refreshKyberPreKeys(serviceIdType);
-        }
-    }
-
-    private void refreshPreKeys(ServiceIdType serviceIdType) throws IOException {
         final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
         if (identityKeyPair == null) {
             return;
@@ -70,28 +50,73 @@ public class PreKeyHelper {
         if (accountId == null) {
             return;
         }
+
+        OneTimePreKeyCounts preKeyCounts;
         try {
-            refreshPreKeys(serviceIdType, identityKeyPair);
+            preKeyCounts = dependencies.getAccountManager().getPreKeyCounts(serviceIdType);
+        } catch (AuthorizationFailedException e) {
+            logger.debug("Failed to get pre key count, ignoring: " + e.getClass().getSimpleName());
+            preKeyCounts = new OneTimePreKeyCounts(0, 0);
+        }
+
+        SignedPreKeyRecord signedPreKeyRecord = null;
+        List<PreKeyRecord> preKeyRecords = null;
+        KyberPreKeyRecord lastResortKyberPreKeyRecord = null;
+        List<KyberPreKeyRecord> kyberPreKeyRecords = null;
+
+        try {
+            if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
+                logger.debug("Refreshing {} ec pre keys, because only {} of min {} pre keys remain",
+                        serviceIdType,
+                        preKeyCounts.getEcCount(),
+                        ServiceConfig.PREKEY_MINIMUM_COUNT);
+                preKeyRecords = generatePreKeys(serviceIdType);
+            }
+            if (signedPreKeyNeedsRefresh(serviceIdType)) {
+                logger.debug("Refreshing {} signed pre key.", serviceIdType);
+                signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
+            }
         } catch (Exception e) {
             logger.warn("Failed to store new pre keys, resetting preKey id offset", e);
             account.resetPreKeyOffsets(serviceIdType);
-            refreshPreKeys(serviceIdType, identityKeyPair);
+            preKeyRecords = generatePreKeys(serviceIdType);
+            signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
+        }
+
+        try {
+            if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
+                logger.debug("Refreshing {} kyber pre keys, because only {} of min {} pre keys remain",
+                        serviceIdType,
+                        preKeyCounts.getKyberCount(),
+                        ServiceConfig.PREKEY_MINIMUM_COUNT);
+                kyberPreKeyRecords = generateKyberPreKeys(serviceIdType, identityKeyPair);
+            }
+            if (lastResortKyberPreKeyNeedsRefresh(serviceIdType)) {
+                logger.debug("Refreshing {} last resort kyber pre key.", serviceIdType);
+                lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair);
+            }
+        } catch (Exception e) {
+            logger.warn("Failed to store new kyber pre keys, resetting preKey id offset", e);
+            account.resetKyberPreKeyOffsets(serviceIdType);
+            kyberPreKeyRecords = generateKyberPreKeys(serviceIdType, identityKeyPair);
+            lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair);
+        }
+
+        if (signedPreKeyRecord != null
+                || preKeyRecords != null
+                || lastResortKyberPreKeyRecord != null
+                || kyberPreKeyRecords != null) {
+            final var preKeyUpload = new PreKeyUpload(serviceIdType,
+                    identityKeyPair.getPublicKey(),
+                    signedPreKeyRecord,
+                    preKeyRecords,
+                    lastResortKyberPreKeyRecord,
+                    kyberPreKeyRecords);
+            dependencies.getAccountManager().setPreKeys(preKeyUpload);
         }
-    }
 
-    private void refreshPreKeys(
-            final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
-    ) throws IOException {
-        final var oneTimePreKeys = generatePreKeys(serviceIdType);
-        final var signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
-
-        final var preKeyUpload = new PreKeyUpload(serviceIdType,
-                identityKeyPair.getPublicKey(),
-                signedPreKeyRecord,
-                oneTimePreKeys,
-                null,
-                null);
-        dependencies.getAccountManager().setPreKeys(preKeyUpload);
+        cleanSignedPreKeys((serviceIdType));
+        cleanOneTimePreKeys(serviceIdType);
     }
 
     private List<PreKeyRecord> generatePreKeys(ServiceIdType serviceIdType) {
@@ -103,6 +128,21 @@ public class PreKeyHelper {
         return records;
     }
 
+    private boolean signedPreKeyNeedsRefresh(ServiceIdType serviceIdType) {
+        final var accountData = account.getAccountData(serviceIdType);
+
+        final var activeSignedPreKeyId = accountData.getPreKeyMetadata().getActiveSignedPreKeyId();
+        if (activeSignedPreKeyId == -1) {
+            return true;
+        }
+        try {
+            final var signedPreKeyRecord = accountData.getSignedPreKeyStore().loadSignedPreKey(activeSignedPreKeyId);
+            return signedPreKeyRecord.getTimestamp() < System.currentTimeMillis() - SIGNED_PREKEY_ROTATE_AGE;
+        } catch (InvalidKeyIdException e) {
+            return true;
+        }
+    }
+
     private SignedPreKeyRecord generateSignedPreKey(ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair) {
         final var signedPreKeyId = account.getNextSignedPreKeyId(serviceIdType);
 
@@ -112,39 +152,6 @@ public class PreKeyHelper {
         return record;
     }
 
-    private void refreshKyberPreKeys(ServiceIdType serviceIdType) throws IOException {
-        final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
-        if (identityKeyPair == null) {
-            return;
-        }
-        final var accountId = account.getAccountId(serviceIdType);
-        if (accountId == null) {
-            return;
-        }
-        try {
-            refreshKyberPreKeys(serviceIdType, identityKeyPair);
-        } catch (Exception e) {
-            logger.warn("Failed to store new pre keys, resetting preKey id offset", e);
-            account.resetKyberPreKeyOffsets(serviceIdType);
-            refreshKyberPreKeys(serviceIdType, identityKeyPair);
-        }
-    }
-
-    private void refreshKyberPreKeys(
-            final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
-    ) throws IOException {
-        final var oneTimePreKeys = generateKyberPreKeys(serviceIdType, identityKeyPair);
-        final var lastResortPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair);
-
-        final var preKeyUpload = new PreKeyUpload(serviceIdType,
-                identityKeyPair.getPublicKey(),
-                null,
-                null,
-                lastResortPreKeyRecord,
-                oneTimePreKeys);
-        dependencies.getAccountManager().setPreKeys(preKeyUpload);
-    }
-
     private List<KyberPreKeyRecord> generateKyberPreKeys(
             ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
     ) {
@@ -156,6 +163,22 @@ public class PreKeyHelper {
         return records;
     }
 
+    private boolean lastResortKyberPreKeyNeedsRefresh(ServiceIdType serviceIdType) {
+        final var accountData = account.getAccountData(serviceIdType);
+
+        final var activeLastResortKyberPreKeyId = accountData.getPreKeyMetadata().getActiveLastResortKyberPreKeyId();
+        if (activeLastResortKyberPreKeyId == -1) {
+            return true;
+        }
+        try {
+            final var kyberPreKeyRecord = accountData.getKyberPreKeyStore()
+                    .loadKyberPreKey(activeLastResortKyberPreKeyId);
+            return kyberPreKeyRecord.getTimestamp() < System.currentTimeMillis() - SIGNED_PREKEY_ROTATE_AGE;
+        } catch (InvalidKeyIdException e) {
+            return true;
+        }
+    }
+
     private KyberPreKeyRecord generateLastResortKyberPreKey(
             ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair
     ) {
@@ -166,4 +189,23 @@ public class PreKeyHelper {
 
         return record;
     }
+
+    private void cleanSignedPreKeys(ServiceIdType serviceIdType) {
+        final var accountData = account.getAccountData(serviceIdType);
+
+        final var activeSignedPreKeyId = accountData.getPreKeyMetadata().getActiveSignedPreKeyId();
+        accountData.getSignedPreKeyStore().removeOldSignedPreKeys(activeSignedPreKeyId);
+
+        final var activeLastResortKyberPreKeyId = accountData.getPreKeyMetadata().getActiveLastResortKyberPreKeyId();
+        accountData.getKyberPreKeyStore().removeOldLastResortKyberPreKeys(activeLastResortKyberPreKeyId);
+    }
+
+    private void cleanOneTimePreKeys(ServiceIdType serviceIdType) {
+        long threshold = System.currentTimeMillis() - PREKEY_STALE_AGE;
+        int minCount = 200;
+
+        final var accountData = account.getAccountData(serviceIdType);
+        accountData.getPreKeyStore().deleteAllStaleOneTimeEcPreKeys(threshold, minCount);
+        accountData.getKyberPreKeyStore().deleteAllStaleOneTimeKyberPreKeys(threshold, minCount);
+    }
 }