1 package org
.asamk
.signal
.manager
.storage
.sessions
;
3 import org
.asamk
.signal
.manager
.api
.Pair
;
4 import org
.asamk
.signal
.manager
.storage
.Database
;
5 import org
.asamk
.signal
.manager
.storage
.Utils
;
6 import org
.signal
.libsignal
.protocol
.InvalidMessageException
;
7 import org
.signal
.libsignal
.protocol
.NoSessionException
;
8 import org
.signal
.libsignal
.protocol
.SignalProtocolAddress
;
9 import org
.signal
.libsignal
.protocol
.ecc
.ECPublicKey
;
10 import org
.signal
.libsignal
.protocol
.message
.CiphertextMessage
;
11 import org
.signal
.libsignal
.protocol
.state
.SessionRecord
;
12 import org
.slf4j
.Logger
;
13 import org
.slf4j
.LoggerFactory
;
14 import org
.whispersystems
.signalservice
.api
.SignalServiceSessionStore
;
15 import org
.whispersystems
.signalservice
.api
.push
.ServiceId
;
16 import org
.whispersystems
.signalservice
.api
.push
.ServiceIdType
;
17 import org
.whispersystems
.signalservice
.api
.util
.UuidUtil
;
19 import java
.sql
.Connection
;
20 import java
.sql
.ResultSet
;
21 import java
.sql
.SQLException
;
22 import java
.util
.ArrayList
;
23 import java
.util
.Collection
;
24 import java
.util
.HashMap
;
25 import java
.util
.List
;
27 import java
.util
.Objects
;
29 import java
.util
.stream
.Collectors
;
/**
 * {@link SignalServiceSessionStore} backed by the {@code session} database table, with an
 * in-memory cache of {@link SessionRecord}s keyed by (service id, device id).
 *
 * <p>Every query filters on {@code account_id_type}, the value derived from the
 * {@link ServiceIdType} passed to the constructor, so an instance only reads and writes the
 * rows of that type.
 */
