]> nmode's Git Repositories - signal-cli/commitdiff
Implement support for kyber pre keys
author AsamK <asamk@gmx.de>
Sat, 17 Jun 2023 19:18:24 +0000 (21:18 +0200)
committer AsamK <asamk@gmx.de>
Sat, 17 Jun 2023 20:32:05 +0000 (22:32 +0200)
graalvm-config-dir/jni-config.json
graalvm-config-dir/reflect-config.json
lib/src/main/java/org/asamk/signal/manager/config/ServiceConfig.java
lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java
lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java
lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java
lib/src/main/java/org/asamk/signal/manager/storage/prekeys/KyberPreKeyStore.java [new file with mode: 0644]
lib/src/main/java/org/asamk/signal/manager/storage/prekeys/PreKeyStore.java
lib/src/main/java/org/asamk/signal/manager/storage/protocol/SignalProtocolStore.java
lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
lib/src/main/java/org/asamk/signal/manager/util/KeyUtils.java

index 387549dd46c4b58430e1708b60f811d6151bfaf0..33dbd070a31198c28db806c939c46573bb4e6361 100644 (file)
@@ -46,7 +46,7 @@
 },
 {
   "name":"org.asamk.signal.manager.storage.protocol.SignalProtocolStore",
-  "methods":[{"name":"getIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"getIdentityKeyPair","parameterTypes":[] }, {"name":"getLocalRegistrationId","parameterTypes":[] }, {"name":"isTrustedIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey","org.signal.libsignal.protocol.state.IdentityKeyStore$Direction"] }, {"name":"loadPreKey","parameterTypes":["int"] }, {"name":"loadSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID"] }, {"name":"loadSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"loadSignedPreKey","parameterTypes":["int"] }, {"name":"removePreKey","parameterTypes":["int"] }, {"name":"saveIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey"] }, {"name":"storeSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID","org.signal.libsignal.protocol.groups.state.SenderKeyRecord"] }, {"name":"storeSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.state.SessionRecord"] }]
+  "methods":[{"name":"getIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"getIdentityKeyPair","parameterTypes":[] }, {"name":"getLocalRegistrationId","parameterTypes":[] }, {"name":"isTrustedIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey","org.signal.libsignal.protocol.state.IdentityKeyStore$Direction"] }, {"name":"loadKyberPreKey","parameterTypes":["int"] }, {"name":"loadPreKey","parameterTypes":["int"] }, {"name":"loadSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID"] }, {"name":"loadSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress"] }, {"name":"loadSignedPreKey","parameterTypes":["int"] }, {"name":"markKyberPreKeyUsed","parameterTypes":["int"] }, {"name":"removePreKey","parameterTypes":["int"] }, {"name":"saveIdentity","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.IdentityKey"] }, {"name":"storeSenderKey","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","java.util.UUID","org.signal.libsignal.protocol.groups.state.SenderKeyRecord"] }, {"name":"storeSession","parameterTypes":["org.signal.libsignal.protocol.SignalProtocolAddress","org.signal.libsignal.protocol.state.SessionRecord"] }]
 },
 {
   "name":"org.asamk.signal.manager.storage.senderKeys.SenderKeyStore",
   "name":"org.signal.libsignal.protocol.state.IdentityKeyStore$Direction",
   "fields":[{"name":"RECEIVING"}, {"name":"SENDING"}]
 },
+{
+  "name":"org.signal.libsignal.protocol.state.KyberPreKeyRecord",
+  "fields":[{"name":"unsafeHandle"}]
+},
 {
   "name":"org.signal.libsignal.protocol.state.KyberPreKeyStore"
 },
index 6a75476634b12c1ef0cfa1e28ce270fef0436036..0f08810ccd78e9955674b67fbfc70eca5262ffee 100644 (file)
   "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity",
   "allDeclaredFields":true,
   "queryAllDeclaredMethods":true,
-  "queryAllDeclaredConstructors":true
+  "queryAllDeclaredConstructors":true,
+  "methods":[{"name":"<init>","parameterTypes":[] }, {"name":"getKeyId","parameterTypes":[] }, {"name":"getPublicKey","parameterTypes":[] }, {"name":"getSignature","parameterTypes":[] }]
 },
 {
   "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$ByteArrayDeserializer",
   "methods":[{"name":"<init>","parameterTypes":[] }]
 },
+{
+  "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$ByteArraySerializer",
+  "methods":[{"name":"<init>","parameterTypes":[] }]
+},
 {
   "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$KEMPublicKeyDeserializer",
   "methods":[{"name":"<init>","parameterTypes":[] }]
 },
