From: AsamK
Date: Sun, 15 Oct 2023 20:36:45 +0000 (+0200)
Subject: Implement full CDSI refresh
X-Git-Tag: v0.12.3~12
X-Git-Url: https://git.nmode.ca/signal-cli/commitdiff_plain/5c39344cff49befe4f48e34fc84be8238a4df4fa

Implement full CDSI refresh
---
diff --git a/lib/src/main/java/org/asamk/signal/manager/helper/RecipientHelper.java b/lib/src/main/java/org/asamk/signal/manager/helper/RecipientHelper.java
index 89354524..322bd032 100644
--- a/lib/src/main/java/org/asamk/signal/manager/helper/RecipientHelper.java
+++ b/lib/src/main/java/org/asamk/signal/manager/helper/RecipientHelper.java
@@ -14,6 +14,7 @@ import org.whispersystems.signalservice.api.push.ServiceId;
 import org.whispersystems.signalservice.api.push.ServiceId.ACI;
 import org.whispersystems.signalservice.api.push.ServiceId.PNI;
 import org.whispersystems.signalservice.api.push.SignalServiceAddress;
+import org.whispersystems.signalservice.api.push.exceptions.CdsiInvalidTokenException;
 import org.whispersystems.signalservice.api.services.CdsiV2Service;
 import org.whispersystems.util.Base64UrlSafe;
 
@@ -115,6 +116,18 @@ public class RecipientHelper {
         }
     }
 
