From: AsamK Date: Mon, 11 Apr 2022 18:05:02 +0000 (+0200) Subject: Refresh pre keys for PNI identity X-Git-Tag: v0.10.5~3 X-Git-Url: https://git.nmode.ca/signal-cli/commitdiff_plain/945ff44de3df067545412c3e1c0fff8a5d95e810 Refresh pre keys for PNI identity Fixes #930 --- diff --git a/lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java b/lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java index 52e443da..ea5ccdab 100644 --- a/lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java +++ b/lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java @@ -45,32 +45,31 @@ public class PreKeyHelper { } public void refreshPreKeys(ServiceIdType serviceIdType) throws IOException { - if (serviceIdType != ServiceIdType.ACI) { - // TODO implement + final var oneTimePreKeys = generatePreKeys(serviceIdType); + final var identityKeyPair = account.getIdentityKeyPair(serviceIdType); + if (identityKeyPair == null) { return; } - var oneTimePreKeys = generatePreKeys(); - final var identityKeyPair = account.getAciIdentityKeyPair(); - var signedPreKeyRecord = generateSignedPreKey(identityKeyPair); + final var signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair); dependencies.getAccountManager() .setPreKeys(serviceIdType, identityKeyPair.getPublicKey(), signedPreKeyRecord, oneTimePreKeys); } - private List generatePreKeys() { - final var offset = account.getPreKeyIdOffset(); + private List generatePreKeys(ServiceIdType serviceIdType) { + final var offset = account.getPreKeyIdOffset(serviceIdType); var records = KeyUtils.generatePreKeyRecords(offset, ServiceConfig.PREKEY_BATCH_SIZE); - account.addPreKeys(records); + account.addPreKeys(serviceIdType, records); return records; } - private SignedPreKeyRecord generateSignedPreKey(IdentityKeyPair identityKeyPair) { - final var signedPreKeyId = account.getNextSignedPreKeyId(); + private SignedPreKeyRecord generateSignedPreKey(ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair) { + final var signedPreKeyId = account.getNextSignedPreKeyId(serviceIdType); var record = KeyUtils.generateSignedPreKeyRecord(identityKeyPair, signedPreKeyId); - account.addSignedPreKey(record); + account.addSignedPreKey(serviceIdType, record); return record; } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java index 1ab383df..66f515ae 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java @@ -53,6 +53,7 @@ import org.whispersystems.signalservice.api.push.ACI; import org.whispersystems.signalservice.api.push.DistributionId; import org.whispersystems.signalservice.api.push.PNI; import org.whispersystems.signalservice.api.push.ServiceId; +import org.whispersystems.signalservice.api.push.ServiceIdType; import org.whispersystems.signalservice.api.push.SignalServiceAddress; import org.whispersystems.signalservice.api.storage.StorageKey; import org.whispersystems.signalservice.api.util.CredentialsProvider; @@ -106,8 +107,10 @@ public class SignalAccount implements Closeable { private StorageKey storageKey; private long storageManifestVersion = -1; private ProfileKey profileKey; - private int preKeyIdOffset = 1; - private int nextSignedPreKeyId = 1; + private int aciPreKeyIdOffset = 1; + private int aciNextSignedPreKeyId = 1; + private int pniPreKeyIdOffset = 1; + private int pniNextSignedPreKeyId = 1; private IdentityKeyPair aciIdentityKeyPair; private IdentityKeyPair pniIdentityKeyPair; private int localRegistrationId; @@ -117,8 +120,10 @@ public class SignalAccount implements Closeable { private boolean registered = false; private SignalProtocolStore signalProtocolStore; - private PreKeyStore preKeyStore; - private SignedPreKeyStore signedPreKeyStore; + private PreKeyStore aciPreKeyStore; + private SignedPreKeyStore aciSignedPreKeyStore; + private PreKeyStore pniPreKeyStore; + private SignedPreKeyStore pniSignedPreKeyStore; private SessionStore sessionStore; private IdentityKeyStore identityKeyStore; private SenderKeyStore senderKeyStore; @@ -259,10 +264,14 @@ public class SignalAccount implements Closeable { } private void clearAllPreKeys() { - this.preKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE); - this.nextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE); - this.getPreKeyStore().removeAllPreKeys(); - this.getSignedPreKeyStore().removeAllSignedPreKeys(); + this.aciPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE); + this.aciNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE); + this.pniPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE); + this.pniNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE); + this.getAciPreKeyStore().removeAllPreKeys(); + this.getAciSignedPreKeyStore().removeAllSignedPreKeys(); + this.getPniPreKeyStore().removeAllPreKeys(); + this.getPniSignedPreKeyStore().removeAllSignedPreKeys(); save(); } @@ -407,14 +416,22 @@ public class SignalAccount implements Closeable { return new File(getUserPath(dataPath, account), "group-cache"); } - private static File getPreKeysPath(File dataPath, String account) { + private static File getAciPreKeysPath(File dataPath, String account) { return new File(getUserPath(dataPath, account), "pre-keys"); } - private static File getSignedPreKeysPath(File dataPath, String account) { + private static File getAciSignedPreKeysPath(File dataPath, String account) { return new File(getUserPath(dataPath, account), "signed-pre-keys"); } + private static File getPniPreKeysPath(File dataPath, String account) { + return new File(getUserPath(dataPath, account), "pre-keys-pni"); + } + + private static File getPniSignedPreKeysPath(File dataPath, String account) { + return new File(getUserPath(dataPath, account), "signed-pre-keys-pni"); + } + private static File getIdentitiesPath(File dataPath, String account) { return new File(getUserPath(dataPath, account), "identities"); } @@ -528,14 +545,24 @@ public class SignalAccount implements Closeable { storageManifestVersion = rootNode.get("storageManifestVersion").asLong(); } if (rootNode.hasNonNull("preKeyIdOffset")) { - preKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1); + aciPreKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1); } else { - preKeyIdOffset = 1; + aciPreKeyIdOffset = 1; } if (rootNode.hasNonNull("nextSignedPreKeyId")) { - nextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1); + aciNextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1); } else { - nextSignedPreKeyId = 1; + aciNextSignedPreKeyId = 1; + } + if (rootNode.hasNonNull("pniPreKeyIdOffset")) { + pniPreKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1); + } else { + pniPreKeyIdOffset = 1; + } + if (rootNode.hasNonNull("pniNextSignedPreKeyId")) { + pniNextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1); + } else { + pniNextSignedPreKeyId = 1; } if (rootNode.hasNonNull("profileKey")) { try { @@ -618,7 +645,7 @@ public class SignalAccount implements Closeable { logger.debug("Migrating legacy pre key store."); for (var entry : legacySignalProtocolStore.getLegacyPreKeyStore().getPreKeys().entrySet()) { try { - getPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue())); + getAciPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue())); } catch (InvalidMessageException e) { logger.warn("Failed to migrate pre key, ignoring", e); } @@ -630,7 +657,8 @@ public class SignalAccount implements Closeable { logger.debug("Migrating legacy signed pre key store."); for (var entry : legacySignalProtocolStore.getLegacySignedPreKeyStore().getSignedPreKeys().entrySet()) { try { - getSignedPreKeyStore().storeSignedPreKey(entry.getKey(), new SignedPreKeyRecord(entry.getValue())); + getAciSignedPreKeyStore().storeSignedPreKey(entry.getKey(), + new SignedPreKeyRecord(entry.getValue())); } catch (InvalidMessageException e) { logger.warn("Failed to migrate signed pre key, ignoring", e); } @@ -813,8 +841,10 @@ public class SignalAccount implements Closeable { .put("storageKey", storageKey == null ? null : Base64.getEncoder().encodeToString(storageKey.serialize())) .put("storageManifestVersion", storageManifestVersion == -1 ? null : storageManifestVersion) - .put("preKeyIdOffset", preKeyIdOffset) - .put("nextSignedPreKeyId", nextSignedPreKeyId) + .put("preKeyIdOffset", aciPreKeyIdOffset) + .put("nextSignedPreKeyId", aciNextSignedPreKeyId) + .put("pniPreKeyIdOffset", pniPreKeyIdOffset) + .put("pniNextSignedPreKeyId", pniNextSignedPreKeyId) .put("profileKey", profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize())) .put("registered", registered) @@ -852,25 +882,63 @@ public class SignalAccount implements Closeable { return new Pair<>(fileChannel, lock); } - public void addPreKeys(List records) { + public void addPreKeys(ServiceIdType serviceIdType, List records) { + if (serviceIdType.equals(ServiceIdType.ACI)) { + addAciPreKeys(records); + } else { + addPniPreKeys(records); + } + } + + public void addAciPreKeys(List records) { + for (var record : records) { + if (aciPreKeyIdOffset != record.getId()) { + logger.error("Invalid pre key id {}, expected {}", record.getId(), aciPreKeyIdOffset); + throw new AssertionError("Invalid pre key id"); + } + getAciPreKeyStore().storePreKey(record.getId(), record); + aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % Medium.MAX_VALUE; + } + save(); + } + + public void addPniPreKeys(List records) { for (var record : records) { - if (preKeyIdOffset != record.getId()) { - logger.error("Invalid pre key id {}, expected {}", record.getId(), preKeyIdOffset); + if (pniPreKeyIdOffset != record.getId()) { + logger.error("Invalid pre key id {}, expected {}", record.getId(), pniPreKeyIdOffset); throw new AssertionError("Invalid pre key id"); } - getPreKeyStore().storePreKey(record.getId(), record); - preKeyIdOffset = (preKeyIdOffset + 1) % Medium.MAX_VALUE; + getPniPreKeyStore().storePreKey(record.getId(), record); + pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % Medium.MAX_VALUE; } save(); } - public void addSignedPreKey(SignedPreKeyRecord record) { - if (nextSignedPreKeyId != record.getId()) { - logger.error("Invalid signed pre key id {}, expected {}", record.getId(), nextSignedPreKeyId); + public void addSignedPreKey(ServiceIdType serviceIdType, SignedPreKeyRecord record) { + if (serviceIdType.equals(ServiceIdType.ACI)) { + addAciSignedPreKey(record); + } else { + addPniSignedPreKey(record); + } + } + + public void addAciSignedPreKey(SignedPreKeyRecord record) { + if (aciNextSignedPreKeyId != record.getId()) { + logger.error("Invalid signed pre key id {}, expected {}", record.getId(), aciNextSignedPreKeyId); throw new AssertionError("Invalid signed pre key id"); } - getSignedPreKeyStore().storeSignedPreKey(record.getId(), record); - nextSignedPreKeyId = (nextSignedPreKeyId + 1) % Medium.MAX_VALUE; + getAciSignedPreKeyStore().storeSignedPreKey(record.getId(), record); + aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % Medium.MAX_VALUE; + save(); + } + + public void addPniSignedPreKey(SignedPreKeyRecord record) { + if (pniNextSignedPreKeyId != record.getId()) { + logger.error("Invalid signed pre key id {}, expected {}", record.getId(), pniNextSignedPreKeyId); + throw new AssertionError("Invalid signed pre key id"); + } + getPniSignedPreKeyStore().storeSignedPreKey(record.getId(), record); + pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % Medium.MAX_VALUE; save(); } @@ -906,22 +974,32 @@ public class SignalAccount implements Closeable { public SignalServiceAccountDataStore getSignalServiceAccountDataStore() { return getOrCreate(() -> signalProtocolStore, - () -> signalProtocolStore = new SignalProtocolStore(getPreKeyStore(), - getSignedPreKeyStore(), + () -> signalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(), + getAciSignedPreKeyStore(), getSessionStore(), getIdentityKeyStore(), getSenderKeyStore(), this::isMultiDevice)); } - private PreKeyStore getPreKeyStore() { - return getOrCreate(() -> preKeyStore, - () -> preKeyStore = new PreKeyStore(getPreKeysPath(dataPath, accountPath))); + private PreKeyStore getAciPreKeyStore() { + return getOrCreate(() -> aciPreKeyStore, + () -> aciPreKeyStore = new PreKeyStore(getAciPreKeysPath(dataPath, accountPath))); + } + + private SignedPreKeyStore getAciSignedPreKeyStore() { + return getOrCreate(() -> aciSignedPreKeyStore, + () -> aciSignedPreKeyStore = new SignedPreKeyStore(getAciSignedPreKeysPath(dataPath, accountPath))); } - private SignedPreKeyStore getSignedPreKeyStore() { - return getOrCreate(() -> signedPreKeyStore, - () -> signedPreKeyStore = new SignedPreKeyStore(getSignedPreKeysPath(dataPath, accountPath))); + private PreKeyStore getPniPreKeyStore() { + return getOrCreate(() -> pniPreKeyStore, + () -> pniPreKeyStore = new PreKeyStore(getPniPreKeysPath(dataPath, accountPath))); + } + + private SignedPreKeyStore getPniSignedPreKeyStore() { + return getOrCreate(() -> pniSignedPreKeyStore, + () -> pniSignedPreKeyStore = new SignedPreKeyStore(getPniSignedPreKeysPath(dataPath, accountPath))); } public SessionStore getSessionStore() { @@ -1078,6 +1156,10 @@ public class SignalAccount implements Closeable { return deviceId == SignalServiceAddress.DEFAULT_DEVICE_ID; } + public IdentityKeyPair getIdentityKeyPair(ServiceIdType serviceIdType) { + return serviceIdType.equals(ServiceIdType.ACI) ? aciIdentityKeyPair : pniIdentityKeyPair; + } + public IdentityKeyPair getAciIdentityKeyPair() { return aciIdentityKeyPair; } @@ -1157,12 +1239,12 @@ public class SignalAccount implements Closeable { return UnidentifiedAccess.deriveAccessKeyFrom(getProfileKey()); } - public int getPreKeyIdOffset() { - return preKeyIdOffset; + public int getPreKeyIdOffset(ServiceIdType serviceIdType) { + return serviceIdType.equals(ServiceIdType.ACI) ? aciPreKeyIdOffset : pniPreKeyIdOffset; } - public int getNextSignedPreKeyId() { - return nextSignedPreKeyId; + public int getNextSignedPreKeyId(ServiceIdType serviceIdType) { + return serviceIdType.equals(ServiceIdType.ACI) ? aciNextSignedPreKeyId : pniNextSignedPreKeyId; } public boolean isRegistered() {