nmode's Git Repositories - signal-cli/blobdiff - lib/src/main/java/org/asamk/signal/manager/helper/SendHelper.java
Implementing sending group messages with sender keys
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / helper / SendHelper.java
index c8d8bbb767aa9e6d62c48acd9c9cb2ed4ca09d61..41dccaa474a9a3a4b55cbf5cae95065b6c319a87 100644 (file)
@@ -8,14 +8,19 @@ import org.asamk.signal.manager.groups.GroupUtils;
 import org.asamk.signal.manager.groups.NotAGroupMemberException;
 import org.asamk.signal.manager.storage.SignalAccount;
 import org.asamk.signal.manager.storage.groups.GroupInfo;
+import org.asamk.signal.manager.storage.recipients.Profile;
 import org.asamk.signal.manager.storage.recipients.RecipientId;
 import org.asamk.signal.manager.storage.recipients.RecipientResolver;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.whispersystems.libsignal.InvalidKeyException;
+import org.whispersystems.libsignal.InvalidRegistrationIdException;
+import org.whispersystems.libsignal.NoSessionException;
 import org.whispersystems.libsignal.protocol.DecryptionErrorMessage;
 import org.whispersystems.libsignal.util.guava.Optional;
 import org.whispersystems.signalservice.api.SignalServiceMessageSender;
 import org.whispersystems.signalservice.api.crypto.ContentHint;
+import org.whispersystems.signalservice.api.crypto.UnidentifiedAccess;
 import org.whispersystems.signalservice.api.crypto.UnidentifiedAccessPair;
 import org.whispersystems.signalservice.api.messages.SendMessageResult;
 import org.whispersystems.signalservice.api.messages.SignalServiceDataMessage;
@@ -23,16 +28,22 @@ import org.whispersystems.signalservice.api.messages.SignalServiceReceiptMessage
 import org.whispersystems.signalservice.api.messages.SignalServiceTypingMessage;
 import org.whispersystems.signalservice.api.messages.multidevice.SentTranscriptMessage;
 import org.whispersystems.signalservice.api.messages.multidevice.SignalServiceSyncMessage;
+import org.whispersystems.signalservice.api.push.DistributionId;
 import org.whispersystems.signalservice.api.push.SignalServiceAddress;
+import org.whispersystems.signalservice.api.push.exceptions.NotFoundException;
 import org.whispersystems.signalservice.api.push.exceptions.ProofRequiredException;
 import org.whispersystems.signalservice.api.push.exceptions.RateLimitException;
 import org.whispersystems.signalservice.api.push.exceptions.UnregisteredUserException;