+    /**
+     * Runs a full (non-partial) CDSI lookup for every phone number in the recipient store.
+     * Previously queried numbers are sent together with the stored CDSI token, and the new token
+     * is persisted afterwards. If the stored token is rejected, the token and the cdsi table are
+     * cleared before the {@link CdsiInvalidTokenException} is rethrown.
+     *
+     * @throws IOException if the CDSI lookup fails
+     */
+    public void refreshUsers() throws IOException {
+        getRegisteredUsers(account.getRecipientStore().getAllNumbers(), false);
+    }
+
     public RecipientId refreshRegisteredUser(RecipientId recipientId) throws IOException, UnregisteredRecipientException {
         final var address = resolveSignalServiceAddress(recipientId);
         if (address.getNumber().isEmpty()) {
@@ -126,8 +139,16 @@ public class RecipientHelper {
                 .resolveRecipientTrusted(new SignalServiceAddress(serviceId, number));
     }
 
-    public Map<String, RegisteredUser> getRegisteredUsers(final Set<String> numbers) throws IOException {
-        Map<String, RegisteredUser> registeredUsers = getRegisteredUsersV2(numbers, true);
+    public Map<String, RegisteredUser> getRegisteredUsers(
+            final Set<String> numbers
+    ) throws IOException {
+        return getRegisteredUsers(numbers, true);
+    }
+
+    private Map<String, RegisteredUser> getRegisteredUsers(
+            final Set<String> numbers, final boolean isPartialRefresh
+    ) throws IOException {
+        Map<String, RegisteredUser> registeredUsers = getRegisteredUsersV2(numbers, isPartialRefresh, true);
 
         // Store numbers as recipients, so we have the number/uuid association
         registeredUsers.forEach((number, u) -> account.getRecipientTrustedResolver()
@@ -139,7 +160,7 @@ public class RecipientHelper {
     private ServiceId getRegisteredUserByNumber(final String number) throws IOException, UnregisteredRecipientException {
         final Map<String, RegisteredUser> aciMap;
         try {
-            aciMap = getRegisteredUsers(Set.of(number));
+            aciMap = getRegisteredUsers(Set.of(number), true);
         } catch (NumberFormatException e) {
             throw new UnregisteredRecipientException(new org.asamk.signal.manager.api.RecipientAddress(null, number));
         }
@@ -151,22 +172,50 @@ public class RecipientHelper {
     }
 
     private Map<String, RegisteredUser> getRegisteredUsersV2(
-            final Set<String> numbers, boolean useCompat
+            final Set<String> numbers, boolean isPartialRefresh, boolean useCompat
     ) throws IOException {
-        // Only partial refresh is implemented here
+        final var previousNumbers = isPartialRefresh ? Set.<String>of() : account.getCdsiStore().getAllNumbers();
+        final var newNumbers = new HashSet<>(numbers) {{
+            removeAll(previousNumbers);
+        }};
+        if (newNumbers.isEmpty() && previousNumbers.isEmpty()) {
+            logger.debug("No new numbers to query.");
+            return Map.of();
+        }
+        logger.trace("Querying CDSI for {} new numbers ({} previous)", newNumbers.size(), previousNumbers.size());
+        final var token = previousNumbers.isEmpty()
+                ? Optional.<byte[]>empty()
+                : Optional.ofNullable(account.getCdsiToken());
+
         final CdsiV2Service.Response response;
         try {
             response = dependencies.getAccountManager()
-                    .getRegisteredUsersWithCdsi(Set.of(),
-                            numbers,
+                    .getRegisteredUsersWithCdsi(previousNumbers,
+                            newNumbers,
                             account.getRecipientStore().getServiceIdToProfileKeyMap(),
                             useCompat,
-                            Optional.empty(),
+                            token,
                             serviceEnvironmentConfig.cdsiMrenclave(),
                             null,
-                            token -> {
-                                // Not storing for partial refresh
+                            newToken -> {
+                                if (isPartialRefresh) {
+                                    account.getCdsiStore().updateAfterPartialCdsQuery(newNumbers);
+                                    // Not storing newToken for partial refresh
+                                } else {
+                                    final var fullNumbers = new HashSet<>(previousNumbers) {{
+                                        addAll(newNumbers);
+                                    }};
+                                    final var seenNumbers = new HashSet<>(numbers) {{
+                                        addAll(newNumbers);
+                                    }};
+                                    account.getCdsiStore().updateAfterFullCdsQuery(fullNumbers, seenNumbers);
+                                    account.setCdsiToken(newToken);
+                                }
                             });
+        } catch (CdsiInvalidTokenException e) {
+            account.setCdsiToken(null);
+            account.getCdsiStore().clearAll();
+            throw e;
         } catch (NumberFormatException e) {
             throw new IOException(e);
         }
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java b/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java
index 3545eaa7..e09cd4c2 100644
--- a/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java
+++ b/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java
@@ -9,6 +9,7 @@ import org.asamk.signal.manager.storage.keyValue.KeyValueStore;
 import org.asamk.signal.manager.storage.prekeys.KyberPreKeyStore;
 import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
 import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore;
+import org.asamk.signal.manager.storage.recipients.CdsiStore;
 import org.asamk.signal.manager.storage.recipients.RecipientStore;
 import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore;
 import org.asamk.signal.manager.storage.senderKeys.SenderKeyRecordStore;
@@ -31,7 +32,7 @@ import java.util.UUID;
 public class AccountDatabase extends Database {
 
     private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class);
-    private static final long DATABASE_VERSION = 17;
+    private static final long DATABASE_VERSION = 18;
 
     private AccountDatabase(final HikariDataSource dataSource) {
         super(logger, DATABASE_VERSION, dataSource);
@@ -55,6 +56,7 @@ public class AccountDatabase extends Database {
         SenderKeyRecordStore.createSql(connection);
         SenderKeySharedStore.createSql(connection);
         KeyValueStore.createSql(connection);
+        CdsiStore.createSql(connection);
     }
 
     @Override
@@ -517,5 +519,17 @@ public class AccountDatabase extends Database {
                                         """);
             }
         }
+        if (oldVersion < 18) {
+            logger.debug("Updating database: Adding cdsi table");
+            try (final var statement = connection.createStatement()) {
+                statement.executeUpdate("""
+                                        CREATE TABLE cdsi (
+                                          _id INTEGER PRIMARY KEY,
+                                          number TEXT NOT NULL UNIQUE,
+                                          last_seen_at INTEGER NOT NULL
+                                        ) STRICT;
+                                        """);
+            }
+        }
     }
 }
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java
index 2cfd5cd7..fbfd88f6 100644
--- a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java
+++ b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java
@@ -33,6 +33,7 @@ import org.asamk.signal.manager.storage.profiles.LegacyProfileStore;
 import org.asamk.signal.manager.storage.profiles.ProfileStore;
 import org.asamk.signal.manager.storage.protocol.LegacyJsonSignalProtocolStore;
 import org.asamk.signal.manager.storage.protocol.SignalProtocolStore;
+import org.asamk.signal.manager.storage.recipients.CdsiStore;
 import org.asamk.signal.manager.storage.recipients.LegacyRecipientStore;
 import org.asamk.signal.manager.storage.recipients.LegacyRecipientStore2;
 import org.asamk.signal.manager.storage.recipients.RecipientAddress;
@@ -145,6 +146,7 @@ public class SignalAccount implements Closeable {
     private final KeyValueEntry<Long> lastReceiveTimestamp = new KeyValueEntry<>("last-receive-timestamp",
             long.class,
             0L);
+    private final KeyValueEntry<byte[]> cdsiToken = new KeyValueEntry<>("cdsi-token", byte[].class);
     private final KeyValueEntry<Long> storageManifestVersion = new KeyValueEntry<>("storage-manifest-version",
             long.class,
             -1L);
@@ -160,6 +162,7 @@ public class SignalAccount implements Closeable {
     private StickerStore stickerStore;
     private ConfigurationStore configurationStore;
     private KeyValueStore keyValueStore;
+    private CdsiStore cdsiStore;
 
     private MessageCache messageCache;
     private MessageSendLogStore messageSendLogStore;
@@ -1220,6 +1223,10 @@ public class SignalAccount implements Closeable {
         return getRecipientStore();
     }
 
+    public CdsiStore getCdsiStore() {
+        return getOrCreate(() -> cdsiStore, () -> cdsiStore = new CdsiStore(getAccountDatabase()));
+    }
+
     private RecipientIdCreator getRecipientIdCreator() {
         return recipientId -> getRecipientStore().create(recipientId);
     }
@@ -1571,6 +1578,14 @@ public class SignalAccount implements Closeable {
         }
     }
 
+    public byte[] getCdsiToken() {
+        return getKeyValueStore().getEntry(cdsiToken);
+    }
+
+    public void setCdsiToken(final byte[] value) {
+        getKeyValueStore().storeEntry(cdsiToken, value);
+    }
+
     public ProfileKey getProfileKey() {
         return profileKey;
     }
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/keyValue/KeyValueStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/keyValue/KeyValueStore.java
index 5785dfdc..90c932af 100644
--- a/lib/src/main/java/org/asamk/signal/manager/storage/keyValue/KeyValueStore.java
+++ b/lib/src/main/java/org/asamk/signal/manager/storage/keyValue/KeyValueStore.java
@@ -90,6 +90,8 @@ public class KeyValueStore {
                 value = resultSet.getLong("value");
             } else if (clazz == boolean.class || clazz == Boolean.class) {
                 value = resultSet.getBoolean("value");
+            } else if (clazz == byte[].class) {
+                value = resultSet.getBytes("value");
             } else if (clazz == String.class) {
                 value = resultSet.getString("value");
             } else if (Enum.class.isAssignableFrom(clazz)) {
@@ -134,6 +136,12 @@ public class KeyValueStore {
             } else {
                 statement.setBoolean(parameterIndex, (boolean) value);
             }
+        } else if (clazz == byte[].class) {
+            if (value == null) {
+                statement.setNull(parameterIndex, Types.BLOB);
+            } else {
+                statement.setBytes(parameterIndex, (byte[]) value);
+            }
         } else if (clazz == String.class) {
             if (value == null) {
                 statement.setNull(parameterIndex, Types.VARCHAR);
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/recipients/CdsiStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/recipients/CdsiStore.java
new file mode 100644
index 00000000..fd0846e9
--- /dev/null
+++ b/lib/src/main/java/org/asamk/signal/manager/storage/recipients/CdsiStore.java
@@ -0,0 +1,174 @@
+package org.asamk.signal.manager.storage.recipients;
+
+import org.asamk.signal.manager.storage.Database;
+import org.asamk.signal.manager.storage.Utils;
+
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+public class CdsiStore {
+
+    private static final String TABLE_CDSI = "cdsi";
+
+    private final Database database;
+
+    public static void createSql(Connection connection) throws SQLException {
+        // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
+        try (final var statement = connection.createStatement()) {
+            statement.executeUpdate("""
+                                    CREATE TABLE cdsi (
+                                      _id INTEGER PRIMARY KEY,
+                                      number TEXT NOT NULL UNIQUE,
+                                      last_seen_at INTEGER NOT NULL
+                                    ) STRICT;
+                                    """);
+        }
+    }
+
+    public CdsiStore(final Database database) {
+        this.database = database;
+    }
+
+    public Set<String> getAllNumbers() {
+        try (final var connection = database.getConnection()) {
+            return getAllNumbers(connection);
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed read from cdsi store", e);
+        }
+    }
+
+    /**
+     * Saves the set of e164 numbers used after a full refresh.
+     *
+     * @param fullNumbers All the e164 numbers used in the last CDS query (previous and new).
+     * @param seenNumbers The E164 numbers that were seen in either the system contacts or recipients table. This is different from fullNumbers in that fullNumbers
+     *                    includes every number we've ever seen, even if it's not in our contacts anymore.
+     */
+    public void updateAfterFullCdsQuery(Set<String> fullNumbers, Set<String> seenNumbers) {
+        final var lastSeen = System.currentTimeMillis();
+        try (final var connection = database.getConnection()) {
+            final var existingNumbers = getAllNumbers(connection);
+
+            final var removedNumbers = new HashSet<>(existingNumbers) {{
+                removeAll(fullNumbers);
+            }};
+            removeNumbers(connection, removedNumbers);
+
+            final var addedNumbers = new HashSet<>(fullNumbers) {{
+                removeAll(existingNumbers);
+            }};
+            addNumbers(connection, addedNumbers, lastSeen);
+
+            updateLastSeen(connection, seenNumbers, lastSeen);
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update cdsi store", e);
+        }
+    }
+
+    /**
+     * Updates after a partial CDS query. Will not insert new entries.
+     * Instead, this will simply update the lastSeen timestamp of any entry we already have.
+     *
+     * @param seenNumbers The newly-added E164 numbers that we hadn't previously queried for.
+     */
+    public void updateAfterPartialCdsQuery(Set<String> seenNumbers) {
+        final var lastSeen = System.currentTimeMillis();
+
+        try (final var connection = database.getConnection()) {
+            updateLastSeen(connection, seenNumbers, lastSeen);
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update cdsi store", e);
+        }
+    }
+
+    private static Set<String> getAllNumbers(final Connection connection) throws SQLException {
+        final var sql = (
+                """
+                SELECT c.number
+                FROM %s c
+                """
+        ).formatted(TABLE_CDSI);
+        try (final var statement = connection.prepareStatement(sql)) {
+            try (var result = Utils.executeQueryForStream(statement, r -> r.getString("number"))) {
+                return result.collect(Collectors.toSet());
+            }
+        }
+    }
+
+    private static void removeNumbers(
+            final Connection connection, final Set<String> numbers
+    ) throws SQLException {
+        final var sql = (
+                """
+                DELETE FROM %s
+                WHERE number = ?
+                """
+        ).formatted(TABLE_CDSI);
+        try (final var statement = connection.prepareStatement(sql)) {
+            for (final var number : numbers) {
+                statement.setString(1, number);
+                statement.executeUpdate();
+            }
+        }
+    }
+
+    private static void addNumbers(
+            final Connection connection, final Set<String> numbers, final long lastSeen
+    ) throws SQLException {
+        final var sql = (
+                """
+                INSERT INTO %s (number, last_seen_at)
+                VALUES (?, ?)
+                ON CONFLICT (number) DO UPDATE SET last_seen_at = excluded.last_seen_at
+                """
+        ).formatted(TABLE_CDSI);
+        try (final var statement = connection.prepareStatement(sql)) {
+            for (final var number : numbers) {
+                statement.setString(1, number);
+                statement.setLong(2, lastSeen);
+                statement.executeUpdate();
+            }
+        }
+    }
+
+    private static void updateLastSeen(
+            final Connection connection, final Set<String> numbers, final long lastSeen
+    ) throws SQLException {
+        final var sql = (
+                """
+                UPDATE %s
+                SET last_seen_at = ?
+                WHERE number = ?
+                """
+        ).formatted(TABLE_CDSI);
+        try (final var statement = connection.prepareStatement(sql)) {
+            for (final var number : numbers) {
+                statement.setLong(1, lastSeen);
+                statement.setString(2, number);
+                statement.executeUpdate();
+            }
+        }
+    }
+
+    /**
+     * Deletes every row from the cdsi table.
+     * SQLite has no TRUNCATE statement, so an unqualified DELETE is used instead.
+     */
+    public void clearAll() {
+        final var sql = (
+                """
+                DELETE FROM %s
+                """
+        ).formatted(TABLE_CDSI);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.executeUpdate();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update cdsi store", e);
+        }
+    }
+}
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientStore.java
index 9ef734c7..62ab525a 100644
--- a/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientStore.java
+++ b/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientStore.java
@@ -383,6 +383,25 @@ public class RecipientStore implements RecipientIdCreator, RecipientResolver, Re
         }
     }
 
+    public Set<String> getAllNumbers() {
+        final var sql = (
+                """
+                SELECT r.number
+                FROM %s r
+                WHERE r.number IS NOT NULL
+                """
+        ).formatted(TABLE_RECIPIENT);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                return Utils.executeQueryForStream(statement, resultSet -> resultSet.getString("number"))
+                        .filter(Objects::nonNull)
+                        .collect(Collectors.toSet());
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed read from recipient store", e);
+        }
+    }
+
     public Map<ServiceId, ProfileKey> getServiceIdToProfileKeyMap() {
         final var sql = (
                 """