]> nmode's Git Repositories - signal-cli/blobdiff - lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java
Update libsignal-service
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / helper / PreKeyHelper.java
index 06dc31b91d0431ead7052a20dc99063be643c6bf..bf7ad580bafde15e6668b5abfe7f2d066e0548dc 100644 (file)
@@ -5,26 +5,34 @@ import org.asamk.signal.manager.internal.SignalDependencies;
 import org.asamk.signal.manager.storage.SignalAccount;
 import org.asamk.signal.manager.util.KeyUtils;
 import org.signal.libsignal.protocol.IdentityKeyPair;
+import org.signal.libsignal.protocol.InvalidKeyIdException;
+import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
 import org.signal.libsignal.protocol.state.PreKeyRecord;
 import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.whispersystems.signalservice.api.NetworkResultUtil;
 import org.whispersystems.signalservice.api.account.PreKeyUpload;
+import org.whispersystems.signalservice.api.keys.OneTimePreKeyCounts;
 import org.whispersystems.signalservice.api.push.ServiceIdType;
+import org.whispersystems.signalservice.api.push.exceptions.AuthorizationFailedException;
+import org.whispersystems.signalservice.api.push.exceptions.NonSuccessfulResponseCodeException;
 
 import java.io.IOException;
 import java.util.List;
 
