1 package org
.asamk
.signal
.manager
.storage
.sessions
;
3 import org
.asamk
.signal
.manager
.api
.Pair
;
4 import org
.asamk
.signal
.manager
.storage
.Database
;
5 import org
.asamk
.signal
.manager
.storage
.Utils
;
6 import org
.asamk
.signal
.manager
.storage
.recipients
.RecipientId
;
7 import org
.asamk
.signal
.manager
.storage
.recipients
.RecipientIdCreator
;
8 import org
.asamk
.signal
.manager
.storage
.recipients
.RecipientResolver
;
9 import org
.signal
.libsignal
.protocol
.InvalidMessageException
;
10 import org
.signal
.libsignal
.protocol
.NoSessionException
;
11 import org
.signal
.libsignal
.protocol
.SignalProtocolAddress
;
12 import org
.signal
.libsignal
.protocol
.ecc
.ECPublicKey
;
13 import org
.signal
.libsignal
.protocol
.message
.CiphertextMessage
;
14 import org
.signal
.libsignal
.protocol
.state
.SessionRecord
;
15 import org
.slf4j
.Logger
;
16 import org
.slf4j
.LoggerFactory
;
17 import org
.whispersystems
.signalservice
.api
.SignalServiceSessionStore
;
18 import org
.whispersystems
.signalservice
.api
.push
.ServiceIdType
;
20 import java
.sql
.Connection
;
21 import java
.sql
.ResultSet
;
22 import java
.sql
.SQLException
;
23 import java
.util
.ArrayList
;
24 import java
.util
.Collection
;
25 import java
.util
.HashMap
;
26 import java
.util
.List
;
28 import java
.util
.Objects
;
30 import java
.util
.stream
.Collectors
;
32 public class SessionStore
implements SignalServiceSessionStore
{
34 private static final String TABLE_SESSION
= "session";
35 private final static Logger logger
= LoggerFactory
.getLogger(SessionStore
.class);
37 private final Map
<Key
, SessionRecord
> cachedSessions
= new HashMap
<>();
39 private final Database database
;
40 private final int accountIdType
;
41 private final RecipientResolver resolver
;
42 private final RecipientIdCreator recipientIdCreator
;
44 public static void createSql(Connection connection
) throws SQLException
{
45 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
46 try (final var statement
= connection
.createStatement()) {
47 statement
.executeUpdate("""
48 CREATE TABLE session (
49 _id INTEGER PRIMARY KEY,
50 account_id_type INTEGER NOT NULL,
51 recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE,
52 device_id INTEGER NOT NULL,
54 UNIQUE(account_id_type, recipient_id, device_id)
61 final Database database
,
62 final ServiceIdType serviceIdType
,
63 final RecipientResolver resolver
,
64 final RecipientIdCreator recipientIdCreator
66 this.database
= database
;
67 this.accountIdType
= Utils
.getAccountIdType(serviceIdType
);
68 this.resolver
= resolver
;
69 this.recipientIdCreator
= recipientIdCreator
;
73 public SessionRecord
loadSession(SignalProtocolAddress address
) {
74 final var key
= getKey(address
);
75 try (final var connection
= database
.getConnection()) {
76 final var session
= loadSession(connection
, key
);
77 return Objects
.requireNonNullElseGet(session
, SessionRecord
::new);
78 } catch (SQLException e
) {
79 throw new RuntimeException("Failed read from session store", e
);
84 public List
<SessionRecord
> loadExistingSessions(final List
<SignalProtocolAddress
> addresses
) throws NoSessionException
{
85 final var keys
= addresses
.stream().map(this::getKey
).toList();
87 try (final var connection
= database
.getConnection()) {
88 final var sessions
= new ArrayList
<SessionRecord
>();
89 for (final var key
: keys
) {
90 final var sessionRecord
= loadSession(connection
, key
);
91 if (sessionRecord
!= null) {
92 sessions
.add(sessionRecord
);
96 if (sessions
.size() != addresses
.size()) {
97 String message
= "Mismatch! Asked for "
99 + " sessions, but only found "
102 logger
.warn(message
);
103 throw new NoSessionException(message
);
107 } catch (SQLException e
) {
108 throw new RuntimeException("Failed read from session store", e
);
113 public List
<Integer
> getSubDeviceSessions(String name
) {
114 final var recipientId
= resolver
.resolveRecipient(name
);
115 // get all sessions for recipient except primary device session
120 WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id != 1
122 ).formatted(TABLE_SESSION
);
123 try (final var connection
= database
.getConnection()) {
124 try (final var statement
= connection
.prepareStatement(sql
)) {
125 statement
.setInt(1, accountIdType
);
126 statement
.setLong(2, recipientId
.id());
127 return Utils
.executeQueryForStream(statement
, res
-> res
.getInt("device_id")).toList();
129 } catch (SQLException e
) {
130 throw new RuntimeException("Failed read from session store", e
);
134 public boolean isCurrentRatchetKey(RecipientId recipientId
, int deviceId
, ECPublicKey ratchetKey
) {
135 final var key
= new Key(recipientId
, deviceId
);
137 try (final var connection
= database
.getConnection()) {
138 final var session
= loadSession(connection
, key
);
139 if (session
== null) {
142 return session
.currentRatchetKeyMatches(ratchetKey
);
143 } catch (SQLException e
) {
144 throw new RuntimeException("Failed read from session store", e
);
149 public void storeSession(SignalProtocolAddress address
, SessionRecord session
) {
150 final var key
= getKey(address
);
152 try (final var connection
= database
.getConnection()) {
153 storeSession(connection
, key
, session
);
154 } catch (SQLException e
) {
155 throw new RuntimeException("Failed read from session store", e
);
160 public boolean containsSession(SignalProtocolAddress address
) {
161 final var key
= getKey(address
);
163 try (final var connection
= database
.getConnection()) {
164 final var session
= loadSession(connection
, key
);
165 return isActive(session
);
166 } catch (SQLException e
) {
167 throw new RuntimeException("Failed read from session store", e
);
172 public void deleteSession(SignalProtocolAddress address
) {
173 final var key
= getKey(address
);
175 try (final var connection
= database
.getConnection()) {
176 deleteSession(connection
, key
);
177 } catch (SQLException e
) {
178 throw new RuntimeException("Failed update session store", e
);
183 public void deleteAllSessions(String name
) {
184 final var recipientId
= resolver
.resolveRecipient(name
);
185 deleteAllSessions(recipientId
);
188 public void deleteAllSessions(RecipientId recipientId
) {
189 try (final var connection
= database
.getConnection()) {
190 deleteAllSessions(connection
, recipientId
);
191 } catch (SQLException e
) {
192 throw new RuntimeException("Failed update session store", e
);
197 public void archiveSession(final SignalProtocolAddress address
) {
198 final var key
= getKey(address
);
200 try (final var connection
= database
.getConnection()) {
201 connection
.setAutoCommit(false);
202 final var session
= loadSession(connection
, key
);
203 if (session
!= null) {
204 session
.archiveCurrentState();
205 storeSession(connection
, key
, session
);
208 } catch (SQLException e
) {
209 throw new RuntimeException("Failed update session store", e
);
214 public Set
<SignalProtocolAddress
> getAllAddressesWithActiveSessions(final List
<String
> addressNames
) {
215 final var recipientIdToNameMap
= addressNames
.stream()
216 .collect(Collectors
.toMap(resolver
::resolveRecipient
, name
-> name
));
217 final var recipientIdsCommaSeparated
= recipientIdToNameMap
.keySet()
219 .map(recipientId
-> String
.valueOf(recipientId
.id()))
220 .collect(Collectors
.joining(","));
223 SELECT s.recipient_id, s.device_id, s.record
225 WHERE s.account_id_type = ? AND s.recipient_id IN (%s)
227 ).formatted(TABLE_SESSION
, recipientIdsCommaSeparated
);
228 try (final var connection
= database
.getConnection()) {
229 try (final var statement
= connection
.prepareStatement(sql
)) {
230 statement
.setInt(1, accountIdType
);
231 return Utils
.executeQueryForStream(statement
,
232 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
233 .filter(pair
-> isActive(pair
.second()))
235 .map(key
-> new SignalProtocolAddress(recipientIdToNameMap
.get(key
.recipientId
),
237 .collect(Collectors
.toSet());
239 } catch (SQLException e
) {
240 throw new RuntimeException("Failed read from session store", e
);
244 public void archiveAllSessions() {
247 SELECT s.recipient_id, s.device_id, s.record
249 WHERE s.account_id_type = ?
251 ).formatted(TABLE_SESSION
);
252 try (final var connection
= database
.getConnection()) {
253 connection
.setAutoCommit(false);
254 final List
<Pair
<Key
, SessionRecord
>> records
;
255 try (final var statement
= connection
.prepareStatement(sql
)) {
256 statement
.setInt(1, accountIdType
);
257 records
= Utils
.executeQueryForStream(statement
,
258 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
))).toList();
260 for (final var record : records
) {
261 record.second().archiveCurrentState();
262 storeSession(connection
, record.first(), record.second());
265 } catch (SQLException e
) {
266 throw new RuntimeException("Failed update session store", e
);
270 public void archiveSessions(final RecipientId recipientId
) {
273 SELECT s.recipient_id, s.device_id, s.record
275 WHERE s.account_id_type = ? AND s.recipient_id = ?
277 ).formatted(TABLE_SESSION
);
278 try (final var connection
= database
.getConnection()) {
279 connection
.setAutoCommit(false);
280 final List
<Pair
<Key
, SessionRecord
>> records
;
281 try (final var statement
= connection
.prepareStatement(sql
)) {
282 statement
.setInt(1, accountIdType
);
283 statement
.setLong(2, recipientId
.id());
284 records
= Utils
.executeQueryForStream(statement
,
285 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
))).toList();
287 for (final var record : records
) {
288 record.second().archiveCurrentState();
289 storeSession(connection
, record.first(), record.second());
292 } catch (SQLException e
) {
293 throw new RuntimeException("Failed update session store", e
);
297 public void mergeRecipients(RecipientId recipientId
, RecipientId toBeMergedRecipientId
) {
298 try (final var connection
= database
.getConnection()) {
299 connection
.setAutoCommit(false);
300 synchronized (cachedSessions
) {
301 cachedSessions
.clear();
307 WHERE account_id_type = ? AND recipient_id = ?
308 """.formatted(TABLE_SESSION
);
309 try (final var statement
= connection
.prepareStatement(sql
)) {
310 statement
.setLong(1, recipientId
.id());
311 statement
.setInt(2, accountIdType
);
312 statement
.setLong(3, toBeMergedRecipientId
.id());
313 final var rows
= statement
.executeUpdate();
315 logger
.debug("Reassigned {} sessions of to be merged recipient.", rows
);
318 // Delete all conflicting sessions now
319 deleteAllSessions(connection
, toBeMergedRecipientId
);
321 } catch (SQLException e
) {
322 throw new RuntimeException("Failed update session store", e
);
326 void addLegacySessions(final Collection
<Pair
<Key
, SessionRecord
>> sessions
) {
327 logger
.debug("Migrating legacy sessions to database");
328 long start
= System
.nanoTime();
329 try (final var connection
= database
.getConnection()) {
330 connection
.setAutoCommit(false);
331 for (final var pair
: sessions
) {
332 storeSession(connection
, pair
.first(), pair
.second());
335 } catch (SQLException e
) {
336 throw new RuntimeException("Failed update session store", e
);
338 logger
.debug("Complete sessions migration took {}ms", (System
.nanoTime() - start
) / 1000000);
341 private Key
getKey(final SignalProtocolAddress address
) {
342 final var recipientId
= resolver
.resolveRecipient(address
.getName());
343 return new Key(recipientId
, address
.getDeviceId());
346 private SessionRecord
loadSession(Connection connection
, final Key key
) throws SQLException
{
347 synchronized (cachedSessions
) {
348 final var session
= cachedSessions
.get(key
);
349 if (session
!= null) {
357 WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id = ?
359 ).formatted(TABLE_SESSION
);
360 try (final var statement
= connection
.prepareStatement(sql
)) {
361 statement
.setInt(1, accountIdType
);
362 statement
.setLong(2, key
.recipientId().id());
363 statement
.setInt(3, key
.deviceId());
364 return Utils
.executeQueryForOptional(statement
, this::getSessionRecordFromResultSet
).orElse(null);
368 private Key
getKeyFromResultSet(ResultSet resultSet
) throws SQLException
{
369 final var recipientId
= resultSet
.getLong("recipient_id");
370 final var deviceId
= resultSet
.getInt("device_id");
371 return new Key(recipientIdCreator
.create(recipientId
), deviceId
);
374 private SessionRecord
getSessionRecordFromResultSet(ResultSet resultSet
) throws SQLException
{
376 final var record = resultSet
.getBytes("record");
377 return new SessionRecord(record);
378 } catch (InvalidMessageException e
) {
379 logger
.warn("Failed to load session, resetting session: {}", e
.getMessage());
384 private void storeSession(
385 final Connection connection
, final Key key
, final SessionRecord session
386 ) throws SQLException
{
387 synchronized (cachedSessions
) {
388 cachedSessions
.put(key
, session
);
392 INSERT OR REPLACE INTO %s (account_id_type, recipient_id, device_id, record)
394 """.formatted(TABLE_SESSION
);
395 try (final var statement
= connection
.prepareStatement(sql
)) {
396 statement
.setInt(1, accountIdType
);
397 statement
.setLong(2, key
.recipientId().id());
398 statement
.setInt(3, key
.deviceId());
399 statement
.setBytes(4, session
.serialize());
400 statement
.executeUpdate();
404 private void deleteAllSessions(final Connection connection
, final RecipientId recipientId
) throws SQLException
{
405 synchronized (cachedSessions
) {
406 cachedSessions
.clear();
412 WHERE s.account_id_type = ? AND s.recipient_id = ?
414 ).formatted(TABLE_SESSION
);
415 try (final var statement
= connection
.prepareStatement(sql
)) {
416 statement
.setInt(1, accountIdType
);
417 statement
.setLong(2, recipientId
.id());
418 statement
.executeUpdate();
422 private void deleteSession(Connection connection
, final Key key
) throws SQLException
{
423 synchronized (cachedSessions
) {
424 cachedSessions
.remove(key
);
430 WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id = ?
432 ).formatted(TABLE_SESSION
);
433 try (final var statement
= connection
.prepareStatement(sql
)) {
434 statement
.setInt(1, accountIdType
);
435 statement
.setLong(2, key
.recipientId().id());
436 statement
.setInt(3, key
.deviceId());
437 statement
.executeUpdate();
441 private static boolean isActive(SessionRecord
record) {
442 return record != null
443 && record.hasSenderChain()
444 && record.getSessionVersion() == CiphertextMessage
.CURRENT_VERSION
;
447 record Key(RecipientId recipientId
, int deviceId
) {}