+{
+  "name":"org.whispersystems.signalservice.internal.push.KyberPreKeyEntity$KEMPublicKeySerializer",
+  "methods":[{"name":"<init>","parameterTypes":[] }]
+},
 {
   "name":"org.whispersystems.signalservice.internal.push.MismatchedDevices",
   "allDeclaredFields":true,
   "name":"org.whispersystems.signalservice.internal.push.PreKeyState",
   "allDeclaredFields":true,
   "allDeclaredMethods":true,
-  "allDeclaredConstructors":true
+  "allDeclaredConstructors":true,
+  "methods":[{"name":"getIdentityKey","parameterTypes":[] }, {"name":"getPreKeys","parameterTypes":[] }, {"name":"getSignedPreKey","parameterTypes":[] }]
 },
 {
   "name":"org.whispersystems.signalservice.internal.push.PreKeyStatus",
index b2b45a9cf67825d1618ec68fbdb6db0c7536632e..f97e8f68c3da15a3b8b02513d28d775db4e6b1a6 100644 (file)
@@ -1,6 +1,7 @@
 package org.asamk.signal.manager.config;
 
 import org.asamk.signal.manager.api.ServiceEnvironment;
+import org.signal.libsignal.protocol.util.Medium;
 import org.whispersystems.signalservice.api.account.AccountAttributes;
 import org.whispersystems.signalservice.api.push.TrustStore;
 
@@ -15,8 +16,9 @@ import okhttp3.Interceptor;
 
 public class ServiceConfig {
 
-    public final static int PREKEY_MINIMUM_COUNT = 20;
+    public final static int PREKEY_MINIMUM_COUNT = 10;
     public final static int PREKEY_BATCH_SIZE = 100;
+    public final static int PREKEY_MAXIMUM_ID = Medium.MAX_VALUE;
     public final static int MAX_ATTACHMENT_SIZE = 150 * 1024 * 1024;
     public final static long MAX_ENVELOPE_SIZE = 0;
     public final static long AVATAR_DOWNLOAD_FAILSAFE_MAX_SIZE = 10 * 1024 * 1024;
index 06dc31b91d0431ead7052a20dc99063be643c6bf..282b017aa435d57b4b6f22e7a5957da22c4bdfe0 100644 (file)
@@ -5,6 +5,7 @@ import org.asamk.signal.manager.internal.SignalDependencies;
 import org.asamk.signal.manager.storage.SignalAccount;
 import org.asamk.signal.manager.util.KeyUtils;
 import org.signal.libsignal.protocol.IdentityKeyPair;
+import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
 import org.signal.libsignal.protocol.state.PreKeyRecord;
 import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
 import org.slf4j.Logger;
@@ -39,7 +40,9 @@ public class PreKeyHelper {
         if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
             refreshPreKeys(serviceIdType);
         }
-        // TODO kyber pre keys
+        if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
+            refreshKyberPreKeys(serviceIdType);
+        }
     }
 
     public void refreshPreKeys() throws IOException {
@@ -47,7 +50,7 @@ public class PreKeyHelper {
         refreshPreKeys(ServiceIdType.PNI);
     }
 
-    public void refreshPreKeys(ServiceIdType serviceIdType) throws IOException {
+    private void refreshPreKeys(ServiceIdType serviceIdType) throws IOException {
         final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
         if (identityKeyPair == null) {
             return;
@@ -97,4 +100,61 @@ public class PreKeyHelper {
 
         return record;
     }
+
+    private void refreshKyberPreKeys(ServiceIdType serviceIdType) throws IOException {
+        final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
+        if (identityKeyPair == null) {
+            return;
+        }
+        final var accountId = account.getAccountId(serviceIdType);
+        if (accountId == null) {
+            return;
+        }
+        try {
+            refreshKyberPreKeys(serviceIdType, identityKeyPair);
+        } catch (Exception e) {
+            logger.warn("Failed to store new pre keys, resetting preKey id offset", e);
+            account.resetKyberPreKeyOffsets(serviceIdType);
+            refreshKyberPreKeys(serviceIdType, identityKeyPair);
+        }
+    }
+
+    private void refreshKyberPreKeys(
+            final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
+    ) throws IOException {
+        final var oneTimePreKeys = generateKyberPreKeys(serviceIdType, identityKeyPair);
+        final var lastResortPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair);
+
+        final var preKeyUpload = new PreKeyUpload(serviceIdType,
+                identityKeyPair.getPublicKey(),
+                null,
+                null,
+                lastResortPreKeyRecord,
+                oneTimePreKeys);
+        dependencies.getAccountManager().setPreKeys(preKeyUpload);
+    }
+
+    private List<KyberPreKeyRecord> generateKyberPreKeys(
+            ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
+    ) {
+        final var offset = account.getKyberPreKeyIdOffset(serviceIdType);
+
+        var records = KeyUtils.generateKyberPreKeyRecords(offset,
+                ServiceConfig.PREKEY_BATCH_SIZE,
+                identityKeyPair.getPrivateKey());
+        account.addKyberPreKeys(serviceIdType, records);
+
+        return records;
+    }
+
+    private KyberPreKeyRecord generateLastResortKyberPreKey(
+            ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair
+    ) {
+        final var signedPreKeyId = account.getKyberPreKeyIdOffset(serviceIdType);
+
+        var record = KeyUtils.generateKyberPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey());
+        account.addLastResortKyberPreKey(serviceIdType, record);
+
+        return record;
+    }
 }
index a15b4aabc07e7c24726b4d586e3c946f3f481fe5..7f6babe6fc572c4dde15269f22b698931d2b55a2 100644 (file)
@@ -4,6 +4,7 @@ import com.zaxxer.hikari.HikariDataSource;
 
 import org.asamk.signal.manager.storage.groups.GroupStore;
 import org.asamk.signal.manager.storage.identities.IdentityKeyStore;
+import org.asamk.signal.manager.storage.prekeys.KyberPreKeyStore;
 import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
 import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore;
 import org.asamk.signal.manager.storage.recipients.RecipientStore;
@@ -23,7 +24,7 @@ import java.sql.SQLException;
 public class AccountDatabase extends Database {
 
     private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class);
-    private static final long DATABASE_VERSION = 13;
+    private static final long DATABASE_VERSION = 14;
 
     private AccountDatabase(final HikariDataSource dataSource) {
         super(logger, DATABASE_VERSION, dataSource);
@@ -40,6 +41,7 @@ public class AccountDatabase extends Database {
         StickerStore.createSql(connection);
         PreKeyStore.createSql(connection);
         SignedPreKeyStore.createSql(connection);
+        KyberPreKeyStore.createSql(connection);
         GroupStore.createSql(connection);
         SessionStore.createSql(connection);
         IdentityKeyStore.createSql(connection);
@@ -328,5 +330,23 @@ public class AccountDatabase extends Database {
                 }
             }
         }
+        if (oldVersion < 14) {
+            logger.debug("Updating database: Creating kyber_pre_key table");
+            {
+                try (final var statement = connection.createStatement()) {
+                    statement.executeUpdate("""
+                                            CREATE TABLE kyber_pre_key (
+                                                    _id INTEGER PRIMARY KEY,
+                                                    account_id_type INTEGER NOT NULL,
+                                                    key_id INTEGER NOT NULL,
+                                                    serialized BLOB NOT NULL,
+                                                    is_last_resort INTEGER NOT NULL,
+                                                    UNIQUE(account_id_type, key_id)
+                                            ) STRICT;
+                                            """);
+                }
+            }
+
+        }
     }
 }
