1 package org
.asamk
.signal
.manager
.storage
.sessions
;
3 import org
.asamk
.signal
.manager
.api
.Pair
;
4 import org
.asamk
.signal
.manager
.storage
.Database
;
5 import org
.asamk
.signal
.manager
.storage
.Utils
;
6 import org
.signal
.libsignal
.protocol
.InvalidMessageException
;
7 import org
.signal
.libsignal
.protocol
.NoSessionException
;
8 import org
.signal
.libsignal
.protocol
.SignalProtocolAddress
;
9 import org
.signal
.libsignal
.protocol
.ecc
.ECPublicKey
;
10 import org
.signal
.libsignal
.protocol
.message
.CiphertextMessage
;
11 import org
.signal
.libsignal
.protocol
.state
.SessionRecord
;
12 import org
.signal
.libsignal
.protocol
.util
.Hex
;
13 import org
.slf4j
.Logger
;
14 import org
.slf4j
.LoggerFactory
;
15 import org
.whispersystems
.signalservice
.api
.SignalServiceSessionStore
;
16 import org
.whispersystems
.signalservice
.api
.push
.ServiceId
;
17 import org
.whispersystems
.signalservice
.api
.push
.ServiceIdType
;
18 import org
.whispersystems
.signalservice
.api
.util
.UuidUtil
;
20 import java
.sql
.Connection
;
21 import java
.sql
.ResultSet
;
22 import java
.sql
.SQLException
;
23 import java
.util
.ArrayList
;
24 import java
.util
.Collection
;
25 import java
.util
.HashMap
;
26 import java
.util
.List
;
28 import java
.util
.Objects
;
30 import java
.util
.stream
.Collectors
;
32 public class SessionStore
implements SignalServiceSessionStore
{
34 private static final String TABLE_SESSION
= "session";
35 private final static Logger logger
= LoggerFactory
.getLogger(SessionStore
.class);
37 private final Map
<Key
, SessionRecord
> cachedSessions
= new HashMap
<>();
39 private final Database database
;
40 private final int accountIdType
;
/**
 * Creates the session table on the given connection.
 *
 * NOTE(review): the CREATE statement visible in this view lists _id, account_id_type, device_id
 * and the UNIQUE(account_id_type, uuid, device_id) constraint. It shows neither uuid/record
 * column definitions nor the end of the text block, although the queries in this class read and
 * write s.uuid and s.record. Confirm the complete statement against the full file.
 *
 * @param connection open JDBC connection on which the CREATE TABLE statement is executed
 * @throws SQLException if creating the statement or executing the update fails
 */
42 public static void createSql(Connection connection
) throws SQLException
{
43 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
44 try (final var statement
= connection
.createStatement()) {
// The Statement is closed by try-with-resources once the DDL has run.
45 statement
.executeUpdate("""
46 CREATE TABLE session (
47 _id INTEGER PRIMARY KEY,
48 account_id_type INTEGER NOT NULL,
50 device_id INTEGER NOT NULL,
52 UNIQUE(account_id_type, uuid, device_id)
/**
 * Creates a session store backed by {@code database} and scoped to one service-id type.
 *
 * @param database      provider of JDBC connections
 * @param serviceIdType service-id type, translated through Utils.getAccountIdType into the
 *                      account_id_type value used by every query of this store
 */
public SessionStore(final Database database, final ServiceIdType serviceIdType) {
    this.accountIdType = Utils.getAccountIdType(serviceIdType);
    this.database = database;
}
64 public SessionRecord
loadSession(SignalProtocolAddress address
) {
65 final var key
= getKey(address
);
66 try (final var connection
= database
.getConnection()) {
67 final var session
= loadSession(connection
, key
);
68 return Objects
.requireNonNullElseGet(session
, SessionRecord
::new);
69 } catch (SQLException e
) {
70 throw new RuntimeException("Failed read from session store", e
);
/**
 * Loads the stored session for every address in {@code addresses}.
 *
 * <p>All lookups run one key at a time on a single connection. Addresses without a stored
 * session are not added to the result list. When the number of sessions found differs from the
 * number of addresses requested, a NoSessionException is thrown whose message starts with
 * "Mismatch! Asked for ".
 *
 * NOTE(review): several lines of this method are not visible in this view, including the
 * remaining message operands and the success-path return. Confirm them against the full file.
 *
 * @param addresses protocol addresses whose sessions should be loaded
 * @throws NoSessionException if at least one address has no stored session
 * @throws RuntimeException wrapping any SQLException
 */
75 public List
<SessionRecord
> loadExistingSessions(final List
<SignalProtocolAddress
> addresses
) throws NoSessionException
{
76 final var keys
= addresses
.stream().map(this::getKey
).toList();
78 try (final var connection
= database
.getConnection()) {
79 final var sessions
= new ArrayList
<SessionRecord
>();
// Missing sessions (null from loadSession) are skipped here and detected by the size check below.
80 for (final var key
: keys
) {
81 final var sessionRecord
= loadSession(connection
, key
);
82 if (sessionRecord
!= null) {
83 sessions
.add(sessionRecord
);
// Every requested address must have a session; otherwise the call fails as a whole.
87 if (sessions
.size() != addresses
.size()) {
88 String message
= "Mismatch! Asked for "
90 + " sessions, but only found "
94 throw new NoSessionException(message
);
98 } catch (SQLException e
) {
99 throw new RuntimeException("Failed read from session store", e
);
104 public List
<Integer
> getSubDeviceSessions(String name
) {
105 final var serviceId
= ServiceId
.parseOrThrow(name
);
106 // get all sessions for recipient except primary device session
111 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id != 1
113 ).formatted(TABLE_SESSION
);
114 try (final var connection
= database
.getConnection()) {
115 try (final var statement
= connection
.prepareStatement(sql
)) {
116 statement
.setInt(1, accountIdType
);
117 statement
.setBytes(2, serviceId
.toByteArray());
118 return Utils
.executeQueryForStream(statement
, res
-> res
.getInt("device_id")).toList();
120 } catch (SQLException e
) {
121 throw new RuntimeException("Failed read from session store", e
);
125 public boolean isCurrentRatchetKey(ServiceId serviceId
, int deviceId
, ECPublicKey ratchetKey
) {
126 final var key
= new Key(serviceId
, deviceId
);
128 try (final var connection
= database
.getConnection()) {
129 final var session
= loadSession(connection
, key
);
130 if (session
== null) {
133 return session
.currentRatchetKeyMatches(ratchetKey
);
134 } catch (SQLException e
) {
135 throw new RuntimeException("Failed read from session store", e
);
140 public void storeSession(SignalProtocolAddress address
, SessionRecord session
) {
141 final var key
= getKey(address
);
143 try (final var connection
= database
.getConnection()) {
144 storeSession(connection
, key
, session
);
145 } catch (SQLException e
) {
146 throw new RuntimeException("Failed read from session store", e
);
151 public boolean containsSession(SignalProtocolAddress address
) {
152 final var key
= getKey(address
);
154 try (final var connection
= database
.getConnection()) {
155 final var session
= loadSession(connection
, key
);
156 return isActive(session
);
157 } catch (SQLException e
) {
158 throw new RuntimeException("Failed read from session store", e
);
163 public void deleteSession(SignalProtocolAddress address
) {
164 final var key
= getKey(address
);
166 try (final var connection
= database
.getConnection()) {
167 deleteSession(connection
, key
);
168 } catch (SQLException e
) {
169 throw new RuntimeException("Failed update session store", e
);
174 public void deleteAllSessions(String name
) {
175 final var serviceId
= ServiceId
.parseOrThrow(name
);
176 deleteAllSessions(serviceId
);
179 public void deleteAllSessions(ServiceId serviceId
) {
180 try (final var connection
= database
.getConnection()) {
181 deleteAllSessions(connection
, serviceId
);
182 } catch (SQLException e
) {
183 throw new RuntimeException("Failed update session store", e
);
/**
 * Archives the current state of the session stored for {@code address}.
 *
 * <p>Addresses whose name is not a UUID (per UuidUtil.isUuid) are handled by the first branch.
 * NOTE(review): the body of that branch is not visible in this view; it is presumably an early
 * return, so confirm. Otherwise the session is loaded on a connection with auto-commit disabled,
 * archived with archiveCurrentState(), and written back through the private storeSession overload.
 *
 * NOTE(review): no connection.commit() is visible in this view. With auto-commit disabled, the
 * write is only durable once committed, so confirm the commit exists.
 *
 * NOTE(review): the private loadSession presumably returns the cached instance when one exists.
 * If so, archiveCurrentState() mutates that cached object in place before the database write.
 *
 * @param address protocol address whose session should be archived
 * @throws RuntimeException wrapping any SQLException
 */
188 public void archiveSession(final SignalProtocolAddress address
) {
189 if (!UuidUtil
.isUuid(address
.getName())) {
193 final var key
= getKey(address
);
195 try (final var connection
= database
.getConnection()) {
// Group the load, archive and store into one transaction.
196 connection
.setAutoCommit(false);
197 final var session
= loadSession(connection
, key
);
// Nothing to archive when no session is stored for this key.
198 if (session
!= null) {
199 session
.archiveCurrentState();
200 storeSession(connection
, key
, session
);
203 } catch (SQLException e
) {
204 throw new RuntimeException("Failed update session store", e
);
209 public Set
<SignalProtocolAddress
> getAllAddressesWithActiveSessions(final List
<String
> addressNames
) {
210 final var serviceIdsCommaSeparated
= addressNames
.stream()
211 .map(ServiceId
::parseOrThrow
)
212 .map(ServiceId
::toByteArray
)
213 .map(uuid
-> "x'" + Hex
.toStringCondensed(uuid
) + "'")
214 .collect(Collectors
.joining(","));
217 SELECT s.uuid, s.device_id, s.record
219 WHERE s.account_id_type = ? AND s.uuid IN (%s)
221 ).formatted(TABLE_SESSION
, serviceIdsCommaSeparated
);
222 try (final var connection
= database
.getConnection()) {
223 try (final var statement
= connection
.prepareStatement(sql
)) {
224 statement
.setInt(1, accountIdType
);
225 return Utils
.executeQueryForStream(statement
,
226 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
227 .filter(pair
-> isActive(pair
.second()))
229 .map(key
-> key
.serviceId().toProtocolAddress(key
.deviceId()))
230 .collect(Collectors
.toSet());
232 } catch (SQLException e
) {
233 throw new RuntimeException("Failed read from session store", e
);
237 public void archiveAllSessions() {
240 SELECT s.uuid, s.device_id, s.record
242 WHERE s.account_id_type = ?
244 ).formatted(TABLE_SESSION
);
245 try (final var connection
= database
.getConnection()) {
246 connection
.setAutoCommit(false);
247 final List
<Pair
<Key
, SessionRecord
>> records
;
248 try (final var statement
= connection
.prepareStatement(sql
)) {
249 statement
.setInt(1, accountIdType
);
250 records
= Utils
.executeQueryForStream(statement
,
251 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
))).toList();
253 for (final var record : records
) {
254 record.second().archiveCurrentState();
255 storeSession(connection
, record.first(), record.second());
258 } catch (SQLException e
) {
259 throw new RuntimeException("Failed update session store", e
);
263 public void archiveSessions(final ServiceId serviceId
) {
266 SELECT s.uuid, s.device_id, s.record
268 WHERE s.account_id_type = ? AND s.uuid = ?
270 ).formatted(TABLE_SESSION
);
271 try (final var connection
= database
.getConnection()) {
272 connection
.setAutoCommit(false);
273 final List
<Pair
<Key
, SessionRecord
>> records
;
274 try (final var statement
= connection
.prepareStatement(sql
)) {
275 statement
.setInt(1, accountIdType
);
276 statement
.setBytes(2, serviceId
.toByteArray());
277 records
= Utils
.executeQueryForStream(statement
,
278 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
))).toList();
280 for (final var record : records
) {
281 record.second().archiveCurrentState();
282 storeSession(connection
, record.first(), record.second());
285 } catch (SQLException e
) {
286 throw new RuntimeException("Failed update session store", e
);
290 void addLegacySessions(final Collection
<Pair
<Key
, SessionRecord
>> sessions
) {
291 logger
.debug("Migrating legacy sessions to database");
292 long start
= System
.nanoTime();
293 try (final var connection
= database
.getConnection()) {
294 connection
.setAutoCommit(false);
295 for (final var pair
: sessions
) {
296 storeSession(connection
, pair
.first(), pair
.second());
299 } catch (SQLException e
) {
300 throw new RuntimeException("Failed update session store", e
);
302 logger
.debug("Complete sessions migration took {}ms", (System
.nanoTime() - start
) / 1000000);
305 private Key
getKey(final SignalProtocolAddress address
) {
306 final var serviceId
= ServiceId
.parseOrThrow(address
.getName());
307 return new Key(serviceId
, address
.getDeviceId());
310 private SessionRecord
loadSession(Connection connection
, final Key key
) throws SQLException
{
311 synchronized (cachedSessions
) {
312 final var session
= cachedSessions
.get(key
);
313 if (session
!= null) {
321 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
323 ).formatted(TABLE_SESSION
);
324 try (final var statement
= connection
.prepareStatement(sql
)) {
325 statement
.setInt(1, accountIdType
);
326 statement
.setBytes(2, key
.serviceId().toByteArray());
327 statement
.setInt(3, key
.deviceId());
328 return Utils
.executeQueryForOptional(statement
, this::getSessionRecordFromResultSet
).orElse(null);
332 private Key
getKeyFromResultSet(ResultSet resultSet
) throws SQLException
{
333 final var serviceId
= ServiceId
.parseOrThrow(resultSet
.getBytes("uuid"));
334 final var deviceId
= resultSet
.getInt("device_id");
335 return new Key(serviceId
, deviceId
);
/**
 * Decodes the {@code record} column of the current result-set row into a SessionRecord.
 *
 * <p>If the stored bytes cannot be parsed, the InvalidMessageException is caught and a warning
 * with the exception message is logged.
 *
 * NOTE(review): the opening {@code try} and the rest of the catch branch are not visible in this
 * view. Confirm what is returned for an undecodable record; isActive tolerates null records.
 *
 * @param resultSet result set positioned on a session row
 * @return the decoded session record
 * @throws SQLException if reading the column fails
 */
338 private SessionRecord
getSessionRecordFromResultSet(ResultSet resultSet
) throws SQLException
{
340 final var record = resultSet
.getBytes("record");
341 return new SessionRecord(record);
342 } catch (InvalidMessageException e
) {
343 logger
.warn("Failed to load session, resetting session: {}", e
.getMessage());
348 private void storeSession(
349 final Connection connection
, final Key key
, final SessionRecord session
350 ) throws SQLException
{
351 synchronized (cachedSessions
) {
352 cachedSessions
.put(key
, session
);
356 INSERT OR REPLACE INTO %s (account_id_type, uuid, device_id, record)
358 """.formatted(TABLE_SESSION
);
359 try (final var statement
= connection
.prepareStatement(sql
)) {
360 statement
.setInt(1, accountIdType
);
361 statement
.setBytes(2, key
.serviceId().toByteArray());
362 statement
.setInt(3, key
.deviceId());
363 statement
.setBytes(4, session
.serialize());
364 statement
.executeUpdate();
368 private void deleteAllSessions(final Connection connection
, final ServiceId serviceId
) throws SQLException
{
369 synchronized (cachedSessions
) {
370 cachedSessions
.clear();
376 WHERE s.account_id_type = ? AND s.uuid = ?
378 ).formatted(TABLE_SESSION
);
379 try (final var statement
= connection
.prepareStatement(sql
)) {
380 statement
.setInt(1, accountIdType
);
381 statement
.setBytes(2, serviceId
.toByteArray());
382 statement
.executeUpdate();
386 private void deleteSession(Connection connection
, final Key key
) throws SQLException
{
387 synchronized (cachedSessions
) {
388 cachedSessions
.remove(key
);
394 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
396 ).formatted(TABLE_SESSION
);
397 try (final var statement
= connection
.prepareStatement(sql
)) {
398 statement
.setInt(1, accountIdType
);
399 statement
.setBytes(2, key
.serviceId().toByteArray());
400 statement
.setInt(3, key
.deviceId());
401 statement
.executeUpdate();
405 private static boolean isActive(SessionRecord
record) {
406 return record != null
407 && record.hasSenderChain()
408 && record.getSessionVersion() == CiphertextMessage
.CURRENT_VERSION
;
411 record Key(ServiceId serviceId
, int deviceId
) {}