+import org.whispersystems.signalservice.internal.push.exceptions.InvalidUnidentifiedAccessHeaderException;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
 
 public class SendHelper {
 
@@ -45,6 +56,7 @@ public class SendHelper {
     private final RecipientResolver recipientResolver;
     private final IdentityFailureHandler identityFailureHandler;
     private final GroupProvider groupProvider;
+    private final ProfileProvider profileProvider;
     private final RecipientRegistrationRefresher recipientRegistrationRefresher;
 
     public SendHelper(
@@ -55,6 +67,7 @@ public class SendHelper {
             final RecipientResolver recipientResolver,
             final IdentityFailureHandler identityFailureHandler,
             final GroupProvider groupProvider,
+            final ProfileProvider profileProvider,
             final RecipientRegistrationRefresher recipientRegistrationRefresher
     ) {
         this.account = account;
@@ -64,6 +77,7 @@ public class SendHelper {
         this.recipientResolver = recipientResolver;
         this.identityFailureHandler = identityFailureHandler;
         this.groupProvider = groupProvider;
+        this.profileProvider = profileProvider;
         this.recipientRegistrationRefresher = recipientRegistrationRefresher;
     }
 
@@ -81,7 +95,7 @@ public class SendHelper {
 
         final var message = messageBuilder.build();
         final var result = sendMessage(message, recipientId);
-        handlePossibleIdentityFailure(result);
+        handleSendMessageResult(result);
         return result;
     }
 
@@ -116,7 +130,7 @@ public class SendHelper {
             }
         }
 
-        return sendGroupMessage(message, recipients);
+        return sendGroupMessage(message, recipients, g.getDistributionId());
     }
 
     /**
@@ -124,12 +138,14 @@ public class SendHelper {
      * This method should only be used for create/update/quit group messages.
      */
     public List<SendMessageResult> sendGroupMessage(
-            final SignalServiceDataMessage message, final Set<RecipientId> recipientIds
+            final SignalServiceDataMessage message,
+            final Set<RecipientId> recipientIds,
+            final DistributionId distributionId
     ) throws IOException {
-        List<SendMessageResult> result = sendGroupMessageInternal(message, recipientIds);
+        List<SendMessageResult> result = sendGroupMessageInternal(message, recipientIds, distributionId);
 
         for (var r : result) {
-            handlePossibleIdentityFailure(r);
+            handleSendMessageResult(r);
         }
 
         return result;
@@ -245,27 +261,189 @@ public class SendHelper {
     }
 
+    /**
+     * Sends a group data message, using sender key for capable recipients and legacy 1:1 sends for the rest.
+     * <p>
+     * Recipients whose sender-key send failed are retried with legacy sends. If the sender-key attempt is
+     * abandoned as a whole, every recipient is sent to with legacy sends.
+     *
+     * @param message        the message to send
+     * @param recipientIds   the recipients to deliver to
+     * @param distributionId the group's sender-key distribution id; {@code null} disables sender key
+     * @return the successful sender-key results followed by the legacy send results
+     */
     private List<SendMessageResult> sendGroupMessageInternal(
-            final SignalServiceDataMessage message, final Set<RecipientId> recipientIds
+            final SignalServiceDataMessage message,
+            final Set<RecipientId> recipientIds,
+            final DistributionId distributionId
     ) throws IOException {
+        // isRecipientUpdate is true if we've already sent this message to some recipients in the past, otherwise false.
+        final var isRecipientUpdate = false;
+        // Sender key is only attempted when a distribution id is available.
+        Set<RecipientId> senderKeyTargets = distributionId == null
+                ? Set.of()
+                : getSenderKeyCapableRecipientIds(recipientIds);
+        final var allResults = new ArrayList<SendMessageResult>(recipientIds.size());
+
+        if (senderKeyTargets.size() > 0) {
+            final var results = sendGroupMessageInternalWithSenderKey(message,
+                    senderKeyTargets,
+                    distributionId,
+                    isRecipientUpdate);
+
+            // A null result means the sender-key send was abandoned: send to everyone via legacy instead.
+            if (results == null) {
+                senderKeyTargets = Set.of();
+            } else {
+                // Keep only successes; failed recipients are removed from senderKeyTargets so the legacy
+                // send below retries them.
+                results.stream().filter(SendMessageResult::isSuccess).forEach(allResults::add);
+                final var failedTargets = results.stream()
+                        .filter(r -> !r.isSuccess())
+                        .map(r -> recipientResolver.resolveRecipient(r.getAddress()))
+                        .toList();
+                if (failedTargets.size() > 0) {
+                    senderKeyTargets = new HashSet<>(senderKeyTargets);
+                    failedTargets.forEach(senderKeyTargets::remove);
+                }
+            }
+        }
+
+        final var legacyTargets = new HashSet<>(recipientIds);
+        legacyTargets.removeAll(senderKeyTargets);
+        // Even with no recipients, a legacy send is done when we have linked devices, so they get a sync message.
+        final boolean onlyTargetIsSelfWithLinkedDevice = recipientIds.isEmpty() && account.isMultiDevice();
+
+        if (legacyTargets.size() > 0 || onlyTargetIsSelfWithLinkedDevice) {
+            if (legacyTargets.size() > 0) {
+                logger.debug("Need to do {} legacy sends.", legacyTargets.size());
+            } else {
+                logger.debug("Need to do a legacy send to send a sync message for a group of only ourselves.");
+            }
+
+            // Treated as a recipient update if the sender-key path already delivered to someone.
+            final List<SendMessageResult> results = sendGroupMessageInternalWithLegacy(message,
+                    legacyTargets,
+                    isRecipientUpdate || allResults.size() > 0);
+            allResults.addAll(results);
+        }
+
+        return allResults;
+    }
+
+    private Set<RecipientId> getSenderKeyCapableRecipientIds(final Set<RecipientId> recipientIds) {
+        final var selfProfile = profileProvider.getProfile(account.getSelfRecipientId());
+        if (selfProfile == null || !selfProfile.getCapabilities().contains(Profile.Capability.senderKey)) {
+            logger.debug("Not all of our devices support sender key. Using legacy.");
+            return Set.of();
+        }
+
+        final var senderKeyTargets = new HashSet<RecipientId>();
+        for (final var recipientId : recipientIds) {
+            // TODO filter out unregistered
+            final var profile = profileProvider.getProfile(recipientId);
+            if (profile == null || !profile.getCapabilities().contains(Profile.Capability.senderKey)) {
+                continue;
+            }
+
+            final var access = unidentifiedAccessHelper.getAccessFor(recipientId);
+            if (!access.isPresent() || !access.get().getTargetUnidentifiedAccess().isPresent()) {
+                continue;
+            }
+
+            final var identity = account.getIdentityKeyStore().getIdentity(recipientId);
+            if (identity == null || !identity.getTrustLevel().isTrusted()) {
+                continue;
+            }
+
+            senderKeyTargets.add(recipientId);
+        }
+
+        if (senderKeyTargets.size() < 2) {
+            logger.debug("Too few sender-key-capable users ({}). Doing all legacy sends.", senderKeyTargets.size());
+            return Set.of();
+        }
+
+        logger.debug("Can use sender key for {}/{} recipients.", senderKeyTargets.size(), recipientIds.size());
+        return senderKeyTargets;
+    }
+
+    /**
+     * Sends a group data message to each recipient with a separate 1:1 send (the non sender-key path).
+     *
+     * @param message           the message to send
+     * @param recipientIds      the recipients to send to; may be empty when only a sync message is needed
+     * @param isRecipientUpdate passed through unchanged to the message sender
+     * @return the per-recipient send results, or an empty list if an untrusted identity aborts the send
+     */
+    private List<SendMessageResult> sendGroupMessageInternalWithLegacy(
+            final SignalServiceDataMessage message, final Set<RecipientId> recipientIds, final boolean isRecipientUpdate
+    ) throws IOException {
+        final var recipientIdList = new ArrayList<>(recipientIds);
+        final var addresses = recipientIdList.stream().map(addressResolver::resolveSignalServiceAddress).toList();
+        final var unidentifiedAccesses = unidentifiedAccessHelper.getAccessFor(recipientIdList);
+        final var messageSender = dependencies.getMessageSender();
         try {
-            var messageSender = dependencies.getMessageSender();
-            // isRecipientUpdate is true if we've already sent this message to some recipients in the past, otherwise false.
-            final var isRecipientUpdate = false;
-            final var recipientIdList = new ArrayList<>(recipientIds);
-            final var addresses = recipientIdList.stream().map(addressResolver::resolveSignalServiceAddress).toList();
-            return messageSender.sendDataMessage(addresses,
-                    unidentifiedAccessHelper.getAccessFor(recipientIdList),
+            final var results = messageSender.sendDataMessage(addresses,
+                    unidentifiedAccesses,
                     isRecipientUpdate,
                     ContentHint.DEFAULT,
                     message,
                     SignalServiceMessageSender.LegacyGroupEvents.EMPTY,
                     sendResult -> logger.trace("Partial message send result: {}", sendResult.isSuccess()),
                     () -> false);
+
+            final var successCount = results.stream().filter(SendMessageResult::isSuccess).count();
+            logger.debug("Successfully sent using 1:1 to {}/{} legacy targets.", successCount, recipientIdList.size());
+            return results;
         } catch (org.whispersystems.signalservice.api.crypto.UntrustedIdentityException e) {
+            // NOTE(review): the exception is neither logged nor reported and no results are returned — confirm intended.
             return List.of();
         }
     }
 
+    private List<SendMessageResult> sendGroupMessageInternalWithSenderKey(
+            final SignalServiceDataMessage message,
+            final Set<RecipientId> recipientIds,
+            final DistributionId distributionId,
+            final boolean isRecipientUpdate
+    ) throws IOException {
+        final var recipientIdList = new ArrayList<>(recipientIds);
+        final var messageSender = dependencies.getMessageSender();
+
+        long keyCreateTime = account.getSenderKeyStore()
+                .getCreateTimeForOurKey(account.getSelfRecipientId(), account.getDeviceId(), distributionId);
+        long keyAge = System.currentTimeMillis() - keyCreateTime;
+
+        if (keyCreateTime != -1 && keyAge > TimeUnit.DAYS.toMillis(14)) {
+            logger.debug("DistributionId {} was created at {} and is {} ms old (~{} days). Rotating.",
+                    distributionId,
+                    keyCreateTime,
+                    keyAge,
+                    TimeUnit.MILLISECONDS.toDays(keyAge));
+            account.getSenderKeyStore().deleteOurKey(account.getSelfRecipientId(), distributionId);
+        }
+
+        List<SignalServiceAddress> addresses = recipientIdList.stream()
+                .map(addressResolver::resolveSignalServiceAddress)
+                .collect(Collectors.toList());
+        List<UnidentifiedAccess> unidentifiedAccesses = recipientIdList.stream()
+                .map(unidentifiedAccessHelper::getAccessFor)
+                .map(Optional::get)
+                .map(UnidentifiedAccessPair::getTargetUnidentifiedAccess)
+                .map(Optional::get)
+                .collect(Collectors.toList());
+
+        try {
+            List<SendMessageResult> results = messageSender.sendGroupDataMessage(distributionId,
+                    addresses,
+                    unidentifiedAccesses,
+                    isRecipientUpdate,
+                    ContentHint.DEFAULT,
+                    message,
+                    SignalServiceMessageSender.SenderKeyGroupEvents.EMPTY);
+
+            final var successCount = results.stream().filter(SendMessageResult::isSuccess).count();
+            logger.debug("Successfully sent using sender key to {}/{} sender key targets.",
+                    successCount,
+                    addresses.size());
+
+            return results;
+        } catch (org.whispersystems.signalservice.api.crypto.UntrustedIdentityException e) {
+            return null;
+        } catch (InvalidUnidentifiedAccessHeaderException e) {
+            logger.warn("Someone had a bad UD header. Falling back to legacy sends.", e);
+            return null;
+        } catch (NoSessionException e) {
+            logger.warn("No session. Falling back to legacy sends.", e);
+            account.getSenderKeyStore().deleteOurKey(account.getSelfRecipientId(), distributionId);
+            return null;
+        } catch (InvalidKeyException e) {
+            logger.warn("Invalid key. Falling back to legacy sends.", e);
+            account.getSenderKeyStore().deleteOurKey(account.getSelfRecipientId(), distributionId);
+            return null;
+        } catch (InvalidRegistrationIdException e) {
+            logger.warn("Invalid registrationId. Falling back to legacy sends.", e);
+            return null;
+        } catch (NotFoundException e) {
+            logger.warn("Someone was unregistered. Falling back to legacy sends.", e);
+            return null;
+        }
+    }
+
     private SendMessageResult sendMessage(
             SignalServiceDataMessage message, RecipientId recipientId
     ) {
@@ -317,7 +495,7 @@ public class SendHelper {
         return sendSyncMessage(syncMessage);
     }
 
-    private void handlePossibleIdentityFailure(final SendMessageResult r) {
+    private void handleSendMessageResult(final SendMessageResult r) {
         if (r.getIdentityFailure() != null) {
             final var recipientId = recipientResolver.resolveRecipient(r.getAddress());
             identityFailureHandler.handleIdentityFailure(recipientId, r.getIdentityFailure());