nmode's Git Repositories - signal-cli/commitdiff
Implement support for receiving sender key messages
author AsamK <asamk@gmx.de>
Fri, 3 Sep 2021 20:38:45 +0000 (22:38 +0200)
committer AsamK <asamk@gmx.de>
Sat, 4 Sep 2021 11:54:06 +0000 (13:54 +0200)
lib/src/main/java/org/asamk/signal/manager/config/ServiceConfig.java
lib/src/main/java/org/asamk/signal/manager/helper/IncomingMessageHandler.java
lib/src/main/java/org/asamk/signal/manager/helper/RecipientAddressResolver.java [new file with mode: 0644]
lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java
lib/src/main/java/org/asamk/signal/manager/storage/protocol/SignalProtocolStore.java
lib/src/main/java/org/asamk/signal/manager/storage/senderKeys/SenderKeyRecordStore.java [new file with mode: 0644]
lib/src/main/java/org/asamk/signal/manager/storage/senderKeys/SenderKeySharedStore.java [new file with mode: 0644]
lib/src/main/java/org/asamk/signal/manager/storage/senderKeys/SenderKeyStore.java [new file with mode: 0644]
src/main/java/org/asamk/signal/ReceiveMessageHandler.java

index 3f97be6bb9d2306b8a1ebbe1c387675f9c331a9a..5324439b2ee163e0c296a426da21c608388f3e63 100644 (file)
@@ -34,12 +34,7 @@ public class ServiceConfig {
         } catch (Throwable ignored) {
             zkGroupAvailable = false;
         }
         } catch (Throwable ignored) {
             zkGroupAvailable = false;
         }
