From: AsamK Date: Sun, 18 Jun 2023 12:44:57 +0000 (+0200) Subject: Refactor ACI/PNI store handling X-Git-Tag: v0.12.0~24 X-Git-Url: https://git.nmode.ca/signal-cli/commitdiff_plain/0ebfd989d12f3479ec220eb98ccfc72bbb47a5c3?ds=sidebyside Refactor ACI/PNI store handling --- diff --git a/lib/src/main/java/org/asamk/signal/manager/actions/RenewSessionAction.java b/lib/src/main/java/org/asamk/signal/manager/actions/RenewSessionAction.java index 2718bc26..40117246 100644 --- a/lib/src/main/java/org/asamk/signal/manager/actions/RenewSessionAction.java +++ b/lib/src/main/java/org/asamk/signal/manager/actions/RenewSessionAction.java @@ -8,15 +8,17 @@ public class RenewSessionAction implements HandleAction { private final RecipientId recipientId; private final ServiceId serviceId; + private final ServiceId accountId; - public RenewSessionAction(final RecipientId recipientId, final ServiceId serviceId) { + public RenewSessionAction(final RecipientId recipientId, final ServiceId serviceId, final ServiceId accountId) { this.recipientId = recipientId; this.serviceId = serviceId; + this.accountId = accountId; } @Override public void execute(Context context) throws Throwable { - context.getAccount().getAciSessionStore().archiveSessions(serviceId); + context.getAccount().getAccountData(accountId).getSessionStore().archiveSessions(serviceId); if (!recipientId.equals(context.getAccount().getSelfRecipientId())) { context.getSendHelper().sendNullMessage(recipientId); } diff --git a/lib/src/main/java/org/asamk/signal/manager/actions/SendRetryMessageRequestAction.java b/lib/src/main/java/org/asamk/signal/manager/actions/SendRetryMessageRequestAction.java index 2300eae9..add09e72 100644 --- a/lib/src/main/java/org/asamk/signal/manager/actions/SendRetryMessageRequestAction.java +++ b/lib/src/main/java/org/asamk/signal/manager/actions/SendRetryMessageRequestAction.java @@ -18,22 +18,25 @@ public class SendRetryMessageRequestAction implements HandleAction { private final ServiceId serviceId; private final ProtocolException protocolException; private final SignalServiceEnvelope envelope; + private final ServiceId accountId; public SendRetryMessageRequestAction( final RecipientId recipientId, final ServiceId serviceId, final ProtocolException protocolException, - final SignalServiceEnvelope envelope + final SignalServiceEnvelope envelope, + final ServiceId accountId ) { this.recipientId = recipientId; this.serviceId = serviceId; this.protocolException = protocolException; this.envelope = envelope; + this.accountId = accountId; } @Override public void execute(Context context) throws Throwable { - context.getAccount().getAciSessionStore().archiveSessions(serviceId); + context.getAccount().getAccountData(accountId).getSessionStore().archiveSessions(serviceId); int senderDevice = protocolException.getSenderDevice(); Optional groupId = protocolException.getGroupId().isPresent() ? Optional.of(GroupId.unknownVersion( diff --git a/lib/src/main/java/org/asamk/signal/manager/helper/IncomingMessageHandler.java b/lib/src/main/java/org/asamk/signal/manager/helper/IncomingMessageHandler.java index 3b616644..7cc3d09d 100644 --- a/lib/src/main/java/org/asamk/signal/manager/helper/IncomingMessageHandler.java +++ b/lib/src/main/java/org/asamk/signal/manager/helper/IncomingMessageHandler.java @@ -181,12 +181,13 @@ public final class IncomingMessageHandler { .contains(Profile.Capability.senderKey); final var isSelfSenderKeyCapable = selfProfile != null && selfProfile.getCapabilities() .contains(Profile.Capability.senderKey); + final var destination = getDestination(envelope).serviceId(); if (!isSelf && isSenderSenderKeyCapable && isSelfSenderKeyCapable) { logger.debug("Received invalid message, requesting message resend."); - actions.add(new SendRetryMessageRequestAction(sender, serviceId, e, envelope)); + actions.add(new SendRetryMessageRequestAction(sender, serviceId, e, envelope, destination)); } else { logger.debug("Received invalid message, queuing renew session action."); - actions.add(new RenewSessionAction(sender, serviceId)); + actions.add(new RenewSessionAction(sender, serviceId, destination)); } } else { logger.debug("Received invalid message from invalid sender: {}", e.getSender()); @@ -346,7 +347,12 @@ public final class IncomingMessageHandler { senderDeviceId, message.getTimestamp()); if (message.getDeviceId() == account.getDeviceId()) { - handleDecryptionErrorMessage(actions, sender, senderServiceId, senderDeviceId, message); + handleDecryptionErrorMessage(actions, + sender, + senderServiceId, + senderDeviceId, + message, + destination.serviceId()); } else { logger.debug("Request is for another one of our devices"); } @@ -430,7 +436,8 @@ public final class IncomingMessageHandler { final RecipientId sender, final ServiceId senderServiceId, final int senderDeviceId, - final DecryptionErrorMessage message + final DecryptionErrorMessage message, + final ServiceId destination ) { final var logEntries = account.getMessageSendLogStore() .findMessages(senderServiceId, @@ -443,14 +450,14 @@ public final class IncomingMessageHandler { } if (message.getRatchetKey().isPresent()) { - if (account.getAciSessionStore() - .isCurrentRatchetKey(senderServiceId, senderDeviceId, message.getRatchetKey().get())) { + final var sessionStore = account.getAccountData(destination).getSessionStore(); + if (sessionStore.isCurrentRatchetKey(senderServiceId, senderDeviceId, message.getRatchetKey().get())) { if (logEntries.isEmpty()) { logger.debug("Renewing the session with sender"); - actions.add(new RenewSessionAction(sender, senderServiceId)); + actions.add(new RenewSessionAction(sender, senderServiceId, destination)); } else { logger.trace("Archiving the session with sender, a resend message has already been queued"); - context.getAccount().getAciSessionStore().archiveSessions(senderServiceId); + sessionStore.archiveSessions(senderServiceId); } } return; @@ -806,9 +813,12 @@ public final class IncomingMessageHandler { } } + final var selfAddress = isSync ? source : destination; final var conversationPartnerAddress = isSync ? destination : source; if (conversationPartnerAddress != null && message.isEndSession()) { - account.getAciSessionStore().deleteAllSessions(conversationPartnerAddress.serviceId()); + account.getAccountData(selfAddress.serviceId()) + .getSessionStore() + .deleteAllSessions(conversationPartnerAddress.serviceId()); } if (message.isExpirationUpdate() || message.getBody().isPresent()) { if (message.getGroupContext().isPresent()) { @@ -854,10 +864,12 @@ public final class IncomingMessageHandler { if (message.getQuote().isPresent()) { final var quote = message.getQuote().get(); - for (var quotedAttachment : quote.getAttachments()) { - final var thumbnail = quotedAttachment.getThumbnail(); - if (thumbnail != null) { - context.getAttachmentHelper().downloadAttachment(thumbnail); + if (quote.getAttachments() != null) { + for (var quotedAttachment : quote.getAttachments()) { + final var thumbnail = quotedAttachment.getThumbnail(); + if (thumbnail != null) { + context.getAttachmentHelper().downloadAttachment(thumbnail); + } } } } @@ -972,7 +984,9 @@ public final class IncomingMessageHandler { return new DeviceAddress(account.getSelfRecipientId(), account.getAci(), account.getDeviceId()); } final var address = addressOptional.get(); - return new DeviceAddress(context.getRecipientHelper().resolveRecipient(address), address.getServiceId(), 0); + return new DeviceAddress(context.getRecipientHelper().resolveRecipient(address), + address.getServiceId(), + account.getDeviceId()); } private record DeviceAddress(RecipientId recipientId, ServiceId serviceId, int deviceId) {} diff --git a/lib/src/main/java/org/asamk/signal/manager/internal/ManagerImpl.java b/lib/src/main/java/org/asamk/signal/manager/internal/ManagerImpl.java index 90b2f6a4..faea76d7 100644 --- a/lib/src/main/java/org/asamk/signal/manager/internal/ManagerImpl.java +++ b/lib/src/main/java/org/asamk/signal/manager/internal/ManagerImpl.java @@ -82,6 +82,7 @@ import org.whispersystems.signalservice.api.messages.SignalServiceReceiptMessage import org.whispersystems.signalservice.api.messages.SignalServiceTypingMessage; import org.whispersystems.signalservice.api.push.ACI; import org.whispersystems.signalservice.api.push.ServiceId; +import org.whispersystems.signalservice.api.push.ServiceIdType; import org.whispersystems.signalservice.api.util.DeviceNameUtil; import org.whispersystems.signalservice.api.util.InvalidNumberException; import org.whispersystems.signalservice.api.util.PhoneNumberFormatter; @@ -183,8 +184,8 @@ public class ManagerImpl implements Manager { }); disposable.add(account.getIdentityKeyStore().getIdentityChanges().subscribe(serviceId -> { logger.trace("Archiving old sessions for {}", serviceId); - account.getAciSessionStore().archiveSessions(serviceId); - account.getPniSessionStore().archiveSessions(serviceId); + account.getAccountData(ServiceIdType.ACI).getSessionStore().archiveSessions(serviceId); + account.getAccountData(ServiceIdType.PNI).getSessionStore().archiveSessions(serviceId); account.getSenderKeyStore().deleteSharedWith(serviceId); final var recipientId = account.getRecipientResolver().resolveRecipient(serviceId); final var profile = account.getProfileStore().getProfile(recipientId); @@ -775,7 +776,7 @@ public class ManagerImpl implements Manager { .resolveRecipientAddress(recipientId) .serviceId(); if (serviceId.isPresent()) { - account.getAciSessionStore().deleteAllSessions(serviceId.get()); + account.getAccountData(ServiceIdType.ACI).getSessionStore().deleteAllSessions(serviceId.get()); } } } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java index 76867b72..931a45a2 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java @@ -89,7 +89,6 @@ import java.nio.channels.ClosedChannelException; import java.nio.channels.FileChannel; import java.nio.channels.FileLock; import java.nio.file.Files; -import java.security.SecureRandom; import java.sql.Connection; import java.sql.SQLException; import java.util.Base64; @@ -136,36 +135,14 @@ public class SignalAccount implements Closeable { private StorageKey storageKey; private long storageManifestVersion = -1; private ProfileKey profileKey; - private int aciPreKeyIdOffset = 1; - private int aciNextSignedPreKeyId = 1; - private int pniPreKeyIdOffset = 1; - private int pniNextSignedPreKeyId = 1; - private int aciKyberPreKeyIdOffset = 1; - private int aciActiveLastResortKyberPreKeyId = -1; - private int pniKyberPreKeyIdOffset = 1; - private int pniActiveLastResortKyberPreKeyId = -1; - private IdentityKeyPair aciIdentityKeyPair; - private IdentityKeyPair pniIdentityKeyPair; - private int localRegistrationId; - private int localPniRegistrationId; private Settings settings; private long lastReceiveTimestamp = 0; private boolean registered = false; - private SignalProtocolStore aciSignalProtocolStore; - private SignalProtocolStore pniSignalProtocolStore; - private PreKeyStore aciPreKeyStore; - private SignedPreKeyStore aciSignedPreKeyStore; - private KyberPreKeyStore aciKyberPreKeyStore; - private PreKeyStore pniPreKeyStore; - private SignedPreKeyStore pniSignedPreKeyStore; - private KyberPreKeyStore pniKyberPreKeyStore; - private SessionStore aciSessionStore; - private SessionStore pniSessionStore; + private final AccountData aciAccountData = new AccountData(ServiceIdType.ACI); + private final AccountData pniAccountData = new AccountData(ServiceIdType.PNI); private IdentityKeyStore identityKeyStore; - private SignalIdentityKeyStore aciIdentityKeyStore; - private SignalIdentityKeyStore pniIdentityKeyStore; private SenderKeyStore senderKeyStore; private GroupStore groupStore; private RecipientStore recipientStore; @@ -231,10 +208,10 @@ public class SignalAccount implements Closeable { signalAccount.profileKey = profileKey; signalAccount.dataPath = dataPath; - signalAccount.aciIdentityKeyPair = aciIdentityKey; - signalAccount.pniIdentityKeyPair = pniIdentityKey; - signalAccount.localRegistrationId = registrationId; - signalAccount.localPniRegistrationId = pniRegistrationId; + signalAccount.aciAccountData.setIdentityKeyPair(aciIdentityKey); + signalAccount.pniAccountData.setIdentityKeyPair(pniIdentityKey); + signalAccount.aciAccountData.setLocalRegistrationId(registrationId); + signalAccount.pniAccountData.setLocalRegistrationId(pniRegistrationId); signalAccount.settings = settings; signalAccount.configurationStore = new ConfigurationStore(signalAccount::saveConfigurationStore); @@ -296,8 +273,8 @@ public class SignalAccount implements Closeable { profileKey); signalAccount.getRecipientTrustedResolver() .resolveSelfRecipientTrusted(signalAccount.getSelfRecipientAddress()); - signalAccount.getAciSessionStore().archiveAllSessions(); - signalAccount.getPniSessionStore().archiveAllSessions(); + signalAccount.aciAccountData.getSessionStore().archiveAllSessions(); + signalAccount.pniAccountData.getSessionStore().archiveAllSessions(); signalAccount.getSenderKeyStore().deleteAll(); signalAccount.clearAllPreKeys(); return signalAccount; @@ -308,16 +285,17 @@ public class SignalAccount implements Closeable { } private void clearAllPreKeys() { - resetPreKeyOffsets(ServiceIdType.ACI); - resetPreKeyOffsets(ServiceIdType.PNI); - resetKyberPreKeyOffsets(ServiceIdType.ACI); - resetKyberPreKeyOffsets(ServiceIdType.PNI); - this.getAciPreKeyStore().removeAllPreKeys(); - this.getAciSignedPreKeyStore().removeAllSignedPreKeys(); - this.getAciKyberPreKeyStore().removeAllKyberPreKeys(); - this.getPniPreKeyStore().removeAllPreKeys(); - this.getPniSignedPreKeyStore().removeAllSignedPreKeys(); - this.getPniKyberPreKeyStore().removeAllKyberPreKeys(); + clearAllPreKeys(ServiceIdType.ACI); + clearAllPreKeys(ServiceIdType.PNI); + } + + private void clearAllPreKeys(ServiceIdType serviceIdType) { + final var accountData = getAccountData(serviceIdType); + resetPreKeyOffsets(serviceIdType); + resetKyberPreKeyOffsets(serviceIdType); + accountData.getPreKeyStore().removeAllPreKeys(); + accountData.getSignedPreKeyStore().removeAllSignedPreKeys(); + accountData.getKyberPreKeyStore().removeAllKyberPreKeys(); save(); } @@ -347,8 +325,8 @@ public class SignalAccount implements Closeable { signalAccount.dataPath = dataPath; signalAccount.accountPath = accountPath; signalAccount.serviceEnvironment = serviceEnvironment; - signalAccount.localRegistrationId = registrationId; - signalAccount.localPniRegistrationId = pniRegistrationId; + signalAccount.aciAccountData.setLocalRegistrationId(registrationId); + signalAccount.pniAccountData.setLocalRegistrationId(pniRegistrationId); signalAccount.settings = settings; signalAccount.setProvisioningData(number, aci, @@ -391,8 +369,8 @@ public class SignalAccount implements Closeable { getProfileStore().storeSelfProfileKey(getSelfRecipientId(), getProfileKey()); this.encryptedDeviceName = encryptedDeviceName; this.deviceId = deviceId; - this.aciIdentityKeyPair = aciIdentity; - this.pniIdentityKeyPair = pniIdentity; + this.aciAccountData.setIdentityKeyPair(aciIdentity); + this.pniAccountData.setIdentityKeyPair(pniIdentity); this.registered = true; this.isMultiDevice = true; this.lastReceiveTimestamp = 0; @@ -400,13 +378,9 @@ public class SignalAccount implements Closeable { this.storageManifestVersion = -1; this.setStorageManifest(null); this.storageKey = null; - final var aciPublicKey = getAciIdentityKeyPair().getPublicKey(); - getIdentityKeyStore().saveIdentity(getAci(), aciPublicKey); - getIdentityKeyStore().setIdentityTrustLevel(getAci(), aciPublicKey, TrustLevel.TRUSTED_VERIFIED); + trustSelfIdentity(ServiceIdType.ACI); if (getPniIdentityKeyPair() != null) { - final var pniPublicKey = getPniIdentityKeyPair().getPublicKey(); - getIdentityKeyStore().saveIdentity(getPni(), pniPublicKey); - getIdentityKeyStore().setIdentityTrustLevel(getPni(), pniPublicKey, TrustLevel.TRUSTED_VERIFIED); + trustSelfIdentity(ServiceIdType.PNI); } } @@ -438,8 +412,8 @@ public class SignalAccount implements Closeable { getMessageCache().deleteMessages(recipientId); if (recipientAddress.serviceId().isPresent()) { final var serviceId = recipientAddress.serviceId().get(); - getAciSessionStore().deleteAllSessions(serviceId); - getPniSessionStore().deleteAllSessions(serviceId); + aciAccountData.getSessionStore().deleteAllSessions(serviceId); + pniAccountData.getSessionStore().deleteAllSessions(serviceId); getIdentityKeyStore().deleteIdentity(serviceId); getSenderKeyStore().deleteAll(serviceId); } @@ -595,9 +569,9 @@ public class SignalAccount implements Closeable { registrationId = rootNode.get("registrationId").asInt(); } if (rootNode.hasNonNull("pniRegistrationId")) { - localPniRegistrationId = rootNode.get("pniRegistrationId").asInt(); + pniAccountData.setLocalRegistrationId(rootNode.get("pniRegistrationId").asInt()); } else { - localPniRegistrationId = KeyHelper.generateRegistrationId(false); + pniAccountData.setLocalRegistrationId(KeyHelper.generateRegistrationId(false)); } IdentityKeyPair aciIdentityKeyPair = null; if (rootNode.hasNonNull("identityPrivateKey") && rootNode.hasNonNull("identityKey")) { @@ -608,7 +582,7 @@ public class SignalAccount implements Closeable { if (rootNode.hasNonNull("pniIdentityPrivateKey") && rootNode.hasNonNull("pniIdentityKey")) { final var publicKeyBytes = Base64.getDecoder().decode(rootNode.get("pniIdentityKey").asText()); final var privateKeyBytes = Base64.getDecoder().decode(rootNode.get("pniIdentityPrivateKey").asText()); - pniIdentityKeyPair = KeyUtils.getIdentityKeyPair(publicKeyBytes, privateKeyBytes); + pniAccountData.setIdentityKeyPair(KeyUtils.getIdentityKeyPair(publicKeyBytes, privateKeyBytes)); } if (rootNode.hasNonNull("registrationLockPin")) { @@ -624,44 +598,46 @@ public class SignalAccount implements Closeable { storageManifestVersion = rootNode.get("storageManifestVersion").asLong(); } if (rootNode.hasNonNull("preKeyIdOffset")) { - aciPreKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1); + aciAccountData.preKeyMetadata.preKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1); } else { - aciPreKeyIdOffset = getRandomPreKeyIdOffset(); + aciAccountData.preKeyMetadata.preKeyIdOffset = getRandomPreKeyIdOffset(); } if (rootNode.hasNonNull("nextSignedPreKeyId")) { - aciNextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1); + aciAccountData.preKeyMetadata.nextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1); } else { - aciNextSignedPreKeyId = getRandomPreKeyIdOffset(); + aciAccountData.preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset(); } if (rootNode.hasNonNull("pniPreKeyIdOffset")) { - pniPreKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1); + pniAccountData.preKeyMetadata.preKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1); } else { - pniPreKeyIdOffset = getRandomPreKeyIdOffset(); + pniAccountData.preKeyMetadata.preKeyIdOffset = getRandomPreKeyIdOffset(); } if (rootNode.hasNonNull("pniNextSignedPreKeyId")) { - pniNextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1); + pniAccountData.preKeyMetadata.nextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1); } else { - pniNextSignedPreKeyId = getRandomPreKeyIdOffset(); + pniAccountData.preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset(); } if (rootNode.hasNonNull("kyberPreKeyIdOffset")) { - aciKyberPreKeyIdOffset = rootNode.get("kyberPreKeyIdOffset").asInt(1); + aciAccountData.preKeyMetadata.kyberPreKeyIdOffset = rootNode.get("kyberPreKeyIdOffset").asInt(1); } else { - aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset(); + aciAccountData.preKeyMetadata.kyberPreKeyIdOffset = getRandomPreKeyIdOffset(); } if (rootNode.hasNonNull("activeLastResortKyberPreKeyId")) { - aciActiveLastResortKyberPreKeyId = rootNode.get("activeLastResortKyberPreKeyId").asInt(-1); + aciAccountData.preKeyMetadata.activeLastResortKyberPreKeyId = rootNode.get("activeLastResortKyberPreKeyId") + .asInt(-1); } else { - aciActiveLastResortKyberPreKeyId = -1; + aciAccountData.preKeyMetadata.activeLastResortKyberPreKeyId = -1; } if (rootNode.hasNonNull("pniKyberPreKeyIdOffset")) { - pniKyberPreKeyIdOffset = rootNode.get("pniKyberPreKeyIdOffset").asInt(1); + pniAccountData.preKeyMetadata.kyberPreKeyIdOffset = rootNode.get("pniKyberPreKeyIdOffset").asInt(1); } else { - pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset(); + pniAccountData.preKeyMetadata.kyberPreKeyIdOffset = getRandomPreKeyIdOffset(); } if (rootNode.hasNonNull("pniActiveLastResortKyberPreKeyId")) { - pniActiveLastResortKyberPreKeyId = rootNode.get("pniActiveLastResortKyberPreKeyId").asInt(-1); + pniAccountData.preKeyMetadata.activeLastResortKyberPreKeyId = rootNode.get( + "pniActiveLastResortKyberPreKeyId").asInt(-1); } else { - pniActiveLastResortKyberPreKeyId = -1; + pniAccountData.preKeyMetadata.activeLastResortKyberPreKeyId = -1; } if (rootNode.hasNonNull("profileKey")) { try { @@ -687,22 +663,22 @@ public class SignalAccount implements Closeable { } final var legacyAciPreKeysPath = getAciPreKeysPath(dataPath, accountPath); if (legacyAciPreKeysPath.exists()) { - LegacyPreKeyStore.migrate(legacyAciPreKeysPath, getAciPreKeyStore()); + LegacyPreKeyStore.migrate(legacyAciPreKeysPath, aciAccountData.getPreKeyStore()); migratedLegacyConfig = true; } final var legacyPniPreKeysPath = getPniPreKeysPath(dataPath, accountPath); if (legacyPniPreKeysPath.exists()) { - LegacyPreKeyStore.migrate(legacyPniPreKeysPath, getPniPreKeyStore()); + LegacyPreKeyStore.migrate(legacyPniPreKeysPath, pniAccountData.getPreKeyStore()); migratedLegacyConfig = true; } final var legacyAciSignedPreKeysPath = getAciSignedPreKeysPath(dataPath, accountPath); if (legacyAciSignedPreKeysPath.exists()) { - LegacySignedPreKeyStore.migrate(legacyAciSignedPreKeysPath, getAciSignedPreKeyStore()); + LegacySignedPreKeyStore.migrate(legacyAciSignedPreKeysPath, aciAccountData.getSignedPreKeyStore()); migratedLegacyConfig = true; } final var legacyPniSignedPreKeysPath = getPniSignedPreKeysPath(dataPath, accountPath); if (legacyPniSignedPreKeysPath.exists()) { - LegacySignedPreKeyStore.migrate(legacyPniSignedPreKeysPath, getPniSignedPreKeyStore()); + LegacySignedPreKeyStore.migrate(legacyPniSignedPreKeysPath, pniAccountData.getSignedPreKeyStore()); migratedLegacyConfig = true; } final var legacySessionsPath = getSessionsPath(dataPath, accountPath); @@ -710,7 +686,7 @@ public class SignalAccount implements Closeable { LegacySessionStore.migrate(legacySessionsPath, getRecipientResolver(), getRecipientAddressResolver(), - getAciSessionStore()); + aciAccountData.getSessionStore()); migratedLegacyConfig = true; } final var legacyIdentitiesPath = getIdentitiesPath(dataPath, accountPath); @@ -731,8 +707,8 @@ public class SignalAccount implements Closeable { migratedLegacyConfig = true; } - this.aciIdentityKeyPair = aciIdentityKeyPair; - this.localRegistrationId = registrationId; + this.aciAccountData.setIdentityKeyPair(aciIdentityKeyPair); + this.aciAccountData.setLocalRegistrationId(registrationId); migratedLegacyConfig = loadLegacyStores(rootNode, legacySignalProtocolStore) || migratedLegacyConfig; @@ -805,7 +781,7 @@ public class SignalAccount implements Closeable { logger.debug("Migrating legacy pre key store."); for (var entry : legacySignalProtocolStore.getLegacyPreKeyStore().getPreKeys().entrySet()) { try { - getAciPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue())); + aciAccountData.getPreKeyStore().storePreKey(entry.getKey(), new PreKeyRecord(entry.getValue())); } catch (InvalidMessageException e) { logger.warn("Failed to migrate pre key, ignoring", e); } @@ -817,8 +793,8 @@ public class SignalAccount implements Closeable { logger.debug("Migrating legacy signed pre key store."); for (var entry : legacySignalProtocolStore.getLegacySignedPreKeyStore().getSignedPreKeys().entrySet()) { try { - getAciSignedPreKeyStore().storeSignedPreKey(entry.getKey(), - new SignedPreKeyRecord(entry.getValue())); + aciAccountData.getSignedPreKeyStore() + .storeSignedPreKey(entry.getKey(), new SignedPreKeyRecord(entry.getValue())); } catch (InvalidMessageException e) { logger.warn("Failed to migrate signed pre key, ignoring", e); } @@ -830,8 +806,9 @@ public class SignalAccount implements Closeable { logger.debug("Migrating legacy session store."); for (var session : legacySignalProtocolStore.getLegacySessionStore().getSessions()) { try { - getAciSessionStore().storeSession(new SignalProtocolAddress(session.address.getIdentifier(), - session.deviceId), new SessionRecord(session.sessionRecord)); + aciAccountData.getSessionStore() + .storeSession(new SignalProtocolAddress(session.address.getIdentifier(), session.deviceId), + new SessionRecord(session.sessionRecord)); } catch (Exception e) { logger.warn("Failed to migrate session, ignoring", e); } @@ -981,35 +958,44 @@ public class SignalAccount implements Closeable { .put("isMultiDevice", isMultiDevice) .put("lastReceiveTimestamp", lastReceiveTimestamp) .put("password", password) - .put("registrationId", localRegistrationId) - .put("pniRegistrationId", localPniRegistrationId) + .put("registrationId", aciAccountData.getLocalRegistrationId()) + .put("pniRegistrationId", pniAccountData.getLocalRegistrationId()) .put("identityPrivateKey", - Base64.getEncoder().encodeToString(aciIdentityKeyPair.getPrivateKey().serialize())) + Base64.getEncoder() + .encodeToString(aciAccountData.getIdentityKeyPair().getPrivateKey().serialize())) .put("identityKey", - Base64.getEncoder().encodeToString(aciIdentityKeyPair.getPublicKey().serialize())) + Base64.getEncoder() + .encodeToString(aciAccountData.getIdentityKeyPair().getPublicKey().serialize())) .put("pniIdentityPrivateKey", - pniIdentityKeyPair == null + pniAccountData.getIdentityKeyPair() == null ? null : Base64.getEncoder() - .encodeToString(pniIdentityKeyPair.getPrivateKey().serialize())) + .encodeToString(pniAccountData.getIdentityKeyPair() + .getPrivateKey() + .serialize())) .put("pniIdentityKey", - pniIdentityKeyPair == null + pniAccountData.getIdentityKeyPair() == null ? null - : Base64.getEncoder().encodeToString(pniIdentityKeyPair.getPublicKey().serialize())) + : Base64.getEncoder() + .encodeToString(pniAccountData.getIdentityKeyPair() + .getPublicKey() + .serialize())) .put("registrationLockPin", registrationLockPin) .put("pinMasterKey", pinMasterKey == null ? null : Base64.getEncoder().encodeToString(pinMasterKey.serialize())) .put("storageKey", storageKey == null ? null : Base64.getEncoder().encodeToString(storageKey.serialize())) .put("storageManifestVersion", storageManifestVersion == -1 ? null : storageManifestVersion) - .put("preKeyIdOffset", aciPreKeyIdOffset) - .put("nextSignedPreKeyId", aciNextSignedPreKeyId) - .put("pniPreKeyIdOffset", pniPreKeyIdOffset) - .put("pniNextSignedPreKeyId", pniNextSignedPreKeyId) - .put("kyberPreKeyIdOffset", aciKyberPreKeyIdOffset) - .put("activeLastResortKyberPreKeyId", aciActiveLastResortKyberPreKeyId) - .put("pniKyberPreKeyIdOffset", pniKyberPreKeyIdOffset) - .put("pniActiveLastResortKyberPreKeyId", pniActiveLastResortKyberPreKeyId) + .put("preKeyIdOffset", aciAccountData.getPreKeyMetadata().preKeyIdOffset) + .put("nextSignedPreKeyId", aciAccountData.getPreKeyMetadata().nextSignedPreKeyId) + .put("pniPreKeyIdOffset", pniAccountData.getPreKeyMetadata().preKeyIdOffset) + .put("pniNextSignedPreKeyId", pniAccountData.getPreKeyMetadata().nextSignedPreKeyId) + .put("kyberPreKeyIdOffset", aciAccountData.getPreKeyMetadata().kyberPreKeyIdOffset) + .put("activeLastResortKyberPreKeyId", + aciAccountData.getPreKeyMetadata().activeLastResortKyberPreKeyId) + .put("pniKyberPreKeyIdOffset", pniAccountData.getPreKeyMetadata().kyberPreKeyIdOffset) + .put("pniActiveLastResortKyberPreKeyId", + pniAccountData.getPreKeyMetadata().activeLastResortKyberPreKeyId) .put("profileKey", profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize())) .put("registered", registered) @@ -1054,154 +1040,79 @@ public class SignalAccount implements Closeable { } public void resetPreKeyOffsets(final ServiceIdType serviceIdType) { - if (serviceIdType.equals(ServiceIdType.ACI)) { - this.aciPreKeyIdOffset = getRandomPreKeyIdOffset(); - this.aciNextSignedPreKeyId = getRandomPreKeyIdOffset(); - } else { - this.pniPreKeyIdOffset = getRandomPreKeyIdOffset(); - this.pniNextSignedPreKeyId = getRandomPreKeyIdOffset(); - } + final var preKeyMetadata = getAccountData(serviceIdType).getPreKeyMetadata(); + preKeyMetadata.preKeyIdOffset = getRandomPreKeyIdOffset(); + preKeyMetadata.nextSignedPreKeyId = getRandomPreKeyIdOffset(); save(); } private static int getRandomPreKeyIdOffset() { - return new SecureRandom().nextInt(PREKEY_MAXIMUM_ID); + return KeyUtils.getRandomInt(PREKEY_MAXIMUM_ID); } public void addPreKeys(ServiceIdType serviceIdType, List records) { - if (serviceIdType.equals(ServiceIdType.ACI)) { - addAciPreKeys(records); - } else { - addPniPreKeys(records); - } - } - - private void addAciPreKeys(List records) { + final var accountData = getAccountData(serviceIdType); + final var preKeyMetadata = accountData.getPreKeyMetadata(); for (var record : records) { - if (aciPreKeyIdOffset != record.getId()) { - logger.error("Invalid pre key id {}, expected {}", record.getId(), aciPreKeyIdOffset); + if (preKeyMetadata.preKeyIdOffset != record.getId()) { + logger.error("Invalid pre key id {}, expected {}", record.getId(), preKeyMetadata.preKeyIdOffset); throw new AssertionError("Invalid pre key id"); } - getAciPreKeyStore().storePreKey(record.getId(), record); - aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; - } - save(); - } - - private void addPniPreKeys(List records) { - for (var record : records) { - if (pniPreKeyIdOffset != record.getId()) { - logger.error("Invalid pre key id {}, expected {}", record.getId(), pniPreKeyIdOffset); - throw new AssertionError("Invalid pre key id"); - } - getPniPreKeyStore().storePreKey(record.getId(), record); - pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; + accountData.getPreKeyStore().storePreKey(record.getId(), record); + preKeyMetadata.preKeyIdOffset = (preKeyMetadata.preKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; } save(); } public void addSignedPreKey(ServiceIdType serviceIdType, SignedPreKeyRecord record) { - if (serviceIdType.equals(ServiceIdType.ACI)) { - addAciSignedPreKey(record); - } else { - addPniSignedPreKey(record); - } - } - - public void addAciSignedPreKey(SignedPreKeyRecord record) { - if (aciNextSignedPreKeyId != record.getId()) { - logger.error("Invalid signed pre key id {}, expected {}", record.getId(), aciNextSignedPreKeyId); - throw new AssertionError("Invalid signed pre key id"); - } - getAciSignedPreKeyStore().storeSignedPreKey(record.getId(), record); - aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID; - save(); - } - - public void addPniSignedPreKey(SignedPreKeyRecord record) { - if (pniNextSignedPreKeyId != record.getId()) { - logger.error("Invalid signed pre key id {}, expected {}", record.getId(), pniNextSignedPreKeyId); + final var accountData = getAccountData(serviceIdType); + final var preKeyMetadata = accountData.getPreKeyMetadata(); + if (preKeyMetadata.nextSignedPreKeyId != record.getId()) { + logger.error("Invalid signed pre key id {}, expected {}", + record.getId(), + preKeyMetadata.nextSignedPreKeyId); throw new AssertionError("Invalid signed pre key id"); } - getPniSignedPreKeyStore().storeSignedPreKey(record.getId(), record); - pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID; + accountData.getSignedPreKeyStore().storeSignedPreKey(record.getId(), record); + preKeyMetadata.nextSignedPreKeyId = (preKeyMetadata.nextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID; save(); } public void resetKyberPreKeyOffsets(final ServiceIdType serviceIdType) { - if (serviceIdType.equals(ServiceIdType.ACI)) { - this.aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset(); - this.aciActiveLastResortKyberPreKeyId = -1; - } else { - this.pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset(); - this.pniActiveLastResortKyberPreKeyId = -1; - } + final var preKeyMetadata = getAccountData(serviceIdType).getPreKeyMetadata(); + preKeyMetadata.kyberPreKeyIdOffset = getRandomPreKeyIdOffset(); + preKeyMetadata.activeLastResortKyberPreKeyId = -1; save(); } public void addKyberPreKeys(ServiceIdType serviceIdType, List records) { - if (serviceIdType.equals(ServiceIdType.ACI)) { - addAciKyberPreKeys(records); - } else { - addPniKyberPreKeys(records); - } - } - - private void addAciKyberPreKeys(List records) { + final var accountData = getAccountData(serviceIdType); + final var preKeyMetadata = accountData.getPreKeyMetadata(); for (var record : records) { - if (aciKyberPreKeyIdOffset != record.getId()) { - logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), aciKyberPreKeyIdOffset); + if (preKeyMetadata.kyberPreKeyIdOffset != record.getId()) { + logger.error("Invalid kyber pre key id {}, expected {}", + record.getId(), + preKeyMetadata.kyberPreKeyIdOffset); throw new AssertionError("Invalid kyber pre key id"); } - getAciKyberPreKeyStore().storeKyberPreKey(record.getId(), record); - aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; - } - save(); - } - - private void addPniKyberPreKeys(List records) { - for (var record : records) { - if (pniKyberPreKeyIdOffset != record.getId()) { - logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), pniKyberPreKeyIdOffset); - throw new AssertionError("Invalid kyber pre key id"); - } - getPniKyberPreKeyStore().storeKyberPreKey(record.getId(), record); - pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; + accountData.getKyberPreKeyStore().storeKyberPreKey(record.getId(), record); + preKeyMetadata.kyberPreKeyIdOffset = (preKeyMetadata.kyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; } save(); } public void addLastResortKyberPreKey(ServiceIdType serviceIdType, KyberPreKeyRecord record) { - if (serviceIdType.equals(ServiceIdType.ACI)) { - addAciLastResortKyberPreKey(record); - } else { - addPniLastResortKyberPreKey(record); - } - } - - public void addAciLastResortKyberPreKey(KyberPreKeyRecord record) { - if (aciKyberPreKeyIdOffset != record.getId()) { + final var accountData = getAccountData(serviceIdType); + final var preKeyMetadata = accountData.getPreKeyMetadata(); + if (preKeyMetadata.kyberPreKeyIdOffset != record.getId()) { logger.error("Invalid last resort kyber pre key id {}, expected {}", record.getId(), - aciKyberPreKeyIdOffset); + preKeyMetadata.kyberPreKeyIdOffset); throw new AssertionError("Invalid last resort kyber pre key id"); } - getAciKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record); - aciActiveLastResortKyberPreKeyId = record.getId(); - aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; - save(); - } - - public void addPniLastResortKyberPreKey(KyberPreKeyRecord record) { - if (pniKyberPreKeyIdOffset != record.getId()) { - logger.error("Invalid last resort kyber pre key id {}, expected {}", - record.getId(), - pniKyberPreKeyIdOffset); - throw new AssertionError("Invalid last resort kyber pre key id"); - } - getPniKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record); - pniActiveLastResortKyberPreKeyId = record.getId(); - pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; + accountData.getKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record); + preKeyMetadata.activeLastResortKyberPreKeyId = record.getId(); + preKeyMetadata.kyberPreKeyIdOffset = (preKeyMetadata.kyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; save(); } @@ -1209,6 +1120,23 @@ public class SignalAccount implements Closeable { return previousStorageVersion; } + public AccountData getAccountData(ServiceIdType serviceIdType) { + return switch (serviceIdType) { + case ACI -> aciAccountData; + case PNI -> pniAccountData; + }; + } + + public AccountData getAccountData(ServiceId accountIdentifier) { + if (accountIdentifier.equals(aci)) { + return aciAccountData; + } else if (accountIdentifier.equals(pni)) { + return pniAccountData; + } else { + throw new IllegalArgumentException("No matching account data found for " + accountIdentifier); + } + } + public SignalServiceDataStore getSignalServiceDataStore() { return new SignalServiceDataStore() { @Override @@ -1224,12 +1152,12 @@ public class SignalAccount implements Closeable { @Override public SignalServiceAccountDataStore aci() { - return getAciSignalServiceAccountDataStore(); + return aciAccountData.getSignalServiceAccountDataStore(); } @Override public SignalServiceAccountDataStore pni() { - return getPniSignalServiceAccountDataStore(); + return pniAccountData.getSignalServiceAccountDataStore(); } @Override @@ -1239,89 +1167,11 @@ public class SignalAccount implements Closeable { }; } - private SignalServiceAccountDataStore getAciSignalServiceAccountDataStore() { - return getOrCreate(() -> aciSignalProtocolStore, - () -> aciSignalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(), - getAciSignedPreKeyStore(), - getAciKyberPreKeyStore(), - getAciSessionStore(), - getAciIdentityKeyStore(), - getSenderKeyStore(), - this::isMultiDevice)); - } - - private SignalServiceAccountDataStore getPniSignalServiceAccountDataStore() { - return getOrCreate(() -> pniSignalProtocolStore, - () -> pniSignalProtocolStore = new SignalProtocolStore(getPniPreKeyStore(), - getPniSignedPreKeyStore(), - getPniKyberPreKeyStore(), - getPniSessionStore(), - getPniIdentityKeyStore(), - getSenderKeyStore(), - this::isMultiDevice)); - } - - private PreKeyStore getAciPreKeyStore() { - return getOrCreate(() -> aciPreKeyStore, - () -> aciPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.ACI)); - } - - private SignedPreKeyStore getAciSignedPreKeyStore() { - return getOrCreate(() -> aciSignedPreKeyStore, - () -> aciSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.ACI)); - } - - private KyberPreKeyStore getAciKyberPreKeyStore() { - return getOrCreate(() -> aciKyberPreKeyStore, - () -> aciKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.ACI)); - } - - private PreKeyStore getPniPreKeyStore() { - return getOrCreate(() -> pniPreKeyStore, - () -> pniPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.PNI)); - } - - private SignedPreKeyStore getPniSignedPreKeyStore() { - return getOrCreate(() -> pniSignedPreKeyStore, - () -> pniSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.PNI)); - } - - private KyberPreKeyStore getPniKyberPreKeyStore() { - return getOrCreate(() -> pniKyberPreKeyStore, - () -> pniKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.PNI)); - } - - public SessionStore getAciSessionStore() { - return getOrCreate(() -> aciSessionStore, - () -> aciSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.ACI)); - } - - public SessionStore getPniSessionStore() { - return getOrCreate(() -> pniSessionStore, - () -> pniSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.PNI)); - } - public IdentityKeyStore getIdentityKeyStore() { return getOrCreate(() -> identityKeyStore, () -> identityKeyStore = new IdentityKeyStore(getAccountDatabase(), settings.trustNewIdentity())); } - public SignalIdentityKeyStore getAciIdentityKeyStore() { - return getOrCreate(() -> aciIdentityKeyStore, - () -> aciIdentityKeyStore = new SignalIdentityKeyStore(getRecipientResolver(), - () -> aciIdentityKeyPair, - localRegistrationId, - getIdentityKeyStore())); - } - - public SignalIdentityKeyStore getPniIdentityKeyStore() { - return getOrCreate(() -> pniIdentityKeyStore, - () -> pniIdentityKeyStore = new SignalIdentityKeyStore(getRecipientResolver(), - () -> pniIdentityKeyPair, - localRegistrationId, - getIdentityKeyStore())); - } - public GroupStore getGroupStore() { return getOrCreate(() -> groupStore, () -> groupStore = new GroupStore(getAccountDatabase(), @@ -1489,11 +1339,7 @@ public class SignalAccount implements Closeable { if (this.pni != null && !this.pni.equals(updatedPni)) { // Clear data for old PNI identityKeyStore.deleteIdentity(this.pni); - getPniPreKeyStore().removeAllPreKeys(); - getPniSignedPreKeyStore().removeAllSignedPreKeys(); - getPniKyberPreKeyStore().removeAllKyberPreKeys(); - aciActiveLastResortKyberPreKeyId = -1; - pniActiveLastResortKyberPreKeyId = -1; + clearAllPreKeys(ServiceIdType.PNI); } this.pni = updatedPni; @@ -1509,7 +1355,7 @@ public class SignalAccount implements Closeable { setPni(updatedPni); setPniIdentityKeyPair(pniIdentityKeyPair); - addPniSignedPreKey(pniSignedPreKey); + addSignedPreKey(ServiceIdType.PNI, pniSignedPreKey); setLocalPniRegistrationId(localPniRegistrationId); } @@ -1552,35 +1398,33 @@ public class SignalAccount implements Closeable { } public IdentityKeyPair getIdentityKeyPair(ServiceIdType serviceIdType) { - return serviceIdType.equals(ServiceIdType.ACI) ? aciIdentityKeyPair : pniIdentityKeyPair; + return getAccountData(serviceIdType).getIdentityKeyPair(); } public IdentityKeyPair getAciIdentityKeyPair() { - return aciIdentityKeyPair; + return aciAccountData.getIdentityKeyPair(); } public IdentityKeyPair getPniIdentityKeyPair() { - return pniIdentityKeyPair; + return pniAccountData.getIdentityKeyPair(); } public void setPniIdentityKeyPair(final IdentityKeyPair identityKeyPair) { - pniIdentityKeyPair = identityKeyPair; - final var pniPublicKey = identityKeyPair.getPublicKey(); - getIdentityKeyStore().saveIdentity(getPni(), pniPublicKey); - getIdentityKeyStore().setIdentityTrustLevel(getPni(), pniPublicKey, TrustLevel.TRUSTED_VERIFIED); + pniAccountData.setIdentityKeyPair(identityKeyPair); + trustSelfIdentity(ServiceIdType.PNI); save(); } public int getLocalRegistrationId() { - return localRegistrationId; + return aciAccountData.getLocalRegistrationId(); } public int getLocalPniRegistrationId() { - return localPniRegistrationId; + return pniAccountData.getLocalRegistrationId(); } public void setLocalPniRegistrationId(final int localPniRegistrationId) { - this.localPniRegistrationId = localPniRegistrationId; + pniAccountData.setLocalRegistrationId(localPniRegistrationId); save(); } @@ -1711,15 +1555,15 @@ public class SignalAccount implements Closeable { } public int getPreKeyIdOffset(ServiceIdType serviceIdType) { - return serviceIdType.equals(ServiceIdType.ACI) ? aciPreKeyIdOffset : pniPreKeyIdOffset; + return getAccountData(serviceIdType).getPreKeyMetadata().preKeyIdOffset; } public int getNextSignedPreKeyId(ServiceIdType serviceIdType) { - return serviceIdType.equals(ServiceIdType.ACI) ? aciNextSignedPreKeyId : pniNextSignedPreKeyId; + return getAccountData(serviceIdType).getPreKeyMetadata().nextSignedPreKeyId; } public int getKyberPreKeyIdOffset(ServiceIdType serviceIdType) { - return serviceIdType.equals(ServiceIdType.ACI) ? aciKyberPreKeyIdOffset : pniKyberPreKeyIdOffset; + return getAccountData(serviceIdType).getPreKeyMetadata().kyberPreKeyIdOffset; } public boolean isRegistered() { @@ -1778,22 +1622,26 @@ public class SignalAccount implements Closeable { save(); clearAllPreKeys(); - getAciSessionStore().archiveAllSessions(); - getPniSessionStore().archiveAllSessions(); + aciAccountData.getSessionStore().archiveAllSessions(); + pniAccountData.getSessionStore().archiveAllSessions(); getSenderKeyStore().deleteAll(); getRecipientTrustedResolver().resolveSelfRecipientTrusted(getSelfRecipientAddress()); - final var aciPublicKey = getAciIdentityKeyPair().getPublicKey(); - getIdentityKeyStore().saveIdentity(getAci(), aciPublicKey); - getIdentityKeyStore().setIdentityTrustLevel(getAci(), aciPublicKey, TrustLevel.TRUSTED_VERIFIED); + trustSelfIdentity(ServiceIdType.ACI); if (getPniIdentityKeyPair() == null) { setPniIdentityKeyPair(KeyUtils.generateIdentityKeyPair()); } else { - final var pniPublicKey = getPniIdentityKeyPair().getPublicKey(); - getIdentityKeyStore().saveIdentity(getPni(), pniPublicKey); - getIdentityKeyStore().setIdentityTrustLevel(getPni(), pniPublicKey, TrustLevel.TRUSTED_VERIFIED); + trustSelfIdentity(ServiceIdType.PNI); } } + private void trustSelfIdentity(ServiceIdType serviceIdType) { + final var accountData = getAccountData(serviceIdType); + final var serviceId = accountData.getServiceId(); + final var publicKey = accountData.getIdentityKeyPair().getPublicKey(); + getIdentityKeyStore().saveIdentity(serviceId, publicKey); + getIdentityKeyStore().setIdentityTrustLevel(serviceId, publicKey, TrustLevel.TRUSTED_VERIFIED); + } + public void deleteAccountData() throws IOException { close(); try (final var files = Files.walk(getUserPath(dataPath, accountPath).toPath()) @@ -1850,4 +1698,95 @@ public class SignalAccount implements Closeable { void call(); } + + private static class PreKeyMetadata { + + private int preKeyIdOffset = 1; + private int nextSignedPreKeyId = 1; + private int kyberPreKeyIdOffset = 1; + private int activeLastResortKyberPreKeyId = -1; + } + + public class AccountData { + + private final ServiceIdType serviceIdType; + private IdentityKeyPair identityKeyPair; + private int localRegistrationId; + private final PreKeyMetadata preKeyMetadata = new PreKeyMetadata(); + + private SignalProtocolStore signalProtocolStore; + private PreKeyStore preKeyStore; + private SignedPreKeyStore signedPreKeyStore; + private KyberPreKeyStore kyberPreKeyStore; + private SessionStore sessionStore; + private SignalIdentityKeyStore identityKeyStore; + + public AccountData(final ServiceIdType serviceIdType) { + this.serviceIdType = serviceIdType; + } + + public ServiceId getServiceId() { + return getAccountId(serviceIdType); + } + + public IdentityKeyPair getIdentityKeyPair() { + return identityKeyPair; + } + + private void setIdentityKeyPair(final IdentityKeyPair identityKeyPair) { + this.identityKeyPair = identityKeyPair; + } + + public int getLocalRegistrationId() { + return localRegistrationId; + } + + private void setLocalRegistrationId(final int localRegistrationId) { + this.localRegistrationId = localRegistrationId; + this.identityKeyStore = null; + } + + public PreKeyMetadata getPreKeyMetadata() { + return preKeyMetadata; + } + + private SignalServiceAccountDataStore getSignalServiceAccountDataStore() { + return getOrCreate(() -> signalProtocolStore, + () -> signalProtocolStore = new SignalProtocolStore(getPreKeyStore(), + getSignedPreKeyStore(), + getKyberPreKeyStore(), + getSessionStore(), + getIdentityKeyStore(), + getSenderKeyStore(), + SignalAccount.this::isMultiDevice)); + } + + private PreKeyStore getPreKeyStore() { + return getOrCreate(() -> preKeyStore, + () -> preKeyStore = new PreKeyStore(getAccountDatabase(), serviceIdType)); + } + + private SignedPreKeyStore getSignedPreKeyStore() { + return getOrCreate(() -> signedPreKeyStore, + () -> signedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), serviceIdType)); + } + + private KyberPreKeyStore getKyberPreKeyStore() { + return getOrCreate(() -> kyberPreKeyStore, + () -> kyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), serviceIdType)); + } + + public SessionStore getSessionStore() { + return getOrCreate(() -> sessionStore, + () -> sessionStore = new SessionStore(getAccountDatabase(), serviceIdType)); + } + + public SignalIdentityKeyStore getIdentityKeyStore() { + return getOrCreate(() -> identityKeyStore, + () -> identityKeyStore = new SignalIdentityKeyStore(getRecipientResolver(), + () -> identityKeyPair, + localRegistrationId, + SignalAccount.this.getIdentityKeyStore())); + } + } } diff --git a/lib/src/main/java/org/asamk/signal/manager/util/KeyUtils.java b/lib/src/main/java/org/asamk/signal/manager/util/KeyUtils.java index 3ee5657e..8db8d0a4 100644 --- a/lib/src/main/java/org/asamk/signal/manager/util/KeyUtils.java +++ b/lib/src/main/java/org/asamk/signal/manager/util/KeyUtils.java @@ -120,4 +120,8 @@ public class KeyUtils { secureRandom.nextBytes(secret); return secret; } + + public static int getRandomInt(int bound) { + return secureRandom.nextInt(bound); + } }