From: AsamK Date: Sat, 17 Jun 2023 19:18:24 +0000 (+0200) Subject: Implement support for kyber pre keys X-Git-Tag: v0.12.0~25 X-Git-Url: https://git.nmode.ca/signal-cli/commitdiff_plain/306e38c9ee3ba1d1c1038f16bdf07422ab2c34bb?ds=sidebyside Implement support for kyber pre keys --- diff --git a/graalvm-config-dir/jni-config.json b/graalvm-config-dir/jni-config.json index 387549dd..33dbd070 100644 --- a/graalvm-config-dir/jni-config.json +++ b/graalvm-config-dir/jni-config.json @@ -46,7 +46,7 @@ }, { "name":"org.asamk.signal.manager.storage.protocol.SignalProtocolStore", - "methods":[{"name":"getIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"getIdentityKeyPair","parameterTypes":[] }, {"name":"getLocalRegistrationId","parameterTypes":[] }, {"name":"isTrustedIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey","org.signal.libsignal.protocol.state.IdentityKeyStore$Direction"] }, {"name":"loadPreKey","parameterTypes":["int"] }, {"name":"loadSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID"] }, {"name":"loadSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"loadSignedPreKey","parameterTypes":["int"] }, {"name":"removePreKey","parameterTypes":["int"] }, {"name":"saveIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey"] }, {"name":"storeSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID","org.signal.libsignal.protocol.groups.state.SenderKeyRecord"] }, {"name":"storeSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.state.SessionRecord"] }] + "methods":[{"name":"getIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"getIdentityKeyPair","parameterTypes":[] }, {"name":"getLocalRegistrationId","parameterTypes":[] }, {"name":"isTrustedIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey","org.signal.libsignal.protocol.state.IdentityKeyStore$Direction"] }, {"name":"loadKyberPreKey","parameterTypes":["int"] }, {"name":"loadPreKey","parameterTypes":["int"] }, {"name":"loadSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID"] }, {"name":"loadSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"loadSignedPreKey","parameterTypes":["int"] }, {"name":"markKyberPreKeyUsed","parameterTypes":["int"] }, {"name":"removePreKey","parameterTypes":["int"] }, {"name":"saveIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey"] }, {"name":"storeSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID","org.signal.libsignal.protocol.groups.state.SenderKeyRecord"] }, {"name":"storeSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.state.SessionRecord"] }] }, { "name":"org.asamk.signal.manager.storage.senderKeys.SenderKeyStore", @@ -133,6 +133,10 @@ "name":"org.signal.libsignal.protocol.state.IdentityKeyStore$Direction", "fields":[{"name":"RECEIVING"}, {"name":"SENDING"}] }, +{ + "name":"org.signal.libsignal.protocol.state.KyberPreKeyRecord", + "fields":[{"name":"unsafeHandle"}] +}, { "name":"org.signal.libsignal.protocol.state.KyberPreKeyStore" }, diff --git a/graalvm-config-dir/reflect-config.json b/graalvm-config-dir/reflect-config.json index 6a754766..0f08810c 100644 --- a/graalvm-config-dir/reflect-config.json +++ b/graalvm-config-dir/reflect-config.json @@ -2179,16 +2179,25 @@ "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity", "allDeclaredFields":true, "queryAllDeclaredMethods":true, - "queryAllDeclaredConstructors":true + "queryAllDeclaredConstructors":true, + "methods":[{"name":"","parameterTypes":[] }, {"name":"getKeyId","parameterTypes":[] }, {"name":"getPublicKey","parameterTypes":[] }, {"name":"getSignature","parameterTypes":[] }] }, { "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$ByteArrayDeserializer", "methods":[{"name":"","parameterTypes":[] }] }, +{ + "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$ByteArraySerializer", + "methods":[{"name":"","parameterTypes":[] }] +}, { "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$KEMPublicKeyDeserializer", "methods":[{"name":"","parameterTypes":[] }] }, +{ + "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$KEMPublicKeySerializer", + "methods":[{"name":"","parameterTypes":[] }] +}, { "name":"org.whispersystems.signalservice.internal.push.MismatchedDevices", "allDeclaredFields":true, @@ -2250,7 +2259,8 @@ "name":"org.whispersystems.signalservice.internal.push.PreKeyState", "allDeclaredFields":true, "allDeclaredMethods":true, - "allDeclaredConstructors":true + "allDeclaredConstructors":true, + "methods":[{"name":"getIdentityKey","parameterTypes":[] }, {"name":"getPreKeys","parameterTypes":[] }, {"name":"getSignedPreKey","parameterTypes":[] }] }, { "name":"org.whispersystems.signalservice.internal.push.PreKeyStatus", diff --git a/lib/src/main/java/org/asamk/signal/manager/config/ServiceConfig.java b/lib/src/main/java/org/asamk/signal/manager/config/ServiceConfig.java index b2b45a9c..f97e8f68 100644 --- a/lib/src/main/java/org/asamk/signal/manager/config/ServiceConfig.java +++ b/lib/src/main/java/org/asamk/signal/manager/config/ServiceConfig.java @@ -1,6 +1,7 @@ package org.asamk.signal.manager.config; import org.asamk.signal.manager.api.ServiceEnvironment; +import org.signal.libsignal.protocol.util.Medium; import org.whispersystems.signalservice.api.account.AccountAttributes; import org.whispersystems.signalservice.api.push.TrustStore; @@ -15,8 +16,9 @@ import okhttp3.Interceptor; public class ServiceConfig { - public final static int PREKEY_MINIMUM_COUNT = 20; + public final static int PREKEY_MINIMUM_COUNT = 10; public final static int PREKEY_BATCH_SIZE = 100; + public final static int PREKEY_MAXIMUM_ID = Medium.MAX_VALUE; public final static int MAX_ATTACHMENT_SIZE = 150 * 1024 * 1024; public final static long MAX_ENVELOPE_SIZE = 0; public final static long AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE = 10 * 1024 * 1024; diff --git a/lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java b/lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java index 06dc31b9..282b017a 100644 --- a/lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java +++ b/lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java @@ -5,6 +5,7 @@ import org.asamk.signal.manager.internal.SignalDependencies; import org.asamk.signal.manager.storage.SignalAccount; import org.asamk.signal.manager.util.KeyUtils; import org.signal.libsignal.protocol.IdentityKeyPair; +import org.signal.libsignal.protocol.state.KyberPreKeyRecord; import org.signal.libsignal.protocol.state.PreKeyRecord; import org.signal.libsignal.protocol.state.SignedPreKeyRecord; import org.slf4j.Logger; @@ -39,7 +40,9 @@ public class PreKeyHelper { if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) { refreshPreKeys(serviceIdType); } - // TODO kyber pre keys + if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) { + refreshKyberPreKeys(serviceIdType); + } } public void refreshPreKeys() throws IOException { @@ -47,7 +50,7 @@ public class PreKeyHelper { refreshPreKeys(ServiceIdType.PNI); } - public void refreshPreKeys(ServiceIdType serviceIdType) throws IOException { + private void refreshPreKeys(ServiceIdType serviceIdType) throws IOException { final var identityKeyPair = account.getIdentityKeyPair(serviceIdType); if (identityKeyPair == null) { return; @@ -97,4 +100,61 @@ public class PreKeyHelper { return record; } + + private void refreshKyberPreKeys(ServiceIdType serviceIdType) throws IOException { + final var identityKeyPair = account.getIdentityKeyPair(serviceIdType); + if (identityKeyPair == null) { + return; + } + final var accountId = account.getAccountId(serviceIdType); + if (accountId == null) { + return; + } + try { + refreshKyberPreKeys(serviceIdType, identityKeyPair); + } catch (Exception e) { + logger.warn("Failed to store new pre keys, resetting preKey id offset", e); + account.resetKyberPreKeyOffsets(serviceIdType); + refreshKyberPreKeys(serviceIdType, identityKeyPair); + } + } + + private void refreshKyberPreKeys( + final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair + ) throws IOException { + final var oneTimePreKeys = generateKyberPreKeys(serviceIdType, identityKeyPair); + final var lastResortPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair); + + final var preKeyUpload = new PreKeyUpload(serviceIdType, + identityKeyPair.getPublicKey(), + null, + null, + lastResortPreKeyRecord, + oneTimePreKeys); + dependencies.getAccountManager().setPreKeys(preKeyUpload); + } + + private List generateKyberPreKeys( + ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair + ) { + final var offset = account.getKyberPreKeyIdOffset(serviceIdType); + + var records = KeyUtils.generateKyberPreKeyRecords(offset, + ServiceConfig.PREKEY_BATCH_SIZE, + identityKeyPair.getPrivateKey()); + account.addKyberPreKeys(serviceIdType, records); + + return records; + } + + private KyberPreKeyRecord generateLastResortKyberPreKey( + ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair + ) { + final var signedPreKeyId = account.getKyberPreKeyIdOffset(serviceIdType); + + var record = KeyUtils.generateKyberPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey()); + account.addLastResortKyberPreKey(serviceIdType, record); + + return record; + } } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java b/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java index a15b4aab..7f6babe6 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java @@ -4,6 +4,7 @@ import com.zaxxer.hikari.HikariDataSource; import org.asamk.signal.manager.storage.groups.GroupStore; import org.asamk.signal.manager.storage.identities.IdentityKeyStore; +import org.asamk.signal.manager.storage.prekeys.KyberPreKeyStore; import org.asamk.signal.manager.storage.prekeys.PreKeyStore; import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore; import org.asamk.signal.manager.storage.recipients.RecipientStore; @@ -23,7 +24,7 @@ import java.sql.SQLException; public class AccountDatabase extends Database { private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class); - private static final long DATABASE_VERSION = 13; + private static final long DATABASE_VERSION = 14; private AccountDatabase(final HikariDataSource dataSource) { super(logger, DATABASE_VERSION, dataSource); @@ -40,6 +41,7 @@ public class AccountDatabase extends Database { StickerStore.createSql(connection); PreKeyStore.createSql(connection); SignedPreKeyStore.createSql(connection); + KyberPreKeyStore.createSql(connection); GroupStore.createSql(connection); SessionStore.createSql(connection); IdentityKeyStore.createSql(connection); @@ -328,5 +330,23 @@ public class AccountDatabase extends Database { } } } + if (oldVersion < 14) { + logger.debug("Updating database: Creating kyber_pre_key table"); + { + try (final var statement = connection.createStatement()) { + statement.executeUpdate(""" + CREATE TABLE kyber_pre_key ( + _id INTEGER PRIMARY KEY, + account_id_type INTEGER NOT NULL, + key_id INTEGER NOT NULL, + serialized BLOB NOT NULL, + is_last_resort INTEGER NOT NULL, + UNIQUE(account_id_type, key_id) + ) STRICT; + """); + } + } + + } } } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java index 850b264f..76867b72 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java @@ -21,6 +21,7 @@ import org.asamk.signal.manager.storage.identities.IdentityKeyStore; import org.asamk.signal.manager.storage.identities.LegacyIdentityKeyStore; import org.asamk.signal.manager.storage.identities.SignalIdentityKeyStore; import org.asamk.signal.manager.storage.messageCache.MessageCache; +import org.asamk.signal.manager.storage.prekeys.KyberPreKeyStore; import org.asamk.signal.manager.storage.prekeys.LegacyPreKeyStore; import org.asamk.signal.manager.storage.prekeys.LegacySignedPreKeyStore; import org.asamk.signal.manager.storage.prekeys.PreKeyStore; @@ -51,11 +52,11 @@ import org.asamk.signal.manager.util.KeyUtils; import org.signal.libsignal.protocol.IdentityKeyPair; import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.SignalProtocolAddress; +import org.signal.libsignal.protocol.state.KyberPreKeyRecord; import org.signal.libsignal.protocol.state.PreKeyRecord; import org.signal.libsignal.protocol.state.SessionRecord; import org.signal.libsignal.protocol.state.SignedPreKeyRecord; import org.signal.libsignal.protocol.util.KeyHelper; -import org.signal.libsignal.protocol.util.Medium; import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.profiles.ProfileKey; import org.slf4j.Logger; @@ -98,6 +99,7 @@ import java.util.List; import java.util.Optional; import java.util.function.Supplier; +import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_MAXIMUM_ID; import static org.asamk.signal.manager.config.ServiceConfig.getCapabilities; public class SignalAccount implements Closeable { @@ -105,7 +107,7 @@ public class SignalAccount implements Closeable { private final static Logger logger = LoggerFactory.getLogger(SignalAccount.class); private static final int MINIMUM_STORAGE_VERSION = 1; - private static final int CURRENT_STORAGE_VERSION = 6; + private static final int CURRENT_STORAGE_VERSION = 7; private final Object LOCK = new Object(); @@ -138,6 +140,10 @@ public class SignalAccount implements Closeable { private int aciNextSignedPreKeyId = 1; private int pniPreKeyIdOffset = 1; private int pniNextSignedPreKeyId = 1; + private int aciKyberPreKeyIdOffset = 1; + private int aciActiveLastResortKyberPreKeyId = -1; + private int pniKyberPreKeyIdOffset = 1; + private int pniActiveLastResortKyberPreKeyId = -1; private IdentityKeyPair aciIdentityKeyPair; private IdentityKeyPair pniIdentityKeyPair; private int localRegistrationId; @@ -151,8 +157,10 @@ public class SignalAccount implements Closeable { private SignalProtocolStore pniSignalProtocolStore; private PreKeyStore aciPreKeyStore; private SignedPreKeyStore aciSignedPreKeyStore; + private KyberPreKeyStore aciKyberPreKeyStore; private PreKeyStore pniPreKeyStore; private SignedPreKeyStore pniSignedPreKeyStore; + private KyberPreKeyStore pniKyberPreKeyStore; private SessionStore aciSessionStore; private SessionStore pniSessionStore; private IdentityKeyStore identityKeyStore; @@ -302,10 +310,14 @@ public class SignalAccount implements Closeable { private void clearAllPreKeys() { resetPreKeyOffsets(ServiceIdType.ACI); resetPreKeyOffsets(ServiceIdType.PNI); + resetKyberPreKeyOffsets(ServiceIdType.ACI); + resetKyberPreKeyOffsets(ServiceIdType.PNI); this.getAciPreKeyStore().removeAllPreKeys(); this.getAciSignedPreKeyStore().removeAllSignedPreKeys(); + this.getAciKyberPreKeyStore().removeAllKyberPreKeys(); this.getPniPreKeyStore().removeAllPreKeys(); this.getPniSignedPreKeyStore().removeAllSignedPreKeys(); + this.getPniKyberPreKeyStore().removeAllKyberPreKeys(); save(); } @@ -614,22 +626,42 @@ public class SignalAccount implements Closeable { if (rootNode.hasNonNull("preKeyIdOffset")) { aciPreKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1); } else { - aciPreKeyIdOffset = 1; + aciPreKeyIdOffset = getRandomPreKeyIdOffset(); } if (rootNode.hasNonNull("nextSignedPreKeyId")) { aciNextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1); } else { - aciNextSignedPreKeyId = 1; + aciNextSignedPreKeyId = getRandomPreKeyIdOffset(); } if (rootNode.hasNonNull("pniPreKeyIdOffset")) { pniPreKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1); } else { - pniPreKeyIdOffset = 1; + pniPreKeyIdOffset = getRandomPreKeyIdOffset(); } if (rootNode.hasNonNull("pniNextSignedPreKeyId")) { pniNextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1); } else { - pniNextSignedPreKeyId = 1; + pniNextSignedPreKeyId = getRandomPreKeyIdOffset(); + } + if (rootNode.hasNonNull("kyberPreKeyIdOffset")) { + aciKyberPreKeyIdOffset = rootNode.get("kyberPreKeyIdOffset").asInt(1); + } else { + aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset(); + } + if (rootNode.hasNonNull("activeLastResortKyberPreKeyId")) { + aciActiveLastResortKyberPreKeyId = rootNode.get("activeLastResortKyberPreKeyId").asInt(-1); + } else { + aciActiveLastResortKyberPreKeyId = -1; + } + if (rootNode.hasNonNull("pniKyberPreKeyIdOffset")) { + pniKyberPreKeyIdOffset = rootNode.get("pniKyberPreKeyIdOffset").asInt(1); + } else { + pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset(); + } + if (rootNode.hasNonNull("pniActiveLastResortKyberPreKeyId")) { + pniActiveLastResortKyberPreKeyId = rootNode.get("pniActiveLastResortKyberPreKeyId").asInt(-1); + } else { + pniActiveLastResortKyberPreKeyId = -1; } if (rootNode.hasNonNull("profileKey")) { try { @@ -974,6 +1006,10 @@ public class SignalAccount implements Closeable { .put("nextSignedPreKeyId", aciNextSignedPreKeyId) .put("pniPreKeyIdOffset", pniPreKeyIdOffset) .put("pniNextSignedPreKeyId", pniNextSignedPreKeyId) + .put("kyberPreKeyIdOffset", aciKyberPreKeyIdOffset) + .put("activeLastResortKyberPreKeyId", aciActiveLastResortKyberPreKeyId) + .put("pniKyberPreKeyIdOffset", pniKyberPreKeyIdOffset) + .put("pniActiveLastResortKyberPreKeyId", pniActiveLastResortKyberPreKeyId) .put("profileKey", profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize())) .put("registered", registered) @@ -1019,15 +1055,19 @@ public class SignalAccount implements Closeable { public void resetPreKeyOffsets(final ServiceIdType serviceIdType) { if (serviceIdType.equals(ServiceIdType.ACI)) { - this.aciPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE); - this.aciNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE); + this.aciPreKeyIdOffset = getRandomPreKeyIdOffset(); + this.aciNextSignedPreKeyId = getRandomPreKeyIdOffset(); } else { - this.pniPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE); - this.pniNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE); + this.pniPreKeyIdOffset = getRandomPreKeyIdOffset(); + this.pniNextSignedPreKeyId = getRandomPreKeyIdOffset(); } save(); } + private static int getRandomPreKeyIdOffset() { + return new SecureRandom().nextInt(PREKEY_MAXIMUM_ID); + } + public void addPreKeys(ServiceIdType serviceIdType, List records) { if (serviceIdType.equals(ServiceIdType.ACI)) { addAciPreKeys(records); @@ -1036,26 +1076,26 @@ public class SignalAccount implements Closeable { } } - public void addAciPreKeys(List records) { + private void addAciPreKeys(List records) { for (var record : records) { if (aciPreKeyIdOffset != record.getId()) { logger.error("Invalid pre key id {}, expected {}", record.getId(), aciPreKeyIdOffset); throw new AssertionError("Invalid pre key id"); } getAciPreKeyStore().storePreKey(record.getId(), record); - aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % Medium.MAX_VALUE; + aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; } save(); } - public void addPniPreKeys(List records) { + private void addPniPreKeys(List records) { for (var record : records) { if (pniPreKeyIdOffset != record.getId()) { logger.error("Invalid pre key id {}, expected {}", record.getId(), pniPreKeyIdOffset); throw new AssertionError("Invalid pre key id"); } getPniPreKeyStore().storePreKey(record.getId(), record); - pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % Medium.MAX_VALUE; + pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; } save(); } @@ -1074,7 +1114,7 @@ public class SignalAccount implements Closeable { throw new AssertionError("Invalid signed pre key id"); } getAciSignedPreKeyStore().storeSignedPreKey(record.getId(), record); - aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % Medium.MAX_VALUE; + aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID; save(); } @@ -1084,7 +1124,84 @@ public class SignalAccount implements Closeable { throw new AssertionError("Invalid signed pre key id"); } getPniSignedPreKeyStore().storeSignedPreKey(record.getId(), record); - pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % Medium.MAX_VALUE; + pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID; + save(); + } + + public void resetKyberPreKeyOffsets(final ServiceIdType serviceIdType) { + if (serviceIdType.equals(ServiceIdType.ACI)) { + this.aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset(); + this.aciActiveLastResortKyberPreKeyId = -1; + } else { + this.pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset(); + this.pniActiveLastResortKyberPreKeyId = -1; + } + save(); + } + + public void addKyberPreKeys(ServiceIdType serviceIdType, List records) { + if (serviceIdType.equals(ServiceIdType.ACI)) { + addAciKyberPreKeys(records); + } else { + addPniKyberPreKeys(records); + } + } + + private void addAciKyberPreKeys(List records) { + for (var record : records) { + if (aciKyberPreKeyIdOffset != record.getId()) { + logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), aciKyberPreKeyIdOffset); + throw new AssertionError("Invalid kyber pre key id"); + } + getAciKyberPreKeyStore().storeKyberPreKey(record.getId(), record); + aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; + } + save(); + } + + private void addPniKyberPreKeys(List records) { + for (var record : records) { + if (pniKyberPreKeyIdOffset != record.getId()) { + logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), pniKyberPreKeyIdOffset); + throw new AssertionError("Invalid kyber pre key id"); + } + getPniKyberPreKeyStore().storeKyberPreKey(record.getId(), record); + pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; + } + save(); + } + + public void addLastResortKyberPreKey(ServiceIdType serviceIdType, KyberPreKeyRecord record) { + if (serviceIdType.equals(ServiceIdType.ACI)) { + addAciLastResortKyberPreKey(record); + } else { + addPniLastResortKyberPreKey(record); + } + } + + public void addAciLastResortKyberPreKey(KyberPreKeyRecord record) { + if (aciKyberPreKeyIdOffset != record.getId()) { + logger.error("Invalid last resort kyber pre key id {}, expected {}", + record.getId(), + aciKyberPreKeyIdOffset); + throw new AssertionError("Invalid last resort kyber pre key id"); + } + getAciKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record); + aciActiveLastResortKyberPreKeyId = record.getId(); + aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; + save(); + } + + public void addPniLastResortKyberPreKey(KyberPreKeyRecord record) { + if (pniKyberPreKeyIdOffset != record.getId()) { + logger.error("Invalid last resort kyber pre key id {}, expected {}", + record.getId(), + pniKyberPreKeyIdOffset); + throw new AssertionError("Invalid last resort kyber pre key id"); + } + getPniKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record); + pniActiveLastResortKyberPreKeyId = record.getId(); + pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID; save(); } @@ -1126,6 +1243,7 @@ public class SignalAccount implements Closeable { return getOrCreate(() -> aciSignalProtocolStore, () -> aciSignalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(), getAciSignedPreKeyStore(), + getAciKyberPreKeyStore(), getAciSessionStore(), getAciIdentityKeyStore(), getSenderKeyStore(), @@ -1136,6 +1254,7 @@ public class SignalAccount implements Closeable { return getOrCreate(() -> pniSignalProtocolStore, () -> pniSignalProtocolStore = new SignalProtocolStore(getPniPreKeyStore(), getPniSignedPreKeyStore(), + getPniKyberPreKeyStore(), getPniSessionStore(), getPniIdentityKeyStore(), getSenderKeyStore(), @@ -1152,6 +1271,11 @@ public class SignalAccount implements Closeable { () -> aciSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.ACI)); } + private KyberPreKeyStore getAciKyberPreKeyStore() { + return getOrCreate(() -> aciKyberPreKeyStore, + () -> aciKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.ACI)); + } + private PreKeyStore getPniPreKeyStore() { return getOrCreate(() -> pniPreKeyStore, () -> pniPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.PNI)); @@ -1162,6 +1286,11 @@ public class SignalAccount implements Closeable { () -> pniSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.PNI)); } + private KyberPreKeyStore getPniKyberPreKeyStore() { + return getOrCreate(() -> pniKyberPreKeyStore, + () -> pniKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.PNI)); + } + public SessionStore getAciSessionStore() { return getOrCreate(() -> aciSessionStore, () -> aciSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.ACI)); @@ -1362,6 +1491,9 @@ public class SignalAccount implements Closeable { identityKeyStore.deleteIdentity(this.pni); getPniPreKeyStore().removeAllPreKeys(); getPniSignedPreKeyStore().removeAllSignedPreKeys(); + getPniKyberPreKeyStore().removeAllKyberPreKeys(); + aciActiveLastResortKyberPreKeyId = -1; + pniActiveLastResortKyberPreKeyId = -1; } this.pni = updatedPni; @@ -1406,10 +1538,6 @@ public class SignalAccount implements Closeable { save(); } - public byte[] getEncryptedDeviceName() { - return encryptedDeviceName == null ? null : Base64.getDecoder().decode(encryptedDeviceName); - } - public void setEncryptedDeviceName(final String encryptedDeviceName) { this.encryptedDeviceName = encryptedDeviceName; save(); @@ -1590,6 +1718,10 @@ public class SignalAccount implements Closeable { return serviceIdType.equals(ServiceIdType.ACI) ? aciNextSignedPreKeyId : pniNextSignedPreKeyId; } + public int getKyberPreKeyIdOffset(ServiceIdType serviceIdType) { + return serviceIdType.equals(ServiceIdType.ACI) ? aciKyberPreKeyIdOffset : pniKyberPreKeyIdOffset; + } + public boolean isRegistered() { return registered; } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/KyberPreKeyStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/KyberPreKeyStore.java new file mode 100644 index 00000000..54fbba43 --- /dev/null +++ b/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/KyberPreKeyStore.java @@ -0,0 +1,211 @@ +package org.asamk.signal.manager.storage.prekeys; + +import org.asamk.signal.manager.storage.Database; +import org.asamk.signal.manager.storage.Utils; +import org.signal.libsignal.protocol.InvalidKeyIdException; +import org.signal.libsignal.protocol.InvalidMessageException; +import org.signal.libsignal.protocol.state.KyberPreKeyRecord; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.signalservice.api.SignalServiceKyberPreKeyStore; +import org.whispersystems.signalservice.api.push.ServiceIdType; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.List; + +public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore { + + private static final String TABLE_KYBER_PRE_KEY = "kyber_pre_key"; + private final static Logger logger = LoggerFactory.getLogger(KyberPreKeyStore.class); + + private final Database database; + private final int accountIdType; + + public static void createSql(Connection connection) throws SQLException { + // When modifying the CREATE statement here, also add a migration in AccountDatabase.java + try (final var statement = connection.createStatement()) { + statement.executeUpdate(""" + CREATE TABLE kyber_pre_key ( + _id INTEGER PRIMARY KEY, + account_id_type INTEGER NOT NULL, + key_id INTEGER NOT NULL, + serialized BLOB NOT NULL, + is_last_resort INTEGER NOT NULL, + UNIQUE(account_id_type, key_id) + ) STRICT; + """); + } + } + + public KyberPreKeyStore(final Database database, final ServiceIdType serviceIdType) { + this.database = database; + this.accountIdType = Utils.getAccountIdType(serviceIdType); + } + + @Override + public KyberPreKeyRecord loadKyberPreKey(final int keyId) throws InvalidKeyIdException { + final var kyberPreKey = getPreKey(keyId); + if (kyberPreKey == null) { + throw new InvalidKeyIdException("No such kyber pre key record: " + keyId); + } + return kyberPreKey; + } + + @Override + public List loadKyberPreKeys() { + final var sql = ( + """ + SELECT p.serialized + FROM %s p + WHERE p.account_id_type = ? + """ + ).formatted(TABLE_KYBER_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + return Utils.executeQueryForStream(statement, this::getKyberPreKeyRecordFromResultSet).toList(); + } + } catch (SQLException e) { + throw new RuntimeException("Failed read from kyber_pre_key store", e); + } + } + + @Override + public List loadLastResortKyberPreKeys() { + final var sql = ( + """ + SELECT p.serialized + FROM %s p + WHERE p.account_id_type = ? AND p.is_last_resort = TRUE + """ + ).formatted(TABLE_KYBER_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + return Utils.executeQueryForStream(statement, this::getKyberPreKeyRecordFromResultSet).toList(); + } + } catch (SQLException e) { + throw new RuntimeException("Failed read from kyber_pre_key store", e); + } + } + + @Override + public void storeLastResortKyberPreKey(final int keyId, final KyberPreKeyRecord record) { + storeKyberPreKey(keyId, record, true); + } + + @Override + public void storeKyberPreKey(final int keyId, final KyberPreKeyRecord record) { + storeKyberPreKey(keyId, record, false); + } + + public void storeKyberPreKey(final int keyId, final KyberPreKeyRecord record, final boolean isLastResort) { + final var sql = ( + """ + INSERT INTO %s (account_id_type, key_id, serialized, is_last_resort) + VALUES (?, ?, ?, ?) + """ + ).formatted(TABLE_KYBER_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setInt(2, keyId); + statement.setBytes(3, record.serialize()); + statement.setBoolean(4, isLastResort); + statement.executeUpdate(); + } + } catch (SQLException e) { + throw new RuntimeException("Failed update kyber_pre_key store", e); + } + } + + @Override + public boolean containsKyberPreKey(final int keyId) { + return getPreKey(keyId) != null; + } + + @Override + public void markKyberPreKeyUsed(final int keyId) { + final var sql = ( + """ + DELETE FROM %s AS p + WHERE p.account_id_type = ? AND p.key_id = ? AND p.is_last_resort = FALSE + """ + ).formatted(TABLE_KYBER_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setInt(2, keyId); + statement.executeUpdate(); + } + } catch (SQLException e) { + throw new RuntimeException("Failed update kyber_pre_key store", e); + } + } + + @Override + public void removeKyberPreKey(final int keyId) { + final var sql = ( + """ + DELETE FROM %s AS p + WHERE p.account_id_type = ? AND p.key_id = ? + """ + ).formatted(TABLE_KYBER_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setInt(2, keyId); + statement.executeUpdate(); + } + } catch (SQLException e) { + throw new RuntimeException("Failed update kyber_pre_key store", e); + } + } + + public void removeAllKyberPreKeys() { + final var sql = ( + """ + DELETE FROM %s AS p + WHERE p.account_id_type = ? + """ + ).formatted(TABLE_KYBER_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.executeUpdate(); + } + } catch (SQLException e) { + throw new RuntimeException("Failed update kyber_pre_key store", e); + } + } + + private KyberPreKeyRecord getPreKey(int keyId) { + final var sql = ( + """ + SELECT p.serialized + FROM %s p + WHERE p.account_id_type = ? AND p.key_id = ? + """ + ).formatted(TABLE_KYBER_PRE_KEY); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setInt(2, keyId); + return Utils.executeQueryForOptional(statement, this::getKyberPreKeyRecordFromResultSet).orElse(null); + } + } catch (SQLException e) { + throw new RuntimeException("Failed read from kyber_pre_key store", e); + } + } + + private KyberPreKeyRecord getKyberPreKeyRecordFromResultSet(ResultSet resultSet) throws SQLException { + try { + final var serialized = resultSet.getBytes("serialized"); + return new KyberPreKeyRecord(serialized); + } catch (InvalidMessageException e) { + return null; + } + } +} diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/PreKeyStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/PreKeyStore.java index cc51a20a..f6dc271e 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/PreKeyStore.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/PreKeyStore.java @@ -49,7 +49,7 @@ public class PreKeyStore implements org.signal.libsignal.protocol.state.PreKeySt public PreKeyRecord loadPreKey(int preKeyId) throws InvalidKeyIdException { final var preKey = getPreKey(preKeyId); if (preKey == null) { - throw new InvalidKeyIdException("No such signed pre key record: " + preKeyId); + throw new InvalidKeyIdException("No such pre key record: " + preKeyId); } return preKey; } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/protocol/SignalProtocolStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/protocol/SignalProtocolStore.java index b4e27afc..e565e915 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/protocol/SignalProtocolStore.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/protocol/SignalProtocolStore.java @@ -14,6 +14,7 @@ import org.signal.libsignal.protocol.state.SessionRecord; import org.signal.libsignal.protocol.state.SignedPreKeyRecord; import org.signal.libsignal.protocol.state.SignedPreKeyStore; import org.whispersystems.signalservice.api.SignalServiceAccountDataStore; +import org.whispersystems.signalservice.api.SignalServiceKyberPreKeyStore; import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore; import org.whispersystems.signalservice.api.SignalServiceSessionStore; import org.whispersystems.signalservice.api.push.DistributionId; @@ -28,6 +29,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore { private final PreKeyStore preKeyStore; private final SignedPreKeyStore signedPreKeyStore; + private final SignalServiceKyberPreKeyStore kyberPreKeyStore; private final SignalServiceSessionStore sessionStore; private final IdentityKeyStore identityKeyStore; private final SignalServiceSenderKeyStore senderKeyStore; @@ -36,6 +38,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore { public SignalProtocolStore( final PreKeyStore preKeyStore, final SignedPreKeyStore signedPreKeyStore, + final SignalServiceKyberPreKeyStore kyberPreKeyStore, final SignalServiceSessionStore sessionStore, final IdentityKeyStore identityKeyStore, final SignalServiceSenderKeyStore senderKeyStore, @@ -43,6 +46,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore { ) { this.preKeyStore = preKeyStore; this.signedPreKeyStore = signedPreKeyStore; + this.kyberPreKeyStore = kyberPreKeyStore; this.sessionStore = sessionStore; this.identityKeyStore = identityKeyStore; this.senderKeyStore = senderKeyStore; @@ -201,45 +205,41 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore { @Override public KyberPreKeyRecord loadKyberPreKey(final int kyberPreKeyId) throws InvalidKeyIdException { - // TODO - throw new InvalidKeyIdException("Missing kyber prekey with ID: $kyberPreKeyId"); + return kyberPreKeyStore.loadKyberPreKey(kyberPreKeyId); } @Override public List loadKyberPreKeys() { - // TODO - return List.of(); + return kyberPreKeyStore.loadKyberPreKeys(); } @Override public void storeKyberPreKey(final int kyberPreKeyId, final KyberPreKeyRecord record) { - // TODO + kyberPreKeyStore.storeKyberPreKey(kyberPreKeyId, record); } @Override public boolean containsKyberPreKey(final int kyberPreKeyId) { - // TODO - return false; + return kyberPreKeyStore.containsKyberPreKey(kyberPreKeyId); } @Override public void markKyberPreKeyUsed(final int kyberPreKeyId) { - // TODO + kyberPreKeyStore.markKyberPreKeyUsed(kyberPreKeyId); } @Override public List loadLastResortKyberPreKeys() { - // TODO - return List.of(); + return kyberPreKeyStore.loadLastResortKyberPreKeys(); } @Override public void removeKyberPreKey(final int i) { - // TODO + kyberPreKeyStore.removeKyberPreKey(i); } @Override public void storeLastResortKyberPreKey(final int i, final KyberPreKeyRecord kyberPreKeyRecord) { - // TODO + kyberPreKeyStore.storeLastResortKyberPreKey(i, kyberPreKeyRecord); } } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java index 0f267583..3608e0c1 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java @@ -406,9 +406,7 @@ public class SessionStore implements SignalServiceSessionStore { } private static boolean isActive(SessionRecord record) { - return record != null - && record.hasSenderChain() - && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION; + return record != null && record.hasSenderChain(); } record Key(ServiceId serviceId, int deviceId) {} diff --git a/lib/src/main/java/org/asamk/signal/manager/util/KeyUtils.java b/lib/src/main/java/org/asamk/signal/manager/util/KeyUtils.java index d87868ee..3ee5657e 100644 --- a/lib/src/main/java/org/asamk/signal/manager/util/KeyUtils.java +++ b/lib/src/main/java/org/asamk/signal/manager/util/KeyUtils.java @@ -5,9 +5,11 @@ import org.signal.libsignal.protocol.IdentityKeyPair; import org.signal.libsignal.protocol.InvalidKeyException; import org.signal.libsignal.protocol.ecc.Curve; import org.signal.libsignal.protocol.ecc.ECPrivateKey; +import org.signal.libsignal.protocol.kem.KEMKeyPair; +import org.signal.libsignal.protocol.kem.KEMKeyType; +import org.signal.libsignal.protocol.state.KyberPreKeyRecord; import org.signal.libsignal.protocol.state.PreKeyRecord; import org.signal.libsignal.protocol.state.SignedPreKeyRecord; -import org.signal.libsignal.protocol.util.Medium; import org.signal.libsignal.zkgroup.InvalidInputException; import org.signal.libsignal.zkgroup.profiles.ProfileKey; import org.whispersystems.signalservice.api.kbs.MasterKey; @@ -17,6 +19,8 @@ import java.util.ArrayList; import java.util.Base64; import java.util.List; +import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_MAXIMUM_ID; + public class KeyUtils { private static final SecureRandom secureRandom = new SecureRandom(); @@ -46,7 +50,7 @@ public class KeyUtils { public static List generatePreKeyRecords(final int offset, final int batchSize) { var records = new ArrayList(batchSize); for (var i = 0; i < batchSize; i++) { - var preKeyId = (offset + i) % Medium.MAX_VALUE; + var preKeyId = (offset + i) % PREKEY_MAXIMUM_ID; var keyPair = Curve.generateKeyPair(); var record = new PreKeyRecord(preKeyId, keyPair); @@ -68,6 +72,24 @@ public class KeyUtils { return new SignedPreKeyRecord(signedPreKeyId, System.currentTimeMillis(), keyPair, signature); } + public static List generateKyberPreKeyRecords( + final int offset, final int batchSize, final ECPrivateKey privateKey + ) { + var records = new ArrayList(batchSize); + for (var i = 0; i < batchSize; i++) { + var preKeyId = (offset + i) % PREKEY_MAXIMUM_ID; + records.add(generateKyberPreKeyRecord(preKeyId, privateKey)); + } + return records; + } + + public static KyberPreKeyRecord generateKyberPreKeyRecord(final int preKeyId, final ECPrivateKey privateKey) { + KEMKeyPair keyPair = KEMKeyPair.generate(KEMKeyType.KYBER_1024); + byte[] signature = privateKey.calculateSignature(keyPair.getPublicKey().serialize()); + + return new KyberPreKeyRecord(preKeyId, System.currentTimeMillis(), keyPair, signature); + } + public static ProfileKey createProfileKey() { try { return new ProfileKey(getSecretBytes(32));