nmode's Git Repositories - signal-cli/commitdiff
Move session store to database
author AsamK <asamk@gmx.de>
Thu, 9 Jun 2022 14:54:48 +0000 (16:54 +0200)
committer AsamK <asamk@gmx.de>
Sun, 28 Aug 2022 14:04:05 +0000 (16:04 +0200)
lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java
lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java
lib/src/main/java/org/asamk/signal/manager/storage/sessions/LegacySessionStore.java [new file with mode: 0644]
lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java

index 93a145ebf6b60218906b296aa527dc1dddad9bfd..ef88c227ca2c134f5fcc14e19d370466c08f34b8 100644 (file)
@@ -7,6 +7,7 @@ import org.asamk.signal.manager.storage.prekeys.PreKeyStore;
 import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore;
 import org.asamk.signal.manager.storage.recipients.RecipientStore;
 import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore;
+import org.asamk.signal.manager.storage.sessions.SessionStore;
 import org.asamk.signal.manager.storage.stickers.StickerStore;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -18,7 +19,7 @@ import java.sql.SQLException;
 public class AccountDatabase extends Database {
 
     private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class);
-    private static final long DATABASE_VERSION = 5;
+    private static final long DATABASE_VERSION = 6; // 6: sessions moved from files into the session table
 
     private AccountDatabase(final HikariDataSource dataSource) {
         super(logger, DATABASE_VERSION, dataSource);
@@ -36,6 +37,7 @@ public class AccountDatabase extends Database {
         PreKeyStore.createSql(connection);
         SignedPreKeyStore.createSql(connection);
         GroupStore.createSql(connection);
+        SessionStore.createSql(connection); // same schema as the version 6 migration below
     }
 
     @Override
@@ -143,5 +145,20 @@ public class AccountDatabase extends Database {
                                         """);
             }
         }
+        if (oldVersion < 6) { // keep this CREATE TABLE in sync with SessionStore.createSql
+            logger.debug("Updating database: Creating session tables");
+            try (final var statement = connection.createStatement()) {
+                statement.executeUpdate("""
+                                        CREATE TABLE session (
+                                          _id INTEGER PRIMARY KEY,
+                                          account_id_type INTEGER NOT NULL,
+                                          recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE,
+                                          device_id INTEGER NOT NULL,
+                                          record BLOB NOT NULL,
+                                          UNIQUE(account_id_type, recipient_id, device_id)
+                                        );
+                                        """);
+            }
+        }
     }
 }
index 675c3f3e3711e0c0d5de32be7a5c741f18c471b0..6e4621cd40d87defa8c3ba3abe1b781b71e39553 100644 (file)
@@ -38,6 +38,7 @@ import org.asamk.signal.manager.storage.recipients.RecipientStore;
 import org.asamk.signal.manager.storage.recipients.RecipientTrustedResolver;
 import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore;
 import org.asamk.signal.manager.storage.senderKeys.SenderKeyStore;
+import org.asamk.signal.manager.storage.sessions.LegacySessionStore;
 import org.asamk.signal.manager.storage.sessions.SessionStore;
 import org.asamk.signal.manager.storage.stickers.LegacyStickerStore;
 import org.asamk.signal.manager.storage.stickers.StickerStore;
@@ -634,6 +635,11 @@ public class SignalAccount implements Closeable {
             LegacySignedPreKeyStore.migrate(legacyPniSignedPreKeysPath, getPniSignedPreKeyStore());
             migratedLegacyConfig = true;
         }
+        final var legacySessionsPath = getSessionsPath(dataPath, accountPath); // directory of the old file based store
+        if (legacySessionsPath.exists()) { // migrate copies the files into the database, then deletes them
+            LegacySessionStore.migrate(legacySessionsPath, getRecipientResolver(), getSessionStore());
+            migratedLegacyConfig = true;
+        }
         final var legacySignalProtocolStore = rootNode.hasNonNull("axolotlStore")
                 ? jsonProcessor.convertValue(Utils.getNotNullNode(rootNode, "axolotlStore"),
                 LegacyJsonSignalProtocolStore.class)
@@ -1067,7 +1073,10 @@ public class SignalAccount implements Closeable {
 
     public SessionStore getSessionStore() {
         return getOrCreate(() -> sessionStore,
-                () -> sessionStore = new SessionStore(getSessionsPath(dataPath, accountPath), getRecipientResolver()));
+                () -> sessionStore = new SessionStore(getAccountDatabase(),
+                        ServiceIdType.ACI, // rows are scoped by account_id_type; this instance only sees ACI sessions
+                        getRecipientResolver(),
+                        getRecipientIdCreator()));
     }
 
     public IdentityKeyStore getIdentityKeyStore() {
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/sessions/LegacySessionStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/sessions/LegacySessionStore.java
new file mode 100644 (file)
index 0000000..d0646ec
--- /dev/null
@@ -0,0 +1,129 @@
+package org.asamk.signal.manager.storage.sessions;
+
+import org.asamk.signal.manager.api.Pair;
+import org.asamk.signal.manager.storage.recipients.RecipientResolver;
+import org.asamk.signal.manager.storage.sessions.SessionStore.Key;
+import org.signal.libsignal.protocol.state.SessionRecord;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.regex.Pattern;
+
+/**
+ * One-shot migration of the legacy file based session store into the database backed {@link SessionStore}.
+ * <p>
+ * Each legacy session is stored in its own file named {@code <recipientId>_<deviceId>} inside the sessions
+ * directory, containing a serialized {@link SessionRecord}.
+ */
+public class LegacySessionStore {
+
+    private final static Logger logger = LoggerFactory.getLogger(LegacySessionStore.class);
+
+    /** Legacy session file name: {@code <recipientId>_<deviceId>}. */
+    private static final Pattern sessionFileNamePattern = Pattern.compile("(\\d+)_(\\d+)");
+
+    private LegacySessionStore() {
+    }
+
+    /**
+     * Imports all readable legacy session files into {@code sessionStore}, then deletes the legacy directory.
+     * <p>
+     * Files with an unparsable name, an unresolvable recipient or unreadable content are skipped (and deleted
+     * with the rest of the directory). If importing into the store throws, no file is deleted.
+     *
+     * @param sessionsPath directory containing the legacy session files
+     * @param resolver     resolves the numeric recipient ids found in the file names
+     * @param sessionStore database backed store that receives the sessions
+     */
+    public static void migrate(
+            final File sessionsPath, final RecipientResolver resolver, final SessionStore sessionStore
+    ) {
+        final var sessions = loadSessions(sessionsPath, resolver);
+        sessionStore.addLegacySessions(sessions);
+        deleteAllSessions(sessionsPath);
+    }
+
+    /** Reads every session file in {@code sessionsPath} whose name parses and whose recipient resolves. */
+    private static List<Pair<Key, SessionRecord>> loadSessions(
+            final File sessionsPath, final RecipientResolver resolver
+    ) {
+        final var files = sessionsPath.listFiles();
+        if (files == null) {
+            return List.of();
+        }
+        final var sessions = new ArrayList<Pair<Key, SessionRecord>>();
+        for (final var file : files) {
+            final var key = parseKey(file.getName(), resolver);
+            if (key == null) {
+                continue;
+            }
+            // Read the listed file itself instead of rebuilding its name from the key: the resolver is not
+            // guaranteed to return the same recipient id that appears in the file name.
+            final var record = loadSession(file);
+            if (record != null) {
+                sessions.add(new Pair<>(key, record));
+            }
+        }
+        return sessions;
+    }
+
+    /** Parses a legacy file name into a key; returns null if it doesn't match or the recipient is unknown. */
+    private static Key parseKey(final String fileName, final RecipientResolver resolver) {
+        final var matcher = sessionFileNamePattern.matcher(fileName);
+        if (!matcher.matches()) {
+            return null;
+        }
+        final long legacyRecipientId;
+        final int deviceId;
+        try {
+            legacyRecipientId = Long.parseLong(matcher.group(1));
+            deviceId = Integer.parseInt(matcher.group(2));
+        } catch (NumberFormatException e) {
+            // The pattern also matches digit runs outside the long/int range; skip them instead of failing
+            logger.warn("Ignoring legacy session file with invalid name: {}", fileName);
+            return null;
+        }
+        final var recipientId = resolver.resolveRecipient(legacyRecipientId);
+        if (recipientId == null) {
+            return null;
+        }
+        return new Key(recipientId, deviceId);
+    }
+
+    /** Deserializes one legacy session file; returns null if it cannot be read or parsed. */
+    private static SessionRecord loadSession(final File file) {
+        try {
+            return new SessionRecord(Files.readAllBytes(file.toPath()));
+        } catch (Exception e) {
+            // Best effort: an unreadable session is dropped, i.e. that session gets reset
+            logger.warn("Failed to load session file {}, resetting session: {}", file, e.getMessage());
+            return null;
+        }
+    }
+
+    /** Deletes every file in the legacy sessions directory and then the directory itself, logging failures. */
+    private static void deleteAllSessions(File sessionsPath) {
+        final var files = sessionsPath.listFiles();
+        if (files == null) {
+            return;
+        }
+
+        for (var file : files) {
+            try {
+                Files.delete(file.toPath());
+            } catch (IOException e) {
+                logger.error("Failed to delete session file {}: {}", file, e.getMessage());
+            }
+        }
+        try {
+            Files.delete(sessionsPath.toPath());
+        } catch (IOException e) {
+            logger.error("Failed to delete session directory {}: {}", sessionsPath, e.getMessage());
+        }
+    }
+}
index b265e0bc8822b6f00e1361a91827eca1a84a7baa..ad0e198ea6b03b13f431766e65eed004c8cd2d30 100644 (file)
@@ -1,8 +1,12 @@
 package org.asamk.signal.manager.storage.sessions;
 
+import org.asamk.signal.manager.api.Pair;
+import org.asamk.signal.manager.storage.Database;
+import org.asamk.signal.manager.storage.Utils;
 import org.asamk.signal.manager.storage.recipients.RecipientId;
+import org.asamk.signal.manager.storage.recipients.RecipientIdCreator;
 import org.asamk.signal.manager.storage.recipients.RecipientResolver;
-import org.asamk.signal.manager.util.IOUtils;
+import org.signal.libsignal.protocol.InvalidMessageException;
 import org.signal.libsignal.protocol.NoSessionException;
 import org.signal.libsignal.protocol.SignalProtocolAddress;
 import org.signal.libsignal.protocol.ecc.ECPublicKey;
@@ -11,50 +15,68 @@ import org.signal.libsignal.protocol.state.SessionRecord;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
+import org.whispersystems.signalservice.api.push.ServiceIdType;
 
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.nio.file.Files;
-import java.util.Arrays;
+import java.sql.Connection;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
-import java.util.regex.Matcher;
-import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
 public class SessionStore implements SignalServiceSessionStore {
 
+    private static final String TABLE_SESSION = "session"; // table created by createSql
     private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
 
     private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
 
-    private final File sessionsPath;
-
+    private final Database database;
+    private final int accountIdType; // Utils.getAccountIdType(serviceIdType), matched against account_id_type
     private final RecipientResolver resolver;
+    private final RecipientIdCreator recipientIdCreator;
+
+    public static void createSql(Connection connection) throws SQLException { // creates the session table
+        // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
+        try (final var statement = connection.createStatement()) {
+            statement.executeUpdate("""
+                                    CREATE TABLE session (
+                                      _id INTEGER PRIMARY KEY,
+                                      account_id_type INTEGER NOT NULL,
+                                      recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE,
+                                      device_id INTEGER NOT NULL,
+                                      record BLOB NOT NULL,
+                                      UNIQUE(account_id_type, recipient_id, device_id)
+                                    );
+                                    """);
+        }
+    }
 
     public SessionStore(
-            final File sessionsPath, final RecipientResolver resolver
+            final Database database,
+            final ServiceIdType serviceIdType, // selects the account_id_type rows this store reads and writes
+            final RecipientResolver resolver,
+            final RecipientIdCreator recipientIdCreator // wraps raw recipient_id column values as RecipientIds
     ) {
-        this.sessionsPath = sessionsPath;
+        this.database = database;
+        this.accountIdType = Utils.getAccountIdType(serviceIdType);
         this.resolver = resolver;
+        this.recipientIdCreator = recipientIdCreator;
     }
 
     @Override
     public SessionRecord loadSession(SignalProtocolAddress address) {
         final var key = getKey(address);
-
-        synchronized (cachedSessions) {
-            final var session = loadSessionLocked(key);
-            if (session == null) {
-                return new SessionRecord();
-            }
-            return session;
+        try (final var connection = database.getConnection()) {
+            final var session = loadSession(connection, key);
+            return Objects.requireNonNullElseGet(session, SessionRecord::new); // empty record if none is stored
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed read from session store", e);
         }
     }
 
@@ -62,8 +84,14 @@ public class SessionStore implements SignalServiceSessionStore {
     public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
         final var keys = addresses.stream().map(this::getKey).toList();
 
-        synchronized (cachedSessions) {
-            final var sessions = keys.stream().map(this::loadSessionLocked).filter(Objects::nonNull).toList();
+        try (final var connection = database.getConnection()) {
+            final var sessions = new ArrayList<SessionRecord>();
+            for (final var key : keys) {
+                final var sessionRecord = loadSession(connection, key);
+                if (sessionRecord != null) { // missing sessions are reported by the size check below
+                    sessions.add(sessionRecord);
+                }
+            }
 
             if (sessions.size() != addresses.size()) {
                 String message = "Mismatch! Asked for "
@@ -76,31 +104,44 @@ public class SessionStore implements SignalServiceSessionStore {
             }
 
             return sessions;
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed read from session store", e);
         }
     }
 
     @Override
     public List<Integer> getSubDeviceSessions(String name) {
-        final var recipientId = resolveRecipient(name);
-
-        synchronized (cachedSessions) {
-            return getKeysLocked(recipientId).stream()
-                    // get all sessions for recipient except primary device session
-                    .filter(key -> key.deviceId() != 1 && key.recipientId().equals(recipientId))
-                    .map(Key::deviceId)
-                    .toList();
+        final var recipientId = resolver.resolveRecipient(name);
+        // get all sessions for recipient except primary device session (device id 1)
+        final var sql = (
+                """
+                SELECT s.device_id
+                FROM %s AS s
+                WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id != 1
+                """
+        ).formatted(TABLE_SESSION);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                statement.setLong(2, recipientId.id());
+                return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed read from session store", e);
         }
     }
 
     public boolean isCurrentRatchetKey(RecipientId recipientId, int deviceId, ECPublicKey ratchetKey) {
         final var key = new Key(recipientId, deviceId);
 
-        synchronized (cachedSessions) {
-            final var session = loadSessionLocked(key);
+        try (final var connection = database.getConnection()) {
+            final var session = loadSession(connection, key); // cache first, then the database
             if (session == null) {
                 return false;
             }
             return session.currentRatchetKeyMatches(ratchetKey);
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed read from session store", e);
         }
     }
 
@@ -108,8 +149,10 @@ public class SessionStore implements SignalServiceSessionStore {
     public void storeSession(SignalProtocolAddress address, SessionRecord session) {
         final var key = getKey(address);
 
-        synchronized (cachedSessions) {
-            storeSessionLocked(key, session);
+        try (final var connection = database.getConnection()) {
+            storeSession(connection, key, session);
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed read from session store", e);
         }
     }
 
@@ -117,9 +160,11 @@ public class SessionStore implements SignalServiceSessionStore {
     public boolean containsSession(SignalProtocolAddress address) {
         final var key = getKey(address);
 
-        synchronized (cachedSessions) {
-            final var session = loadSessionLocked(key);
+        try (final var connection = database.getConnection()) {
+            final var session = loadSession(connection, key); // null if no readable session is stored
             return isActive(session);
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed read from session store", e);
         }
     }
 
@@ -127,23 +172,24 @@ public class SessionStore implements SignalServiceSessionStore {
     public void deleteSession(SignalProtocolAddress address) {
         final var key = getKey(address);
 
-        synchronized (cachedSessions) {
-            deleteSessionLocked(key);
+        try (final var connection = database.getConnection()) {
+            deleteSession(connection, key); // evicts the key from cachedSessions, then deletes the row
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update session store", e);
         }
     }
 
     @Override
     public void deleteAllSessions(String name) {
-        final var recipientId = resolveRecipient(name);
+        final var recipientId = resolver.resolveRecipient(name); // name: serialized uuid or e164 phone number
         deleteAllSessions(recipientId);
     }
 
     public void deleteAllSessions(RecipientId recipientId) {
-        synchronized (cachedSessions) {
-            final var keys = getKeysLocked(recipientId);
-            for (var key : keys) {
-                deleteSessionLocked(key);
-            }
+        try (final var connection = database.getConnection()) {
+            deleteAllSessions(connection, recipientId); // clears the whole cache, then deletes this recipient's rows
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update session store", e);
         }
     }
 
@@ -151,186 +197,244 @@ public class SessionStore implements SignalServiceSessionStore {
     public void archiveSession(final SignalProtocolAddress address) {
         final var key = getKey(address);
 
-        synchronized (cachedSessions) {
-            archiveSessionLocked(key);
+        try (final var connection = database.getConnection()) {
+            connection.setAutoCommit(false); // load, archive and store in one transaction
+            final var session = loadSession(connection, key);
+            if (session != null) { // nothing to archive or commit otherwise
+                session.archiveCurrentState();
+                storeSession(connection, key, session);
+                connection.commit();
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update session store", e);
         }
     }
 
     @Override
     public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
         final var recipientIdToNameMap = addressNames.stream()
-                .collect(Collectors.toMap(this::resolveRecipient, name -> name));
-        synchronized (cachedSessions) {
-            return recipientIdToNameMap.keySet()
-                    .stream()
-                    .flatMap(recipientId -> getKeysLocked(recipientId).stream())
-                    .filter(key -> isActive(this.loadSessionLocked(key)))
-                    .map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId), key.deviceId()))
-                    .collect(Collectors.toSet());
+                .collect(Collectors.toMap(resolver::resolveRecipient, name -> name)); // NOTE(review): throws if two names resolve to the same recipient
+        final var recipientIdsCommaSeparated = recipientIdToNameMap.keySet()
+                .stream()
+                .map(recipientId -> String.valueOf(recipientId.id()))
+                .collect(Collectors.joining(",")); // numeric ids only, so inlining them into the SQL is injection safe
+        final var sql = (
+                """
+                SELECT s.recipient_id, s.device_id, s.record
+                FROM %s AS s
+                WHERE s.account_id_type = ? AND s.recipient_id IN (%s)
+                """
+        ).formatted(TABLE_SESSION, recipientIdsCommaSeparated);
+        try (final var connection = database.getConnection()) {
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                return Utils.executeQueryForStream(statement,
+                                res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
+                        .filter(pair -> isActive(pair.second()))
+                        .map(Pair::first)
+                        .map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId),
+                                key.deviceId()))
+                        .collect(Collectors.toSet());
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed read from session store", e);
         }
     }
 
     public void archiveAllSessions() {
-        synchronized (cachedSessions) {
-            final var keys = getKeysLocked();
-            for (var key : keys) {
-                archiveSessionLocked(key);
+        final var sql = (
+                """
+                SELECT s.recipient_id, s.device_id, s.record
+                FROM %s AS s
+                WHERE s.account_id_type = ?
+                """
+        ).formatted(TABLE_SESSION);
+        try (final var connection = database.getConnection()) {
+            connection.setAutoCommit(false);
+            final List<Pair<Key, SessionRecord>> records;
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                records = Utils.executeQueryForStream(statement,
+                        res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
+            }
+            for (final var record : records) {
+                record.second().archiveCurrentState();
+                storeSession(connection, record.first(), record.second());
             }
+            connection.commit();
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update session store", e);
         }
     }
 
     public void archiveSessions(final RecipientId recipientId) {
-        synchronized (cachedSessions) {
-            getKeysLocked().stream()
-                    .filter(key -> key.recipientId.equals(recipientId))
-                    .forEach(this::archiveSessionLocked);
+        final var sql = (
+                """
+                SELECT s.recipient_id, s.device_id, s.record
+                FROM %s AS s
+                WHERE s.account_id_type = ? AND s.recipient_id = ?
+                """
+        ).formatted(TABLE_SESSION);
+        try (final var connection = database.getConnection()) {
+            connection.setAutoCommit(false);
+            final List<Pair<Key, SessionRecord>> records;
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setInt(1, accountIdType);
+                statement.setLong(2, recipientId.id());
+                records = Utils.executeQueryForStream(statement,
+                        res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
+            }
+            for (final var record : records) {
+                record.second().archiveCurrentState();
+                storeSession(connection, record.first(), record.second());
+            }
+            connection.commit();
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update session store", e);
         }
     }
 
     public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
-        synchronized (cachedSessions) {
-            final var keys = getKeysLocked(toBeMergedRecipientId);
-            final var otherHasSession = keys.size() > 0;
-            if (!otherHasSession) {
-                return;
+        try (final var connection = database.getConnection()) {
+            connection.setAutoCommit(false); // reassign and delete in one transaction
+            synchronized (cachedSessions) {
+                cachedSessions.clear();
             }
 
-            final var hasSession = getKeysLocked(recipientId).size() > 0;
-            if (hasSession) {
-                logger.debug("To be merged recipient had sessions, deleting.");
-                deleteAllSessions(toBeMergedRecipientId);
-            } else {
-                logger.debug("Only to be merged recipient had sessions, re-assigning to the new recipient.");
-                for (var key : keys) {
-                    final var session = loadSessionLocked(key);
-                    deleteSessionLocked(key);
-                    if (session == null) {
-                        continue;
-                    }
-                    final var newKey = new Key(recipientId, key.deviceId());
-                    storeSessionLocked(newKey, session);
+            final var sql = """
+                            UPDATE OR IGNORE %s
+                            SET recipient_id = ?
+                            WHERE account_id_type = ? AND recipient_id = ?
+                            """.formatted(TABLE_SESSION);
+            try (final var statement = connection.prepareStatement(sql)) {
+                statement.setLong(1, recipientId.id());
+                statement.setInt(2, accountIdType);
+                statement.setLong(3, toBeMergedRecipientId.id());
+                final var rows = statement.executeUpdate();
+                if (rows > 0) {
+                    logger.debug("Reassigned {} sessions of to be merged recipient.", rows);
                 }
             }
+            // Delete the rows UPDATE OR IGNORE skipped (target already had a session for that device id)
+            deleteAllSessions(connection, toBeMergedRecipientId);
+            connection.commit();
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update session store", e);
         }
     }
 
-    /**
-     * @param identifier can be either a serialized uuid or a e164 phone number
-     */
-    private RecipientId resolveRecipient(String identifier) {
-        return resolver.resolveRecipient(identifier);
+    void addLegacySessions(final Collection<Pair<Key, SessionRecord>> sessions) { // used by LegacySessionStore.migrate
+        logger.debug("Migrating legacy sessions to database");
+        long start = System.nanoTime();
+        try (final var connection = database.getConnection()) {
+            connection.setAutoCommit(false); // all legacy sessions are imported in one transaction
+            for (final var pair : sessions) {
+                storeSession(connection, pair.first(), pair.second());
+            }
+            connection.commit();
+        } catch (SQLException e) {
+            throw new RuntimeException("Failed update session store", e);
+        }
+        logger.debug("Complete sessions migration took {}ms", (System.nanoTime() - start) / 1000000);
     }
 
     private Key getKey(final SignalProtocolAddress address) {
-        final var recipientId = resolveRecipient(address.getName());
+        final var recipientId = resolver.resolveRecipient(address.getName()); // name: serialized uuid or e164 number
         return new Key(recipientId, address.getDeviceId());
     }
 
-    private List<Key> getKeysLocked(RecipientId recipientId) {
-        final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.id() + "_"));
-        if (files == null) {
-            return List.of();
+    private SessionRecord loadSession(Connection connection, final Key key) throws SQLException {
+        synchronized (cachedSessions) { // cache first; null result means absent or unreadable
+            final var session = cachedSessions.get(key);
+            if (session != null) {
+                return session;
+            }
         }
-        return parseFileNames(files);
-    }
-
-    private Collection<Key> getKeysLocked() {
-        final var files = sessionsPath.listFiles();
-        if (files == null) {
-            return List.of();
+        final var sql = ( // cache miss: query the database (NOTE(review): the result is not added to the cache)
+                """
+                SELECT s.record
+                FROM %s AS s
+                WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id = ?
+                """
+        ).formatted(TABLE_SESSION);
+        try (final var statement = connection.prepareStatement(sql)) {
+            statement.setInt(1, accountIdType);
+            statement.setLong(2, key.recipientId().id());
+            statement.setInt(3, key.deviceId());
+            return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null);
         }
-        return parseFileNames(files);
     }
 
-    final Pattern sessionFileNamePattern = Pattern.compile("(\\d+)_(\\d+)");
-
-    private List<Key> parseFileNames(final File[] files) {
-        return Arrays.stream(files)
-                .map(f -> sessionFileNamePattern.matcher(f.getName()))
-                .filter(Matcher::matches)
-                .map(matcher -> {
-                    final var recipientId = resolver.resolveRecipient(Long.parseLong(matcher.group(1)));
-                    if (recipientId == null) {
-                        return null;
-                    }
-                    return new Key(recipientId, Integer.parseInt(matcher.group(2)));
-                })
-                .filter(Objects::nonNull)
-                .toList();
+    private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException { // needs recipient_id and device_id columns
+        final var recipientId = resultSet.getLong("recipient_id");
+        final var deviceId = resultSet.getInt("device_id");
+        return new Key(recipientIdCreator.create(recipientId), deviceId);
     }
 
-    private File getSessionFile(Key key) {
+    private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) throws SQLException {
         try {
-            IOUtils.createPrivateDirectories(sessionsPath);
-        } catch (IOException e) {
-            throw new AssertionError("Failed to create sessions path", e);
+            final var record = resultSet.getBytes("record");
+            return new SessionRecord(record);
+        } catch (InvalidMessageException e) {
+            logger.warn("Failed to load session, resetting session: {}", e.getMessage());
+            return null; // callers handle null like a missing session
         }
-        return new File(sessionsPath, key.recipientId().id() + "_" + key.deviceId());
     }
 
-    private SessionRecord loadSessionLocked(final Key key) {
-        {
-            final var session = cachedSessions.get(key);
-            if (session != null) {
-                return session;
-            }
+    private void storeSession(
+            final Connection connection, final Key key, final SessionRecord session
+    ) throws SQLException {
+        synchronized (cachedSessions) {
+            cachedSessions.put(key, session); // NOTE(review): not reverted if the surrounding transaction rolls back
         }
 
-        final var file = getSessionFile(key);
-        if (!file.exists()) {
-            return null;
-        }
-        try (var inputStream = new FileInputStream(file)) {
-            final var session = new SessionRecord(inputStream.readAllBytes());
-            cachedSessions.put(key, session);
-            return session;
-        } catch (Exception e) {
-            logger.warn("Failed to load session, resetting session: {}", e.getMessage());
-            return null;
+        final var sql = """
+                        INSERT OR REPLACE INTO %s (account_id_type, recipient_id, device_id, record)
+                        VALUES (?, ?, ?, ?)
+                        """.formatted(TABLE_SESSION);
+        try (final var statement = connection.prepareStatement(sql)) {
+            statement.setInt(1, accountIdType);
+            statement.setLong(2, key.recipientId().id());
+            statement.setInt(3, key.deviceId());
+            statement.setBytes(4, session.serialize());
+            statement.executeUpdate(); // upsert via UNIQUE(account_id_type, recipient_id, device_id)
         }
     }
 
-    private void storeSessionLocked(final Key key, final SessionRecord session) {
-        cachedSessions.put(key, session);
-
-        final var file = getSessionFile(key);
-        try {
-            try (var outputStream = new FileOutputStream(file)) {
-                outputStream.write(session.serialize());
-            }
-        } catch (IOException e) {
-            logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
-            try {
-                Files.delete(file.toPath());
-                try (var outputStream = new FileOutputStream(file)) {
-                    outputStream.write(session.serialize());
-                }
-            } catch (IOException e2) {
-                logger.error("Failed to store session file {}: {}", file, e2.getMessage());
-            }
+    private void deleteAllSessions(final Connection connection, final RecipientId recipientId) throws SQLException {
+        synchronized (cachedSessions) {
+            cachedSessions.clear(); // drops every entry; later reads fall back to the database
         }
-    }
 
-    private void archiveSessionLocked(final Key key) {
-        final var session = loadSessionLocked(key);
-        if (session == null) {
-            return;
+        final var sql = (
+                """
+                DELETE FROM %s AS s
+                WHERE s.account_id_type = ? AND s.recipient_id = ?
+                """
+        ).formatted(TABLE_SESSION);
+        try (final var statement = connection.prepareStatement(sql)) {
+            statement.setInt(1, accountIdType);
+            statement.setLong(2, recipientId.id());
+            statement.executeUpdate();
         }
-        session.archiveCurrentState();
-        storeSessionLocked(key, session);
     }
 
-    private void deleteSessionLocked(final Key key) {
-        cachedSessions.remove(key);
-
-        final var file = getSessionFile(key);
-        if (!file.exists()) {
-            return;
+    private void deleteSession(Connection connection, final Key key) throws SQLException {
+        synchronized (cachedSessions) {
+            cachedSessions.remove(key); // evict only this key
         }
-        try {
-            Files.delete(file.toPath());
-        } catch (IOException e) {
-            logger.error("Failed to delete session file {}: {}", file, e.getMessage());
+
+        final var sql = (
+                """
+                DELETE FROM %s AS s
+                WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id = ?
+                """
+        ).formatted(TABLE_SESSION);
+        try (final var statement = connection.prepareStatement(sql)) {
+            statement.setInt(1, accountIdType);
+            statement.setLong(2, key.recipientId().id());
+            statement.setInt(3, key.deviceId());
+            statement.executeUpdate();
         }
     }
 
@@ -340,5 +444,5 @@ public class SessionStore implements SignalServiceSessionStore {
                 && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
     }
 
-    private record Key(RecipientId recipientId, int deviceId) {}
+    record Key(RecipientId recipientId, int deviceId) {} // package-private: LegacySessionStore builds keys too
 }