From: AsamK Date: Thu, 9 Jun 2022 14:54:48 +0000 (+0200) Subject: Move session store to database X-Git-Tag: v0.11.0~21 X-Git-Url: https://git.nmode.ca/signal-cli/commitdiff_plain/484daa4c69bf9c1da451585fdbcc91a40ed00053?ds=inline Move session store to database --- diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java b/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java index 93a145eb..ef88c227 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/AccountDatabase.java @@ -7,6 +7,7 @@ import org.asamk.signal.manager.storage.prekeys.PreKeyStore; import org.asamk.signal.manager.storage.prekeys.SignedPreKeyStore; import org.asamk.signal.manager.storage.recipients.RecipientStore; import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore; +import org.asamk.signal.manager.storage.sessions.SessionStore; import org.asamk.signal.manager.storage.stickers.StickerStore; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -18,7 +19,7 @@ import java.sql.SQLException; public class AccountDatabase extends Database { private final static Logger logger = LoggerFactory.getLogger(AccountDatabase.class); - private static final long DATABASE_VERSION = 5; + private static final long DATABASE_VERSION = 6; private AccountDatabase(final HikariDataSource dataSource) { super(logger, DATABASE_VERSION, dataSource); @@ -36,6 +37,7 @@ public class AccountDatabase extends Database { PreKeyStore.createSql(connection); SignedPreKeyStore.createSql(connection); GroupStore.createSql(connection); + SessionStore.createSql(connection); } @Override @@ -143,5 +145,20 @@ public class AccountDatabase extends Database { """); } } + if (oldVersion < 6) { + logger.debug("Updating database: Creating session tables"); + try (final var statement = connection.createStatement()) { + statement.executeUpdate(""" + CREATE TABLE session ( + _id INTEGER PRIMARY KEY, + account_id_type INTEGER NOT NULL, + recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE, + device_id INTEGER NOT NULL, + record BLOB NOT NULL, + UNIQUE(account_id_type, recipient_id, device_id) + ); + """); + } + } } } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java index 675c3f3e..6e4621cd 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java @@ -38,6 +38,7 @@ import org.asamk.signal.manager.storage.recipients.RecipientStore; import org.asamk.signal.manager.storage.recipients.RecipientTrustedResolver; import org.asamk.signal.manager.storage.sendLog.MessageSendLogStore; import org.asamk.signal.manager.storage.senderKeys.SenderKeyStore; +import org.asamk.signal.manager.storage.sessions.LegacySessionStore; import org.asamk.signal.manager.storage.sessions.SessionStore; import org.asamk.signal.manager.storage.stickers.LegacyStickerStore; import org.asamk.signal.manager.storage.stickers.StickerStore; @@ -634,6 +635,11 @@ public class SignalAccount implements Closeable { LegacySignedPreKeyStore.migrate(legacyPniSignedPreKeysPath, getPniSignedPreKeyStore()); migratedLegacyConfig = true; } + final var legacySessionsPath = getSessionsPath(dataPath, accountPath); + if (legacySessionsPath.exists()) { + LegacySessionStore.migrate(legacySessionsPath, getRecipientResolver(), getSessionStore()); + migratedLegacyConfig = true; + } final var legacySignalProtocolStore = rootNode.hasNonNull("axolotlStore") ? jsonProcessor.convertValue(Utils.getNotNullNode(rootNode, "axolotlStore"), LegacyJsonSignalProtocolStore.class) @@ -1067,7 +1073,10 @@ public class SignalAccount implements Closeable { public SessionStore getSessionStore() { return getOrCreate(() -> sessionStore, - () -> sessionStore = new SessionStore(getSessionsPath(dataPath, accountPath), getRecipientResolver())); + () -> sessionStore = new SessionStore(getAccountDatabase(), + ServiceIdType.ACI, + getRecipientResolver(), + getRecipientIdCreator())); } public IdentityKeyStore getIdentityKeyStore() { diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/sessions/LegacySessionStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/sessions/LegacySessionStore.java new file mode 100644 index 00000000..d0646ec1 --- /dev/null +++ b/lib/src/main/java/org/asamk/signal/manager/storage/sessions/LegacySessionStore.java @@ -0,0 +1,107 @@ +package org.asamk.signal.manager.storage.sessions; + +import org.asamk.signal.manager.api.Pair; +import org.asamk.signal.manager.storage.recipients.RecipientResolver; +import org.asamk.signal.manager.storage.sessions.SessionStore.Key; +import org.asamk.signal.manager.util.IOUtils; +import org.signal.libsignal.protocol.state.SessionRecord; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.nio.file.Files; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class LegacySessionStore { + + private final static Logger logger = LoggerFactory.getLogger(LegacySessionStore.class); + + public static void migrate( + final File sessionsPath, final RecipientResolver resolver, final SessionStore sessionStore + ) { + final var keys = getKeysLocked(sessionsPath, resolver); + final var sessions = keys.stream().map(key -> { + final var record = loadSessionLocked(key, sessionsPath); + if (record == null) { + return null; + } + return new Pair<>(key, record); + }).filter(Objects::nonNull).toList(); + sessionStore.addLegacySessions(sessions); + deleteAllSessions(sessionsPath); + } + + private static void deleteAllSessions(File sessionsPath) { + final var files = sessionsPath.listFiles(); + if (files == null) { + return; + } + + for (var file : files) { + try { + Files.delete(file.toPath()); + } catch (IOException e) { + logger.error("Failed to delete session file {}: {}", file, e.getMessage()); + } + } + try { + Files.delete(sessionsPath.toPath()); + } catch (IOException e) { + logger.error("Failed to delete session directory {}: {}", sessionsPath, e.getMessage()); + } + } + + private static Collection getKeysLocked(File sessionsPath, final RecipientResolver resolver) { + final var files = sessionsPath.listFiles(); + if (files == null) { + return List.of(); + } + return parseFileNames(files, resolver); + } + + static final Pattern sessionFileNamePattern = Pattern.compile("(\\d+)_(\\d+)"); + + private static List parseFileNames(final File[] files, final RecipientResolver resolver) { + return Arrays.stream(files) + .map(f -> sessionFileNamePattern.matcher(f.getName())) + .filter(Matcher::matches) + .map(matcher -> { + final var recipientId = resolver.resolveRecipient(Long.parseLong(matcher.group(1))); + if (recipientId == null) { + return null; + } + return new Key(recipientId, Integer.parseInt(matcher.group(2))); + }) + .filter(Objects::nonNull) + .toList(); + } + + private static File getSessionFile(Key key, final File sessionsPath) { + try { + IOUtils.createPrivateDirectories(sessionsPath); + } catch (IOException e) { + throw new AssertionError("Failed to create sessions path", e); + } + return new File(sessionsPath, key.recipientId().id() + "_" + key.deviceId()); + } + + private static SessionRecord loadSessionLocked(final Key key, final File sessionsPath) { + final var file = getSessionFile(key, sessionsPath); + if (!file.exists()) { + return null; + } + try (var inputStream = new FileInputStream(file)) { + return new SessionRecord(inputStream.readAllBytes()); + } catch (Exception e) { + logger.warn("Failed to load session, resetting session: {}", e.getMessage()); + return null; + } + } +} diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java index b265e0bc..ad0e198e 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java @@ -1,8 +1,12 @@ package org.asamk.signal.manager.storage.sessions; +import org.asamk.signal.manager.api.Pair; +import org.asamk.signal.manager.storage.Database; +import org.asamk.signal.manager.storage.Utils; import org.asamk.signal.manager.storage.recipients.RecipientId; +import org.asamk.signal.manager.storage.recipients.RecipientIdCreator; import org.asamk.signal.manager.storage.recipients.RecipientResolver; -import org.asamk.signal.manager.util.IOUtils; +import org.signal.libsignal.protocol.InvalidMessageException; import org.signal.libsignal.protocol.NoSessionException; import org.signal.libsignal.protocol.SignalProtocolAddress; import org.signal.libsignal.protocol.ecc.ECPublicKey; @@ -11,50 +15,68 @@ import org.signal.libsignal.protocol.state.SessionRecord; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.signalservice.api.SignalServiceSessionStore; +import org.whispersystems.signalservice.api.push.ServiceIdType; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.nio.file.Files; -import java.util.Arrays; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import java.util.stream.Collectors; public class SessionStore implements SignalServiceSessionStore { + private static final String TABLE_SESSION = "session"; private final static Logger logger = LoggerFactory.getLogger(SessionStore.class); private final Map cachedSessions = new HashMap<>(); - private final File sessionsPath; - + private final Database database; + private final int accountIdType; private final RecipientResolver resolver; + private final RecipientIdCreator recipientIdCreator; + + public static void createSql(Connection connection) throws SQLException { + // When modifying the CREATE statement here, also add a migration in AccountDatabase.java + try (final var statement = connection.createStatement()) { + statement.executeUpdate(""" + CREATE TABLE session ( + _id INTEGER PRIMARY KEY, + account_id_type INTEGER NOT NULL, + recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE, + device_id INTEGER NOT NULL, + record BLOB NOT NULL, + UNIQUE(account_id_type, recipient_id, device_id) + ); + """); + } + } public SessionStore( - final File sessionsPath, final RecipientResolver resolver + final Database database, + final ServiceIdType serviceIdType, + final RecipientResolver resolver, + final RecipientIdCreator recipientIdCreator ) { - this.sessionsPath = sessionsPath; + this.database = database; + this.accountIdType = Utils.getAccountIdType(serviceIdType); this.resolver = resolver; + this.recipientIdCreator = recipientIdCreator; } @Override public SessionRecord loadSession(SignalProtocolAddress address) { final var key = getKey(address); - - synchronized (cachedSessions) { - final var session = loadSessionLocked(key); - if (session == null) { - return new SessionRecord(); - } - return session; + try (final var connection = database.getConnection()) { + final var session = loadSession(connection, key); + return Objects.requireNonNullElseGet(session, SessionRecord::new); + } catch (SQLException e) { + throw new RuntimeException("Failed read from session store", e); } } @@ -62,8 +84,14 @@ public class SessionStore implements SignalServiceSessionStore { public List loadExistingSessions(final List addresses) throws NoSessionException { final var keys = addresses.stream().map(this::getKey).toList(); - synchronized (cachedSessions) { - final var sessions = keys.stream().map(this::loadSessionLocked).filter(Objects::nonNull).toList(); + try (final var connection = database.getConnection()) { + final var sessions = new ArrayList(); + for (final var key : keys) { + final var sessionRecord = loadSession(connection, key); + if (sessionRecord != null) { + sessions.add(sessionRecord); + } + } if (sessions.size() != addresses.size()) { String message = "Mismatch! Asked for " @@ -76,31 +104,44 @@ public class SessionStore implements SignalServiceSessionStore { } return sessions; + } catch (SQLException e) { + throw new RuntimeException("Failed read from session store", e); } } @Override public List getSubDeviceSessions(String name) { - final var recipientId = resolveRecipient(name); - - synchronized (cachedSessions) { - return getKeysLocked(recipientId).stream() - // get all sessions for recipient except primary device session - .filter(key -> key.deviceId() != 1 && key.recipientId().equals(recipientId)) - .map(Key::deviceId) - .toList(); + final var recipientId = resolver.resolveRecipient(name); + // get all sessions for recipient except primary device session + final var sql = ( + """ + SELECT s.device_id + FROM %s AS s + WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id != 1 + """ + ).formatted(TABLE_SESSION); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setLong(2, recipientId.id()); + return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList(); + } + } catch (SQLException e) { + throw new RuntimeException("Failed read from session store", e); } } public boolean isCurrentRatchetKey(RecipientId recipientId, int deviceId, ECPublicKey ratchetKey) { final var key = new Key(recipientId, deviceId); - synchronized (cachedSessions) { - final var session = loadSessionLocked(key); + try (final var connection = database.getConnection()) { + final var session = loadSession(connection, key); if (session == null) { return false; } return session.currentRatchetKeyMatches(ratchetKey); + } catch (SQLException e) { + throw new RuntimeException("Failed read from session store", e); } } @@ -108,8 +149,10 @@ public class SessionStore implements SignalServiceSessionStore { public void storeSession(SignalProtocolAddress address, SessionRecord session) { final var key = getKey(address); - synchronized (cachedSessions) { - storeSessionLocked(key, session); + try (final var connection = database.getConnection()) { + storeSession(connection, key, session); + } catch (SQLException e) { + throw new RuntimeException("Failed read from session store", e); } } @@ -117,9 +160,11 @@ public class SessionStore implements SignalServiceSessionStore { public boolean containsSession(SignalProtocolAddress address) { final var key = getKey(address); - synchronized (cachedSessions) { - final var session = loadSessionLocked(key); + try (final var connection = database.getConnection()) { + final var session = loadSession(connection, key); return isActive(session); + } catch (SQLException e) { + throw new RuntimeException("Failed read from session store", e); } } @@ -127,23 +172,24 @@ public class SessionStore implements SignalServiceSessionStore { public void deleteSession(SignalProtocolAddress address) { final var key = getKey(address); - synchronized (cachedSessions) { - deleteSessionLocked(key); + try (final var connection = database.getConnection()) { + deleteSession(connection, key); + } catch (SQLException e) { + throw new RuntimeException("Failed update session store", e); } } @Override public void deleteAllSessions(String name) { - final var recipientId = resolveRecipient(name); + final var recipientId = resolver.resolveRecipient(name); deleteAllSessions(recipientId); } public void deleteAllSessions(RecipientId recipientId) { - synchronized (cachedSessions) { - final var keys = getKeysLocked(recipientId); - for (var key : keys) { - deleteSessionLocked(key); - } + try (final var connection = database.getConnection()) { + deleteAllSessions(connection, recipientId); + } catch (SQLException e) { + throw new RuntimeException("Failed update session store", e); } } @@ -151,186 +197,244 @@ public class SessionStore implements SignalServiceSessionStore { public void archiveSession(final SignalProtocolAddress address) { final var key = getKey(address); - synchronized (cachedSessions) { - archiveSessionLocked(key); + try (final var connection = database.getConnection()) { + connection.setAutoCommit(false); + final var session = loadSession(connection, key); + if (session != null) { + session.archiveCurrentState(); + storeSession(connection, key, session); + connection.commit(); + } + } catch (SQLException e) { + throw new RuntimeException("Failed update session store", e); } } @Override public Set getAllAddressesWithActiveSessions(final List addressNames) { final var recipientIdToNameMap = addressNames.stream() - .collect(Collectors.toMap(this::resolveRecipient, name -> name)); - synchronized (cachedSessions) { - return recipientIdToNameMap.keySet() - .stream() - .flatMap(recipientId -> getKeysLocked(recipientId).stream()) - .filter(key -> isActive(this.loadSessionLocked(key))) - .map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId), key.deviceId())) - .collect(Collectors.toSet()); + .collect(Collectors.toMap(resolver::resolveRecipient, name -> name)); + final var recipientIdsCommaSeparated = recipientIdToNameMap.keySet() + .stream() + .map(recipientId -> String.valueOf(recipientId.id())) + .collect(Collectors.joining(",")); + final var sql = ( + """ + SELECT s.recipient_id, s.device_id, s.record + FROM %s AS s + WHERE s.account_id_type = ? AND s.recipient_id IN (%s) + """ + ).formatted(TABLE_SESSION, recipientIdsCommaSeparated); + try (final var connection = database.getConnection()) { + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + return Utils.executeQueryForStream(statement, + res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))) + .filter(pair -> isActive(pair.second())) + .map(Pair::first) + .map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId), + key.deviceId())) + .collect(Collectors.toSet()); + } + } catch (SQLException e) { + throw new RuntimeException("Failed read from session store", e); } } public void archiveAllSessions() { - synchronized (cachedSessions) { - final var keys = getKeysLocked(); - for (var key : keys) { - archiveSessionLocked(key); + final var sql = ( + """ + SELECT s.recipient_id, s.device_id, s.record + FROM %s AS s + WHERE s.account_id_type = ? + """ + ).formatted(TABLE_SESSION); + try (final var connection = database.getConnection()) { + connection.setAutoCommit(false); + final List> records; + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + records = Utils.executeQueryForStream(statement, + res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList(); + } + for (final var record : records) { + record.second().archiveCurrentState(); + storeSession(connection, record.first(), record.second()); } + connection.commit(); + } catch (SQLException e) { + throw new RuntimeException("Failed update session store", e); } } public void archiveSessions(final RecipientId recipientId) { - synchronized (cachedSessions) { - getKeysLocked().stream() - .filter(key -> key.recipientId.equals(recipientId)) - .forEach(this::archiveSessionLocked); + final var sql = ( + """ + SELECT s.recipient_id, s.device_id, s.record + FROM %s AS s + WHERE s.account_id_type = ? AND s.recipient_id = ? + """ + ).formatted(TABLE_SESSION); + try (final var connection = database.getConnection()) { + connection.setAutoCommit(false); + final List> records; + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setLong(2, recipientId.id()); + records = Utils.executeQueryForStream(statement, + res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList(); + } + for (final var record : records) { + record.second().archiveCurrentState(); + storeSession(connection, record.first(), record.second()); + } + connection.commit(); + } catch (SQLException e) { + throw new RuntimeException("Failed update session store", e); } } public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) { - synchronized (cachedSessions) { - final var keys = getKeysLocked(toBeMergedRecipientId); - final var otherHasSession = keys.size() > 0; - if (!otherHasSession) { - return; + try (final var connection = database.getConnection()) { + connection.setAutoCommit(false); + synchronized (cachedSessions) { + cachedSessions.clear(); } - final var hasSession = getKeysLocked(recipientId).size() > 0; - if (hasSession) { - logger.debug("To be merged recipient had sessions, deleting."); - deleteAllSessions(toBeMergedRecipientId); - } else { - logger.debug("Only to be merged recipient had sessions, re-assigning to the new recipient."); - for (var key : keys) { - final var session = loadSessionLocked(key); - deleteSessionLocked(key); - if (session == null) { - continue; - } - final var newKey = new Key(recipientId, key.deviceId()); - storeSessionLocked(newKey, session); + final var sql = """ + UPDATE OR IGNORE %s + SET recipient_id = ? + WHERE account_id_type = ? AND recipient_id = ? + """.formatted(TABLE_SESSION); + try (final var statement = connection.prepareStatement(sql)) { + statement.setLong(1, recipientId.id()); + statement.setInt(2, accountIdType); + statement.setLong(3, toBeMergedRecipientId.id()); + final var rows = statement.executeUpdate(); + if (rows > 0) { + logger.debug("Reassigned {} sessions of to be merged recipient.", rows); } } + // Delete all conflicting sessions now + deleteAllSessions(connection, toBeMergedRecipientId); + connection.commit(); + } catch (SQLException e) { + throw new RuntimeException("Failed update session store", e); } } - /** - * @param identifier can be either a serialized uuid or a e164 phone number - */ - private RecipientId resolveRecipient(String identifier) { - return resolver.resolveRecipient(identifier); + void addLegacySessions(final Collection> sessions) { + logger.debug("Migrating legacy sessions to database"); + long start = System.nanoTime(); + try (final var connection = database.getConnection()) { + connection.setAutoCommit(false); + for (final var pair : sessions) { + storeSession(connection, pair.first(), pair.second()); + } + connection.commit(); + } catch (SQLException e) { + throw new RuntimeException("Failed update session store", e); + } + logger.debug("Complete sessions migration took {}ms", (System.nanoTime() - start) / 1000000); } private Key getKey(final SignalProtocolAddress address) { - final var recipientId = resolveRecipient(address.getName()); + final var recipientId = resolver.resolveRecipient(address.getName()); return new Key(recipientId, address.getDeviceId()); } - private List getKeysLocked(RecipientId recipientId) { - final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.id() + "_")); - if (files == null) { - return List.of(); + private SessionRecord loadSession(Connection connection, final Key key) throws SQLException { + synchronized (cachedSessions) { + final var session = cachedSessions.get(key); + if (session != null) { + return session; + } } - return parseFileNames(files); - } - - private Collection getKeysLocked() { - final var files = sessionsPath.listFiles(); - if (files == null) { - return List.of(); + final var sql = ( + """ + SELECT s.record + FROM %s AS s + WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id = ? + """ + ).formatted(TABLE_SESSION); + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setLong(2, key.recipientId().id()); + statement.setInt(3, key.deviceId()); + return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null); } - return parseFileNames(files); } - final Pattern sessionFileNamePattern = Pattern.compile("(\\d+)_(\\d+)"); - - private List parseFileNames(final File[] files) { - return Arrays.stream(files) - .map(f -> sessionFileNamePattern.matcher(f.getName())) - .filter(Matcher::matches) - .map(matcher -> { - final var recipientId = resolver.resolveRecipient(Long.parseLong(matcher.group(1))); - if (recipientId == null) { - return null; - } - return new Key(recipientId, Integer.parseInt(matcher.group(2))); - }) - .filter(Objects::nonNull) - .toList(); + private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException { + final var recipientId = resultSet.getLong("recipient_id"); + final var deviceId = resultSet.getInt("device_id"); + return new Key(recipientIdCreator.create(recipientId), deviceId); } - private File getSessionFile(Key key) { + private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) throws SQLException { try { - IOUtils.createPrivateDirectories(sessionsPath); - } catch (IOException e) { - throw new AssertionError("Failed to create sessions path", e); + final var record = resultSet.getBytes("record"); + return new SessionRecord(record); + } catch (InvalidMessageException e) { + logger.warn("Failed to load session, resetting session: {}", e.getMessage()); + return null; } - return new File(sessionsPath, key.recipientId().id() + "_" + key.deviceId()); } - private SessionRecord loadSessionLocked(final Key key) { - { - final var session = cachedSessions.get(key); - if (session != null) { - return session; - } + private void storeSession( + final Connection connection, final Key key, final SessionRecord session + ) throws SQLException { + synchronized (cachedSessions) { + cachedSessions.put(key, session); } - final var file = getSessionFile(key); - if (!file.exists()) { - return null; - } - try (var inputStream = new FileInputStream(file)) { - final var session = new SessionRecord(inputStream.readAllBytes()); - cachedSessions.put(key, session); - return session; - } catch (Exception e) { - logger.warn("Failed to load session, resetting session: {}", e.getMessage()); - return null; + final var sql = """ + INSERT OR REPLACE INTO %s (account_id_type, recipient_id, device_id, record) + VALUES (?, ?, ?, ?) + """.formatted(TABLE_SESSION); + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setLong(2, key.recipientId().id()); + statement.setInt(3, key.deviceId()); + statement.setBytes(4, session.serialize()); + statement.executeUpdate(); } } - private void storeSessionLocked(final Key key, final SessionRecord session) { - cachedSessions.put(key, session); - - final var file = getSessionFile(key); - try { - try (var outputStream = new FileOutputStream(file)) { - outputStream.write(session.serialize()); - } - } catch (IOException e) { - logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage()); - try { - Files.delete(file.toPath()); - try (var outputStream = new FileOutputStream(file)) { - outputStream.write(session.serialize()); - } - } catch (IOException e2) { - logger.error("Failed to store session file {}: {}", file, e2.getMessage()); - } + private void deleteAllSessions(final Connection connection, final RecipientId recipientId) throws SQLException { + synchronized (cachedSessions) { + cachedSessions.clear(); } - } - private void archiveSessionLocked(final Key key) { - final var session = loadSessionLocked(key); - if (session == null) { - return; + final var sql = ( + """ + DELETE FROM %s AS s + WHERE s.account_id_type = ? AND s.recipient_id = ? + """ + ).formatted(TABLE_SESSION); + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setLong(2, recipientId.id()); + statement.executeUpdate(); } - session.archiveCurrentState(); - storeSessionLocked(key, session); } - private void deleteSessionLocked(final Key key) { - cachedSessions.remove(key); - - final var file = getSessionFile(key); - if (!file.exists()) { - return; + private void deleteSession(Connection connection, final Key key) throws SQLException { + synchronized (cachedSessions) { + cachedSessions.remove(key); } - try { - Files.delete(file.toPath()); - } catch (IOException e) { - logger.error("Failed to delete session file {}: {}", file, e.getMessage()); + + final var sql = ( + """ + DELETE FROM %s AS s + WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id = ? + """ + ).formatted(TABLE_SESSION); + try (final var statement = connection.prepareStatement(sql)) { + statement.setInt(1, accountIdType); + statement.setLong(2, key.recipientId().id()); + statement.setInt(3, key.deviceId()); + statement.executeUpdate(); } } @@ -340,5 +444,5 @@ public class SessionStore implements SignalServiceSessionStore { && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION; } - private record Key(RecipientId recipientId, int deviceId) {} + record Key(RecipientId recipientId, int deviceId) {} }