index 850b264fb84286f91503cc0757bae779faf96d2a..76867b728939d750992caae356dc849a5bc347f3 100644 (file)
@@ -21,6 +21,7 @@ import org.asamk.signal.manager.storage.identities.IdentityKeyStore;
 import org.asamk.signal.manager.storage.identities.LegacyIdentityKeyStore;
 import org.asamk.signal.manager.storage.identities.SignalIdentityKeyStore;
 import org.asamk.signal.manager.storage.messageCache.MessageCache;
+import org.asamk.signal.manager.storage.prekeys.KyberPreKeyStore;
 import org.asamk.signal.manager.storage.prekeys.LegacyPreKeyStore;
 import org.asamk.signal.manager.storage.prekeys.LegacySignedPreKeyStore;
 import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
@@ -51,11 +52,11 @@ import org.asamk.signal.manager.util.KeyUtils;
 import org.signal.libsignal.protocol.IdentityKeyPair;
 import org.signal.libsignal.protocol.InvalidMessageException;
 import org.signal.libsignal.protocol.SignalProtocolAddress;
+import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
 import org.signal.libsignal.protocol.state.PreKeyRecord;
 import org.signal.libsignal.protocol.state.SessionRecord;
 import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
 import org.signal.libsignal.protocol.util.KeyHelper;
-import org.signal.libsignal.protocol.util.Medium;
 import org.signal.libsignal.zkgroup.InvalidInputException;
 import org.signal.libsignal.zkgroup.profiles.ProfileKey;
 import org.slf4j.Logger;
@@ -98,6 +99,7 @@ import java.util.List;
 import java.util.Optional;
 import java.util.function.Supplier;
 
