nmode's Git Repositories - signal-cli/commitdiff
Refactor ACI/PNI store handling
author AsamK <asamk@gmx.de>
Sun, 18 Jun 2023 12:44:57 +0000 (14:44 +0200)
committer AsamK <asamk@gmx.de>
Sun, 18 Jun 2023 15:54:52 +0000 (17:54 +0200)
lib/src/main/java/org/asamk/signal/manager/actions/RenewSessionAction.java
lib/src/main/java/org/asamk/signal/manager/actions/SendRetryMessageRequestAction.java
lib/src/main/java/org/asamk/signal/manager/helper/IncomingMessageHandler.java
lib/src/main/java/org/asamk/signal/manager/internal/ManagerImpl.java
lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java
lib/src/main/java/org/asamk/signal/manager/util/KeyUtils.java

index 2718bc26349648dce363590464b92d06fafe0180..401172463af9d5d4e4f6f5c9242ce943146e70bc 100644 (file)
@@ -8,15 +8,17 @@ public class RenewSessionAction implements HandleAction {
 
     private final RecipientId recipientId;
     private final ServiceId serviceId;
+    private final ServiceId accountId;
 
-    public RenewSessionAction(final RecipientId recipientId, final ServiceId serviceId) {
+    public RenewSessionAction(final RecipientId recipientId, final ServiceId serviceId, final ServiceId accountId) {
         this.recipientId = recipientId;
         this.serviceId = serviceId;
+        this.accountId = accountId;
     }
 
     @Override
     public void execute(Context context) throws Throwable {
-        context.getAccount().getAciSessionStore().archiveSessions(serviceId);
+        context.getAccount().getAccountData(accountId).getSessionStore().archiveSessions(serviceId);
         if (!recipientId.equals(context.getAccount().getSelfRecipientId())) {
             context.getSendHelper().sendNullMessage(recipientId);
         }
index 2300eae9d39c864cad97fb0e2b053aa7e638de3a..add09e72efd44515cd78f2080c59de868c754a11 100644 (file)
@@ -18,22 +18,25 @@ public class SendRetryMessageRequestAction implements HandleAction {
     private final ServiceId serviceId;
     private final ProtocolException protocolException;
     private final SignalServiceEnvelope envelope;
+    private final ServiceId accountId;
 
     public SendRetryMessageRequestAction(
             final RecipientId recipientId,
             final ServiceId serviceId,
             final ProtocolException protocolException,
-            final SignalServiceEnvelope envelope
+            final SignalServiceEnvelope envelope,
+            final ServiceId accountId
     ) {
         this.recipientId = recipientId;
         this.serviceId = serviceId;
         this.protocolException = protocolException;
         this.envelope = envelope;
+        this.accountId = accountId;
     }
 
     @Override
     public void execute(Context context) throws Throwable {
-        context.getAccount().getAciSessionStore().archiveSessions(serviceId);
+        context.getAccount().getAccountData(accountId).getSessionStore().archiveSessions(serviceId);
 
         int senderDevice = protocolException.getSenderDevice();
         Optional<GroupId> groupId = protocolException.getGroupId().isPresent() ? Optional.of(GroupId.unknownVersion(
index 3b616644d80be8e7d20fc9948c993892e37fc45e..7cc3d09df06a21e740710f0b80de1286dc1152b8 100644 (file)
@@ -181,12 +181,13 @@ public final class IncomingMessageHandler {
                                 .contains(Profile.Capability.senderKey);
                         final var isSelfSenderKeyCapable = selfProfile != null && selfProfile.getCapabilities()
                                 .contains(Profile.Capability.senderKey);
+                        final var destination = getDestination(envelope).serviceId();
                         if (!isSelf && isSenderSenderKeyCapable && isSelfSenderKeyCapable) {
                             logger.debug("Received invalid message, requesting message resend.");
-                            actions.add(new SendRetryMessageRequestAction(sender, serviceId, e, envelope));
+                            actions.add(new SendRetryMessageRequestAction(sender, serviceId, e, envelope, destination));
                         } else {
                             logger.debug("Received invalid message, queuing renew session action.");
-                            actions.add(new RenewSessionAction(sender, serviceId));
+                            actions.add(new RenewSessionAction(sender, serviceId, destination));
                         }
                     } else {
                         logger.debug("Received invalid message from invalid sender: {}", e.getSender());
@@ -346,7 +347,12 @@ public final class IncomingMessageHandler {
                     senderDeviceId,
                     message.getTimestamp());
             if (message.getDeviceId() == account.getDeviceId()) {
-                handleDecryptionErrorMessage(actions, sender, senderServiceId, senderDeviceId, message);
+                handleDecryptionErrorMessage(actions,
+                        sender,
+                        senderServiceId,
+                        senderDeviceId,
+                        message,
+                        destination.serviceId());
             } else {
                 logger.debug("Request is for another one of our devices");
             }
@@ -430,7 +436,8 @@ public final class IncomingMessageHandler {
             final RecipientId sender,
             final ServiceId senderServiceId,
             final int senderDeviceId,
-            final DecryptionErrorMessage message
+            final DecryptionErrorMessage message,
+            final ServiceId destination
     ) {
         final var logEntries = account.getMessageSendLogStore()
                 .findMessages(senderServiceId,
@@ -443,14 +450,14 @@ public final class IncomingMessageHandler {
         }
 
         if (message.getRatchetKey().isPresent()) {
-            if (account.getAciSessionStore()
-                    .isCurrentRatchetKey(senderServiceId, senderDeviceId, message.getRatchetKey().get())) {
+            final var sessionStore = account.getAccountData(destination).getSessionStore();
+            if (sessionStore.isCurrentRatchetKey(senderServiceId, senderDeviceId, message.getRatchetKey().get())) {
                 if (logEntries.isEmpty()) {
                     logger.debug("Renewing the session with sender");
-                    actions.add(new RenewSessionAction(sender, senderServiceId));
+                    actions.add(new RenewSessionAction(sender, senderServiceId, destination));
                 } else {
                     logger.trace("Archiving the session with sender, a resend message has already been queued");
-                    context.getAccount().getAciSessionStore().archiveSessions(senderServiceId);
+                    sessionStore.archiveSessions(senderServiceId);
                 }
             }
             return;
@@ -806,9 +813,12 @@ public final class IncomingMessageHandler {
             }
         }
 
+        final var selfAddress = isSync ? source : destination;
         final var conversationPartnerAddress = isSync ? destination : source;
         if (conversationPartnerAddress != null && message.isEndSession()) {
-            account.getAciSessionStore().deleteAllSessions(conversationPartnerAddress.serviceId());
+            account.getAccountData(selfAddress.serviceId())
+                    .getSessionStore()
+                    .deleteAllSessions(conversationPartnerAddress.serviceId());
         }
         if (message.isExpirationUpdate() || message.getBody().isPresent()) {
             if (message.getGroupContext().isPresent()) {
@@ -854,10 +864,12 @@ public final class IncomingMessageHandler {
             if (message.getQuote().isPresent()) {
                 final var quote = message.getQuote().get();
 
-                for (var quotedAttachment : quote.getAttachments()) {
-                    final var thumbnail = quotedAttachment.getThumbnail();
-                    if (thumbnail != null) {
-                        context.getAttachmentHelper().downloadAttachment(thumbnail);
+                if (quote.getAttachments() != null) {
+                    for (var quotedAttachment : quote.getAttachments()) {
+                        final var thumbnail = quotedAttachment.getThumbnail();
+                        if (thumbnail != null) {
+                            context.getAttachmentHelper().downloadAttachment(thumbnail);
+                        }
                     }
                 }
             }
@@ -972,7 +984,9 @@ public final class IncomingMessageHandler {
             return new DeviceAddress(account.getSelfRecipientId(), account.getAci(), account.getDeviceId());
         }
         final var address = addressOptional.get();
-        return new DeviceAddress(context.getRecipientHelper().resolveRecipient(address), address.getServiceId(), 0);
+        return new DeviceAddress(context.getRecipientHelper().resolveRecipient(address),
+                address.getServiceId(),
+                account.getDeviceId());
     }
 
     private record DeviceAddress(RecipientId recipientId, ServiceId serviceId, int deviceId) {}
index 90b2f6a4cb16e45e3e2ccc9d75cfbc666cb286d5..faea76d7e63fd06325b49e33210512747e8648a9 100644 (file)
@@ -82,6 +82,7 @@ import org.whispersystems.signalservice.api.messages.SignalServiceReceiptMessage
 import org.whispersystems.signalservice.api.messages.SignalServiceTypingMessage;
 import org.whispersystems.signalservice.api.push.ACI;
 import org.whispersystems.signalservice.api.push.ServiceId;
+import org.whispersystems.signalservice.api.push.ServiceIdType;
 import org.whispersystems.signalservice.api.util.DeviceNameUtil;
 import org.whispersystems.signalservice.api.util.InvalidNumberException;
 import org.whispersystems.signalservice.api.util.PhoneNumberFormatter;
@@ -183,8 +184,8 @@ public class ManagerImpl implements Manager {
         });
         disposable.add(account.getIdentityKeyStore().getIdentityChanges().subscribe(serviceId -> {
             logger.trace("Archiving old sessions for {}", serviceId);
-            account.getAciSessionStore().archiveSessions(serviceId);
-            account.getPniSessionStore().archiveSessions(serviceId);
+            account.getAccountData(ServiceIdType.ACI).getSessionStore().archiveSessions(serviceId);
+            account.getAccountData(ServiceIdType.PNI).getSessionStore().archiveSessions(serviceId);
             account.getSenderKeyStore().deleteSharedWith(serviceId);
             final var recipientId = account.getRecipientResolver().resolveRecipient(serviceId);
             final var profile = account.getProfileStore().getProfile(recipientId);
@@ -775,7 +776,7 @@ public class ManagerImpl implements Manager {
                         .resolveRecipientAddress(recipientId)
                         .serviceId();
                 if (serviceId.isPresent()) {
-                    account.getAciSessionStore().deleteAllSessions(serviceId.get());
+                    account.getAccountData(ServiceIdType.ACI).getSessionStore().deleteAllSessions(serviceId.get());
                 }
             }
         }
index 76867b728939d750992caae356dc849a5bc347f3..931a45a20da7c3e3666f602251a3f1bc351ef800 100644 (file)
@@ -89,7 +89,6 @@ import java.nio.channels.ClosedChannelException;
 import java.nio.channels.FileChannel;
 import java.nio.channels.FileLock;
 import java.nio.file.Files;
-import java.security.SecureRandom;
 import java.sql.Connection;
 import java.sql.SQLException;
 import java.util.Base64;
@@ -136,36 +135,14 @@ public class SignalAccount implements Closeable {
     private StorageKey storageKey;
     private long storageManifestVersion = -1;
     private ProfileKey profileKey;
-    private int aciPreKeyIdOffset = 1;
-    private int aciNextSignedPreKeyId = 1;
-    private int pniPreKeyIdOffset = 1;
-    private int pniNextSignedPreKeyId = 1;
-    private int aciKyberPreKeyIdOffset = 1;
-    private int aciActiveLastResortKyberPreKeyId = -1;
-    private int pniKyberPreKeyIdOffset = 1;
-    private int pniActiveLastResortKyberPreKeyId = -1;
-    private IdentityKeyPair aciIdentityKeyPair;
-    private IdentityKeyPair pniIdentityKeyPair;
-    private int localRegistrationId;
-    private int localPniRegistrationId;
     private Settings settings;
     private long lastReceiveTimestamp = 0;
 
     private boolean registered = false;
 
-    private SignalProtocolStore aciSignalProtocolStore;
-    private SignalProtocolStore pniSignalProtocolStore;
-    private PreKeyStore aciPreKeyStore;
-    private SignedPreKeyStore aciSignedPreKeyStore;
-    private KyberPreKeyStore aciKyberPreKeyStore;
-    private PreKeyStore pniPreKeyStore;
-    private SignedPreKeyStore pniSignedPreKeyStore;
-    private KyberPreKeyStore pniKyberPreKeyStore;
-    private SessionStore aciSessionStore;
-    private SessionStore pniSessionStore;
+    private final AccountData aciAccountData = new AccountData(ServiceIdType.ACI);
+    private final AccountData pniAccountData = new AccountData(ServiceIdType.PNI);
     private IdentityKeyStore identityKeyStore;
-    private SignalIdentityKeyStore aciIdentityKeyStore;
-    private SignalIdentityKeyStore pniIdentityKeyStore;
     private SenderKeyStore senderKeyStore;
     private GroupStore groupStore;
     private RecipientStore recipientStore;
@@ -231,10 +208,10 @@ public class SignalAccount implements Closeable {
         signalAccount.profileKey = profileKey;
 
         signalAccount.dataPath = dataPath;
-        signalAccount.aciIdentityKeyPair = aciIdentityKey;
-        signalAccount.pniIdentityKeyPair = pniIdentityKey;
-        signalAccount.localRegistrationId = registrationId;
-        signalAccount.localPniRegistrationId = pniRegistrationId;
+        signalAccount.aciAccountData.setIdentityKeyPair(aciIdentityKey);
+        signalAccount.pniAccountData.setIdentityKeyPair(pniIdentityKey);
+        signalAccount.aciAccountData.setLocalRegistrationId(registrationId);
+        signalAccount.pniAccountData.setLocalRegistrationId(pniRegistrationId);
         signalAccount.settings = settings;
         signalAccount.configurationStore = new ConfigurationStore(signalAccount::saveConfigurationStore);
 
@@ -296,8 +273,8 @@ public class SignalAccount implements Closeable {
                 profileKey);
         signalAccount.getRecipientTrustedResolver()
                 .resolveSelfRecipientTrusted(signalAccount.getSelfRecipientAddress());
-        signalAccount.getAciSessionStore().archiveAllSessions();
-        signalAccount.getPniSessionStore().archiveAllSessions();
+        signalAccount.aciAccountData.getSessionStore().archiveAllSessions();
+        signalAccount.pniAccountData.getSessionStore().archiveAllSessions();
         signalAccount.getSenderKeyStore().deleteAll();
         signalAccount.clearAllPreKeys();
         return signalAccount;
@@ -308,16 +285,17 @@ public class SignalAccount implements Closeable {
     }
 
     private void clearAllPreKeys() {
-        resetPreKeyOffsets(ServiceIdType.ACI);
-        resetPreKeyOffsets(ServiceIdType.PNI);
-        resetKyberPreKeyOffsets(ServiceIdType.ACI);
-        resetKyberPreKeyOffsets(ServiceIdType.PNI);
-        this.getAciPreKeyStore().removeAllPreKeys();
-        this.getAciSignedPreKeyStore().removeAllSignedPreKeys();
-        this.getAciKyberPreKeyStore().removeAllKyberPreKeys();
-        this.getPniPreKeyStore().removeAllPreKeys();
-        this.getPniSignedPreKeyStore().removeAllSignedPreKeys();
-        this.getPniKyberPreKeyStore().removeAllKyberPreKeys();
+        clearAllPreKeys(ServiceIdType.ACI);
+        clearAllPreKeys(ServiceIdType.PNI);
+    }
+
+    private void clearAllPreKeys(ServiceIdType serviceIdType) {
+        final var accountData = getAccountData(serviceIdType);
+        resetPreKeyOffsets(serviceIdType);
+        resetKyberPreKeyOffsets(serviceIdType);
+        accountData.getPreKeyStore().removeAllPreKeys();
+        accountData.getSignedPreKeyStore().removeAllSignedPreKeys();
+        accountData.getKyberPreKeyStore().removeAllKyberPreKeys();
         save();
     }
 
@@ -347,8 +325,8 @@ public class SignalAccount implements Closeable {
         signalAccount.dataPath = dataPath;
         signalAccount.accountPath = accountPath;
         signalAccount.serviceEnvironment = serviceEnvironment;
-        signalAccount.localRegistrationId = registrationId;
-        signalAccount.localPniRegistrationId = pniRegistrationId;
+        signalAccount.aciAccountData.setLocalRegistrationId(registrationId);
+        signalAccount.pniAccountData.setLocalRegistrationId(pniRegistrationId);
         signalAccount.settings = settings;
         signalAccount.setProvisioningData(number,
                 aci,
@@ -391,8 +369,8 @@ public class SignalAccount implements Closeable {
         getProfileStore().storeSelfProfileKey(getSelfRecipientId(), getProfileKey());
         this.encryptedDeviceName = encryptedDeviceName;
         this.deviceId = deviceId;
-        this.aciIdentityKeyPair = aciIdentity;
-        this.pniIdentityKeyPair = pniIdentity;
+        this.aciAccountData.setIdentityKeyPair(aciIdentity);
+        this.pniAccountData.setIdentityKeyPair(pniIdentity);
         this.registered = true;
         this.isMultiDevice = true;
         this.lastReceiveTimestamp = 0;
@@ -400,13 +378,9 @@ public class SignalAccount implements Closeable {
         this.storageManifestVersion = -1;
         this.setStorageManifest(null);
         this.storageKey = null;
-        final var aciPublicKey = getAciIdentityKeyPair().getPublicKey();
-        getIdentityKeyStore().saveIdentity(getAci(), aciPublicKey);
-        getIdentityKeyStore().setIdentityTrustLevel(getAci(), aciPublicKey, TrustLevel.TRUSTED_VERIFIED);
+        trustSelfIdentity(ServiceIdType.ACI);
         if (getPniIdentityKeyPair() != null) {
-            final var pniPublicKey = getPniIdentityKeyPair().getPublicKey();
-            getIdentityKeyStore().saveIdentity(getPni(), pniPublicKey);
-            getIdentityKeyStore().setIdentityTrustLevel(getPni(), pniPublicKey, TrustLevel.TRUSTED_VERIFIED);
+            trustSelfIdentity(ServiceIdType.PNI);
         }
     }
 
@@ -438,8 +412,8 @@ public class SignalAccount implements Closeable {
         getMessageCache().deleteMessages(recipientId);
         if (recipientAddress.serviceId().isPresent()) {
             final var serviceId = recipientAddress.serviceId().get();
-            getAciSessionStore().deleteAllSessions(serviceId);
-            getPniSessionStore().deleteAllSessions(serviceId);
+            aciAccountData.getSessionStore().deleteAllSessions(serviceId);
+            pniAccountData.getSessionStore().deleteAllSessions(serviceId);
             getIdentityKeyStore().deleteIdentity(serviceId);
             getSenderKeyStore().deleteAll(serviceId);
         }
@@ -595,9 +569,9 @@ public class SignalAccount implements Closeable {
             registrationId = rootNode.get("registrationId").asInt();
         }
         if (rootNode.hasNonNull("pniRegistrationId")) {
-            localPniRegistrationId = rootNode.get("pniRegistrationId").asInt();
+            pniAccountData.setLocalRegistrationId(rootNode.get("pniRegistrationId").asInt());
         } else {
-            localPniRegistrationId = KeyHelper.generateRegistrationId(false);
+            pniAccountData.setLocalRegistrationId(KeyHelper.generateRegistrationId(false));
         }
         IdentityKeyPair aciIdentityKeyPair = null;
         if (rootNode.hasNonNull("identityPrivateKey") && rootNode.hasNonNull("identityKey")) {
@@ -608,7 +582,7 @@ public class SignalAccount implements Closeable {
         if (rootNode.hasNonNull("pniIdentityPrivateKey") && rootNode.hasNonNull("pniIdentityKey")) {
             final var publicKeyBytes = Base64.getDecoder().decode(rootNode.get("pniIdentityKey").asText());
             final var privateKeyBytes = Base64.getDecoder().decode(rootNode.get("pniIdentityPrivateKey").asText());
-            pniIdentityKeyPair = KeyUtils.getIdentityKeyPair(publicKeyBytes, privateKeyBytes);
+            pniAccountData.setIdentityKeyPair(KeyUtils.getIdentityKeyPair(publicKeyBytes, privateKeyBytes));
         }
 
         if (rootNode.hasNonNull("registrationLockPin")) {
@@ -624,44 +598,46 @@ public class SignalAccount implements Closeable {
             storageManifestVersion = rootNode.get("storageManifestVersion").asLong();
         }
         if (rootNode.hasNonNull("preKeyIdOffset")) {
-            aciPreKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1);
+            aciAccountData.preKeyMetadata.preKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1);
         } else {
-            aciPreKeyIdOffset = getRandomPreKeyIdOffset();
+            aciAccountData.preKeyMetadata.preKeyIdOffset = getRandomPreKeyIdOffset();
         }
         if (rootNode.hasNonNull("nextSignedPreKeyId")) {
-            aciNextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1);
+            aciAccountData.preKeyMetadata.nextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1);
         } else {
-            aciNextSignedPreKeyId = getRandomPreKeyIdOffset();
+            aciAccountData.preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset();
         }
         if (rootNode.hasNonNull("pniPreKeyIdOffset")) {
-            pniPreKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1);
+            pniAccountData.preKeyMetadata.preKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1);
         } else {
-            pniPreKeyIdOffset = getRandomPreKeyIdOffset();
+            pniAccountData.preKeyMetadata.preKeyIdOffset = getRandomPreKeyIdOffset();
         }
         if (rootNode.hasNonNull("pniNextSignedPreKeyId")) {
-            pniNextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1);
+            pniAccountData.preKeyMetadata.nextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1);
         } else {
-            pniNextSignedPreKeyId = getRandomPreKeyIdOffset();
+            pniAccountData.preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset();
         }
         if (rootNode.hasNonNull("kyberPreKeyIdOffset")) {
-            aciKyberPreKeyIdOffset = rootNode.get("kyberPreKeyIdOffset").asInt(1);
+            aciAccountData.preKeyMetadata.kyberPreKeyIdOffset = rootNode.get("kyberPreKeyIdOffset").asInt(1);
         } else {
-            aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
+            aciAccountData.preKeyMetadata.kyberPreKeyIdOffset = getRandomPreKeyIdOffset();
         }
         if (rootNode.hasNonNull("activeLastResortKyberPreKeyId")) {
-            aciActiveLastResortKyberPreKeyId = rootNode.get("activeLastResortKyberPreKeyId").asInt(-1);
+            aciAccountData.preKeyMetadata.activeLastResortKyberPreKeyId = rootNode.get("activeLastResortKyberPreKeyId")
+                    .asInt(-1);
         } else {
-            aciActiveLastResortKyberPreKeyId = -1;
+            aciAccountData.preKeyMetadata.activeLastResortKyberPreKeyId = -1;
         }
         if (rootNode.hasNonNull("pniKyberPreKeyIdOffset")) {
-            pniKyberPreKeyIdOffset = rootNode.get("pniKyberPreKeyIdOffset").asInt(1);
+            pniAccountData.preKeyMetadata.kyberPreKeyIdOffset = rootNode.get("pniKyberPreKeyIdOffset").asInt(1);
         } else {
-            pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
+            pniAccountData.preKeyMetadata.kyberPreKeyIdOffset = getRandomPreKeyIdOffset();
         }
         if (rootNode.hasNonNull("pniActiveLastResortKyberPreKeyId")) {
-            pniActiveLastResortKyberPreKeyId = rootNode.get("pniActiveLastResortKyberPreKeyId").asInt(-1);
+            pniAccountData.preKeyMetadata.activeLastResortKyberPreKeyId = rootNode.get(
+                    "pniActiveLastResortKyberPreKeyId").asInt(-1);
         } else {
-            pniActiveLastResortKyberPreKeyId = -1;
+            pniAccountData.preKeyMetadata.activeLastResortKyberPreKeyId = -1;
         }
         if (rootNode.hasNonNull("profileKey")) {
             try {
@@ -687,22 +663,22 @@ public class SignalAccount implements Closeable {
         }
         final var legacyAciPreKeysPath = getAciPreKeysPath(dataPath, accountPath);
         if (legacyAciPreKeysPath.exists()) {
-            LegacyPreKeyStore.migrate(legacyAciPreKeysPath, getAciPreKeyStore());
+            LegacyPreKeyStore.migrate(legacyAciPreKeysPath, aciAccountData.getPreKeyStore());
             migratedLegacyConfig = true;
         }
         final var legacyPniPreKeysPath = getPniPreKeysPath(dataPath, accountPath);
         if (legacyPniPreKeysPath.exists()) {
-            LegacyPreKeyStore.migrate(legacyPniPreKeysPath, getPniPreKeyStore());
+            LegacyPreKeyStore.migrate(legacyPniPreKeysPath, pniAccountData.getPreKeyStore());
             migratedLegacyConfig = true;
         }
         final var legacyAciSignedPreKeysPath = getAciSignedPreKeysPath(dataPath, accountPath);
         if (legacyAciSignedPreKeysPath.exists()) {
-            LegacySignedPreKeyStore.migrate(legacyAciSignedPreKeysPath, getAciSignedPreKeyStore());
+            LegacySignedPreKeyStore.migrate(legacyAciSignedPreKeysPath, aciAccountData.getSignedPreKeyStore());
             migratedLegacyConfig = true;
         }
         final var legacyPniSignedPreKeysPath = getPniSignedPreKeysPath(dataPath, accountPath);
         if (legacyPniSignedPreKeysPath.exists()) {
-            LegacySignedPreKeyStore.migrate(legacyPniSignedPreKeysPath, getPniSignedPreKeyStore());
+            LegacySignedPreKeyStore.migrate(legacyPniSignedPreKeysPath, pniAccountData.getSignedPreKeyStore());
             migratedLegacyConfig = true;
         }
         final var legacySessionsPath = getSessionsPath(dataPath, accountPath);
@@ -710,7 +686,7 @@ public class SignalAccount implements Closeable {
             LegacySessionStore.migrate(legacySessionsPath,
                     getRecipientResolver(),
                     getRecipientAddressResolver(),
-                    getAciSessionStore());
+                    aciAccountData.getSessionStore());
             migratedLegacyConfig = true;
         }
         final var legacyIdentitiesPath = getIdentitiesPath(dataPath, accountPath);
@@ -731,8 +707,8 @@ public class SignalAccount implements Closeable {
             migratedLegacyConfig = true;
         }
 
-        this.aciIdentityKeyPair = aciIdentityKeyPair;
-        this.localRegistrationId = registrationId;
+        this.aciAccountData.setIdentityKeyPair(aciIdentityKeyPair);
+        this.aciAccountData.setLocalRegistrationId(registrationId);
 
         migratedLegacyConfig = loadLegacyStores(rootNode, legacySignalProtocolStore) || migratedLegacyConfig;
 
@@ -805,7 +781,7 @@ public class SignalAccount implements Closeable {
             logger.debug("Migrating legacy pre key store.");
             for (var entry : legacySignalProtocolStore.getLegacyPreKeyStore().getPreKeys().entrySet()) {
                 try {
-                    getAciPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue()));
+                    aciAccountData.getPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue()));
                 } catch (InvalidMessageException e) {
                     logger.warn("Failed to migrate pre key, ignoring", e);
                 }
@@ -817,8 +793,8 @@ public class SignalAccount implements Closeable {
             logger.debug("Migrating legacy signed pre key store.");
             for (var entry : legacySignalProtocolStore.getLegacySignedPreKeyStore().getSignedPreKeys().entrySet()) {
                 try {
-                    getAciSignedPreKeyStore().storeSignedPreKey(entry.getKey(),
-                            new SignedPreKeyRecord(entry.getValue()));
+                    aciAccountData.getSignedPreKeyStore()
+                            .storeSignedPreKey(entry.getKey(), new SignedPreKeyRecord(entry.getValue()));
                 } catch (InvalidMessageException e) {
                     logger.warn("Failed to migrate signed pre key, ignoring", e);
                 }
@@ -830,8 +806,9 @@ public class SignalAccount implements Closeable {
             logger.debug("Migrating legacy session store.");
             for (var session : legacySignalProtocolStore.getLegacySessionStore().getSessions()) {
                 try {
-                    getAciSessionStore().storeSession(new SignalProtocolAddress(session.address.getIdentifier(),
-                            session.deviceId), new SessionRecord(session.sessionRecord));
+                    aciAccountData.getSessionStore()
+                            .storeSession(new SignalProtocolAddress(session.address.getIdentifier(), session.deviceId),
+                                    new SessionRecord(session.sessionRecord));
                 } catch (Exception e) {
                     logger.warn("Failed to migrate session, ignoring", e);
                 }
@@ -981,35 +958,44 @@ public class SignalAccount implements Closeable {
                     .put("isMultiDevice", isMultiDevice)
                     .put("lastReceiveTimestamp", lastReceiveTimestamp)
                     .put("password", password)
-                    .put("registrationId", localRegistrationId)
-                    .put("pniRegistrationId", localPniRegistrationId)
+                    .put("registrationId", aciAccountData.getLocalRegistrationId())
+                    .put("pniRegistrationId", pniAccountData.getLocalRegistrationId())
                     .put("identityPrivateKey",
-                            Base64.getEncoder().encodeToString(aciIdentityKeyPair.getPrivateKey().serialize()))
+                            Base64.getEncoder()
+                                    .encodeToString(aciAccountData.getIdentityKeyPair().getPrivateKey().serialize()))
                     .put("identityKey",
-                            Base64.getEncoder().encodeToString(aciIdentityKeyPair.getPublicKey().serialize()))
+                            Base64.getEncoder()
+                                    .encodeToString(aciAccountData.getIdentityKeyPair().getPublicKey().serialize()))
                     .put("pniIdentityPrivateKey",
-                            pniIdentityKeyPair == null
+                            pniAccountData.getIdentityKeyPair() == null
                                     ? null
                                     : Base64.getEncoder()
-                                            .encodeToString(pniIdentityKeyPair.getPrivateKey().serialize()))
+                                            .encodeToString(pniAccountData.getIdentityKeyPair()
+                                                    .getPrivateKey()
+                                                    .serialize()))
                     .put("pniIdentityKey",
-                            pniIdentityKeyPair == null
+                            pniAccountData.getIdentityKeyPair() == null
                                     ? null
-                                    : Base64.getEncoder().encodeToString(pniIdentityKeyPair.getPublicKey().serialize()))
+                                    : Base64.getEncoder()
+                                            .encodeToString(pniAccountData.getIdentityKeyPair()
+                                                    .getPublicKey()
+                                                    .serialize()))
                     .put("registrationLockPin", registrationLockPin)
                     .put("pinMasterKey",
                             pinMasterKey == null ? null : Base64.getEncoder().encodeToString(pinMasterKey.serialize()))
                     .put("storageKey",
                             storageKey == null ? null : Base64.getEncoder().encodeToString(storageKey.serialize()))
                     .put("storageManifestVersion", storageManifestVersion == -1 ? null : storageManifestVersion)
-                    .put("preKeyIdOffset", aciPreKeyIdOffset)
-                    .put("nextSignedPreKeyId", aciNextSignedPreKeyId)
-                    .put("pniPreKeyIdOffset", pniPreKeyIdOffset)
-                    .put("pniNextSignedPreKeyId", pniNextSignedPreKeyId)
-                    .put("kyberPreKeyIdOffset", aciKyberPreKeyIdOffset)
-                    .put("activeLastResortKyberPreKeyId", aciActiveLastResortKyberPreKeyId)
-                    .put("pniKyberPreKeyIdOffset", pniKyberPreKeyIdOffset)
-                    .put("pniActiveLastResortKyberPreKeyId", pniActiveLastResortKyberPreKeyId)
+                    .put("preKeyIdOffset", aciAccountData.getPreKeyMetadata().preKeyIdOffset)
+                    .put("nextSignedPreKeyId", aciAccountData.getPreKeyMetadata().nextSignedPreKeyId)
+                    .put("pniPreKeyIdOffset", pniAccountData.getPreKeyMetadata().preKeyIdOffset)
+                    .put("pniNextSignedPreKeyId", pniAccountData.getPreKeyMetadata().nextSignedPreKeyId)
+                    .put("kyberPreKeyIdOffset", aciAccountData.getPreKeyMetadata().kyberPreKeyIdOffset)
+                    .put("activeLastResortKyberPreKeyId",
+                            aciAccountData.getPreKeyMetadata().activeLastResortKyberPreKeyId)
+                    .put("pniKyberPreKeyIdOffset", pniAccountData.getPreKeyMetadata().kyberPreKeyIdOffset)
+                    .put("pniActiveLastResortKyberPreKeyId",
+                            pniAccountData.getPreKeyMetadata().activeLastResortKyberPreKeyId)
                     .put("profileKey",
                             profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize()))
                     .put("registered", registered)
@@ -1054,154 +1040,79 @@ public class SignalAccount implements Closeable {
     }
 
+    /**
+     * Assigns fresh random values to the pre key id offset and the next signed pre key id of the given
+     * identity (ACI or PNI), then saves the account.
+     */
     public void resetPreKeyOffsets(final ServiceIdType serviceIdType) {
-        if (serviceIdType.equals(ServiceIdType.ACI)) {
-            this.aciPreKeyIdOffset = getRandomPreKeyIdOffset();
-            this.aciNextSignedPreKeyId = getRandomPreKeyIdOffset();
-        } else {
-            this.pniPreKeyIdOffset = getRandomPreKeyIdOffset();
-            this.pniNextSignedPreKeyId = getRandomPreKeyIdOffset();
-        }
+        final var preKeyMetadata = getAccountData(serviceIdType).getPreKeyMetadata();
+        preKeyMetadata.preKeyIdOffset = getRandomPreKeyIdOffset();
+        preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset();
         save();
     }
 
+    /**
+     * Returns a random id in {@code [0, PREKEY_MAXIMUM_ID)}, used to seed the pre key id counters.
+     */
     private static int getRandomPreKeyIdOffset() {
-        return new SecureRandom().nextInt(PREKEY_MAXIMUM_ID);
+        return KeyUtils.getRandomInt(PREKEY_MAXIMUM_ID);
     }
 
+    /**
+     * Stores pre keys in the pre key store of the given identity (ACI or PNI) and saves the account.
+     * <p>
+     * The records must carry consecutive ids starting at the identity's current pre key id offset; the offset is
+     * advanced after each stored record, wrapping modulo {@code PREKEY_MAXIMUM_ID}.
+     *
+     * @throws AssertionError if a record id does not match the expected offset. Records before the mismatching
+     *                        one have already been stored and the in-memory offset advanced, but {@code save()}
+     *                        is not reached.
+     */
     public void addPreKeys(ServiceIdType serviceIdType, List<PreKeyRecord> records) {
-        if (serviceIdType.equals(ServiceIdType.ACI)) {
-            addAciPreKeys(records);
-        } else {
-            addPniPreKeys(records);
-        }
-    }
-
-    private void addAciPreKeys(List<PreKeyRecord> records) {
+        final var accountData = getAccountData(serviceIdType);
+        final var preKeyMetadata = accountData.getPreKeyMetadata();
         for (var record : records) {
-            if (aciPreKeyIdOffset != record.getId()) {
-                logger.error("Invalid pre key id {}, expected {}", record.getId(), aciPreKeyIdOffset);
+            if (preKeyMetadata.preKeyIdOffset != record.getId()) {
+                logger.error("Invalid pre key id {}, expected {}", record.getId(), preKeyMetadata.preKeyIdOffset);
                 throw new AssertionError("Invalid pre key id");
             }
-            getAciPreKeyStore().storePreKey(record.getId(), record);
-            aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
-        }
-        save();
-    }
-
-    private void addPniPreKeys(List<PreKeyRecord> records) {
-        for (var record : records) {
-            if (pniPreKeyIdOffset != record.getId()) {
-                logger.error("Invalid pre key id {}, expected {}", record.getId(), pniPreKeyIdOffset);
-                throw new AssertionError("Invalid pre key id");
-            }
-            getPniPreKeyStore().storePreKey(record.getId(), record);
-            pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
+            accountData.getPreKeyStore().storePreKey(record.getId(), record);
+            preKeyMetadata.preKeyIdOffset = (preKeyMetadata.preKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
         }
         save();
     }
 
+    /**
+     * Stores a signed pre key for the given identity (ACI or PNI), advances that identity's next signed pre key
+     * id (wrapping modulo {@code PREKEY_MAXIMUM_ID}) and saves the account.
+     *
+     * @throws AssertionError if the record id is not the identity's current next signed pre key id; nothing is
+     *                        stored in that case.
+     */
     public void addSignedPreKey(ServiceIdType serviceIdType, SignedPreKeyRecord record) {
-        if (serviceIdType.equals(ServiceIdType.ACI)) {
-            addAciSignedPreKey(record);
-        } else {
-            addPniSignedPreKey(record);
-        }
-    }
-
-    public void addAciSignedPreKey(SignedPreKeyRecord record) {
-        if (aciNextSignedPreKeyId != record.getId()) {
-            logger.error("Invalid signed pre key id {}, expected {}", record.getId(), aciNextSignedPreKeyId);
-            throw new AssertionError("Invalid signed pre key id");
-        }
-        getAciSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
-        aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID;
-        save();
-    }
-
-    public void addPniSignedPreKey(SignedPreKeyRecord record) {
-        if (pniNextSignedPreKeyId != record.getId()) {
-            logger.error("Invalid signed pre key id {}, expected {}", record.getId(), pniNextSignedPreKeyId);
+        final var accountData = getAccountData(serviceIdType);
+        final var preKeyMetadata = accountData.getPreKeyMetadata();
+        if (preKeyMetadata.nextSignedPreKeyId != record.getId()) {
+            logger.error("Invalid signed pre key id {}, expected {}",
+                    record.getId(),
+                    preKeyMetadata.nextSignedPreKeyId);
             throw new AssertionError("Invalid signed pre key id");
         }
-        getPniSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
-        pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID;
+        accountData.getSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
+        preKeyMetadata.nextSignedPreKeyId = (preKeyMetadata.nextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID;
         save();
     }
 
+    /**
+     * Assigns a fresh random kyber pre key id offset to the given identity (ACI or PNI), marks it as having no
+     * active last resort kyber pre key (-1), then saves the account.
+     */
     public void resetKyberPreKeyOffsets(final ServiceIdType serviceIdType) {
-        if (serviceIdType.equals(ServiceIdType.ACI)) {
-            this.aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
-            this.aciActiveLastResortKyberPreKeyId = -1;
-        } else {
-            this.pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
-            this.pniActiveLastResortKyberPreKeyId = -1;
-        }
+        final var preKeyMetadata = getAccountData(serviceIdType).getPreKeyMetadata();
+        preKeyMetadata.kyberPreKeyIdOffset = getRandomPreKeyIdOffset();
+        preKeyMetadata.activeLastResortKyberPreKeyId = -1;
         save();
     }
 
+    /**
+     * Stores kyber pre keys in the kyber pre key store of the given identity (ACI or PNI) and saves the account.
+     * <p>
+     * The records must carry consecutive ids starting at the identity's current kyber pre key id offset; the
+     * offset is advanced after each stored record, wrapping modulo {@code PREKEY_MAXIMUM_ID}. Last resort kyber
+     * pre keys draw their ids from the same counter (see {@code addLastResortKyberPreKey}).
+     *
+     * @throws AssertionError if a record id does not match the expected offset. Records before the mismatching
+     *                        one have already been stored and the in-memory offset advanced, but {@code save()}
+     *                        is not reached.
+     */
     public void addKyberPreKeys(ServiceIdType serviceIdType, List<KyberPreKeyRecord> records) {
-        if (serviceIdType.equals(ServiceIdType.ACI)) {
-            addAciKyberPreKeys(records);
-        } else {
-            addPniKyberPreKeys(records);
-        }
-    }
-
-    private void addAciKyberPreKeys(List<KyberPreKeyRecord> records) {
+        final var accountData = getAccountData(serviceIdType);
+        final var preKeyMetadata = accountData.getPreKeyMetadata();
         for (var record : records) {
-            if (aciKyberPreKeyIdOffset != record.getId()) {
-                logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), aciKyberPreKeyIdOffset);
+            if (preKeyMetadata.kyberPreKeyIdOffset != record.getId()) {
+                logger.error("Invalid kyber pre key id {}, expected {}",
+                        record.getId(),
+                        preKeyMetadata.kyberPreKeyIdOffset);
                 throw new AssertionError("Invalid kyber pre key id");
             }
-            getAciKyberPreKeyStore().storeKyberPreKey(record.getId(), record);
-            aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
-        }
-        save();
-    }
-
-    private void addPniKyberPreKeys(List<KyberPreKeyRecord> records) {
-        for (var record : records) {
-            if (pniKyberPreKeyIdOffset != record.getId()) {
-                logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), pniKyberPreKeyIdOffset);
-                throw new AssertionError("Invalid kyber pre key id");
-            }
-            getPniKyberPreKeyStore().storeKyberPreKey(record.getId(), record);
-            pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
+            accountData.getKyberPreKeyStore().storeKyberPreKey(record.getId(), record);
+            preKeyMetadata.kyberPreKeyIdOffset = (preKeyMetadata.kyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
         }
         save();
     }
 
+    /**
+     * Stores a last resort kyber pre key for the given identity (ACI or PNI), records its id as the identity's
+     * active last resort kyber pre key, advances the shared kyber pre key id offset (wrapping modulo
+     * {@code PREKEY_MAXIMUM_ID}) and saves the account.
+     *
+     * @throws AssertionError if the record id is not the identity's current kyber pre key id offset; nothing is
+     *                        stored in that case.
+     */
     public void addLastResortKyberPreKey(ServiceIdType serviceIdType, KyberPreKeyRecord record) {
-        if (serviceIdType.equals(ServiceIdType.ACI)) {
-            addAciLastResortKyberPreKey(record);
-        } else {
-            addPniLastResortKyberPreKey(record);
-        }
-    }
-
-    public void addAciLastResortKyberPreKey(KyberPreKeyRecord record) {
-        if (aciKyberPreKeyIdOffset != record.getId()) {
+        final var accountData = getAccountData(serviceIdType);
+        final var preKeyMetadata = accountData.getPreKeyMetadata();
+        if (preKeyMetadata.kyberPreKeyIdOffset != record.getId()) {
             logger.error("Invalid last resort kyber pre key id {}, expected {}",
                     record.getId(),
-                    aciKyberPreKeyIdOffset);
+                    preKeyMetadata.kyberPreKeyIdOffset);
             throw new AssertionError("Invalid last resort kyber pre key id");
         }
-        getAciKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record);
-        aciActiveLastResortKyberPreKeyId = record.getId();
-        aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
-        save();
-    }
-
-    public void addPniLastResortKyberPreKey(KyberPreKeyRecord record) {
-        if (pniKyberPreKeyIdOffset != record.getId()) {
-            logger.error("Invalid last resort kyber pre key id {}, expected {}",
-                    record.getId(),
-                    pniKyberPreKeyIdOffset);
-            throw new AssertionError("Invalid last resort kyber pre key id");
-        }
-        getPniKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record);
-        pniActiveLastResortKyberPreKeyId = record.getId();
-        pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
+        accountData.getKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record);
+        preKeyMetadata.activeLastResortKyberPreKeyId = record.getId();
+        preKeyMetadata.kyberPreKeyIdOffset = (preKeyMetadata.kyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
         save();
     }
 
@@ -1209,6 +1120,23 @@ public class SignalAccount implements Closeable {
         return previousStorageVersion;
     }
 
+    /**
+     * Returns the per-identity account data selected by {@code serviceIdType}.
+     *
+     * @param serviceIdType which of the account's identities to look up
+     * @return the ACI account data for {@link ServiceIdType#ACI}, the PNI account data for {@link ServiceIdType#PNI}
+     */
+    public AccountData getAccountData(ServiceIdType serviceIdType) {
+        return switch (serviceIdType) {
+            case ACI:
+                yield aciAccountData;
+            case PNI:
+                yield pniAccountData;
+        };
+    }
+
+    /**
+     * Returns the per-identity account data whose service id equals {@code accountIdentifier}.
+     *
+     * @param accountIdentifier the account's own ACI or PNI
+     * @return the ACI account data if it matches the ACI, otherwise the PNI account data if it matches the PNI
+     * @throws IllegalArgumentException if the identifier is neither the account's ACI nor its PNI
+     */
+    public AccountData getAccountData(ServiceId accountIdentifier) {
+        if (accountIdentifier.equals(aci)) {
+            return aciAccountData;
+        }
+        if (accountIdentifier.equals(pni)) {
+            return pniAccountData;
+        }
+        throw new IllegalArgumentException("No matching account data found for " + accountIdentifier);
+    }
+
     public SignalServiceDataStore getSignalServiceDataStore() {
         return new SignalServiceDataStore() {
             @Override
@@ -1224,12 +1152,12 @@ public class SignalAccount implements Closeable {
 
             @Override
             public SignalServiceAccountDataStore aci() {
-                return getAciSignalServiceAccountDataStore();
+                return aciAccountData.getSignalServiceAccountDataStore();
             }
 
             @Override
             public SignalServiceAccountDataStore pni() {
-                return getPniSignalServiceAccountDataStore();
+                return pniAccountData.getSignalServiceAccountDataStore();
             }
 
             @Override
@@ -1239,89 +1167,11 @@ public class SignalAccount implements Closeable {
         };
     }
 
-    private SignalServiceAccountDataStore getAciSignalServiceAccountDataStore() {
-        return getOrCreate(() -> aciSignalProtocolStore,
-                () -> aciSignalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(),
-                        getAciSignedPreKeyStore(),
-                        getAciKyberPreKeyStore(),
-                        getAciSessionStore(),
-                        getAciIdentityKeyStore(),
-                        getSenderKeyStore(),
-                        this::isMultiDevice));
-    }
-
-    private SignalServiceAccountDataStore getPniSignalServiceAccountDataStore() {
-        return getOrCreate(() -> pniSignalProtocolStore,
-                () -> pniSignalProtocolStore = new SignalProtocolStore(getPniPreKeyStore(),
-                        getPniSignedPreKeyStore(),
-                        getPniKyberPreKeyStore(),
-                        getPniSessionStore(),
-                        getPniIdentityKeyStore(),
-                        getSenderKeyStore(),
-                        this::isMultiDevice));
-    }
-
-    private PreKeyStore getAciPreKeyStore() {
-        return getOrCreate(() -> aciPreKeyStore,
-                () -> aciPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
-    }
-
-    private SignedPreKeyStore getAciSignedPreKeyStore() {
-        return getOrCreate(() -> aciSignedPreKeyStore,
-                () -> aciSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
-    }
-
-    private KyberPreKeyStore getAciKyberPreKeyStore() {
-        return getOrCreate(() -> aciKyberPreKeyStore,
-                () -> aciKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
-    }
-
-    private PreKeyStore getPniPreKeyStore() {
-        return getOrCreate(() -> pniPreKeyStore,
-                () -> pniPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
-    }
-
-    private SignedPreKeyStore getPniSignedPreKeyStore() {
-        return getOrCreate(() -> pniSignedPreKeyStore,
-                () -> pniSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
-    }
-
-    private KyberPreKeyStore getPniKyberPreKeyStore() {
-        return getOrCreate(() -> pniKyberPreKeyStore,
-                () -> pniKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
-    }
-
-    public SessionStore getAciSessionStore() {
-        return getOrCreate(() -> aciSessionStore,
-                () -> aciSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.ACI));
-    }
-
-    public SessionStore getPniSessionStore() {
-        return getOrCreate(() -> pniSessionStore,
-                () -> pniSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.PNI));
-    }
-
+    /**
+     * Returns the account-wide identity key store, creating it via {@code getOrCreate} on first use with the
+     * configured {@code trustNewIdentity} setting. It is passed to each identity's {@code SignalIdentityKeyStore}.
+     */
     public IdentityKeyStore getIdentityKeyStore() {
         return getOrCreate(() -> identityKeyStore,
                 () -> identityKeyStore = new IdentityKeyStore(getAccountDatabase(), settings.trustNewIdentity()));
     }
 
-    public SignalIdentityKeyStore getAciIdentityKeyStore() {
-        return getOrCreate(() -> aciIdentityKeyStore,
-                () -> aciIdentityKeyStore = new SignalIdentityKeyStore(getRecipientResolver(),
-                        () -> aciIdentityKeyPair,
-                        localRegistrationId,
-                        getIdentityKeyStore()));
-    }
-
-    public SignalIdentityKeyStore getPniIdentityKeyStore() {
-        return getOrCreate(() -> pniIdentityKeyStore,
-                () -> pniIdentityKeyStore = new SignalIdentityKeyStore(getRecipientResolver(),
-                        () -> pniIdentityKeyPair,
-                        localRegistrationId,
-                        getIdentityKeyStore()));
-    }
-
     public GroupStore getGroupStore() {
         return getOrCreate(() -> groupStore,
                 () -> groupStore = new GroupStore(getAccountDatabase(),
@@ -1489,11 +1339,7 @@ public class SignalAccount implements Closeable {
         if (this.pni != null && !this.pni.equals(updatedPni)) {
             // Clear data for old PNI
             identityKeyStore.deleteIdentity(this.pni);
-            getPniPreKeyStore().removeAllPreKeys();
-            getPniSignedPreKeyStore().removeAllSignedPreKeys();
-            getPniKyberPreKeyStore().removeAllKyberPreKeys();
-            aciActiveLastResortKyberPreKeyId = -1;
-            pniActiveLastResortKyberPreKeyId = -1;
+            clearAllPreKeys(ServiceIdType.PNI);
         }
 
         this.pni = updatedPni;
@@ -1509,7 +1355,7 @@ public class SignalAccount implements Closeable {
         setPni(updatedPni);
 
         setPniIdentityKeyPair(pniIdentityKeyPair);
-        addPniSignedPreKey(pniSignedPreKey);
+        addSignedPreKey(ServiceIdType.PNI, pniSignedPreKey);
         setLocalPniRegistrationId(localPniRegistrationId);
     }
 
@@ -1552,35 +1398,33 @@ public class SignalAccount implements Closeable {
     }
 
+    /**
+     * Returns the identity key pair of the given identity (ACI or PNI), or {@code null} if none has been set.
+     */
     public IdentityKeyPair getIdentityKeyPair(ServiceIdType serviceIdType) {
-        return serviceIdType.equals(ServiceIdType.ACI) ? aciIdentityKeyPair : pniIdentityKeyPair;
+        return getAccountData(serviceIdType).getIdentityKeyPair();
     }
 
+    /** Returns the ACI identity key pair, or {@code null} if none has been set. */
     public IdentityKeyPair getAciIdentityKeyPair() {
-        return aciIdentityKeyPair;
+        return aciAccountData.getIdentityKeyPair();
     }
 
+    /** Returns the PNI identity key pair, or {@code null} if none has been set. */
     public IdentityKeyPair getPniIdentityKeyPair() {
-        return pniIdentityKeyPair;
+        return pniAccountData.getIdentityKeyPair();
     }
 
+    /**
+     * Replaces the PNI identity key pair, saves its public key for the PNI in the identity key store with trust
+     * level {@code TRUSTED_VERIFIED}, and saves the account.
+     */
     public void setPniIdentityKeyPair(final IdentityKeyPair identityKeyPair) {
-        pniIdentityKeyPair = identityKeyPair;
-        final var pniPublicKey = identityKeyPair.getPublicKey();
-        getIdentityKeyStore().saveIdentity(getPni(), pniPublicKey);
-        getIdentityKeyStore().setIdentityTrustLevel(getPni(), pniPublicKey, TrustLevel.TRUSTED_VERIFIED);
+        pniAccountData.setIdentityKeyPair(identityKeyPair);
+        trustSelfIdentity(ServiceIdType.PNI);
         save();
     }
 
+    /** Returns the local registration id of the ACI identity. */
     public int getLocalRegistrationId() {
-        return localRegistrationId;
+        return aciAccountData.getLocalRegistrationId();
     }
 
+    /** Returns the local registration id of the PNI identity. */
     public int getLocalPniRegistrationId() {
-        return localPniRegistrationId;
+        return pniAccountData.getLocalRegistrationId();
     }
 
+    /** Sets the local registration id of the PNI identity and saves the account. */
     public void setLocalPniRegistrationId(final int localPniRegistrationId) {
-        this.localPniRegistrationId = localPniRegistrationId;
+        pniAccountData.setLocalRegistrationId(localPniRegistrationId);
         save();
     }
 
@@ -1711,15 +1555,15 @@ public class SignalAccount implements Closeable {
     }
 
+    /** Returns the id that the next pre key passed to {@code addPreKeys} must have for the given identity. */
     public int getPreKeyIdOffset(ServiceIdType serviceIdType) {
-        return serviceIdType.equals(ServiceIdType.ACI) ? aciPreKeyIdOffset : pniPreKeyIdOffset;
+        return getAccountData(serviceIdType).getPreKeyMetadata().preKeyIdOffset;
     }
 
+    /** Returns the id that the next signed pre key passed to {@code addSignedPreKey} must have for the given identity. */
     public int getNextSignedPreKeyId(ServiceIdType serviceIdType) {
-        return serviceIdType.equals(ServiceIdType.ACI) ? aciNextSignedPreKeyId : pniNextSignedPreKeyId;
+        return getAccountData(serviceIdType).getPreKeyMetadata().nextSignedPreKeyId;
     }
 
+    /**
+     * Returns the id that the next kyber pre key (regular or last resort) must have for the given identity.
+     */
     public int getKyberPreKeyIdOffset(ServiceIdType serviceIdType) {
-        return serviceIdType.equals(ServiceIdType.ACI) ? aciKyberPreKeyIdOffset : pniKyberPreKeyIdOffset;
+        return getAccountData(serviceIdType).getPreKeyMetadata().kyberPreKeyIdOffset;
     }
 
     public boolean isRegistered() {
@@ -1778,22 +1622,26 @@ public class SignalAccount implements Closeable {
         save();
 
         clearAllPreKeys();
-        getAciSessionStore().archiveAllSessions();
-        getPniSessionStore().archiveAllSessions();
+        aciAccountData.getSessionStore().archiveAllSessions();
+        pniAccountData.getSessionStore().archiveAllSessions();
         getSenderKeyStore().deleteAll();
         getRecipientTrustedResolver().resolveSelfRecipientTrusted(getSelfRecipientAddress());
-        final var aciPublicKey = getAciIdentityKeyPair().getPublicKey();
-        getIdentityKeyStore().saveIdentity(getAci(), aciPublicKey);
-        getIdentityKeyStore().setIdentityTrustLevel(getAci(), aciPublicKey, TrustLevel.TRUSTED_VERIFIED);
+        trustSelfIdentity(ServiceIdType.ACI);
         if (getPniIdentityKeyPair() == null) {
             setPniIdentityKeyPair(KeyUtils.generateIdentityKeyPair());
         } else {
-            final var pniPublicKey = getPniIdentityKeyPair().getPublicKey();
-            getIdentityKeyStore().saveIdentity(getPni(), pniPublicKey);
-            getIdentityKeyStore().setIdentityTrustLevel(getPni(), pniPublicKey, TrustLevel.TRUSTED_VERIFIED);
+            trustSelfIdentity(ServiceIdType.PNI);
         }
     }
 
+    /**
+     * Saves the public key of one of the account's own identities in the identity key store and marks it as
+     * {@code TRUSTED_VERIFIED}.
+     *
+     * @param serviceIdType the own identity (ACI or PNI) to trust; its identity key pair must already be set
+     */
+    private void trustSelfIdentity(ServiceIdType serviceIdType) {
+        final var ownData = getAccountData(serviceIdType);
+        final var ownServiceId = ownData.getServiceId();
+        final var ownPublicKey = ownData.getIdentityKeyPair().getPublicKey();
+        final var identityStore = getIdentityKeyStore();
+        identityStore.saveIdentity(ownServiceId, ownPublicKey);
+        identityStore.setIdentityTrustLevel(ownServiceId, ownPublicKey, TrustLevel.TRUSTED_VERIFIED);
+    }
+
     public void deleteAccountData() throws IOException {
         close();
         try (final var files = Files.walk(getUserPath(dataPath, accountPath).toPath())
@@ -1850,4 +1698,95 @@ public class SignalAccount implements Closeable {
 
         void call();
     }
+
+    /**
+     * Mutable pre key id counters of a single identity (ACI or PNI).
+     * <p>
+     * The fields are read and updated directly by the enclosing {@code SignalAccount}, which calls
+     * {@code save()} after changing them.
+     */
+    private static class PreKeyMetadata {
+
+        /** Id that the next pre key stored for this identity must have. */
+        private int preKeyIdOffset = 1;
+        /** Id that the next signed pre key stored for this identity must have. */
+        private int nextSignedPreKeyId = 1;
+        /** Id that the next kyber pre key (regular or last resort) stored for this identity must have. */
+        private int kyberPreKeyIdOffset = 1;
+        /** Id of the most recently added last resort kyber pre key, or -1 if there is none (or after a reset). */
+        private int activeLastResortKyberPreKeyId = -1;
+    }
+
+    /**
+     * Identity key pair, registration id, pre key counters and protocol stores of one of the account's
+     * identities, selected by {@code serviceIdType} (ACI or PNI).
+     * <p>
+     * The stores are created on first access via {@code getOrCreate} and cached in this instance.
+     */
+    public class AccountData {
+
+        private final ServiceIdType serviceIdType;
+        private IdentityKeyPair identityKeyPair;
+        private int localRegistrationId;
+        private final PreKeyMetadata preKeyMetadata = new PreKeyMetadata();
+
+        private SignalProtocolStore signalProtocolStore;
+        private PreKeyStore preKeyStore;
+        private SignedPreKeyStore signedPreKeyStore;
+        private KyberPreKeyStore kyberPreKeyStore;
+        private SessionStore sessionStore;
+        private SignalIdentityKeyStore identityKeyStore;
+
+        public AccountData(final ServiceIdType serviceIdType) {
+            this.serviceIdType = serviceIdType;
+        }
+
+        /** Returns the service id of this identity, as resolved by {@code getAccountId(serviceIdType)}. */
+        public ServiceId getServiceId() {
+            return getAccountId(serviceIdType);
+        }
+
+        /** Returns the identity key pair, or {@code null} if none has been set yet. */
+        public IdentityKeyPair getIdentityKeyPair() {
+            return identityKeyPair;
+        }
+
+        private void setIdentityKeyPair(final IdentityKeyPair identityKeyPair) {
+            // The identity key store reads the key pair through a supplier, so no cached store needs resetting.
+            this.identityKeyPair = identityKeyPair;
+        }
+
+        public int getLocalRegistrationId() {
+            return localRegistrationId;
+        }
+
+        private void setLocalRegistrationId(final int localRegistrationId) {
+            this.localRegistrationId = localRegistrationId;
+            // The identity key store copies the registration id when it is created, and the protocol store holds a
+            // reference to that identity key store. Drop both so they are rebuilt with the new id; otherwise the
+            // cached protocol store would keep using the stale identity key store created with the old id.
+            this.identityKeyStore = null;
+            this.signalProtocolStore = null;
+        }
+
+        /**
+         * Returns the mutable pre key counters of this identity. Callers update the fields directly and are
+         * expected to call {@code save()} afterwards.
+         */
+        public PreKeyMetadata getPreKeyMetadata() {
+            return preKeyMetadata;
+        }
+
+        /**
+         * Returns the protocol store combining this identity's pre key, signed pre key, kyber pre key, session and
+         * identity key stores with the account's sender key store.
+         */
+        private SignalServiceAccountDataStore getSignalServiceAccountDataStore() {
+            return getOrCreate(() -> signalProtocolStore,
+                    () -> signalProtocolStore = new SignalProtocolStore(getPreKeyStore(),
+                            getSignedPreKeyStore(),
+                            getKyberPreKeyStore(),
+                            getSessionStore(),
+                            getIdentityKeyStore(),
+                            getSenderKeyStore(),
+                            SignalAccount.this::isMultiDevice));
+        }
+
+        /** Returns this identity's pre key store, created on first use. */
+        private PreKeyStore getPreKeyStore() {
+            return getOrCreate(() -> preKeyStore,
+                    () -> preKeyStore = new PreKeyStore(getAccountDatabase(), serviceIdType));
+        }
+
+        /** Returns this identity's signed pre key store, created on first use. */
+        private SignedPreKeyStore getSignedPreKeyStore() {
+            return getOrCreate(() -> signedPreKeyStore,
+                    () -> signedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), serviceIdType));
+        }
+
+        /** Returns this identity's kyber pre key store, created on first use. */
+        private KyberPreKeyStore getKyberPreKeyStore() {
+            return getOrCreate(() -> kyberPreKeyStore,
+                    () -> kyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), serviceIdType));
+        }
+
+        /** Returns this identity's session store, created on first use. */
+        public SessionStore getSessionStore() {
+            return getOrCreate(() -> sessionStore,
+                    () -> sessionStore = new SessionStore(getAccountDatabase(), serviceIdType));
+        }
+
+        /**
+         * Returns this identity's {@code SignalIdentityKeyStore}, built on the account-wide identity key store.
+         * The current {@code localRegistrationId} is captured when the store is created.
+         */
+        public SignalIdentityKeyStore getIdentityKeyStore() {
+            return getOrCreate(() -> identityKeyStore,
+                    () -> identityKeyStore = new SignalIdentityKeyStore(getRecipientResolver(),
+                            () -> identityKeyPair,
+                            localRegistrationId,
+                            SignalAccount.this.getIdentityKeyStore()));
+        }
+    }
 }
index 3ee5657ebda436abd6d4f3bbd81e7f5ff969419a..8db8d0a4018946436142e25e3a9e34dbc3da65a7 100644 (file)
@@ -120,4 +120,8 @@ public class KeyUtils {
         secureRandom.nextBytes(secret);
         return secret;
     }
+
+    /**
+     * Draws a random int from the class's shared {@code secureRandom} instance.
+     *
+     * @param bound the exclusive upper bound; must be positive
+     * @return a value in {@code [0, bound)}
+     */
+    public static int getRandomInt(int bound) {
+        final int randomValue = secureRandom.nextInt(bound);
+        return randomValue;
+    }
 }