]> nmode's Git Repositories - signal-cli/commitdiff
Refresh pre keys for PNI identity
authorAsamK <asamk@gmx.de>
Mon, 11 Apr 2022 18:05:02 +0000 (20:05 +0200)
committerAsamK <asamk@gmx.de>
Mon, 11 Apr 2022 18:05:02 +0000 (20:05 +0200)
Fixes #930

lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java
lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java

index 52e443dad39e9491bb1fd11172b6e9183ea4406f..ea5ccdabd1545f04a80132fdd734174639c24aca 100644 (file)
@@ -45,32 +45,31 @@ public class PreKeyHelper {
     }
 
     public void refreshPreKeys(ServiceIdType serviceIdType) throws IOException {
     }
 
     public void refreshPreKeys(ServiceIdType serviceIdType) throws IOException {
-        if (serviceIdType != ServiceIdType.ACI) {
-            // TODO implement
+        final var oneTimePreKeys = generatePreKeys(serviceIdType);
+        final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
+        if (identityKeyPair == null) {
             return;
         }
             return;
         }
-        var oneTimePreKeys = generatePreKeys();
-        final var identityKeyPair = account.getAciIdentityKeyPair();
-        var signedPreKeyRecord = generateSignedPreKey(identityKeyPair);
+        final var signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
 
         dependencies.getAccountManager()
                 .setPreKeys(serviceIdType, identityKeyPair.getPublicKey(), signedPreKeyRecord, oneTimePreKeys);
     }
 
 
         dependencies.getAccountManager()
                 .setPreKeys(serviceIdType, identityKeyPair.getPublicKey(), signedPreKeyRecord, oneTimePreKeys);
     }
 
-    private List<PreKeyRecord> generatePreKeys() {
-        final var offset = account.getPreKeyIdOffset();
+    private List<PreKeyRecord> generatePreKeys(ServiceIdType serviceIdType) {
+        final var offset = account.getPreKeyIdOffset(serviceIdType);
 
         var records = KeyUtils.generatePreKeyRecords(offset, ServiceConfig.PREKEY_BATCH_SIZE);
 
         var records = KeyUtils.generatePreKeyRecords(offset, ServiceConfig.PREKEY_BATCH_SIZE);
-        account.addPreKeys(records);
+        account.addPreKeys(serviceIdType, records);
 
         return records;
     }
 
 
         return records;
     }
 
-    private SignedPreKeyRecord generateSignedPreKey(IdentityKeyPair identityKeyPair) {
-        final var signedPreKeyId = account.getNextSignedPreKeyId();
+    private SignedPreKeyRecord generateSignedPreKey(ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair) {
+        final var signedPreKeyId = account.getNextSignedPreKeyId(serviceIdType);
 
         var record = KeyUtils.generateSignedPreKeyRecord(identityKeyPair, signedPreKeyId);
 
         var record = KeyUtils.generateSignedPreKeyRecord(identityKeyPair, signedPreKeyId);
-        account.addSignedPreKey(record);
+        account.addSignedPreKey(serviceIdType, record);
 
         return record;
     }
 
         return record;
     }
index 1ab383dfec5aba5679362d33dfd47405678d9e79..66f515aef0bfc33f60db0501775efb746cfe4619 100644 (file)
@@ -53,6 +53,7 @@ import org.whispersystems.signalservice.api.push.ACI;
 import org.whispersystems.signalservice.api.push.DistributionId;
 import org.whispersystems.signalservice.api.push.PNI;
 import org.whispersystems.signalservice.api.push.ServiceId;
 import org.whispersystems.signalservice.api.push.DistributionId;
 import org.whispersystems.signalservice.api.push.PNI;
 import org.whispersystems.signalservice.api.push.ServiceId;
+import org.whispersystems.signalservice.api.push.ServiceIdType;
 import org.whispersystems.signalservice.api.push.SignalServiceAddress;
 import org.whispersystems.signalservice.api.storage.StorageKey;
 import org.whispersystems.signalservice.api.util.CredentialsProvider;
 import org.whispersystems.signalservice.api.push.SignalServiceAddress;
 import org.whispersystems.signalservice.api.storage.StorageKey;
 import org.whispersystems.signalservice.api.util.CredentialsProvider;