+import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_MAXIMUM_ID;
 import static org.asamk.signal.manager.config.ServiceConfig.getCapabilities;
 
 public class SignalAccount implements Closeable {
@@ -105,7 +107,7 @@ public class SignalAccount implements Closeable {
     private final static Logger logger = LoggerFactory.getLogger(SignalAccount.class);
 
     private static final int MINIMUM_STORAGE_VERSION = 1;
-    private static final int CURRENT_STORAGE_VERSION = 6;
+    private static final int CURRENT_STORAGE_VERSION = 7;
 
     private final Object LOCK = new Object();
 
@@ -138,6 +140,10 @@ public class SignalAccount implements Closeable {
     private int aciNextSignedPreKeyId = 1;
     private int pniPreKeyIdOffset = 1;
     private int pniNextSignedPreKeyId = 1;
+    // Next id that a generated ACI kyber pre key (one-time or last resort) must use.
+    private int aciKyberPreKeyIdOffset = 1;
+    // Id of the stored ACI last resort kyber pre key; -1 when none is recorded (also the reset value).
+    private int aciActiveLastResortKyberPreKeyId = -1;
+    // PNI counterparts of the two fields above.
+    private int pniKyberPreKeyIdOffset = 1;
+    private int pniActiveLastResortKyberPreKeyId = -1;
     private IdentityKeyPair aciIdentityKeyPair;
     private IdentityKeyPair pniIdentityKeyPair;
     private int localRegistrationId;
@@ -151,8 +157,10 @@ public class SignalAccount implements Closeable {
     private SignalProtocolStore pniSignalProtocolStore;
     private PreKeyStore aciPreKeyStore;
     private SignedPreKeyStore aciSignedPreKeyStore;
+    private KyberPreKeyStore aciKyberPreKeyStore;
     private PreKeyStore pniPreKeyStore;
     private SignedPreKeyStore pniSignedPreKeyStore;
+    private KyberPreKeyStore pniKyberPreKeyStore;
     private SessionStore aciSessionStore;
     private SessionStore pniSessionStore;
     private IdentityKeyStore identityKeyStore;
@@ -302,10 +310,14 @@ public class SignalAccount implements Closeable {
+    /**
+     * Resets the EC and kyber pre key id offsets for ACI and PNI, deletes every stored one-time,
+     * signed and kyber pre key for both identities, and persists the account.
+     */
     private void clearAllPreKeys() {
         resetPreKeyOffsets(ServiceIdType.ACI);
         resetPreKeyOffsets(ServiceIdType.PNI);
+        resetKyberPreKeyOffsets(ServiceIdType.ACI);
+        resetKyberPreKeyOffsets(ServiceIdType.PNI);
         this.getAciPreKeyStore().removeAllPreKeys();
         this.getAciSignedPreKeyStore().removeAllSignedPreKeys();
+        this.getAciKyberPreKeyStore().removeAllKyberPreKeys();
         this.getPniPreKeyStore().removeAllPreKeys();
         this.getPniSignedPreKeyStore().removeAllSignedPreKeys();
+        this.getPniKyberPreKeyStore().removeAllKyberPreKeys();
         save();
     }
 
@@ -614,22 +626,42 @@ public class SignalAccount implements Closeable {
         if (rootNode.hasNonNull("preKeyIdOffset")) {
             aciPreKeyIdOffset = rootNode.get("preKeyIdOffset").asInt(1);
         } else {
-            aciPreKeyIdOffset = 1;
+            aciPreKeyIdOffset = getRandomPreKeyIdOffset();
         }
         if (rootNode.hasNonNull("nextSignedPreKeyId")) {
             aciNextSignedPreKeyId = rootNode.get("nextSignedPreKeyId").asInt(1);
         } else {
-            aciNextSignedPreKeyId = 1;
+            aciNextSignedPreKeyId = getRandomPreKeyIdOffset();
         }
         if (rootNode.hasNonNull("pniPreKeyIdOffset")) {
             pniPreKeyIdOffset = rootNode.get("pniPreKeyIdOffset").asInt(1);
         } else {
-            pniPreKeyIdOffset = 1;
+            pniPreKeyIdOffset = getRandomPreKeyIdOffset();
         }
         if (rootNode.hasNonNull("pniNextSignedPreKeyId")) {
             pniNextSignedPreKeyId = rootNode.get("pniNextSignedPreKeyId").asInt(1);
         } else {
-            pniNextSignedPreKeyId = 1;
+            pniNextSignedPreKeyId = getRandomPreKeyIdOffset();
+        }
+        if (rootNode.hasNonNull("kyberPreKeyIdOffset")) {
+            aciKyberPreKeyIdOffset = rootNode.get("kyberPreKeyIdOffset").asInt(1);
+        } else {
+            aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
+        }
+        if (rootNode.hasNonNull("activeLastResortKyberPreKeyId")) {
+            aciActiveLastResortKyberPreKeyId = rootNode.get("activeLastResortKyberPreKeyId").asInt(-1);
+        } else {
+            aciActiveLastResortKyberPreKeyId = -1;
+        }
+        if (rootNode.hasNonNull("pniKyberPreKeyIdOffset")) {
+            pniKyberPreKeyIdOffset = rootNode.get("pniKyberPreKeyIdOffset").asInt(1);
+        } else {
+            pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
+        }
+        if (rootNode.hasNonNull("pniActiveLastResortKyberPreKeyId")) {
+            pniActiveLastResortKyberPreKeyId = rootNode.get("pniActiveLastResortKyberPreKeyId").asInt(-1);
+        } else {
+            pniActiveLastResortKyberPreKeyId = -1;
         }
         if (rootNode.hasNonNull("profileKey")) {
             try {
@@ -974,6 +1006,10 @@ public class SignalAccount implements Closeable {
                     .put("nextSignedPreKeyId", aciNextSignedPreKeyId)
                     .put("pniPreKeyIdOffset", pniPreKeyIdOffset)
                     .put("pniNextSignedPreKeyId", pniNextSignedPreKeyId)
+                    .put("kyberPreKeyIdOffset", aciKyberPreKeyIdOffset)
+                    .put("activeLastResortKyberPreKeyId", aciActiveLastResortKyberPreKeyId)
+                    .put("pniKyberPreKeyIdOffset", pniKyberPreKeyIdOffset)
+                    .put("pniActiveLastResortKyberPreKeyId", pniActiveLastResortKyberPreKeyId)
                     .put("profileKey",
                             profileKey == null ? null : Base64.getEncoder().encodeToString(profileKey.serialize()))
                     .put("registered", registered)
@@ -1019,15 +1055,19 @@ public class SignalAccount implements Closeable {
 
+    /**
+     * Moves the one-time pre key id offset and the next signed pre key id of the given identity
+     * (ACI, otherwise PNI) to new random values, then persists the account.
+     */
     public void resetPreKeyOffsets(final ServiceIdType serviceIdType) {
         if (serviceIdType.equals(ServiceIdType.ACI)) {
-            this.aciPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE);
-            this.aciNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE);
+            this.aciPreKeyIdOffset = getRandomPreKeyIdOffset();
+            this.aciNextSignedPreKeyId = getRandomPreKeyIdOffset();
         } else {
-            this.pniPreKeyIdOffset = new SecureRandom().nextInt(Medium.MAX_VALUE);
-            this.pniNextSignedPreKeyId = new SecureRandom().nextInt(Medium.MAX_VALUE);
+            this.pniPreKeyIdOffset = getRandomPreKeyIdOffset();
+            this.pniNextSignedPreKeyId = getRandomPreKeyIdOffset();
         }
         save();
     }
 
+    private static int getRandomPreKeyIdOffset() {
+        return new SecureRandom().nextInt(PREKEY_MAXIMUM_ID);
+    }
+
     public void addPreKeys(ServiceIdType serviceIdType, List<PreKeyRecord> records) {
         if (serviceIdType.equals(ServiceIdType.ACI)) {
             addAciPreKeys(records);
@@ -1036,26 +1076,26 @@ public class SignalAccount implements Closeable {
         }
     }
 
-    public void addAciPreKeys(List<PreKeyRecord> records) {
+    /**
+     * Stores ACI one-time pre keys. Each record id must equal the current ACI pre key id offset,
+     * which advances by one (modulo PREKEY_MAXIMUM_ID) per stored record; the account is then saved.
+     *
+     * @throws AssertionError if a record id does not match the expected next id
+     */
+    private void addAciPreKeys(List<PreKeyRecord> records) {
         for (var record : records) {
             if (aciPreKeyIdOffset != record.getId()) {
                 logger.error("Invalid pre key id {}, expected {}", record.getId(), aciPreKeyIdOffset);
                 throw new AssertionError("Invalid pre key id");
             }
             getAciPreKeyStore().storePreKey(record.getId(), record);
-            aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % Medium.MAX_VALUE;
+            aciPreKeyIdOffset = (aciPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
         }
         save();
     }
 
-    public void addPniPreKeys(List<PreKeyRecord> records) {
+    /**
+     * Stores PNI one-time pre keys. Each record id must equal the current PNI pre key id offset,
+     * which advances by one (modulo PREKEY_MAXIMUM_ID) per stored record; the account is then saved.
+     *
+     * @throws AssertionError if a record id does not match the expected next id
+     */
+    private void addPniPreKeys(List<PreKeyRecord> records) {
         for (var record : records) {
             if (pniPreKeyIdOffset != record.getId()) {
                 logger.error("Invalid pre key id {}, expected {}", record.getId(), pniPreKeyIdOffset);
                 throw new AssertionError("Invalid pre key id");
             }
             getPniPreKeyStore().storePreKey(record.getId(), record);
-            pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % Medium.MAX_VALUE;
+            pniPreKeyIdOffset = (pniPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
         }
         save();
     }
@@ -1074,7 +1114,7 @@ public class SignalAccount implements Closeable {
             throw new AssertionError("Invalid signed pre key id");
         }
         getAciSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
-        aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % Medium.MAX_VALUE;
+        aciNextSignedPreKeyId = (aciNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID;
         save();
     }
 
@@ -1084,7 +1124,84 @@ public class SignalAccount implements Closeable {
             throw new AssertionError("Invalid signed pre key id");
         }
         getPniSignedPreKeyStore().storeSignedPreKey(record.getId(), record);
-        pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % Medium.MAX_VALUE;
+        pniNextSignedPreKeyId = (pniNextSignedPreKeyId + 1) % PREKEY_MAXIMUM_ID;
+        save();
+    }
+
+    public void resetKyberPreKeyOffsets(final ServiceIdType serviceIdType) {
+        if (serviceIdType.equals(ServiceIdType.ACI)) {
+            this.aciKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
+            this.aciActiveLastResortKyberPreKeyId = -1;
+        } else {
+            this.pniKyberPreKeyIdOffset = getRandomPreKeyIdOffset();
+            this.pniActiveLastResortKyberPreKeyId = -1;
+        }
+        save();
+    }
+
+    public void addKyberPreKeys(ServiceIdType serviceIdType, List<KyberPreKeyRecord> records) {
+        if (serviceIdType.equals(ServiceIdType.ACI)) {
+            addAciKyberPreKeys(records);
+        } else {
+            addPniKyberPreKeys(records);
+        }
+    }
+
+    private void addAciKyberPreKeys(List<KyberPreKeyRecord> records) {
+        for (var record : records) {
+            if (aciKyberPreKeyIdOffset != record.getId()) {
+                logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), aciKyberPreKeyIdOffset);
+                throw new AssertionError("Invalid kyber pre key id");
+            }
+            getAciKyberPreKeyStore().storeKyberPreKey(record.getId(), record);
+            aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
+        }
+        save();
+    }
+
+    private void addPniKyberPreKeys(List<KyberPreKeyRecord> records) {
+        for (var record : records) {
+            if (pniKyberPreKeyIdOffset != record.getId()) {
+                logger.error("Invalid kyber pre key id {}, expected {}", record.getId(), pniKyberPreKeyIdOffset);
+                throw new AssertionError("Invalid kyber pre key id");
+            }
+            getPniKyberPreKeyStore().storeKyberPreKey(record.getId(), record);
+            pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
+        }
+        save();
+    }
+
+    public void addLastResortKyberPreKey(ServiceIdType serviceIdType, KyberPreKeyRecord record) {
+        if (serviceIdType.equals(ServiceIdType.ACI)) {
+            addAciLastResortKyberPreKey(record);
+        } else {
+            addPniLastResortKyberPreKey(record);
+        }
+    }
+
+    public void addAciLastResortKyberPreKey(KyberPreKeyRecord record) {
+        if (aciKyberPreKeyIdOffset != record.getId()) {
+            logger.error("Invalid last resort kyber pre key id {}, expected {}",
+                    record.getId(),
+                    aciKyberPreKeyIdOffset);
+            throw new AssertionError("Invalid last resort kyber pre key id");
+        }
+        getAciKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record);
+        aciActiveLastResortKyberPreKeyId = record.getId();
+        aciKyberPreKeyIdOffset = (aciKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
+        save();
+    }
+
+    public void addPniLastResortKyberPreKey(KyberPreKeyRecord record) {
+        if (pniKyberPreKeyIdOffset != record.getId()) {
+            logger.error("Invalid last resort kyber pre key id {}, expected {}",
+                    record.getId(),
+                    pniKyberPreKeyIdOffset);
+            throw new AssertionError("Invalid last resort kyber pre key id");
+        }
+        getPniKyberPreKeyStore().storeLastResortKyberPreKey(record.getId(), record);
+        pniActiveLastResortKyberPreKeyId = record.getId();
+        pniKyberPreKeyIdOffset = (pniKyberPreKeyIdOffset + 1) % PREKEY_MAXIMUM_ID;
         save();
     }
 
@@ -1126,6 +1243,7 @@ public class SignalAccount implements Closeable {
         return getOrCreate(() -> aciSignalProtocolStore,
                 () -> aciSignalProtocolStore = new SignalProtocolStore(getAciPreKeyStore(),
                         getAciSignedPreKeyStore(),
+                        getAciKyberPreKeyStore(),
                         getAciSessionStore(),
                         getAciIdentityKeyStore(),
                         getSenderKeyStore(),
@@ -1136,6 +1254,7 @@ public class SignalAccount implements Closeable {
         return getOrCreate(() -> pniSignalProtocolStore,
                 () -> pniSignalProtocolStore = new SignalProtocolStore(getPniPreKeyStore(),
                         getPniSignedPreKeyStore(),
+                        getPniKyberPreKeyStore(),
                         getPniSessionStore(),
                         getPniIdentityKeyStore(),
                         getSenderKeyStore(),
@@ -1152,6 +1271,11 @@ public class SignalAccount implements Closeable {
                 () -> aciSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
     }
 
+    /**
+     * Returns the ACI kyber pre key store backed by the account database. getOrCreate receives the
+     * cached field and a factory lambda that assigns a new store to it.
+     */
+    private KyberPreKeyStore getAciKyberPreKeyStore() {
+        return getOrCreate(() -> aciKyberPreKeyStore,
+                () -> aciKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.ACI));
+    }
+
     private PreKeyStore getPniPreKeyStore() {
         return getOrCreate(() -> pniPreKeyStore,
                 () -> pniPreKeyStore = new PreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
@@ -1162,6 +1286,11 @@ public class SignalAccount implements Closeable {
                 () -> pniSignedPreKeyStore = new SignedPreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
     }
 
+    /**
+     * Returns the PNI kyber pre key store backed by the account database. getOrCreate receives the
+     * cached field and a factory lambda that assigns a new store to it.
+     */
+    private KyberPreKeyStore getPniKyberPreKeyStore() {
+        return getOrCreate(() -> pniKyberPreKeyStore,
+                () -> pniKyberPreKeyStore = new KyberPreKeyStore(getAccountDatabase(), ServiceIdType.PNI));
+    }
+
     public SessionStore getAciSessionStore() {
         return getOrCreate(() -> aciSessionStore,
                 () -> aciSessionStore = new SessionStore(getAccountDatabase(), ServiceIdType.ACI));
@@ -1362,6 +1491,9 @@ public class SignalAccount implements Closeable {
             identityKeyStore.deleteIdentity(this.pni);
             getPniPreKeyStore().removeAllPreKeys();
             getPniSignedPreKeyStore().removeAllSignedPreKeys();
+            getPniKyberPreKeyStore().removeAllKyberPreKeys();
+            aciActiveLastResortKyberPreKeyId = -1;
+            pniActiveLastResortKyberPreKeyId = -1;
         }
 
         this.pni = updatedPni;
@@ -1406,10 +1538,6 @@ public class SignalAccount implements Closeable {
         save();
     }
 
-    public byte[] getEncryptedDeviceName() {
-        return encryptedDeviceName == null ? null : Base64.getDecoder().decode(encryptedDeviceName);
-    }
-
     public void setEncryptedDeviceName(final String encryptedDeviceName) {
         this.encryptedDeviceName = encryptedDeviceName;
         save();
@@ -1590,6 +1718,10 @@ public class SignalAccount implements Closeable {
         return serviceIdType.equals(ServiceIdType.ACI) ? aciNextSignedPreKeyId : pniNextSignedPreKeyId;
     }
 
+    public int getKyberPreKeyIdOffset(ServiceIdType serviceIdType) {
+        return serviceIdType.equals(ServiceIdType.ACI) ? aciKyberPreKeyIdOffset : pniKyberPreKeyIdOffset;
+    }
+
+    /** Returns the account's persisted {@code registered} flag. */
     public boolean isRegistered() {
         return registered;
     }
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/KyberPreKeyStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/prekeys/KyberPreKeyStore.java
new file mode 100644 (file)
index 0000000..54fbba4
--- /dev/null
@@ -0,0 +1,272 @@
+package org.asamk.signal.manager.storage.prekeys;
+
+import org.asamk.signal.manager.storage.Database;
+import org.asamk.signal.manager.storage.Utils;
+import org.signal.libsignal.protocol.InvalidKeyIdException;
+import org.signal.libsignal.protocol.InvalidMessageException;
+import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.signalservice.api.SignalServiceKyberPreKeyStore;
+import org.whispersystems.signalservice.api.push.ServiceIdType;
+
+import java.sql.Connection;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * SQLite-backed store for the Kyber pre keys of one account identity (ACI or PNI).
+ * <p>
+ * All identities share the {@code kyber_pre_key} table; rows are partitioned by {@code account_id_type},
+ * so an instance only reads and writes the keys of the identity it was created for.
+ */
+public class KyberPreKeyStore implements SignalServiceKyberPreKeyStore {
+
+    private static final String TABLE_KYBER_PRE_KEY = "kyber_pre_key";
+    private static final Logger logger = LoggerFactory.getLogger(KyberPreKeyStore.class);
+
+    private final Database database;
+    private final int accountIdType;
+
+    /**
+     * Creates the {@code kyber_pre_key} table.
+     *
+     * @param connection connection to the account database
+     * @throws SQLException if the statement fails
+     */
+    public static void createSql(Connection connection) throws SQLException {
+        // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
+        try (final var statement = connection.createStatement()) {
+            statement.executeUpdate("""
+                                    CREATE TABLE kyber_pre_key (
+                                      _id INTEGER PRIMARY KEY,
+                                      account_id_type INTEGER NOT NULL,
+                                      key_id INTEGER NOT NULL,
+                                      serialized BLOB NOT NULL,
+                                      is_last_resort INTEGER NOT NULL,
+                                      UNIQUE(account_id_type, key_id)
+                                    ) STRICT;
+                                    """);
+        }
+    }
+
+    /**
+     * @param database      the account database that holds the {@code kyber_pre_key} table
+     * @param serviceIdType the identity (ACI or PNI) whose keys this store manages
+     */
+    public KyberPreKeyStore(final Database database, final ServiceIdType serviceIdType) {
+        this.database = database;
+        this.accountIdType = Utils.getAccountIdType(serviceIdType);
+    }
+
+    /**
+     * @throws InvalidKeyIdException if no record with {@code keyId} could be loaded for this identity
+     */
+    @Override
+    public KyberPreKeyRecord loadKyberPreKey(final int keyId) throws InvalidKeyIdException {
+        final var kyberPreKey = getPreKey(keyId);
+        if (kyberPreKey == null) {
+            throw new InvalidKeyIdException("No such kyber pre key record: " + keyId);
+        }
+        return kyberPreKey;
+    }
+
+    /**
+     * Loads all Kyber pre keys (one-time and last-resort) of this identity.
+     * Rows whose bytes cannot be deserialized are logged and skipped instead of being returned as {@code null}.
+     */
+    @Override
+    public List<KyberPreKeyRecord> loadKyberPreKeys() {
+        final var sql = (
+                """
+                SELECT p.serialized
+                FROM %s p
+                WHERE p.account_id_type = ?
+                """
+        ).formatted(TABLE_KYBER_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                return Utils.executeQueryForStream(statement, this::getKyberPreKeyRecordFromResultSet)
+                        .filter(Objects::nonNull)
+                        .toList();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed read from kyber_pre_key store", e);
+        }
+    }
+
+    /**
+     * Loads the last-resort Kyber pre keys of this identity; undecodable rows are logged and skipped.
+     */
+    @Override
+    public List<KyberPreKeyRecord> loadLastResortKyberPreKeys() {
+        final var sql = (
+                """
+                SELECT p.serialized
+                FROM %s p
+                WHERE p.account_id_type = ? AND p.is_last_resort = TRUE
+                """
+        ).formatted(TABLE_KYBER_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                return Utils.executeQueryForStream(statement, this::getKyberPreKeyRecordFromResultSet)
+                        .filter(Objects::nonNull)
+                        .toList();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed read from kyber_pre_key store", e);
+        }
+    }
+
+    @Override
+    public void storeLastResortKyberPreKey(final int keyId, final KyberPreKeyRecord record) {
+        storeKyberPreKey(keyId, record, true);
+    }
+
+    @Override
+    public void storeKyberPreKey(final int keyId, final KyberPreKeyRecord record) {
+        storeKyberPreKey(keyId, record, false);
+    }
+
+    /**
+     * Stores a Kyber pre key for this identity, replacing any record that already uses the same key id.
+     *
+     * @param keyId        id of the pre key
+     * @param record       the record to persist
+     * @param isLastResort whether the key is a last-resort key, which {@link #markKyberPreKeyUsed(int)} keeps
+     */
+    public void storeKyberPreKey(final int keyId, final KyberPreKeyRecord record, final boolean isLastResort) {
+        // Pre key ids wrap around (modulo PREKEY_MAXIMUM_ID), so an id can be handed out again while an old
+        // row still exists. Overwrite it instead of failing on the UNIQUE(account_id_type, key_id) constraint.
+        final var sql = (
+                """
+                INSERT OR REPLACE INTO %s (account_id_type, key_id, serialized, is_last_resort)
+                VALUES (?, ?, ?, ?)
+                """
+        ).formatted(TABLE_KYBER_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                statement.setInt(2, keyId);
+                statement.setBytes(3, record.serialize());
+                statement.setBoolean(4, isLastResort);
+                statement.executeUpdate();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update kyber_pre_key store", e);
+        }
+    }
+
+    @Override
+    public boolean containsKyberPreKey(final int keyId) {
+        return getPreKey(keyId) != null;
+    }
+
+    /**
+     * Deletes a consumed one-time Kyber pre key. Last-resort keys are kept: only rows with
+     * {@code is_last_resort = FALSE} are deleted.
+     */
+    @Override
+    public void markKyberPreKeyUsed(final int keyId) {
+        final var sql = (
+                """
+                DELETE FROM %s AS p
+                WHERE p.account_id_type = ? AND p.key_id = ? AND p.is_last_resort = FALSE
+                """
+        ).formatted(TABLE_KYBER_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                statement.setInt(2, keyId);
+                statement.executeUpdate();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update kyber_pre_key store", e);
+        }
+    }
+
+    /**
+     * Deletes the key with the given id, whether or not it is a last-resort key.
+     */
+    @Override
+    public void removeKyberPreKey(final int keyId) {
+        final var sql = (
+                """
+                DELETE FROM %s AS p
+                WHERE p.account_id_type = ? AND p.key_id = ?
+                """
+        ).formatted(TABLE_KYBER_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                statement.setInt(2, keyId);
+                statement.executeUpdate();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update kyber_pre_key store", e);
+        }
+    }
+
+    /**
+     * Deletes every Kyber pre key (one-time and last-resort) of this identity.
+     */
+    public void removeAllKyberPreKeys() {
+        final var sql = (
+                """
+                DELETE FROM %s AS p
+                WHERE p.account_id_type = ?
+                """
+        ).formatted(TABLE_KYBER_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                statement.executeUpdate();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update kyber_pre_key store", e);
+        }
+    }
+
+    /**
+     * Looks up a single key of this identity.
+     *
+     * @return the record, or {@code null} if none could be loaded for {@code keyId}
+     */
+    private KyberPreKeyRecord getPreKey(int keyId) {
+        final var sql = (
+                """
+                SELECT p.serialized
+                FROM %s p
+                WHERE p.account_id_type = ? AND p.key_id = ?
+                """
+        ).formatted(TABLE_KYBER_PRE_KEY);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                statement.setInt(2, keyId);
+                return Utils.executeQueryForOptional(statement, this::getKyberPreKeyRecordFromResultSet).orElse(null);
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed read from kyber_pre_key store", e);
+        }
+    }
+
+    /**
+     * Deserializes the {@code serialized} column of the current row.
+     *
+     * @return the record, or {@code null} if the stored bytes are not a valid Kyber pre key record
+     */
+    private KyberPreKeyRecord getKyberPreKeyRecordFromResultSet(ResultSet resultSet) throws SQLException {
+        try {
+            final var serialized = resultSet.getBytes("serialized");
+            return new KyberPreKeyRecord(serialized);
+        } catch (InvalidMessageException e) {
+            logger.warn("Failed to deserialize kyber pre key record, ignoring it", e);
+            return null;
+        }
+    }
+}
index cc51a20aa4ae244a12ab516723404f729be2a5e5..f6dc271eea7765c982864deb475aac8a1811eb21 100644 (file)
@@ -49,7 +49,7 @@ public class PreKeyStore implements org.signal.libsignal.protocol.state.PreKeySt
     public PreKeyRecord loadPreKey(int preKeyId) throws InvalidKeyIdException {
         final var preKey = getPreKey(preKeyId);
         if (preKey == null) {
-            throw new InvalidKeyIdException("No such signed pre key record: " + preKeyId);
+            throw new InvalidKeyIdException("No such pre key record: " + preKeyId);
         }
         return preKey;
     }
index b4e27afcff7cc001a35aa70fb4f73267808af8c0..e565e915761175d0d53e09df8aacea1ca2cab073 100644 (file)
@@ -14,6 +14,7 @@ import org.signal.libsignal.protocol.state.SessionRecord;
 import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
 import org.signal.libsignal.protocol.state.SignedPreKeyStore;
 import org.whispersystems.signalservice.api.SignalServiceAccountDataStore;
+import org.whispersystems.signalservice.api.SignalServiceKyberPreKeyStore;
 import org.whispersystems.signalservice.api.SignalServiceSenderKeyStore;
 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
 import org.whispersystems.signalservice.api.push.DistributionId;
@@ -28,6 +29,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
 
     private final PreKeyStore preKeyStore;
     private final SignedPreKeyStore signedPreKeyStore;
+    private final SignalServiceKyberPreKeyStore kyberPreKeyStore;
     private final SignalServiceSessionStore sessionStore;
     private final IdentityKeyStore identityKeyStore;
     private final SignalServiceSenderKeyStore senderKeyStore;
@@ -36,6 +38,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
     public SignalProtocolStore(
             final PreKeyStore preKeyStore,
             final SignedPreKeyStore signedPreKeyStore,
+            final SignalServiceKyberPreKeyStore kyberPreKeyStore,
             final SignalServiceSessionStore sessionStore,
             final IdentityKeyStore identityKeyStore,
             final SignalServiceSenderKeyStore senderKeyStore,
@@ -43,6 +46,7 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
     ) {
         this.preKeyStore = preKeyStore;
         this.signedPreKeyStore = signedPreKeyStore;
+        this.kyberPreKeyStore = kyberPreKeyStore;
         this.sessionStore = sessionStore;
         this.identityKeyStore = identityKeyStore;
         this.senderKeyStore = senderKeyStore;
@@ -201,45 +205,41 @@ public class SignalProtocolStore implements SignalServiceAccountDataStore {
 
     @Override
     public KyberPreKeyRecord loadKyberPreKey(final int kyberPreKeyId) throws InvalidKeyIdException {
-        // TODO
-        throw new InvalidKeyIdException("Missing kyber prekey with ID: $kyberPreKeyId");
+        return kyberPreKeyStore.loadKyberPreKey(kyberPreKeyId);
     }
 
     @Override
     public List<KyberPreKeyRecord> loadKyberPreKeys() {
-        // TODO
-        return List.of();
+        return kyberPreKeyStore.loadKyberPreKeys();
     }
 
     @Override
     public void storeKyberPreKey(final int kyberPreKeyId, final KyberPreKeyRecord record) {
-        // TODO
+        kyberPreKeyStore.storeKyberPreKey(kyberPreKeyId, record);
     }
 
     @Override
     public boolean containsKyberPreKey(final int kyberPreKeyId) {
-        // TODO
-        return false;
+        return kyberPreKeyStore.containsKyberPreKey(kyberPreKeyId);
     }
 
     @Override
     public void markKyberPreKeyUsed(final int kyberPreKeyId) {
-        // TODO
+        kyberPreKeyStore.markKyberPreKeyUsed(kyberPreKeyId);
     }
 
     @Override
     public List<KyberPreKeyRecord> loadLastResortKyberPreKeys() {
-        // TODO
-        return List.of();
+        return kyberPreKeyStore.loadLastResortKyberPreKeys();
     }
 
     @Override
     public void removeKyberPreKey(final int i) {
-        // TODO
+        kyberPreKeyStore.removeKyberPreKey(i);
     }
 
     @Override
     public void storeLastResortKyberPreKey(final int i, final KyberPreKeyRecord kyberPreKeyRecord) {
-        // TODO
+        kyberPreKeyStore.storeLastResortKyberPreKey(i, kyberPreKeyRecord);
     }
 }
index 0f26758315b19a65877a14924dce35f3a8e8e4aa..3608e0c12ae370364363576a9fe133743802dc2b 100644 (file)
@@ -406,9 +406,7 @@ public class SessionStore implements SignalServiceSessionStore {
     }
 
     private static boolean isActive(SessionRecord record) {
-        return record != null
-                && record.hasSenderChain()
-                && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
+        return record != null && record.hasSenderChain();
     }
 
     record Key(ServiceId serviceId, int deviceId) {}
index d87868eedf754b7f3aa0eb4438ff68424019f99c..3ee5657ebda436abd6d4f3bbd81e7f5ff969419a 100644 (file)
@@ -5,9 +5,11 @@ import org.signal.libsignal.protocol.IdentityKeyPair;
 import org.signal.libsignal.protocol.InvalidKeyException;
 import org.signal.libsignal.protocol.ecc.Curve;
 import org.signal.libsignal.protocol.ecc.ECPrivateKey;
+import org.signal.libsignal.protocol.kem.KEMKeyPair;
+import org.signal.libsignal.protocol.kem.KEMKeyType;
+import org.signal.libsignal.protocol.state.KyberPreKeyRecord;
 import org.signal.libsignal.protocol.state.PreKeyRecord;
 import org.signal.libsignal.protocol.state.SignedPreKeyRecord;
-import org.signal.libsignal.protocol.util.Medium;
 import org.signal.libsignal.zkgroup.InvalidInputException;
 import org.signal.libsignal.zkgroup.profiles.ProfileKey;
 import org.whispersystems.signalservice.api.kbs.MasterKey;
@@ -17,6 +19,8 @@ import java.util.ArrayList;
 import java.util.Base64;
 import java.util.List;
 
+import static org.asamk.signal.manager.config.ServiceConfig.PREKEY_MAXIMUM_ID;
+
 public class KeyUtils {
 
     private static final SecureRandom secureRandom = new SecureRandom();
@@ -46,7 +50,7 @@ public class KeyUtils {
     public static List<PreKeyRecord> generatePreKeyRecords(final int offset, final int batchSize) {
         var records = new ArrayList<PreKeyRecord>(batchSize);
         for (var i = 0; i < batchSize; i++) {
-            var preKeyId = (offset + i) % Medium.MAX_VALUE;
+            var preKeyId = (offset + i) % PREKEY_MAXIMUM_ID;
             var keyPair = Curve.generateKeyPair();
             var record = new PreKeyRecord(preKeyId, keyPair);
 
@@ -68,6 +72,24 @@ public class KeyUtils {
         return new SignedPreKeyRecord(signedPreKeyId, System.currentTimeMillis(), keyPair, signature);
     }
 
+    public static List<KyberPreKeyRecord> generateKyberPreKeyRecords(
+            final int offset, final int batchSize, final ECPrivateKey privateKey
+    ) {
+        final var kyberRecords = new ArrayList<KyberPreKeyRecord>(batchSize);
+        for (int n = 0; n < batchSize; n++) {
+            final var keyId = (offset + n) % PREKEY_MAXIMUM_ID;
+            kyberRecords.add(generateKyberPreKeyRecord(keyId, privateKey));
+        }
+        return kyberRecords;
+    }
+
+    public static KyberPreKeyRecord generateKyberPreKeyRecord(final int preKeyId, final ECPrivateKey privateKey) {
+        final var kemKeyPair = KEMKeyPair.generate(KEMKeyType.KYBER_1024);
+        final var publicKeySignature = privateKey.calculateSignature(kemKeyPair.getPublicKey().serialize());
+        final var createdAt = System.currentTimeMillis();
+        return new KyberPreKeyRecord(preKeyId, createdAt, kemKeyPair, publicKeySignature);
+    }
+
     public static ProfileKey createProfileKey() {
         try {
             return new ProfileKey(getSecretBytes(32));