1 package org
.asamk
.signal
.manager
.storage
.sessions
;
3 import org
.asamk
.signal
.manager
.api
.Pair
;
4 import org
.asamk
.signal
.manager
.storage
.Database
;
5 import org
.asamk
.signal
.manager
.storage
.Utils
;
6 import org
.signal
.libsignal
.protocol
.NoSessionException
;
7 import org
.signal
.libsignal
.protocol
.SignalProtocolAddress
;
8 import org
.signal
.libsignal
.protocol
.ecc
.ECPublicKey
;
9 import org
.signal
.libsignal
.protocol
.message
.CiphertextMessage
;
10 import org
.signal
.libsignal
.protocol
.state
.SessionRecord
;
11 import org
.signal
.libsignal
.protocol
.util
.Hex
;
12 import org
.slf4j
.Logger
;
13 import org
.slf4j
.LoggerFactory
;
14 import org
.whispersystems
.signalservice
.api
.SignalServiceSessionStore
;
15 import org
.whispersystems
.signalservice
.api
.push
.ServiceId
;
16 import org
.whispersystems
.signalservice
.api
.push
.ServiceIdType
;
17 import org
.whispersystems
.signalservice
.api
.util
.UuidUtil
;
19 import java
.sql
.Connection
;
20 import java
.sql
.ResultSet
;
21 import java
.sql
.SQLException
;
22 import java
.util
.ArrayList
;
23 import java
.util
.Collection
;
24 import java
.util
.HashMap
;
25 import java
.util
.List
;
27 import java
.util
.Objects
;
29 import java
.util
.stream
.Collectors
;
31 public class SessionStore
implements SignalServiceSessionStore
{
33 private static final String TABLE_SESSION
= "session";
34 private final static Logger logger
= LoggerFactory
.getLogger(SessionStore
.class);
36 private final Map
<Key
, SessionRecord
> cachedSessions
= new HashMap
<>();
38 private final Database database
;
39 private final int accountIdType
;
/**
 * Creates the {@code session} table on the given connection.
 * <p>
 * Columns include {@code _id} (integer primary key), {@code account_id_type} and {@code device_id};
 * rows are unique per (account_id_type, uuid, device_id).
 *
 * @param connection connection on which the DDL is executed
 * @throws SQLException if executing the statement fails
 */
41 public static void createSql(Connection connection
) throws SQLException
{
42 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
43 try (final var statement
= connection
.createStatement()) {
44 statement
.executeUpdate("""
45 CREATE TABLE session (
46 _id INTEGER PRIMARY KEY,
47 account_id_type INTEGER NOT NULL,
49 device_id INTEGER NOT NULL,
51 UNIQUE(account_id_type, uuid, device_id)
/**
 * Creates a session store backed by {@code database} for one account identity type.
 *
 * @param database      database containing the {@code session} table
 * @param serviceIdType identity type; converted to the integer stored in {@code account_id_type}
 */
public SessionStore(final Database database, final ServiceIdType serviceIdType) {
    this.accountIdType = Utils.getAccountIdType(serviceIdType);
    this.database = database;
}
63 public SessionRecord
loadSession(SignalProtocolAddress address
) {
64 final var key
= getKey(address
);
65 try (final var connection
= database
.getConnection()) {
66 final var session
= loadSession(connection
, key
);
67 return Objects
.requireNonNullElseGet(session
, SessionRecord
::new);
68 } catch (SQLException e
) {
69 throw new RuntimeException("Failed read from session store", e
);
/**
 * Loads the stored sessions for all {@code addresses}, using a single database connection.
 * <p>
 * Addresses without a stored session are skipped while collecting. If the number of sessions found
 * differs from the number of addresses requested, a {@link NoSessionException} with a
 * "Mismatch! ..." message is thrown.
 *
 * @param addresses protocol addresses whose names must parse as {@link ServiceId}s
 * @throws NoSessionException if at least one address has no stored session
 * @throws RuntimeException wrapping the {@link SQLException} if the database read fails
 */
74 public List
<SessionRecord
> loadExistingSessions(final List
<SignalProtocolAddress
> addresses
) throws NoSessionException
{
// Resolve all keys up front so parse failures surface before a connection is opened.
75 final var keys
= addresses
.stream().map(this::getKey
).toList();
77 try (final var connection
= database
.getConnection()) {
78 final var sessions
= new ArrayList
<SessionRecord
>();
79 for (final var key
: keys
) {
80 final var sessionRecord
= loadSession(connection
, key
);
81 if (sessionRecord
!= null) {
82 sessions
.add(sessionRecord
);
// Every requested address must have a stored session; otherwise report the mismatch.
86 if (sessions
.size() != addresses
.size()) {
87 String message
= "Mismatch! Asked for "
89 + " sessions, but only found "
93 throw new NoSessionException(message
);
97 } catch (SQLException e
) {
98 throw new RuntimeException("Failed read from session store", e
);
/**
 * Returns the device ids of all stored sessions for the given service id, excluding device id 1.
 * <p>
 * The query filters on {@code account_id_type}, {@code uuid} and {@code device_id != 1}.
 *
 * @param name service id string; parsed with {@code ServiceId.parseOrThrow}
 * @return device ids read from the {@code device_id} column
 * @throws RuntimeException wrapping the {@link SQLException} if the database read fails
 */
103 public List
<Integer
> getSubDeviceSessions(String name
) {
104 final var serviceId
= ServiceId
.parseOrThrow(name
);
105 // get all sessions for recipient except primary device session
110 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id != 1
112 ).formatted(TABLE_SESSION
);
113 try (final var connection
= database
.getConnection()) {
114 try (final var statement
= connection
.prepareStatement(sql
)) {
115 statement
.setInt(1, accountIdType
);
116 statement
.setBytes(2, serviceId
.toByteArray());
117 return Utils
.executeQueryForStream(statement
, res
-> res
.getInt("device_id")).toList();
119 } catch (SQLException e
) {
120 throw new RuntimeException("Failed read from session store", e
);
/**
 * Checks whether the stored session for ({@code serviceId}, {@code deviceId}) has {@code ratchetKey}
 * as its current ratchet key.
 * <p>
 * If no session is stored, the {@code session == null} branch is taken before the comparison.
 * NOTE(review): presumably that branch returns {@code false} — confirm.
 *
 * @throws RuntimeException wrapping the {@link SQLException} if the database read fails
 */
124 public boolean isCurrentRatchetKey(ServiceId serviceId
, int deviceId
, ECPublicKey ratchetKey
) {
125 final var key
= new Key(serviceId
, deviceId
);
127 try (final var connection
= database
.getConnection()) {
128 final var session
= loadSession(connection
, key
);
129 if (session
== null) {
132 return session
.currentRatchetKeyMatches(ratchetKey
);
133 } catch (SQLException e
) {
134 throw new RuntimeException("Failed read from session store", e
);
139 public void storeSession(SignalProtocolAddress address
, SessionRecord session
) {
140 final var key
= getKey(address
);
142 try (final var connection
= database
.getConnection()) {
143 storeSession(connection
, key
, session
);
144 } catch (SQLException e
) {
145 throw new RuntimeException("Failed read from session store", e
);
150 public boolean containsSession(SignalProtocolAddress address
) {
151 final var key
= getKey(address
);
153 try (final var connection
= database
.getConnection()) {
154 final var session
= loadSession(connection
, key
);
155 return isActive(session
);
156 } catch (SQLException e
) {
157 throw new RuntimeException("Failed read from session store", e
);
162 public void deleteSession(SignalProtocolAddress address
) {
163 final var key
= getKey(address
);
165 try (final var connection
= database
.getConnection()) {
166 deleteSession(connection
, key
);
167 } catch (SQLException e
) {
168 throw new RuntimeException("Failed update session store", e
);
173 public void deleteAllSessions(String name
) {
174 final var serviceId
= ServiceId
.parseOrThrow(name
);
175 deleteAllSessions(serviceId
);
178 public void deleteAllSessions(ServiceId serviceId
) {
179 try (final var connection
= database
.getConnection()) {
180 deleteAllSessions(connection
, serviceId
);
181 } catch (SQLException e
) {
182 throw new RuntimeException("Failed update session store", e
);
/**
 * Archives the current state of the stored session for {@code address}, if one exists, and stores
 * the updated record.
 * <p>
 * The address name is first checked with {@code UuidUtil.isUuid}.
 * NOTE(review): presumably non-UUID names are skipped — confirm.
 * Load and store run on one connection with auto-commit disabled.
 *
 * @throws RuntimeException wrapping the {@link SQLException} if the database update fails
 */
187 public void archiveSession(final SignalProtocolAddress address
) {
188 if (!UuidUtil
.isUuid(address
.getName())) {
192 final var key
= getKey(address
);
194 try (final var connection
= database
.getConnection()) {
195 connection
.setAutoCommit(false);
196 final var session
= loadSession(connection
, key
);
197 if (session
!= null) {
198 session
.archiveCurrentState();
199 storeSession(connection
, key
, session
);
202 } catch (SQLException e
) {
203 throw new RuntimeException("Failed update session store", e
);
/**
 * Returns the protocol addresses, among the given service ids, that have an active session stored.
 * <p>
 * Records are read directly from the database rather than via the cache. Each row is kept only if
 * {@code isActive} accepts its record, and is then mapped to
 * {@code serviceId.toProtocolAddress(deviceId)}.
 *
 * @param addressNames service id strings; each is parsed with {@code ServiceId.parseOrThrow}
 * @return addresses with an active session
 * @throws RuntimeException wrapping the {@link SQLException} if the database read fails
 */
208 public Set
<SignalProtocolAddress
> getAllAddressesWithActiveSessions(final List
<String
> addressNames
) {
// The IN list is formatted into the SQL as blob literals (x'...') instead of bound parameters.
// Each value comes from Hex.toStringCondensed of a parsed ServiceId (presumably hex digits only).
209 final var serviceIdsCommaSeparated
= addressNames
.stream()
210 .map(ServiceId
::parseOrThrow
)
211 .map(ServiceId
::toByteArray
)
212 .map(uuid
-> "x'" + Hex
.toStringCondensed(uuid
) + "'")
213 .collect(Collectors
.joining(","));
216 SELECT s.uuid, s.device_id, s.record
218 WHERE s.account_id_type = ? AND s.uuid IN (%s)
220 ).formatted(TABLE_SESSION
, serviceIdsCommaSeparated
);
221 try (final var connection
= database
.getConnection()) {
222 try (final var statement
= connection
.prepareStatement(sql
)) {
223 statement
.setInt(1, accountIdType
);
224 return Utils
.executeQueryForStream(statement
,
225 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
226 .filter(pair
-> isActive(pair
.second()))
228 .map(key
-> key
.serviceId().toProtocolAddress(key
.deviceId()))
229 .collect(Collectors
.toSet());
231 } catch (SQLException e
) {
232 throw new RuntimeException("Failed read from session store", e
);
/**
 * Archives the current state of every stored session for this account id type and stores the
 * updated records.
 * <p>
 * All rows are first read into a list, then archived and re-stored on the same connection with
 * auto-commit disabled.
 * <p>
 * NOTE(review): {@code filter(Objects::nonNull)} tests the {@code Pair}, which is never null. If
 * {@code getSessionRecordFromResultSet} can yield null for an unreadable record,
 * {@code record.second()} would be null in the loop below. Consider filtering on
 * {@code pair.second() != null} — confirm.
 *
 * @throws RuntimeException wrapping the {@link SQLException} if the database update fails
 */
236 public void archiveAllSessions() {
239 SELECT s.uuid, s.device_id, s.record
241 WHERE s.account_id_type = ?
243 ).formatted(TABLE_SESSION
);
244 try (final var connection
= database
.getConnection()) {
245 connection
.setAutoCommit(false);
246 final List
<Pair
<Key
, SessionRecord
>> records
;
247 try (final var statement
= connection
.prepareStatement(sql
)) {
248 statement
.setInt(1, accountIdType
);
249 records
= Utils
.executeQueryForStream(statement
,
250 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
251 .filter(Objects
::nonNull
)
254 for (final var record : records
) {
255 record.second().archiveCurrentState();
256 storeSession(connection
, record.first(), record.second());
259 } catch (SQLException e
) {
260 throw new RuntimeException("Failed update session store", e
);
/**
 * Archives the current state of every stored session (all devices) for {@code serviceId} and
 * stores the updated records.
 * <p>
 * Runs on one connection with auto-commit disabled.
 * <p>
 * NOTE(review): {@code filter(Objects::nonNull)} tests the {@code Pair}, which is never null. If
 * {@code getSessionRecordFromResultSet} can yield null, {@code record.second()} would be null in the
 * loop below — confirm.
 *
 * @param serviceId service id whose sessions are archived
 * @throws RuntimeException wrapping the {@link SQLException} if the database update fails
 */
264 public void archiveSessions(final ServiceId serviceId
) {
267 SELECT s.uuid, s.device_id, s.record
269 WHERE s.account_id_type = ? AND s.uuid = ?
271 ).formatted(TABLE_SESSION
);
272 try (final var connection
= database
.getConnection()) {
273 connection
.setAutoCommit(false);
274 final List
<Pair
<Key
, SessionRecord
>> records
;
275 try (final var statement
= connection
.prepareStatement(sql
)) {
276 statement
.setInt(1, accountIdType
);
277 statement
.setBytes(2, serviceId
.toByteArray());
278 records
= Utils
.executeQueryForStream(statement
,
279 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
280 .filter(Objects
::nonNull
)
283 for (final var record : records
) {
284 record.second().archiveCurrentState();
285 storeSession(connection
, record.first(), record.second());
288 } catch (SQLException e
) {
289 throw new RuntimeException("Failed update session store", e
);
/**
 * Migrates legacy (key, record) pairs into the {@code session} table.
 * <p>
 * All pairs are stored on a single connection with auto-commit disabled. The elapsed time is
 * logged at debug level in milliseconds.
 *
 * @param sessions pairs of session key and record to store
 * @throws RuntimeException wrapping the {@link SQLException} if the database update fails
 */
293 void addLegacySessions(final Collection
<Pair
<Key
, SessionRecord
>> sessions
) {
294 logger
.debug("Migrating legacy sessions to database");
295 long start
= System
.nanoTime();
296 try (final var connection
= database
.getConnection()) {
297 connection
.setAutoCommit(false);
298 for (final var pair
: sessions
) {
299 storeSession(connection
, pair
.first(), pair
.second());
302 } catch (SQLException e
) {
303 throw new RuntimeException("Failed update session store", e
);
305 logger
.debug("Complete sessions migration took {}ms", (System
.nanoTime() - start
) / 1000000);
308 private Key
getKey(final SignalProtocolAddress address
) {
309 final var serviceId
= ServiceId
.parseOrThrow(address
.getName());
310 return new Key(serviceId
, address
.getDeviceId());
/**
 * Looks up the session for {@code key}, consulting the in-memory cache first.
 * <p>
 * The cache is read under {@code synchronized (cachedSessions)}. Otherwise the row is queried by
 * {@code account_id_type}, {@code uuid} and {@code device_id}, and the record is decoded with
 * {@code getSessionRecordFromResultSet}.
 *
 * @param connection open database connection
 * @param key        session key to look up
 * @return the session record, or {@code null} if no row matches
 * @throws SQLException if the query fails
 */
313 private SessionRecord
loadSession(Connection connection
, final Key key
) throws SQLException
{
314 synchronized (cachedSessions
) {
315 final var session
= cachedSessions
.get(key
);
316 if (session
!= null) {
324 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
326 ).formatted(TABLE_SESSION
);
327 try (final var statement
= connection
.prepareStatement(sql
)) {
328 statement
.setInt(1, accountIdType
);
329 statement
.setBytes(2, key
.serviceId().toByteArray());
330 statement
.setInt(3, key
.deviceId());
331 return Utils
.executeQueryForOptional(statement
, this::getSessionRecordFromResultSet
).orElse(null);
335 private Key
getKeyFromResultSet(ResultSet resultSet
) throws SQLException
{
336 final var serviceId
= ServiceId
.parseOrThrow(resultSet
.getBytes("uuid"));
337 final var deviceId
= resultSet
.getInt("device_id");
338 return new Key(serviceId
, deviceId
);
/**
 * Decodes the {@code record} column of the current row into a {@link SessionRecord}.
 * <p>
 * Any exception during decoding is caught and logged as a warning ("resetting session").
 * NOTE(review): the value returned in that case is presumably {@code null} — confirm, since
 * callers such as the archive loops dereference the result.
 *
 * @param resultSet row containing the serialized {@code record} bytes
 * @throws SQLException declared by the signature
 */
341 private SessionRecord
getSessionRecordFromResultSet(ResultSet resultSet
) throws SQLException
{
343 final var record = resultSet
.getBytes("record");
344 return new SessionRecord(record);
345 } catch (Exception e
) {
346 logger
.warn("Failed to load session, resetting session: {}", e
.getMessage());
/**
 * Stores {@code session} under {@code key} in both the cache and the database.
 * <p>
 * The cache entry is put under {@code synchronized (cachedSessions)}. The row is written with
 * {@code INSERT OR REPLACE} into (account_id_type, uuid, device_id, record), using
 * {@code session.serialize()} for the record bytes.
 *
 * @param connection open database connection
 * @param key        session key (service id and device id)
 * @param session    record to store
 * @throws SQLException if the insert fails
 */
351 private void storeSession(
352 final Connection connection
, final Key key
, final SessionRecord session
353 ) throws SQLException
{
354 synchronized (cachedSessions
) {
355 cachedSessions
.put(key
, session
);
359 INSERT OR REPLACE INTO %s (account_id_type, uuid, device_id, record)
361 """.formatted(TABLE_SESSION
);
362 try (final var statement
= connection
.prepareStatement(sql
)) {
363 statement
.setInt(1, accountIdType
);
364 statement
.setBytes(2, key
.serviceId().toByteArray());
365 statement
.setInt(3, key
.deviceId());
366 statement
.setBytes(4, session
.serialize());
367 statement
.executeUpdate();
/**
 * Deletes all rows for {@code serviceId} under this account id type.
 * <p>
 * NOTE(review): the whole in-memory cache is cleared, not just entries for {@code serviceId}. This
 * is correct but discards unrelated cached sessions; narrowing it with
 * {@code cachedSessions.keySet().removeIf(k -> k.serviceId().equals(serviceId))} would keep them.
 *
 * @param connection open database connection
 * @param serviceId  service id whose sessions are deleted
 * @throws SQLException if the delete fails
 */
371 private void deleteAllSessions(final Connection connection
, final ServiceId serviceId
) throws SQLException
{
372 synchronized (cachedSessions
) {
373 cachedSessions
.clear();
379 WHERE s.account_id_type = ? AND s.uuid = ?
381 ).formatted(TABLE_SESSION
);
382 try (final var statement
= connection
.prepareStatement(sql
)) {
383 statement
.setInt(1, accountIdType
);
384 statement
.setBytes(2, serviceId
.toByteArray());
385 statement
.executeUpdate();
/**
 * Removes the session for {@code key} from the cache and deletes its row.
 * <p>
 * The cache entry is removed under {@code synchronized (cachedSessions)}. The delete matches
 * {@code account_id_type}, {@code uuid} and {@code device_id}.
 *
 * @param connection open database connection
 * @param key        session key to delete
 * @throws SQLException if the delete fails
 */
389 private void deleteSession(Connection connection
, final Key key
) throws SQLException
{
390 synchronized (cachedSessions
) {
391 cachedSessions
.remove(key
);
397 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
399 ).formatted(TABLE_SESSION
);
400 try (final var statement
= connection
.prepareStatement(sql
)) {
401 statement
.setInt(1, accountIdType
);
402 statement
.setBytes(2, key
.serviceId().toByteArray());
403 statement
.setInt(3, key
.deviceId());
404 statement
.executeUpdate();
408 private static boolean isActive(SessionRecord
record) {
409 return record != null
410 && record.hasSenderChain()
411 && record.getSessionVersion() == CiphertextMessage
.CURRENT_VERSION
;
/**
 * Identifies one stored session by service id and device id.
 * Used as the key of the in-memory session cache and to address rows in the {@code session} table.
 */
414 record Key(ServiceId serviceId
, int deviceId
) {}