@@ -106,8 +107,10 @@ public class SignalAccount implements Closeable {
     private StorageKey storageKey;
     private long storageManifestVersion = -1;
     private ProfileKey profileKey;
     private StorageKey storageKey;
     private long storageManifestVersion = -1;
     private ProfileKey profileKey;
-    private int preKeyIdOffset = 1;
-    private int nextSignedPreKeyId = 1;
+    private int aciPreKeyIdOffset = 1;
+    private int aciNextSignedPreKeyId = 1;
+    private int pniPreKeyIdOffset = 1;
+    private int pniNextSignedPreKeyId = 1;
     private IdentityKeyPair aciIdentityKeyPair;
     private IdentityKeyPair pniIdentityKeyPair;
     private int localRegistrationId;
     private IdentityKeyPair aciIdentityKeyPair;
     private IdentityKeyPair pniIdentityKeyPair;
     private int localRegistrationId;
@@ -117,8 +120,10 @@ public class SignalAccount implements Closeable {
     private boolean registered = false;
 
     private SignalProtocolStore signalProtocolStore;
     private boolean registered = false;
 
     private SignalProtocolStore signalProtocolStore;
-    private PreKeyStore preKeyStore;
-    private SignedPreKeyStore signedPreKeyStore;
+    private PreKeyStore aciPreKeyStore;
+    private SignedPreKeyStore aciSignedPreKeyStore;
+    private PreKeyStore pniPreKeyStore;
+    private SignedPreKeyStore pniSignedPreKeyStore;
     private SessionStore sessionStore;
     private IdentityKeyStore identityKeyStore;
     private SenderKeyStore senderKeyStore;
     private SessionStore sessionStore;
     private IdentityKeyStore identityKeyStore;
     private SenderKeyStore senderKeyStore;
@@ -259,10 +264,14 @@ public class SignalAccount implements Closeable {
     }
 
     private void clearAllPreKeys() {
     }
 
     private void clearAllPreKeys() {
-        this.preKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE);
-        this.nextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE);
-        this.getPreKeyStore().removeAllPreKeys();
-        this.getSignedPreKeyStore().removeAllSignedPreKeys();
+        this.aciPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE);
+        this.aciNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE);
+        this.pniPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE);
+        this.pniNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE);
+        this.getAciPreKeyStore().removeAllPreKeys();
+        this.getAciSignedPreKeyStore().removeAllSignedPreKeys();
+        this.getPniPreKeyStore().removeAllPreKeys();
+        this.getPniSignedPreKeyStore().removeAllSignedPreKeys();
         save();
     }
 
         save();
     }
 
@@ -407,14 +416,22 @@ public class SignalAccount implements Closeable {
         return new File(getUserPath(dataPath, account), "group-cache");
     }
 
         return new File(getUserPath(dataPath, account), "group-cache");
     }
 
-    private static File getPreKeysPath(File dataPath, String account) {
+    private static File getAciPreKeysPath(File dataPath, String account) {
         return new File(getUserPath(dataPath, account), "pre-keys");
     }
 
         return new File(getUserPath(dataPath, account), "pre-keys");
     }
 
-    private static File getSignedPreKeysPath(File dataPath, String account) {
+    private static File getAciSignedPreKeysPath(File dataPath, String account) {
         return new File(getUserPath(dataPath, account), "signed-pre-keys");
     }
 
         return new File(getUserPath(dataPath, account), "signed-pre-keys");
     }
 
+    private static File getPniPreKeysPath(File dataPath, String account) {
+        return new File(getUserPath(dataPath, account), "pre-keys-pni");
+    }
+
+    private static File getPniSignedPreKeysPath(File dataPath, String account) {
+        return new File(getUserPath(dataPath, account), "signed-pre-keys-pni");
+    }
+
     private static File getIdentitiesPath(File dataPath, String account) {
         return new File(getUserPath(dataPath, account), "identities");
     }
     private static File getIdentitiesPath(File dataPath, String account) {
         return new File(getUserPath(dataPath, account), "identities");
     }