-        capabilities = new AccountAttributes.Capabilities(false,
-                zkGroupAvailable,
-                false,
-                zkGroupAvailable,
-                false,
-                true);
+        capabilities = new AccountAttributes.Capabilities(false, zkGroupAvailable, false, zkGroupAvailable, true, true);
 
         try {
             TrustStore contactTrustStore = new IasTrustStore();
 
         try {
             TrustStore contactTrustStore = new IasTrustStore();
index 57b71ee1de8394538016cdb53c3fd165c6168147..e6e434787e3af3ed97fef9ca6c2a487bd4ef158b 100644 (file)
@@ -30,6 +30,7 @@ import org.signal.zkgroup.InvalidInputException;
 import org.signal.zkgroup.profiles.ProfileKey;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.signal.zkgroup.profiles.ProfileKey;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.whispersystems.libsignal.SignalProtocolAddress;
 import org.whispersystems.libsignal.util.Pair;
 import org.whispersystems.signalservice.api.messages.SignalServiceContent;
 import org.whispersystems.signalservice.api.messages.SignalServiceDataMessage;
 import org.whispersystems.libsignal.util.Pair;
 import org.whispersystems.signalservice.api.messages.SignalServiceContent;
 import org.whispersystems.signalservice.api.messages.SignalServiceDataMessage;
@@ -173,10 +174,20 @@ public final class IncomingMessageHandler {
     ) {
         var actions = new ArrayList<HandleAction>();
         final RecipientId sender;
     ) {
         var actions = new ArrayList<HandleAction>();
         final RecipientId sender;
+        final int senderDeviceId;
         if (!envelope.isUnidentifiedSender() && envelope.hasSourceUuid()) {
             sender = recipientResolver.resolveRecipient(envelope.getSourceAddress());
         if (!envelope.isUnidentifiedSender() && envelope.hasSourceUuid()) {
             sender = recipientResolver.resolveRecipient(envelope.getSourceAddress());
+            senderDeviceId = envelope.getSourceDevice();
         } else {
             sender = recipientResolver.resolveRecipient(content.getSender());
         } else {
             sender = recipientResolver.resolveRecipient(content.getSender());
+            senderDeviceId = content.getSenderDevice();
+        }
+
+        if (content.getSenderKeyDistributionMessage().isPresent()) {
+            final var message = content.getSenderKeyDistributionMessage().get();
+            final var protocolAddress = new SignalProtocolAddress(addressResolver.resolveSignalServiceAddress(sender)
+                    .getIdentifier(), senderDeviceId);
+            dependencies.getMessageSender().processSenderKeyDistributionMessage(protocolAddress, message);
         }
 
         if (content.getDataMessage().isPresent()) {
         }
 
         if (content.getDataMessage().isPresent()) {
diff --git a/lib/src/main/java/org/asamk/signal/manager/helper/RecipientAddressResolver.java b/lib/src/main/java/org/asamk/signal/manager/helper/RecipientAddressResolver.java
new file mode 100644 (file)
index 0000000..e2c10f4
--- /dev/null
@@ -0,0 +1,9 @@
+package org.asamk.signal.manager.helper;
+
+import org.asamk.signal.manager.storage.recipients.RecipientAddress;
+import org.asamk.signal.manager.storage.recipients.RecipientId;
+
+public interface RecipientAddressResolver {
+
+    RecipientAddress resolveRecipientAddress(RecipientId recipientId);
+}
index 4e2408870cd86e0ee55f59a52df02166f313787f..e75996c51cd57f0087171cf6fc2067f561b67843 100644 (file)
@@ -24,6 +24,7 @@ import org.asamk.signal.manager.storage.recipients.Profile;
 import org.asamk.signal.manager.storage.recipients.RecipientAddress;
 import org.asamk.signal.manager.storage.recipients.RecipientId;
 import org.asamk.signal.manager.storage.recipients.RecipientStore;
 import org.asamk.signal.manager.storage.recipients.RecipientAddress;
 import org.asamk.signal.manager.storage.recipients.RecipientId;
 import org.asamk.signal.manager.storage.recipients.RecipientStore;
+import org.asamk.signal.manager.storage.senderKeys.SenderKeyStore;
 import org.asamk.signal.manager.storage.sessions.SessionStore;
 import org.asamk.signal.manager.storage.stickers.StickerStore;
 import org.asamk.signal.manager.storage.threads.LegacyJsonThreadStore;
 import org.asamk.signal.manager.storage.sessions.SessionStore;
 import org.asamk.signal.manager.storage.stickers.StickerStore;
 import org.asamk.signal.manager.storage.threads.LegacyJsonThreadStore;
@@ -95,6 +96,7 @@ public class SignalAccount implements Closeable {
     private SignedPreKeyStore signedPreKeyStore;
     private SessionStore sessionStore;
     private IdentityKeyStore identityKeyStore;
     private SignedPreKeyStore signedPreKeyStore;
     private SessionStore sessionStore;
     private IdentityKeyStore identityKeyStore;
+    private SenderKeyStore senderKeyStore;
     private GroupStore groupStore;
     private GroupStore.Storage groupStoreStorage;
     private RecipientStore recipientStore;
     private GroupStore groupStore;
     private GroupStore.Storage groupStoreStorage;
     private RecipientStore recipientStore;
@@ -181,10 +183,15 @@ public class SignalAccount implements Closeable {
                 identityKey,
                 registrationId,
                 trustNewIdentity);
                 identityKey,
                 registrationId,
                 trustNewIdentity);
+        senderKeyStore = new SenderKeyStore(getSharedSenderKeysFile(dataPath, username),
+                getSenderKeysPath(dataPath, username),
+                recipientStore::resolveRecipientAddress,
+                recipientStore);
         signalProtocolStore = new SignalProtocolStore(preKeyStore,
                 signedPreKeyStore,
                 sessionStore,
                 identityKeyStore,
         signalProtocolStore = new SignalProtocolStore(preKeyStore,
                 signedPreKeyStore,
                 sessionStore,
                 identityKeyStore,
+                senderKeyStore,
                 this::isMultiDevice);
 
         messageCache = new MessageCache(getMessageCachePath(dataPath, username));
                 this::isMultiDevice);
 
         messageCache = new MessageCache(getMessageCachePath(dataPath, username));
@@ -221,6 +228,7 @@ public class SignalAccount implements Closeable {
         account.setProvisioningData(username, uuid, password, encryptedDeviceName, deviceId, profileKey);
         account.recipientStore.resolveRecipientTrusted(account.getSelfAddress());
         account.sessionStore.archiveAllSessions();
         account.setProvisioningData(username, uuid, password, encryptedDeviceName, deviceId, profileKey);
         account.recipientStore.resolveRecipientTrusted(account.getSelfAddress());
         account.sessionStore.archiveAllSessions();
+        account.senderKeyStore.deleteAll();
         account.clearAllPreKeys();
         return account;
     }
         account.clearAllPreKeys();
         return account;
     }
@@ -303,6 +311,7 @@ public class SignalAccount implements Closeable {
         identityKeyStore.mergeRecipients(recipientId, toBeMergedRecipientId);
         messageCache.mergeRecipients(recipientId, toBeMergedRecipientId);
         groupStore.mergeRecipients(recipientId, toBeMergedRecipientId);
         identityKeyStore.mergeRecipients(recipientId, toBeMergedRecipientId);
         messageCache.mergeRecipients(recipientId, toBeMergedRecipientId);
         groupStore.mergeRecipients(recipientId, toBeMergedRecipientId);
+        senderKeyStore.mergeRecipients(recipientId, toBeMergedRecipientId);
     }
 
     public static File getFileName(File dataPath, String username) {
     }
 
     public static File getFileName(File dataPath, String username) {
@@ -343,6 +352,14 @@ public class SignalAccount implements Closeable {
         return new File(getUserPath(dataPath, username), "sessions");
     }
 
         return new File(getUserPath(dataPath, username), "sessions");
     }
 
+    private static File getSenderKeysPath(File dataPath, String username) {
+        return new File(getUserPath(dataPath, username), "sender-keys");
+    }
+
+    private static File getSharedSenderKeysFile(File dataPath, String username) {
+        return new File(getUserPath(dataPath, username), "shared-sender-keys-store");
+    }
+
     private static File getRecipientsStoreFile(File dataPath, String username) {
         return new File(getUserPath(dataPath, username), "recipients-store");
     }
     private static File getRecipientsStoreFile(File dataPath, String username) {
         return new File(getUserPath(dataPath, username), "recipients-store");
     }
@@ -768,6 +785,10 @@ public class SignalAccount implements Closeable {
         return stickerStore;
     }
 
         return stickerStore;
     }
 
+    public SenderKeyStore getSenderKeyStore() {
+        return senderKeyStore;
+    }
+
     public MessageCache getMessageCache() {
         return messageCache;
     }
     public MessageCache getMessageCache() {
         return messageCache;
     }
@@ -932,6 +953,7 @@ public class SignalAccount implements Closeable {
         save();
 
         getSessionStore().archiveAllSessions();
         save();
 
         getSessionStore().archiveAllSessions();
+        senderKeyStore.deleteAll();
         final var recipientId = getRecipientStore().resolveRecipientTrusted(getSelfAddress());
         final var publicKey = getIdentityKeyPair().getPublicKey();
         getIdentityKeyStore().saveIdentity(recipientId, publicKey, new Date());
         final var recipientId = getRecipientStore().resolveRecipientTrusted(getSelfAddress());
         final var publicKey = getIdentityKeyPair().getPublicKey();
         getIdentityKeyStore().saveIdentity(recipientId, publicKey, new Date());
index 77eb764a4ab2e4c405a1d3ce96c057a8150020ad..7f2004592107190dda5789765002ff832eca8206 100644 (file)
@@ -13,6 +13,7 @@ import org.whispersystems.libsignal.state.SessionRecord;
 import org.whispersystems.libsignal.state.SignedPreKeyRecord;
 import org.whispersystems.libsignal.state.SignedPreKeyStore;
 import org.whispersystems.signalservice.api.SignalServiceDataStore;
 import org.whispersystems.libsignal.state.SignedPreKeyRecord;
 import org.whispersystems.libsignal.state.SignedPreKeyStore;
 import org.whispersystems.signalservice.api.SignalServiceDataStore;
+import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore;
 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
 import org.whispersystems.signalservice.api.push.DistributionId;
 
 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
 import org.whispersystems.signalservice.api.push.DistributionId;
 
@@ -28,6 +29,7 @@ public class SignalProtocolStore implements SignalServiceDataStore {
     private final SignedPreKeyStore signedPreKeyStore;
     private final SignalServiceSessionStore sessionStore;
     private final IdentityKeyStore identityKeyStore;
     private final SignedPreKeyStore signedPreKeyStore;
     private final SignalServiceSessionStore sessionStore;
     private final IdentityKeyStore identityKeyStore;
+    private final SignalServiceSenderKeyStore senderKeyStore;
     private final Supplier<Boolean> isMultiDevice;
 
     public SignalProtocolStore(
     private final Supplier<Boolean> isMultiDevice;
 
     public SignalProtocolStore(
@@ -35,12 +37,14 @@ public class SignalProtocolStore implements SignalServiceDataStore {
             final SignedPreKeyStore signedPreKeyStore,
             final SignalServiceSessionStore sessionStore,
             final IdentityKeyStore identityKeyStore,
             final SignedPreKeyStore signedPreKeyStore,
             final SignalServiceSessionStore sessionStore,
             final IdentityKeyStore identityKeyStore,
+            final SignalServiceSenderKeyStore senderKeyStore,
             final Supplier<Boolean> isMultiDevice
     ) {
         this.preKeyStore = preKeyStore;
         this.signedPreKeyStore = signedPreKeyStore;
         this.sessionStore = sessionStore;
         this.identityKeyStore = identityKeyStore;
             final Supplier<Boolean> isMultiDevice
     ) {
         this.preKeyStore = preKeyStore;
         this.signedPreKeyStore = signedPreKeyStore;
         this.sessionStore = sessionStore;
         this.identityKeyStore = identityKeyStore;
+        this.senderKeyStore = senderKeyStore;
         this.isMultiDevice = isMultiDevice;
     }
 
         this.isMultiDevice = isMultiDevice;
     }
 
@@ -163,31 +167,29 @@ public class SignalProtocolStore implements SignalServiceDataStore {
     public void storeSenderKey(
             final SignalProtocolAddress sender, final UUID distributionId, final SenderKeyRecord record
     ) {
     public void storeSenderKey(
             final SignalProtocolAddress sender, final UUID distributionId, final SenderKeyRecord record
     ) {
-        // TODO
+        senderKeyStore.storeSenderKey(sender, distributionId, record);
     }
 
     @Override
     public SenderKeyRecord loadSenderKey(final SignalProtocolAddress sender, final UUID distributionId) {
     }
 
     @Override
     public SenderKeyRecord loadSenderKey(final SignalProtocolAddress sender, final UUID distributionId) {
-        // TODO
-        return null;
+        return senderKeyStore.loadSenderKey(sender, distributionId);
     }
 
     @Override
     public Set<SignalProtocolAddress> getSenderKeySharedWith(final DistributionId distributionId) {
     }
 
     @Override
     public Set<SignalProtocolAddress> getSenderKeySharedWith(final DistributionId distributionId) {
-        // TODO
-        return null;
+        return senderKeyStore.getSenderKeySharedWith(distributionId);
     }
 
     @Override
     public void markSenderKeySharedWith(
             final DistributionId distributionId, final Collection<SignalProtocolAddress> addresses
     ) {
     }
 
     @Override
     public void markSenderKeySharedWith(
             final DistributionId distributionId, final Collection<SignalProtocolAddress> addresses
     ) {
-        // TODO
+        senderKeyStore.markSenderKeySharedWith(distributionId, addresses);
     }
 
     @Override
     public void clearSenderKeySharedWith(final Collection<SignalProtocolAddress> addresses) {
     }
 
     @Override
     public void clearSenderKeySharedWith(final Collection<SignalProtocolAddress> addresses) {
-        // TODO
+        senderKeyStore.clearSenderKeySharedWith(addresses);
     }
 
     @Override
     }
 
     @Override
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/senderKeys/SenderKeyRecordStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/senderKeys/SenderKeyRecordStore.java
new file mode 100644 (file)
index 0000000..f84903e
--- /dev/null
@@ -0,0 +1,261 @@
+package org.asamk.signal.manager.storage.senderKeys;
+
+import org.asamk.signal.manager.storage.recipients.RecipientId;
+import org.asamk.signal.manager.storage.recipients.RecipientResolver;
+import org.asamk.signal.manager.util.IOUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.libsignal.SignalProtocolAddress;
+import org.whispersystems.libsignal.groups.state.SenderKeyRecord;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+
+public class SenderKeyRecordStore implements org.whispersystems.libsignal.groups.state.SenderKeyStore {
+
+    private final static Logger logger = LoggerFactory.getLogger(SenderKeyRecordStore.class);
+
+    private final Map<Key, SenderKeyRecord> cachedSenderKeys = new HashMap<>();
+
+    private final File senderKeysPath;
+
+    private final RecipientResolver resolver;
+
+    public SenderKeyRecordStore(
+            final File senderKeysPath, final RecipientResolver resolver
+    ) {
+        this.senderKeysPath = senderKeysPath;
+        this.resolver = resolver;
+    }
+
+    @Override
+    public SenderKeyRecord loadSenderKey(final SignalProtocolAddress address, final UUID distributionId) {
+        final var key = getKey(address, distributionId);
+
+        synchronized (cachedSenderKeys) {
+            return loadSenderKeyLocked(key);
+        }
+    }
+
+    @Override
+    public void storeSenderKey(
+            final SignalProtocolAddress address, final UUID distributionId, final SenderKeyRecord record
+    ) {
+        final var key = getKey(address, distributionId);
+
+        synchronized (cachedSenderKeys) {
+            storeSenderKeyLocked(key, record);
+        }
+    }
+
+    public void deleteAll() {
+        synchronized (cachedSenderKeys) {
+            cachedSenderKeys.clear();
+            final var files = senderKeysPath.listFiles((_file, s) -> senderKeyFileNamePattern.matcher(s).matches());
+            if (files == null) {
+                return;
+            }
+
+            for (final var file : files) {
+                try {
+                    Files.delete(file.toPath());
+                } catch (IOException e) {
+                    logger.error("Failed to delete sender key file {}: {}", file, e.getMessage());
+                }
+            }
+        }
+    }
+
+    public void deleteAllFor(final RecipientId recipientId) {
+        synchronized (cachedSenderKeys) {
+            cachedSenderKeys.clear();
+            final var keys = getKeysLocked(recipientId);
+            for (var key : keys) {
+                deleteSenderKeyLocked(key);
+            }
+        }
+    }
+
+    public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
+        synchronized (cachedSenderKeys) {
+            final var keys = getKeysLocked(toBeMergedRecipientId);
+            final var otherHasSenderKeys = keys.size() > 0;
+            if (!otherHasSenderKeys) {
+                return;
+            }
+
+            logger.debug("Only to be merged recipient had sender keys, re-assigning to the new recipient.");
+            for (var key : keys) {
+                final var toBeMergedSenderKey = loadSenderKeyLocked(key);
+                deleteSenderKeyLocked(key);
+                if (toBeMergedSenderKey == null) {
+                    continue;
+                }
+
+                final var newKey = new Key(recipientId, key.getDeviceId(), key.distributionId);
+                final var senderKeyRecord = loadSenderKeyLocked(newKey);
+                if (senderKeyRecord != null) {
+                    continue;
+                }
+                storeSenderKeyLocked(newKey, senderKeyRecord);
+            }
+        }
+    }
+
+    /**
+     * @param identifier can be either a serialized uuid or a e164 phone number
+     */
+    private RecipientId resolveRecipient(String identifier) {
+        return resolver.resolveRecipient(identifier);
+    }
+
+    private Key getKey(final SignalProtocolAddress address, final UUID distributionId) {
+        final var recipientId = resolveRecipient(address.getName());
+        return new Key(recipientId, address.getDeviceId(), distributionId);
+    }
+
+    private List<Key> getKeysLocked(RecipientId recipientId) {
+        final var files = senderKeysPath.listFiles((_file, s) -> s.startsWith(recipientId.getId() + "_"));
+        if (files == null) {
+            return List.of();
+        }
+        return parseFileNames(files);
+    }
+
+    final Pattern senderKeyFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)_([0-9a-z\\-]+)");
+
+    private List<Key> parseFileNames(final File[] files) {
+        return Arrays.stream(files)
+                .map(f -> senderKeyFileNamePattern.matcher(f.getName()))
+                .filter(Matcher::matches)
+                .map(matcher -> new Key(RecipientId.of(Long.parseLong(matcher.group(1))),
+                        Integer.parseInt(matcher.group(2)),
+                        UUID.fromString(matcher.group(3))))
+                .collect(Collectors.toList());
+    }
+
+    private File getSenderKeyFile(Key key) {
+        try {
+            IOUtils.createPrivateDirectories(senderKeysPath);
+        } catch (IOException e) {
+            throw new AssertionError("Failed to create sender keys path", e);
+        }
+        return new File(senderKeysPath,
+                key.getRecipientId().getId() + "_" + key.getDeviceId() + "_" + key.distributionId.toString());
+    }
+
+    private SenderKeyRecord loadSenderKeyLocked(final Key key) {
+        {
+            final var senderKeyRecord = cachedSenderKeys.get(key);
+            if (senderKeyRecord != null) {
+                return senderKeyRecord;
+            }
+        }
+
+        final var file = getSenderKeyFile(key);
+        if (!file.exists()) {
+            return null;
+        }
+        try (var inputStream = new FileInputStream(file)) {
+            final var senderKeyRecord = new SenderKeyRecord(inputStream.readAllBytes());
+            cachedSenderKeys.put(key, senderKeyRecord);
+            return senderKeyRecord;
+        } catch (IOException e) {
+            logger.warn("Failed to load sender key, resetting sender key: {}", e.getMessage());
+            return null;
+        }
+    }
+
+    private void storeSenderKeyLocked(final Key key, final SenderKeyRecord senderKeyRecord) {
+        cachedSenderKeys.put(key, senderKeyRecord);
+
+        final var file = getSenderKeyFile(key);
+        try {
+            try (var outputStream = new FileOutputStream(file)) {
+                outputStream.write(senderKeyRecord.serialize());
+            }
+        } catch (IOException e) {
+            logger.warn("Failed to store sender key, trying to delete file and retry: {}", e.getMessage());
+            try {
+                Files.delete(file.toPath());
+                try (var outputStream = new FileOutputStream(file)) {
+                    outputStream.write(senderKeyRecord.serialize());
+                }
+            } catch (IOException e2) {
+                logger.error("Failed to store sender key file {}: {}", file, e2.getMessage());
+            }
+        }
+    }
+
+    private void deleteSenderKeyLocked(final Key key) {
+        cachedSenderKeys.remove(key);
+
+        final var file = getSenderKeyFile(key);
+        if (!file.exists()) {
+            return;
+        }
+        try {
+            Files.delete(file.toPath());
+        } catch (IOException e) {
+            logger.error("Failed to delete sender key file {}: {}", file, e.getMessage());
+        }
+    }
+
+    private static final class Key {
+
+        private final RecipientId recipientId;
+        private final int deviceId;
+        private final UUID distributionId;
+
+        public Key(
+                final RecipientId recipientId, final int deviceId, final UUID distributionId
+        ) {
+            this.recipientId = recipientId;
+            this.deviceId = deviceId;
+            this.distributionId = distributionId;
+        }
+
+        public RecipientId getRecipientId() {
+            return recipientId;
+        }
+
+        public int getDeviceId() {
+            return deviceId;
+        }
+
+        public UUID getDistributionId() {
+            return distributionId;
+        }
+
+        @Override
+        public boolean equals(final Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+
+            final Key key = (Key) o;
+
+            if (deviceId != key.deviceId) return false;
+            if (!recipientId.equals(key.recipientId)) return false;
+            return distributionId.equals(key.distributionId);
+        }
+
+        @Override
+        public int hashCode() {
+            int result = recipientId.hashCode();
+            result = 31 * result + deviceId;
+            result = 31 * result + distributionId.hashCode();
+            return result;
+        }
+    }
+}
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/senderKeys/SenderKeySharedStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/senderKeys/SenderKeySharedStore.java
new file mode 100644 (file)
index 0000000..3faf2e7
--- /dev/null
@@ -0,0 +1,270 @@
+package org.asamk.signal.manager.storage.senderKeys;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import org.asamk.signal.manager.helper.RecipientAddressResolver;
+import org.asamk.signal.manager.storage.Utils;
+import org.asamk.signal.manager.storage.recipients.RecipientId;
+import org.asamk.signal.manager.storage.recipients.RecipientResolver;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.libsignal.SignalProtocolAddress;
+import org.whispersystems.signalservice.api.push.DistributionId;
+import org.whispersystems.signalservice.api.util.UuidUtil;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+public class SenderKeySharedStore {
+
+    private final static Logger logger = LoggerFactory.getLogger(SenderKeySharedStore.class);
+
+    private final Map<DistributionId, Set<SenderKeySharedEntry>> sharedSenderKeys;
+
+    private final ObjectMapper objectMapper;
+    private final File file;
+
+    private final RecipientResolver resolver;
+    private final RecipientAddressResolver addressResolver;
+
+    public static SenderKeySharedStore load(
+            final File file, final RecipientAddressResolver addressResolver, final RecipientResolver resolver
+    ) throws IOException {
+        final var objectMapper = Utils.createStorageObjectMapper();
+        try (var inputStream = new FileInputStream(file)) {
+            final var storage = objectMapper.readValue(inputStream, Storage.class);
+            final var sharedSenderKeys = new HashMap<DistributionId, Set<SenderKeySharedEntry>>();
+            for (final var senderKey : storage.sharedSenderKeys) {
+                final var entry = new SenderKeySharedEntry(RecipientId.of(senderKey.recipientId), senderKey.deviceId);
+                final var uuid = UuidUtil.parseOrNull(senderKey.distributionId);
+                if (uuid == null) {
+                    logger.warn("Read invalid distribution id from storage {}, ignoring", senderKey.distributionId);
+                    continue;
+                }
+                final var distributionId = DistributionId.from(uuid);
+                var entries = sharedSenderKeys.get(distributionId);
+                if (entries == null) {
+                    entries = new HashSet<>();
+                }
+                entries.add(entry);
+                sharedSenderKeys.put(distributionId, entries);
+            }
+
+            return new SenderKeySharedStore(sharedSenderKeys, objectMapper, file, addressResolver, resolver);
+        } catch (FileNotFoundException e) {
+            logger.debug("Creating new shared sender key store.");
+            return new SenderKeySharedStore(new HashMap<>(), objectMapper, file, addressResolver, resolver);
+        }
+    }
+
+    /**
+     * Wraps an already-loaded share map. The constructor is private; {@link #load} builds instances.
+     */
+    private SenderKeySharedStore(
+            final Map<DistributionId, Set<SenderKeySharedEntry>> sharedSenderKeys,
+            final ObjectMapper objectMapper,
+            final File file,
+            final RecipientAddressResolver addressResolver,
+            final RecipientResolver resolver
+    ) {
+        this.resolver = resolver;
+        this.addressResolver = addressResolver;
+        this.file = file;
+        this.objectMapper = objectMapper;
+        this.sharedSenderKeys = sharedSenderKeys;
+    }
+
+    public Set<SignalProtocolAddress> getSenderKeySharedWith(final DistributionId distributionId) {
+        synchronized (sharedSenderKeys) {
+            return sharedSenderKeys.get(distributionId)
+                    .stream()
+                    .map(k -> new SignalProtocolAddress(addressResolver.resolveRecipientAddress(k.getRecipientId())
+                            .getIdentifier(), k.getDeviceId()))
+                    .collect(Collectors.toSet());
+        }
+    }
+
+    public void markSenderKeySharedWith(
+            final DistributionId distributionId, final Collection<SignalProtocolAddress> addresses
+    ) {
+        final var newEntries = addresses.stream()
+                .map(a -> new SenderKeySharedEntry(resolveRecipient(a.getName()), a.getDeviceId()))
+                .collect(Collectors.toSet());
+
+        synchronized (sharedSenderKeys) {
+            final var previousEntries = sharedSenderKeys.getOrDefault(distributionId, Set.of());
+
+            sharedSenderKeys.put(distributionId, new HashSet<>() {
+                {
+                    addAll(previousEntries);
+                    addAll(newEntries);
+                }
+            });
+            saveLocked();
+        }
+    }
+
+    public void clearSenderKeySharedWith(final Collection<SignalProtocolAddress> addresses) {
+        final var entriesToDelete = addresses.stream()
+                .map(a -> new SenderKeySharedEntry(resolveRecipient(a.getName()), a.getDeviceId()))
+                .collect(Collectors.toSet());
+
+        synchronized (sharedSenderKeys) {
+            for (final var distributionId : sharedSenderKeys.keySet()) {
+                final var entries = sharedSenderKeys.getOrDefault(distributionId, Set.of());
+
+                sharedSenderKeys.put(distributionId, new HashSet<>(entries) {
+                    {
+                        removeAll(entriesToDelete);
+                    }
+                });
+            }
+            saveLocked();
+        }
+    }
+
+    /**
+     * Forgets every recorded share for all distribution ids and persists the now-empty store.
+     */
+    public void deleteAll() {
+        synchronized (sharedSenderKeys) {
+            sharedSenderKeys.clear();
+            saveLocked();
+        }
+    }
+
+    public void deleteAllFor(final RecipientId recipientId) {
+        synchronized (sharedSenderKeys) {
+            for (final var distributionId : sharedSenderKeys.keySet()) {
+                final var entries = sharedSenderKeys.getOrDefault(distributionId, Set.of());
+
+                sharedSenderKeys.put(distributionId, new HashSet<>(entries) {
+                    {
+                        entries.removeIf(e -> e.getRecipientId().equals(recipientId));
+                    }
+                });
+            }
+            saveLocked();
+        }
+    }
+
+    public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
+        synchronized (sharedSenderKeys) {
+            for (final var distributionId : sharedSenderKeys.keySet()) {
+                final var entries = sharedSenderKeys.getOrDefault(distributionId, Set.of());
+
+                sharedSenderKeys.put(distributionId,
+                        entries.stream()
+                                .map(e -> e.recipientId.equals(toBeMergedRecipientId) ? new SenderKeySharedEntry(
+                                        recipientId,
+                                        e.getDeviceId()) : e)
+                                .collect(Collectors.toSet()));
+            }
+            saveLocked();
+        }
+    }
+
+    /**
+     * @param identifier can be either a serialized uuid or a e164 phone number
+     */
+    private RecipientId resolveRecipient(String identifier) {
+        return resolver.resolveRecipient(identifier);
+    }
+
+    private void saveLocked() {
+        var storage = new Storage(sharedSenderKeys.entrySet().stream().flatMap(pair -> {
+            final var sharedWith = pair.getValue();
+            return sharedWith.stream()
+                    .map(entry -> new Storage.SharedSenderKey(entry.getRecipientId().getId(),
+                            entry.getDeviceId(),
+                            pair.getKey().asUuid().toString()));
+        }).collect(Collectors.toList()));
+
+        // Write to memory first to prevent corrupting the file in case of serialization errors
+        try (var inMemoryOutput = new ByteArrayOutputStream()) {
+            objectMapper.writeValue(inMemoryOutput, storage);
+
+            var input = new ByteArrayInputStream(inMemoryOutput.toByteArray());
+            try (var outputStream = new FileOutputStream(file)) {
+                input.transferTo(outputStream);
+            }
+        } catch (Exception e) {
+            logger.error("Error saving shared sender key store file: {}", e.getMessage());
+        }
+    }
+
+    /**
+     * Serialized shape of the shared-sender-key file; {@code saveLocked()} passes an instance to
+     * {@code objectMapper.writeValue}.
+     * <p>
+     * NOTE(review): presumably bound by Jackson through the public fields and the private no-arg constructors.
+     * Field names are the JSON property names, and declaration order may determine output order, so do not
+     * rename or reorder them.
+     */
+    private static class Storage {
+
+        public List<SharedSenderKey> sharedSenderKeys;
+
+        // For deserialization
+        private Storage() {
+        }
+
+        public Storage(final List<SharedSenderKey> sharedSenderKeys) {
+            this.sharedSenderKeys = sharedSenderKeys;
+        }
+
+        /**
+         * One persisted record: the sender key for {@code distributionId} was shared with device {@code deviceId}
+         * of recipient {@code recipientId}.
+         */
+        private static class SharedSenderKey {
+
+            // Numeric recipient id; saveLocked() fills it from RecipientId.getId().
+            public long recipientId;
+            public int deviceId;
+            // Distribution id UUID in string form; saveLocked() fills it from asUuid().toString().
+            public String distributionId;
+
+            // For deserialization
+            private SharedSenderKey() {
+            }
+
+            public SharedSenderKey(final long recipientId, final int deviceId, final String distributionId) {
+                this.recipientId = recipientId;
+                this.deviceId = deviceId;
+                this.distributionId = distributionId;
+            }
+        }
+    }
+
+    private static final class SenderKeySharedEntry {
+
+        private final RecipientId recipientId;
+        private final int deviceId;
+
+        public SenderKeySharedEntry(
+                final RecipientId recipientId, final int deviceId
+        ) {
+            this.recipientId = recipientId;
+            this.deviceId = deviceId;
+        }
+
+        public RecipientId getRecipientId() {
+            return recipientId;
+        }
+
+        public int getDeviceId() {
+            return deviceId;
+        }
+
+        @Override
+        public boolean equals(final Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+
+            final SenderKeySharedEntry that = (SenderKeySharedEntry) o;
+
+            if (deviceId != that.deviceId) return false;
+            return recipientId.equals(that.recipientId);
+        }
+
+        @Override
+        public int hashCode() {
+            int result = recipientId.hashCode();
+            result = 31 * result + deviceId;
+            return result;
+        }
+    }
+}
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/senderKeys/SenderKeyStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/senderKeys/SenderKeyStore.java
new file mode 100644 (file)
index 0000000..ab02d75
--- /dev/null
@@ -0,0 +1,75 @@
+package org.asamk.signal.manager.storage.senderKeys;
+
+import org.asamk.signal.manager.helper.RecipientAddressResolver;
+import org.asamk.signal.manager.storage.recipients.RecipientId;
+import org.asamk.signal.manager.storage.recipients.RecipientResolver;
+import org.whispersystems.libsignal.SignalProtocolAddress;
+import org.whispersystems.libsignal.groups.state.SenderKeyRecord;
+import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore;
+import org.whispersystems.signalservice.api.push.DistributionId;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Set;
+import java.util.UUID;
+
+public class SenderKeyStore implements SignalServiceSenderKeyStore {
+
+    private final SenderKeyRecordStore senderKeyRecordStore;
+    private final SenderKeySharedStore senderKeySharedStore;
+
+    public SenderKeyStore(
+            final File file,
+            final File senderKeysPath,
+            final RecipientAddressResolver addressResolver,
+            final RecipientResolver resolver
+    ) throws IOException {
+        this.senderKeyRecordStore = new SenderKeyRecordStore(senderKeysPath, resolver);
+        this.senderKeySharedStore = SenderKeySharedStore.load(file, addressResolver, resolver);
+    }
+
+    @Override
+    public void storeSenderKey(
+            final SignalProtocolAddress sender, final UUID distributionId, final SenderKeyRecord record
+    ) {
+        senderKeyRecordStore.storeSenderKey(sender, distributionId, record);
+    }
+
+    @Override
+    public SenderKeyRecord loadSenderKey(final SignalProtocolAddress sender, final UUID distributionId) {
+        return senderKeyRecordStore.loadSenderKey(sender, distributionId);
+    }
+
+    @Override
+    public Set<SignalProtocolAddress> getSenderKeySharedWith(final DistributionId distributionId) {
+        return senderKeySharedStore.getSenderKeySharedWith(distributionId);
+    }
+
+    @Override
+    public void markSenderKeySharedWith(
+            final DistributionId distributionId, final Collection<SignalProtocolAddress> addresses
+    ) {
+        senderKeySharedStore.markSenderKeySharedWith(distributionId, addresses);
+    }
+
+    @Override
+    public void clearSenderKeySharedWith(final Collection<SignalProtocolAddress> addresses) {
+        senderKeySharedStore.clearSenderKeySharedWith(addresses);
+    }
+
+    public void deleteAll() {
+        senderKeySharedStore.deleteAll();
+        senderKeyRecordStore.deleteAll();
+    }
+
+    public void rotateSenderKeys(RecipientId recipientId) {
+        senderKeySharedStore.deleteAllFor(recipientId);
+        senderKeyRecordStore.deleteAllFor(recipientId);
+    }
+
+    public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
+        senderKeySharedStore.mergeRecipients(recipientId, toBeMergedRecipientId);
+        senderKeyRecordStore.mergeRecipients(recipientId, toBeMergedRecipientId);
+    }
+}
index 4a516197e69d33e32307020da6e2d310ed03259f..15dbd1af0b0047d8703cb92a63b94b4dd4edbdc0 100644 (file)
@@ -82,6 +82,12 @@ public class ReceiveMessageHandler implements Manager.ReceiveMessageHandler {
                         DateUtils.formatTimestamp(content.getServerReceivedTimestamp()),
                         DateUtils.formatTimestamp(content.getServerDeliveredTimestamp()));
 
                         DateUtils.formatTimestamp(content.getServerReceivedTimestamp()),
                         DateUtils.formatTimestamp(content.getServerDeliveredTimestamp()));
 
+                if (content.getSenderKeyDistributionMessage().isPresent()) {
+                    final var message = content.getSenderKeyDistributionMessage().get();
+                    writer.println("Received a sender key distribution message for distributionId {}",
+                            message.getDistributionId());
+                }
+
                 if (content.getDataMessage().isPresent()) {
                     var message = content.getDataMessage().get();
                     printDataMessage(writer, message);
                 if (content.getDataMessage().isPresent()) {
                     var message = content.getDataMessage().get();
                     printDataMessage(writer, message);