31 public class SessionStore
implements SignalServiceSessionStore
{
// Table name; substituted into every SQL statement of this class via String#formatted.
33 private static final String TABLE_SESSION
= "session";
34 private final static Logger logger
= LoggerFactory
.getLogger(SessionStore
.class);
// Cache of session records keyed by (serviceId, deviceId). Accesses in this class are
// wrapped in synchronized (cachedSessions) blocks.
36 private final Map
<Key
, SessionRecord
> cachedSessions
= new HashMap
<>();
// Source of JDBC connections; each public operation opens its own via getConnection().
38 private final Database database
;
// Value bound to the account_id_type column in every query; derived from the constructor's
// ServiceIdType via Utils.getAccountIdType.
39 private final int accountIdType
;
/**
 * Creates the {@code session} table on {@code connection}.
 *
 * <p>The queries in this class use the columns account_id_type, uuid, device_id and record.
 * (account_id_type, uuid, device_id) is declared UNIQUE. The INSERT OR REPLACE in
 * storeSession uses this constraint to overwrite an existing row.
 *
 * @param connection connection on which the DDL is executed
 * @throws SQLException if the statement fails
 */
41 public static void createSql(Connection connection
) throws SQLException
{
42 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
43 try (final var statement
= connection
.createStatement()) {
44 statement
.executeUpdate("""
45 CREATE TABLE session (
46 _id INTEGER PRIMARY KEY,
47 account_id_type INTEGER NOT NULL,
49 device_id INTEGER NOT NULL,
51 UNIQUE(account_id_type, uuid, device_id)
/**
 * Creates a session store on top of {@code database}.
 *
 * @param database      source of the JDBC connections used by every operation
 * @param serviceIdType identity type whose sessions this instance manages; converted with
 *                      {@code Utils.getAccountIdType} into the {@code account_id_type} value
 *                      that scopes every query
 */
public SessionStore(final Database database, final ServiceIdType serviceIdType) {
    this.accountIdType = Utils.getAccountIdType(serviceIdType);
    this.database = database;
}
63 public SessionRecord
loadSession(SignalProtocolAddress address
) {
64 final var key
= getKey(address
);
65 try (final var connection
= database
.getConnection()) {
66 final var session
= loadSession(connection
, key
);
67 return Objects
.requireNonNullElseGet(session
, SessionRecord
::new);
68 } catch (SQLException e
) {
69 throw new RuntimeException("Failed read from session store", e
);
/**
 * Loads the stored session for each of {@code addresses}, using a single connection.
 *
 * <p>Unlike loadSession(SignalProtocolAddress), an address without a stored session is not
 * replaced by an empty record. It is skipped, and the size check afterwards turns any such
 * gap into an exception.
 *
 * @param addresses remote addresses to look up
 * @return the loaded records. NOTE(review): presumably the {@code sessions} list, which is
 *         built in the order of {@code addresses}; confirm against the return statement.
 * @throws NoSessionException if fewer sessions were found than addresses were requested
 * @throws RuntimeException wrapping an {@link SQLException} from the database
 */
74 public List
<SessionRecord
> loadExistingSessions(final List
<SignalProtocolAddress
> addresses
) throws NoSessionException
{
75 final var keys
= addresses
.stream().map(this::getKey
).toList();
77 try (final var connection
= database
.getConnection()) {
78 final var sessions
= new ArrayList
<SessionRecord
>();
79 for (final var key
: keys
) {
80 final var sessionRecord
= loadSession(connection
, key
);
81 if (sessionRecord
!= null) {
82 sessions
.add(sessionRecord
);
86 if (sessions
.size() != addresses
.size()) {
87 String message
= "Mismatch! Asked for "
89 + " sessions, but only found "
93 throw new NoSessionException(message
);
97 } catch (SQLException e
) {
98 throw new RuntimeException("Failed read from session store", e
);
/**
 * Returns the device ids of all sessions stored for the recipient {@code name}, excluding
 * device id 1, which the query treats as the primary device.
 *
 * @param name recipient service id string, parsed with {@link ServiceId#parseOrThrow(String)}
 * @return device ids of the matching rows, in the order the query yields them
 * @throws RuntimeException wrapping an {@link SQLException} from the database
 */
103 public List
<Integer
> getSubDeviceSessions(String name
) {
104 final var serviceId
= ServiceId
.parseOrThrow(name
);
105 // get all sessions for recipient except primary device session
110 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id != 1
112 ).formatted(TABLE_SESSION
);
113 try (final var connection
= database
.getConnection()) {
114 try (final var statement
= connection
.prepareStatement(sql
)) {
115 statement
.setInt(1, accountIdType
);
116 statement
.setBytes(2, serviceId
.toByteArray());
117 return Utils
.executeQueryForStream(statement
, res
-> res
.getInt("device_id")).toList();
119 } catch (SQLException e
) {
120 throw new RuntimeException("Failed read from session store", e
);
/**
 * Checks whether {@code ratchetKey} matches the current ratchet key of the session stored
 * for ({@code serviceId}, {@code deviceId}).
 *
 * <p>The comparison is delegated to {@code SessionRecord#currentRatchetKeyMatches}.
 *
 * <p>NOTE(review): the {@code session == null} branch decides the result when no session is
 * stored. It presumably returns {@code false}; confirm.
 *
 * @throws RuntimeException wrapping an {@link SQLException} from the database
 */
124 public boolean isCurrentRatchetKey(ServiceId serviceId
, int deviceId
, ECPublicKey ratchetKey
) {
125 final var key
= new Key(serviceId
, deviceId
);
127 try (final var connection
= database
.getConnection()) {
128 final var session
= loadSession(connection
, key
);
129 if (session
== null) {
132 return session
.currentRatchetKeyMatches(ratchetKey
);
133 } catch (SQLException e
) {
134 throw new RuntimeException("Failed read from session store", e
);
139 public void storeSession(SignalProtocolAddress address
, SessionRecord session
) {
140 final var key
= getKey(address
);
142 try (final var connection
= database
.getConnection()) {
143 storeSession(connection
, key
, session
);
144 } catch (SQLException e
) {
145 throw new RuntimeException("Failed read from session store", e
);
150 public boolean containsSession(SignalProtocolAddress address
) {
151 final var key
= getKey(address
);
153 try (final var connection
= database
.getConnection()) {
154 final var session
= loadSession(connection
, key
);
155 return isActive(session
);
156 } catch (SQLException e
) {
157 throw new RuntimeException("Failed read from session store", e
);
162 public void deleteSession(SignalProtocolAddress address
) {
163 final var key
= getKey(address
);
165 try (final var connection
= database
.getConnection()) {
166 deleteSession(connection
, key
);
167 } catch (SQLException e
) {
168 throw new RuntimeException("Failed update session store", e
);
173 public void deleteAllSessions(String name
) {
174 final var serviceId
= ServiceId
.parseOrThrow(name
);
175 deleteAllSessions(serviceId
);
178 public void deleteAllSessions(ServiceId serviceId
) {
179 try (final var connection
= database
.getConnection()) {
180 deleteAllSessions(connection
, serviceId
);
181 } catch (SQLException e
) {
182 throw new RuntimeException("Failed update session store", e
);
/**
 * Archives the current state of the session stored for {@code address} and writes the
 * updated record back. If no session is stored, nothing is written.
 *
 * <p>Addresses whose name is not a UUID ({@code UuidUtil.isUuid}) are handled by the guard at
 * the top. NOTE(review): presumably an early return; confirm.
 *
 * <p>Auto-commit is switched off before the load/store sequence. NOTE(review): confirm the
 * transaction is committed before the connection is closed.
 *
 * @throws RuntimeException wrapping an {@link SQLException} from the database
 */
187 public void archiveSession(final SignalProtocolAddress address
) {
188 if (!UuidUtil
.isUuid(address
.getName())) {
192 final var key
= getKey(address
);
194 try (final var connection
= database
.getConnection()) {
195 connection
.setAutoCommit(false);
196 final var session
= loadSession(connection
, key
);
197 if (session
!= null) {
198 session
.archiveCurrentState();
199 storeSession(connection
, key
, session
);
202 } catch (SQLException e
) {
203 throw new RuntimeException("Failed update session store", e
);
/**
 * Returns the protocol addresses (service id + device id) of all active sessions stored for
 * the given recipients.
 *
 * <p>"Active" is decided by isActive: the record exists, has a sender chain, and uses
 * {@code CiphertextMessage.CURRENT_VERSION}.
 *
 * @param addressNames recipient service id strings, each parsed with
 *                     {@link ServiceId#parseOrThrow(String)}
 * @return the set of addresses with an active session
 * @throws RuntimeException wrapping an {@link SQLException} from the database
 */
208 public Set
<SignalProtocolAddress
> getAllAddressesWithActiveSessions(final List
<String
> addressNames
) {
// NOTE(review): the ids are spliced into "IN (%s)" as ServiceId#toString() values joined with
// ",". Every other query in this class binds uuid as bytes via setBytes(serviceId.toByteArray()).
// Unless the values are quoted/converted elsewhere, the IN list will not match the stored
// bytes and may not even parse. Consider one "?" placeholder per id, bound with setBytes.
209 final var serviceIdsCommaSeparated
= addressNames
.stream()
210 .map(ServiceId
::parseOrThrow
)
211 .map(ServiceId
::toString
)
212 .collect(Collectors
.joining(","));
215 SELECT s.uuid, s.device_id, s.record
217 WHERE s.account_id_type = ? AND s.uuid IN (%s)
219 ).formatted(TABLE_SESSION
, serviceIdsCommaSeparated
);
220 try (final var connection
= database
.getConnection()) {
221 try (final var statement
= connection
.prepareStatement(sql
)) {
222 statement
.setInt(1, accountIdType
);
223 return Utils
.executeQueryForStream(statement
,
224 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
225 .filter(pair
-> isActive(pair
.second()))
227 .map(key
-> key
.serviceId().toProtocolAddress(key
.deviceId()))
228 .collect(Collectors
.toSet());
230 } catch (SQLException e
) {
231 throw new RuntimeException("Failed read from session store", e
);
/**
 * Archives the current state of every session stored for this account type and writes each
 * updated record back.
 *
 * <p>Everything runs on a single connection with auto-commit disabled. All rows are first
 * read into a list, then archived and stored one by one. NOTE(review): confirm the
 * transaction is committed after the loop.
 *
 * @throws RuntimeException wrapping an {@link SQLException} from the database
 */
235 public void archiveAllSessions() {
238 SELECT s.uuid, s.device_id, s.record
240 WHERE s.account_id_type = ?
242 ).formatted(TABLE_SESSION
);
243 try (final var connection
= database
.getConnection()) {
244 connection
.setAutoCommit(false);
245 final List
<Pair
<Key
, SessionRecord
>> records
;
246 try (final var statement
= connection
.prepareStatement(sql
)) {
247 statement
.setInt(1, accountIdType
);
248 records
= Utils
.executeQueryForStream(statement
,
249 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
))).toList();
// NOTE(review): isActive() guards against null records. If getSessionRecordFromResultSet can
// yield null for unreadable rows, record.second() below would throw; confirm.
251 for (final var record : records
) {
252 record.second().archiveCurrentState();
253 storeSession(connection
, record.first(), record.second());
256 } catch (SQLException e
) {
257 throw new RuntimeException("Failed update session store", e
);
/**
 * Archives the current state of every session (all device ids) stored for {@code serviceId}
 * and writes each updated record back.
 *
 * <p>Everything runs on a single connection with auto-commit disabled. Matching rows are
 * first read into a list, then archived and stored one by one. NOTE(review): confirm the
 * transaction is committed after the loop.
 *
 * @param serviceId recipient whose sessions are archived
 * @throws RuntimeException wrapping an {@link SQLException} from the database
 */
261 public void archiveSessions(final ServiceId serviceId
) {
264 SELECT s.uuid, s.device_id, s.record
266 WHERE s.account_id_type = ? AND s.uuid = ?
268 ).formatted(TABLE_SESSION
);
269 try (final var connection
= database
.getConnection()) {
270 connection
.setAutoCommit(false);
271 final List
<Pair
<Key
, SessionRecord
>> records
;
272 try (final var statement
= connection
.prepareStatement(sql
)) {
273 statement
.setInt(1, accountIdType
);
274 statement
.setBytes(2, serviceId
.toByteArray());
275 records
= Utils
.executeQueryForStream(statement
,
276 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
))).toList();
// NOTE(review): isActive() guards against null records. If getSessionRecordFromResultSet can
// yield null for unreadable rows, record.second() below would throw; confirm.
278 for (final var record : records
) {
279 record.second().archiveCurrentState();
280 storeSession(connection
, record.first(), record.second());
283 } catch (SQLException e
) {
284 throw new RuntimeException("Failed update session store", e
);
/**
 * Writes the given (key, record) pairs into the session table via storeSession.
 *
 * <p>All writes use one connection with auto-commit disabled. The elapsed time, in
 * milliseconds, is logged at debug level. Because the catch block rethrows, that log line
 * runs only when no SQLException occurred. NOTE(review): confirm the transaction is
 * committed after the loop.
 *
 * @param sessions legacy sessions to migrate
 * @throws RuntimeException wrapping an {@link SQLException} from the database
 */
288 void addLegacySessions(final Collection
<Pair
<Key
, SessionRecord
>> sessions
) {
289 logger
.debug("Migrating legacy sessions to database");
290 long start
= System
.nanoTime();
291 try (final var connection
= database
.getConnection()) {
292 connection
.setAutoCommit(false);
293 for (final var pair
: sessions
) {
294 storeSession(connection
, pair
.first(), pair
.second());
297 } catch (SQLException e
) {
298 throw new RuntimeException("Failed update session store", e
);
300 logger
.debug("Complete sessions migration took {}ms", (System
.nanoTime() - start
) / 1000000);
303 private Key
getKey(final SignalProtocolAddress address
) {
304 final var serviceId
= ServiceId
.parseOrThrow(address
.getName());
305 return new Key(serviceId
, address
.getDeviceId());
/**
 * Returns the session record for {@code key}.
 *
 * <p>The in-memory cache is consulted first, then the session table, filtered by
 * account_id_type, uuid and device_id. NOTE(review): the cache-hit branch presumably returns
 * the cached record; confirm.
 *
 * <p>A record loaded from the database is returned directly and is not put into the cache.
 * Only storeSession adds entries. NOTE(review): confirm this is intended.
 *
 * @return the record, or {@code null} if neither the cache nor the table has one
 * @throws SQLException if the query fails
 */
308 private SessionRecord
loadSession(Connection connection
, final Key key
) throws SQLException
{
309 synchronized (cachedSessions
) {
310 final var session
= cachedSessions
.get(key
);
311 if (session
!= null) {
319 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
321 ).formatted(TABLE_SESSION
);
322 try (final var statement
= connection
.prepareStatement(sql
)) {
323 statement
.setInt(1, accountIdType
);
324 statement
.setBytes(2, key
.serviceId().toByteArray());
325 statement
.setInt(3, key
.deviceId());
326 return Utils
.executeQueryForOptional(statement
, this::getSessionRecordFromResultSet
).orElse(null);
330 private Key
getKeyFromResultSet(ResultSet resultSet
) throws SQLException
{
331 final var serviceId
= ServiceId
.parseOrThrow(resultSet
.getBytes("uuid"));
332 final var deviceId
= resultSet
.getInt("device_id");
333 return new Key(serviceId
, deviceId
);
/**
 * Deserializes the {@code record} column of the current row into a {@link SessionRecord}.
 *
 * <p>If the bytes cannot be parsed ({@link InvalidMessageException}), a warning is logged
 * containing only the exception message; the stack trace is dropped. NOTE(review): the
 * "resetting session" log text and the null check in isActive suggest the caller then
 * receives {@code null}; confirm.
 *
 * @throws SQLException if the column cannot be read
 */
336 private SessionRecord
getSessionRecordFromResultSet(ResultSet resultSet
) throws SQLException
{
338 final var record = resultSet
.getBytes("record");
339 return new SessionRecord(record);
340 } catch (InvalidMessageException e
) {
341 logger
.warn("Failed to load session, resetting session: {}", e
.getMessage());
/**
 * Puts {@code session} into the cache and writes it to the session table.
 *
 * <p>The write uses INSERT OR REPLACE, so an existing row for the same
 * (account_id_type, uuid, device_id) is overwritten.
 *
 * <p>NOTE(review): the cache is updated before the SQL statement runs. If executeUpdate
 * throws, the cache keeps a record that was never persisted.
 *
 * @throws SQLException if the write fails
 */
346 private void storeSession(
347 final Connection connection
, final Key key
, final SessionRecord session
348 ) throws SQLException
{
349 synchronized (cachedSessions
) {
350 cachedSessions
.put(key
, session
);
354 INSERT OR REPLACE INTO %s (account_id_type, uuid, device_id, record)
356 """.formatted(TABLE_SESSION
);
357 try (final var statement
= connection
.prepareStatement(sql
)) {
358 statement
.setInt(1, accountIdType
);
359 statement
.setBytes(2, key
.serviceId().toByteArray());
360 statement
.setInt(3, key
.deviceId());
361 statement
.setBytes(4, session
.serialize());
362 statement
.executeUpdate();
/**
 * Deletes all session rows of {@code serviceId} (every device id) for this account type.
 *
 * <p>The whole in-memory cache is cleared, not only the entries of {@code serviceId}. Later
 * lookups for any recipient therefore go back to the database.
 *
 * @throws SQLException if the statement fails
 */
366 private void deleteAllSessions(final Connection connection
, final ServiceId serviceId
) throws SQLException
{
367 synchronized (cachedSessions
) {
368 cachedSessions
.clear();
374 WHERE s.account_id_type = ? AND s.uuid = ?
376 ).formatted(TABLE_SESSION
);
377 try (final var statement
= connection
.prepareStatement(sql
)) {
378 statement
.setInt(1, accountIdType
);
379 statement
.setBytes(2, serviceId
.toByteArray());
380 statement
.executeUpdate();
/**
 * Evicts {@code key} from the cache and deletes the matching row from the session table.
 *
 * <p>The row is matched on account_id_type, uuid and device_id.
 *
 * @throws SQLException if the statement fails
 */
384 private void deleteSession(Connection connection
, final Key key
) throws SQLException
{
385 synchronized (cachedSessions
) {
386 cachedSessions
.remove(key
);
392 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
394 ).formatted(TABLE_SESSION
);
395 try (final var statement
= connection
.prepareStatement(sql
)) {
396 statement
.setInt(1, accountIdType
);
397 statement
.setBytes(2, key
.serviceId().toByteArray());
398 statement
.setInt(3, key
.deviceId());
399 statement
.executeUpdate();
403 private static boolean isActive(SessionRecord
record) {
404 return record != null
405 && record.hasSenderChain()
406 && record.getSessionVersion() == CiphertextMessage
.CURRENT_VERSION
;
409 record Key(ServiceId serviceId
, int deviceId
) {}