@@ -528,14 +545,24 @@ public class SignalAccount implements Closeable {
             storageManifestVersion = rootNode.get("storageManifestVersion").asLong();
         }
         if (rootNode.hasNonNull("preKeyIdOffset")) {
             storageManifestVersion = rootNode.get("storageManifestVersion").asLong();
         }
         if (rootNode.hasNonNull("preKeyIdOffset")) {
-            preKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1);
+            aciPreKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1);
         } else {
         } else {
-            preKeyIdOffset = 1;
+            aciPreKeyIdOffset = 1;
         }
         if (rootNode.hasNonNull("nextSignedPreKeyId")) {
         }
         if (rootNode.hasNonNull("nextSignedPreKeyId")) {
-            nextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1);
+            aciNextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1);
         } else {
         } else {
-            nextSignedPreKeyId = 1;
+            aciNextSignedPreKeyId = 1;
+        }
+        if (rootNode.hasNonNull("pniPreKeyIdOffset")) {
+            pniPreKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1);
+        } else {
+            pniPreKeyIdOffset = 1;
+        }
+        if (rootNode.hasNonNull("pniNextSignedPreKeyId")) {
+            pniNextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1);
+        } else {
+            pniNextSignedPreKeyId = 1;
         }
         if (rootNode.hasNonNull("profileKey")) {
             try {
         }
         if (rootNode.hasNonNull("profileKey")) {
             try {
@@ -618,7 +645,7 @@ public class SignalAccount implements Closeable {
             logger.debug("Migrating legacy pre key store.");
             for (var entry : legacySignalProtocolStore.getLegacyPreKeyStore().getPreKeys().entrySet()) {
                 try {
             logger.debug("Migrating legacy pre key store.");
             for (var entry : legacySignalProtocolStore.getLegacyPreKeyStore().getPreKeys().entrySet()) {
                 try {
-                    getPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue()));
+                    getAciPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue()));
                 } catch (InvalidMessageException e) {
                     logger.warn("Failed to migrate pre key, ignoring", e);
                 }
                 } catch (InvalidMessageException e) {
                     logger.warn("Failed to migrate pre key, ignoring", e);
                 }
@@ -630,7 +657,8 @@ public class SignalAccount implements Closeable {
             logger.debug("Migrating legacy signed pre key store.");
             for (var entry : legacySignalProtocolStore.getLegacySignedPreKeyStore().getSignedPreKeys().entrySet()) {
                 try {
             logger.debug("Migrating legacy signed pre key store.");
             for (var entry : legacySignalProtocolStore.getLegacySignedPreKeyStore().getSignedPreKeys().entrySet()) {
                 try {
-                    getSignedPreKeyStore().storeSignedPreKey(entry.getKey(), new SignedPreKeyRecord(entry.getValue()));
+                    getAciSignedPreKeyStore().storeSignedPreKey(entry.getKey(),
+                            new SignedPreKeyRecord(entry.getValue()));
                 } catch (InvalidMessageException e) {
                     logger.warn("Failed to migrate signed pre key, ignoring", e);
                 }
                 } catch (InvalidMessageException e) {
                     logger.warn("Failed to migrate signed pre key, ignoring", e);
                 }
@@ -813,8 +841,10 @@ public class SignalAccount implements Closeable {
                     .put("storageKey",
                             storageKey == null ? null : Base64.getEncoder().encodeToString(storageKey.serialize()))
                     .put("storageManifestVersion", storageManifestVersion == -1 ? null : storageManifestVersion)
                     .put("storageKey",
                             storageKey == null ? null : Base64.getEncoder().encodeToString(storageKey.serialize()))
                     .put("storageManifestVersion", storageManifestVersion == -1 ? null : storageManifestVersion)
-                    .put("preKeyIdOffset", preKeyIdOffset)
-                    .put("nextSignedPreKeyId", nextSignedPreKeyId)
+                    .put("preKeyIdOffset", aciPreKeyIdOffset)
+                    .put("nextSignedPreKeyId", aciNextSignedPreKeyId)
+                    .put("pniPreKeyIdOffset", pniPreKeyIdOffset)
+                    .put("pniNextSignedPreKeyId", pniNextSignedPreKeyId)
                     .put("profileKey",
                             profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize()))
                     .put("registered", registered)
                     .put("profileKey",
                             profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize()))
                     .put("registered", registered)
@@ -852,25 +882,63 @@ public class SignalAccount implements Closeable {
         return new Pair<>(fileChannel, lock);
     }
 
         return new Pair<>(fileChannel, lock);
     }
 
-    public void addPreKeys(List<PreKeyRecord> records) {
+    public void addPreKeys(ServiceIdType serviceIdType, List<PreKeyRecord> records) {
+        if (serviceIdType.equals(ServiceIdType.ACI)) {
+            addAciPreKeys(records);
+        } else {
+            addPniPreKeys(records);
+        }
+    }
+
+    public void addAciPreKeys(List<PreKeyRecord> records) {
+        for (var record : records) {
+            if (aciPreKeyIdOffset != record.getId()) {
+                logger.error("Invalid pre key id {}, expected {}", record.getId(), aciPreKeyIdOffset);
+                throw new AssertionError("Invalid pre key id");
+            }
+            getAciPreKeyStore().storePreKey(record.getId(), record);
+            aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % Medium.MAX_VALUE;
+        }
+        save();
+    }
+
+    public void addPniPreKeys(List<PreKeyRecord> records) {
         for (var record : records) {
         for (var record : records) {
-            if (preKeyIdOffset != record.getId()) {
-                logger.error("Invalid pre key id {}, expected {}", record.getId(), preKeyIdOffset);
+            if (pniPreKeyIdOffset != record.getId()) {
+                logger.error("Invalid pre key id {}, expected {}", record.getId(), pniPreKeyIdOffset);
                 throw new AssertionError("Invalid pre key id");
             }
                 throw new AssertionError("Invalid pre key id");
             }
-            getPreKeyStore().storePreKey(record.getId(), record);
-            preKeyIdOffset = (preKeyIdOffset + 1) % Medium.MAX_VALUE;
+            getPniPreKeyStore().storePreKey(record.getId(), record);
+            pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % Medium.MAX_VALUE;
         }
         save();
     }
 
         }
         save();
     }
 
-    public void addSignedPreKey(SignedPreKeyRecord record) {
-        if (nextSignedPreKeyId != record.getId()) {
-            logger.error("Invalid signed pre key id {}, expected {}", record.getId(), nextSignedPreKeyId);
+    public void addSignedPreKey(ServiceIdType serviceIdType, SignedPreKeyRecord record) {
+        if (serviceIdType.equals(ServiceIdType.ACI)) {
+            addAciSignedPreKey(record);
+        } else {
+            addPniSignedPreKey(record);
+        }
+    }
+
+    public void addAciSignedPreKey(SignedPreKeyRecord record) {
+        if (aciNextSignedPreKeyId != record.getId()) {
+            logger.error("Invalid signed pre key id {}, expected {}", record.getId(), aciNextSignedPreKeyId);
             throw new AssertionError("Invalid signed pre key id");
         }
             throw new AssertionError("Invalid signed pre key id");
         }
-        getSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
-        nextSignedPreKeyId = (nextSignedPreKeyId + 1) % Medium.MAX_VALUE;
+        getAciSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
+        aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % Medium.MAX_VALUE;
+        save();
+    }
+
+    public void addPniSignedPreKey(SignedPreKeyRecord record) {
+        if (pniNextSignedPreKeyId != record.getId()) {
+            logger.error("Invalid signed pre key id {}, expected {}", record.getId(), pniNextSignedPreKeyId);
+            throw new AssertionError("Invalid signed pre key id");
+        }
+        getPniSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
+        pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % Medium.MAX_VALUE;
         save();
     }
 
         save();
     }
 
@@ -906,22 +974,32 @@ public class SignalAccount implements Closeable {
 
     public SignalServiceAccountDataStore getSignalServiceAccountDataStore() {
         return getOrCreate(() -> signalProtocolStore,
 
     public SignalServiceAccountDataStore getSignalServiceAccountDataStore() {
         return getOrCreate(() -> signalProtocolStore,
-                () -> signalProtocolStore = new SignalProtocolStore(getPreKeyStore(),
-                        getSignedPreKeyStore(),
+                () -> signalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(),
+                        getAciSignedPreKeyStore(),
                         getSessionStore(),
                         getIdentityKeyStore(),
                         getSenderKeyStore(),
                         this::isMultiDevice));
     }
 
                         getSessionStore(),
                         getIdentityKeyStore(),
                         getSenderKeyStore(),
                         this::isMultiDevice));
     }
 
-    private PreKeyStore getPreKeyStore() {
-        return getOrCreate(() -> preKeyStore,
-                () -> preKeyStore = new PreKeyStore(getPreKeysPath(dataPath, accountPath)));
+    private PreKeyStore getAciPreKeyStore() {
+        return getOrCreate(() -> aciPreKeyStore,
+                () -> aciPreKeyStore = new PreKeyStore(getAciPreKeysPath(dataPath, accountPath)));
+    }
+
+    private SignedPreKeyStore getAciSignedPreKeyStore() {
+        return getOrCreate(() -> aciSignedPreKeyStore,
+                () -> aciSignedPreKeyStore = new SignedPreKeyStore(getAciSignedPreKeysPath(dataPath, accountPath)));
     }
 
     }
 
-    private SignedPreKeyStore getSignedPreKeyStore() {
-        return getOrCreate(() -> signedPreKeyStore,
-                () -> signedPreKeyStore = new SignedPreKeyStore(getSignedPreKeysPath(dataPath, accountPath)));
+    private PreKeyStore getPniPreKeyStore() {
+        return getOrCreate(() -> pniPreKeyStore,
+                () -> pniPreKeyStore = new PreKeyStore(getPniPreKeysPath(dataPath, accountPath)));
+    }
+
+    private SignedPreKeyStore getPniSignedPreKeyStore() {
+        return getOrCreate(() -> pniSignedPreKeyStore,
+                () -> pniSignedPreKeyStore = new SignedPreKeyStore(getPniSignedPreKeysPath(dataPath, accountPath)));
     }
 
     public SessionStore getSessionStore() {
     }
 
     public SessionStore getSessionStore() {
@@ -1078,6 +1156,10 @@ public class SignalAccount implements Closeable {
         return deviceId == SignalServiceAddress.DEFAULT_DEVICE_ID;
     }
 
         return deviceId == SignalServiceAddress.DEFAULT_DEVICE_ID;
     }
 
+    public IdentityKeyPair getIdentityKeyPair(ServiceIdType serviceIdType) {
+        return serviceIdType.equals(ServiceIdType.ACI) ? aciIdentityKeyPair : pniIdentityKeyPair;
+    }
+
     public IdentityKeyPair getAciIdentityKeyPair() {
         return aciIdentityKeyPair;
     }
     public IdentityKeyPair getAciIdentityKeyPair() {
         return aciIdentityKeyPair;
     }
@@ -1157,12 +1239,12 @@ public class SignalAccount implements Closeable {
         return UnidentifiedAccess.deriveAccessKeyFrom(getProfileKey());
     }
 
         return UnidentifiedAccess.deriveAccessKeyFrom(getProfileKey());
     }
 
-    public int getPreKeyIdOffset() {
-        return preKeyIdOffset;
+    public int getPreKeyIdOffset(ServiceIdType serviceIdType) {
+        return serviceIdType.equals(ServiceIdType.ACI) ? aciPreKeyIdOffset : pniPreKeyIdOffset;
     }
 
     }
 
-    public int getNextSignedPreKeyId() {
-        return nextSignedPreKeyId;
+    public int getNextSignedPreKeyId(ServiceIdType serviceIdType) {
+        return serviceIdType.equals(ServiceIdType.ACI) ? aciNextSignedPreKeyId : pniNextSignedPreKeyId;
     }
 
     public boolean isRegistered() {
     }
 
     public boolean isRegistered() {