+import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_STALE_AGE;
+import static org.asamk.signal.manager.config.ServiceConfig.SIGNED_PREKEY_ROTATE_AGE;
+import static org.asamk.signal.manager.util.Utils.handleResponseException;
+
 public class PreKeyHelper {
 
-    private final static Logger logger = LoggerFactory.getLogger(PreKeyHelper.class);
+    private static final Logger logger = LoggerFactory.getLogger(PreKeyHelper.class);
 
     private final SignalAccount account;
     private final SignalDependencies dependencies;
 
-    public PreKeyHelper(
-            final SignalAccount account, final SignalDependencies dependencies
-    ) {
+    public PreKeyHelper(final SignalAccount account, final SignalDependencies dependencies) {
         this.account = account;
         this.dependencies = dependencies;
     }
@@ -34,20 +42,27 @@ public class PreKeyHelper {
         refreshPreKeysIfNecessary(ServiceIdType.PNI);
     }
 
+    public void forceRefreshPreKeys() throws IOException {
+        forceRefreshPreKeys(ServiceIdType.ACI);
+        forceRefreshPreKeys(ServiceIdType.PNI);
+    }
+
     public void refreshPreKeysIfNecessary(ServiceIdType serviceIdType) throws IOException {
-        final var preKeyCounts = dependencies.getAccountManager().getPreKeyCounts(serviceIdType);
-        if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
-            refreshPreKeys(serviceIdType);
+        final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
+        if (identityKeyPair == null) {
+            return;
+        }
+        final var accountId = account.getAccountId(serviceIdType);
+        if (accountId == null) {
+            return;
         }
-        // TODO kyber pre keys
-    }
 
-    public void refreshPreKeys() throws IOException {
-        refreshPreKeys(ServiceIdType.ACI);
-        refreshPreKeys(ServiceIdType.PNI);
+        if (refreshPreKeysIfNecessary(serviceIdType, identityKeyPair)) {
+            refreshPreKeysIfNecessary(serviceIdType, identityKeyPair);
+        }
     }
 
-    public void refreshPreKeys(ServiceIdType serviceIdType) throws IOException {
+    public void forceRefreshPreKeys(ServiceIdType serviceIdType) throws IOException {
         final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
         if (identityKeyPair == null) {
             return;
@@ -56,45 +71,209 @@ public class PreKeyHelper {
         if (accountId == null) {
             return;
         }
+
+        final var counts = new OneTimePreKeyCounts(0, 0);
+        if (refreshPreKeysIfNecessary(serviceIdType, identityKeyPair, counts, true)) {
+            refreshPreKeysIfNecessary(serviceIdType, identityKeyPair, counts, true);
+        }
+    }
+
+    private boolean refreshPreKeysIfNecessary(
+            final ServiceIdType serviceIdType,
+            final IdentityKeyPair identityKeyPair
+    ) throws IOException {
+        OneTimePreKeyCounts preKeyCounts;
         try {
-            refreshPreKeys(serviceIdType, identityKeyPair);
-        } catch (Exception e) {
-            logger.warn("Failed to store new pre keys, resetting preKey id offset", e);
-            account.resetPreKeyOffsets(serviceIdType);
-            refreshPreKeys(serviceIdType, identityKeyPair);
+            preKeyCounts = handleResponseException(dependencies.getKeysApi().getAvailablePreKeyCounts(serviceIdType));
+        } catch (AuthorizationFailedException e) {
+            logger.debug("Failed to get pre key count, ignoring: " + e.getClass().getSimpleName());
+            preKeyCounts = new OneTimePreKeyCounts(0, 0);
         }
+
+        return refreshPreKeysIfNecessary(serviceIdType, identityKeyPair, preKeyCounts, false);
     }
 
-    private void refreshPreKeys(
-            final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
+    private boolean refreshPreKeysIfNecessary(
+            final ServiceIdType serviceIdType,
+            final IdentityKeyPair identityKeyPair,
+            final OneTimePreKeyCounts preKeyCounts,
+            final boolean force
     ) throws IOException {
-        final var oneTimePreKeys = generatePreKeys(serviceIdType);
-        final var signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
+        List<PreKeyRecord> preKeyRecords = null;
+        if (force || preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
+            logger.debug("Refreshing {} ec pre keys, because only {} of min {} pre keys remain",
+                    serviceIdType,
+                    preKeyCounts.getEcCount(),
+                    ServiceConfig.PREKEY_MINIMUM_COUNT);
+            preKeyRecords = generatePreKeys(serviceIdType);
+        }
+
+        SignedPreKeyRecord signedPreKeyRecord = null;
+        if (force || signedPreKeyNeedsRefresh(serviceIdType)) {
+            logger.debug("Refreshing {} signed pre key.", serviceIdType);
+            signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
+        }
+
+        List<KyberPreKeyRecord> kyberPreKeyRecords = null;
+        if (force || preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
+            logger.debug("Refreshing {} kyber pre keys, because only {} of min {} pre keys remain",
+                    serviceIdType,
+                    preKeyCounts.getKyberCount(),
+                    ServiceConfig.PREKEY_MINIMUM_COUNT);
+            kyberPreKeyRecords = generateKyberPreKeys(serviceIdType, identityKeyPair);
+        }
+
+        KyberPreKeyRecord lastResortKyberPreKeyRecord = null;
+        if (force || lastResortKyberPreKeyNeedsRefresh(serviceIdType)) {
+            logger.debug("Refreshing {} last resort kyber pre key.", serviceIdType);
+            lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType,
+                    identityKeyPair,
+                    kyberPreKeyRecords == null ? 0 : kyberPreKeyRecords.size());
+        }
+
+        if (signedPreKeyRecord == null
+                && preKeyRecords == null
+                && lastResortKyberPreKeyRecord == null
+                && kyberPreKeyRecords == null) {
+            return false;
+        }
 
         final var preKeyUpload = new PreKeyUpload(serviceIdType,
-                identityKeyPair.getPublicKey(),
                 signedPreKeyRecord,
-                oneTimePreKeys,
-                null,
-                null);
-        dependencies.getAccountManager().setPreKeys(preKeyUpload);
+                preKeyRecords,
+                lastResortKyberPreKeyRecord,
+                kyberPreKeyRecords);
+        var needsReset = false;
+        try {
+            NetworkResultUtil.toPreKeysLegacy(dependencies.getKeysApi().setPreKeys(preKeyUpload));
+            try {
+                if (preKeyRecords != null) {
+                    account.addPreKeys(serviceIdType, preKeyRecords);
+                }
+                if (signedPreKeyRecord != null) {
+                    account.addSignedPreKey(serviceIdType, signedPreKeyRecord);
+                }
+            } catch (Exception e) {
+                logger.warn("Failed to store new pre keys, resetting preKey id offset", e);
+                account.resetPreKeyOffsets(serviceIdType);
+                needsReset = true;
+            }
+            try {
+                if (kyberPreKeyRecords != null) {
+                    account.addKyberPreKeys(serviceIdType, kyberPreKeyRecords);
+                }
+                if (lastResortKyberPreKeyRecord != null) {
+                    account.addLastResortKyberPreKey(serviceIdType, lastResortKyberPreKeyRecord);
+                }
+            } catch (Exception e) {
+                logger.warn("Failed to store new kyber pre keys, resetting preKey id offset", e);
+                account.resetKyberPreKeyOffsets(serviceIdType);
+                needsReset = true;
+            }
+        } catch (AuthorizationFailedException e) {
+            // This can happen when the primary device has changed phone number
+            logger.warn("Failed to updated pre keys: {}", e.getMessage());
+        } catch (NonSuccessfulResponseCodeException e) {
+            if (serviceIdType != ServiceIdType.PNI || e.code != 422) {
+                throw e;
+            }
+            logger.warn("Failed to set PNI pre keys, ignoring for now. Account needs to be reregistered to fix this.");
+        }
+        return needsReset;
+    }
+
+    public void cleanOldPreKeys() {
+        cleanOldPreKeys(ServiceIdType.ACI);
+        cleanOldPreKeys(ServiceIdType.PNI);
+    }
+
+    private void cleanOldPreKeys(final ServiceIdType serviceIdType) {
+        cleanSignedPreKeys(serviceIdType);
+        cleanOneTimePreKeys(serviceIdType);
     }
 
     private List<PreKeyRecord> generatePreKeys(ServiceIdType serviceIdType) {
-        final var offset = account.getPreKeyIdOffset(serviceIdType);
+        final var accountData = account.getAccountData(serviceIdType);
+        final var offset = accountData.getPreKeyMetadata().getNextPreKeyId();
 
-        var records = KeyUtils.generatePreKeyRecords(offset, ServiceConfig.PREKEY_BATCH_SIZE);
-        account.addPreKeys(serviceIdType, records);
+        return KeyUtils.generatePreKeyRecords(offset);
+    }
+
+    private boolean signedPreKeyNeedsRefresh(ServiceIdType serviceIdType) {
+        final var accountData = account.getAccountData(serviceIdType);
 
-        return records;
+        final var activeSignedPreKeyId = accountData.getPreKeyMetadata().getActiveSignedPreKeyId();
+        if (activeSignedPreKeyId == -1) {
+            return true;
+        }
+        try {
+            final var signedPreKeyRecord = accountData.getSignedPreKeyStore().loadSignedPreKey(activeSignedPreKeyId);
+            return signedPreKeyRecord.getTimestamp() < System.currentTimeMillis() - SIGNED_PREKEY_ROTATE_AGE;
+        } catch (InvalidKeyIdException e) {
+            return true;
+        }
     }
 
     private SignedPreKeyRecord generateSignedPreKey(ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair) {
-        final var signedPreKeyId = account.getNextSignedPreKeyId(serviceIdType);
+        final var accountData = account.getAccountData(serviceIdType);
+        final var signedPreKeyId = accountData.getPreKeyMetadata().getNextSignedPreKeyId();
+
+        return KeyUtils.generateSignedPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey());
+    }
+
+    private List<KyberPreKeyRecord> generateKyberPreKeys(
+            ServiceIdType serviceIdType,
+            final IdentityKeyPair identityKeyPair
+    ) {
+        final var accountData = account.getAccountData(serviceIdType);
+        final var offset = accountData.getPreKeyMetadata().getNextKyberPreKeyId();
+
+        return KeyUtils.generateKyberPreKeyRecords(offset, identityKeyPair.getPrivateKey());
+    }
+
+    private boolean lastResortKyberPreKeyNeedsRefresh(ServiceIdType serviceIdType) {
+        final var accountData = account.getAccountData(serviceIdType);
+
+        final var activeLastResortKyberPreKeyId = accountData.getPreKeyMetadata().getActiveLastResortKyberPreKeyId();
+        if (activeLastResortKyberPreKeyId == -1) {
+            return true;
+        }
+        try {
+            final var kyberPreKeyRecord = accountData.getKyberPreKeyStore()
+                    .loadKyberPreKey(activeLastResortKyberPreKeyId);
+            return kyberPreKeyRecord.getTimestamp() < System.currentTimeMillis() - SIGNED_PREKEY_ROTATE_AGE;
+        } catch (InvalidKeyIdException e) {
+            return true;
+        }
+    }
+
+    private KyberPreKeyRecord generateLastResortKyberPreKey(
+            ServiceIdType serviceIdType,
+            IdentityKeyPair identityKeyPair,
+            final int offset
+    ) {
+        final var accountData = account.getAccountData(serviceIdType);
+        final var signedPreKeyId = accountData.getPreKeyMetadata().getNextKyberPreKeyId() + offset;
+
+        return KeyUtils.generateKyberPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey());
+    }
+
+    private void cleanSignedPreKeys(ServiceIdType serviceIdType) {
+        final var accountData = account.getAccountData(serviceIdType);
+
+        final var activeSignedPreKeyId = accountData.getPreKeyMetadata().getActiveSignedPreKeyId();
+        accountData.getSignedPreKeyStore().removeOldSignedPreKeys(activeSignedPreKeyId);
+
+        final var activeLastResortKyberPreKeyId = accountData.getPreKeyMetadata().getActiveLastResortKyberPreKeyId();
+        accountData.getKyberPreKeyStore().removeOldLastResortKyberPreKeys(activeLastResortKyberPreKeyId);
+    }
 
-        var record = KeyUtils.generateSignedPreKeyRecord(identityKeyPair, signedPreKeyId);
-        account.addSignedPreKey(serviceIdType, record);
+    private void cleanOneTimePreKeys(ServiceIdType serviceIdType) {
+        long threshold = System.currentTimeMillis() - PREKEY_STALE_AGE;
+        int minCount = 200;
 
-        return record;
+        final var accountData = account.getAccountData(serviceIdType);
+        accountData.getPreKeyStore().deleteAllStaleOneTimeEcPreKeys(threshold, minCount);
+        accountData.getKyberPreKeyStore().deleteAllStaleOneTimeKyberPreKeys(threshold, minCount);
     }
 }