1 package org
.asamk
.signal
.manager
.storage
.sessions
;
3 import org
.asamk
.signal
.manager
.api
.Pair
;
4 import org
.asamk
.signal
.manager
.storage
.Database
;
5 import org
.asamk
.signal
.manager
.storage
.Utils
;
6 import org
.signal
.libsignal
.protocol
.NoSessionException
;
7 import org
.signal
.libsignal
.protocol
.SignalProtocolAddress
;
8 import org
.signal
.libsignal
.protocol
.ecc
.ECPublicKey
;
9 import org
.signal
.libsignal
.protocol
.message
.CiphertextMessage
;
10 import org
.signal
.libsignal
.protocol
.state
.SessionRecord
;
11 import org
.signal
.libsignal
.protocol
.util
.Hex
;
12 import org
.slf4j
.Logger
;
13 import org
.slf4j
.LoggerFactory
;
14 import org
.whispersystems
.signalservice
.api
.SignalServiceSessionStore
;
15 import org
.whispersystems
.signalservice
.api
.push
.ServiceId
;
16 import org
.whispersystems
.signalservice
.api
.push
.ServiceIdType
;
17 import org
.whispersystems
.signalservice
.api
.util
.UuidUtil
;
19 import java
.sql
.Connection
;
20 import java
.sql
.ResultSet
;
21 import java
.sql
.SQLException
;
22 import java
.util
.ArrayList
;
23 import java
.util
.Collection
;
24 import java
.util
.HashMap
;
25 import java
.util
.List
;
27 import java
.util
.Objects
;
29 import java
.util
.stream
.Collectors
;
31 public class SessionStore
implements SignalServiceSessionStore
{
33 private static final String TABLE_SESSION
= "session";
34 private final static Logger logger
= LoggerFactory
.getLogger(SessionStore
.class);
36 private final Map
<Key
, SessionRecord
> cachedSessions
= new HashMap
<>();
38 private final Database database
;
39 private final int accountIdType
;
41 public static void createSql(Connection connection
) throws SQLException
{
42 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
43 try (final var statement
= connection
.createStatement()) {
44 statement
.executeUpdate("""
45 CREATE TABLE session (
46 _id INTEGER PRIMARY KEY,
47 account_id_type INTEGER NOT NULL,
49 device_id INTEGER NOT NULL,
51 UNIQUE(account_id_type, uuid, device_id)
/**
 * Creates a session store that reads and writes through the given database.
 *
 * @param database      provider of the connections used by all operations
 * @param serviceIdType service id type; converted once with
 *                      {@code Utils.getAccountIdType} and kept as the
 *                      account id type used in every query
 */
public SessionStore(final Database database, final ServiceIdType serviceIdType) {
    this.accountIdType = Utils.getAccountIdType(serviceIdType);
    this.database = database;
}
63 public SessionRecord
loadSession(SignalProtocolAddress address
) {
64 final var key
= getKey(address
);
65 try (final var connection
= database
.getConnection()) {
66 final var session
= loadSession(connection
, key
);
67 return Objects
.requireNonNullElseGet(session
, SessionRecord
::new);
68 } catch (SQLException e
) {
69 throw new RuntimeException("Failed read from session store", e
);
74 public List
<SessionRecord
> loadExistingSessions(final List
<SignalProtocolAddress
> addresses
) throws NoSessionException
{
75 final var keys
= addresses
.stream().map(this::getKey
).toList();
77 try (final var connection
= database
.getConnection()) {
78 final var sessions
= new ArrayList
<SessionRecord
>();
79 for (final var key
: keys
) {
80 final var sessionRecord
= loadSession(connection
, key
);
81 if (sessionRecord
!= null) {
82 sessions
.add(sessionRecord
);
86 if (sessions
.size() != addresses
.size()) {
87 String message
= "Mismatch! Asked for "
89 + " sessions, but only found "
93 throw new NoSessionException(message
);
97 } catch (SQLException e
) {
98 throw new RuntimeException("Failed read from session store", e
);
103 public List
<Integer
> getSubDeviceSessions(String name
) {
104 final var serviceId
= ServiceId
.parseOrThrow(name
);
105 // get all sessions for recipient except primary device session
110 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id != 1
112 ).formatted(TABLE_SESSION
);
113 try (final var connection
= database
.getConnection()) {
114 try (final var statement
= connection
.prepareStatement(sql
)) {
115 statement
.setInt(1, accountIdType
);
116 statement
.setBytes(2, serviceId
.toByteArray());
117 return Utils
.executeQueryForStream(statement
, res
-> res
.getInt("device_id")).toList();
119 } catch (SQLException e
) {
120 throw new RuntimeException("Failed read from session store", e
);
124 public boolean isCurrentRatchetKey(ServiceId serviceId
, int deviceId
, ECPublicKey ratchetKey
) {
125 final var key
= new Key(serviceId
, deviceId
);
127 try (final var connection
= database
.getConnection()) {
128 final var session
= loadSession(connection
, key
);
129 if (session
== null) {
132 return session
.currentRatchetKeyMatches(ratchetKey
);
133 } catch (SQLException e
) {
134 throw new RuntimeException("Failed read from session store", e
);
139 public void storeSession(SignalProtocolAddress address
, SessionRecord session
) {
140 final var key
= getKey(address
);
142 try (final var connection
= database
.getConnection()) {
143 storeSession(connection
, key
, session
);
144 } catch (SQLException e
) {
145 throw new RuntimeException("Failed read from session store", e
);
150 public boolean containsSession(SignalProtocolAddress address
) {
151 final var key
= getKey(address
);
153 try (final var connection
= database
.getConnection()) {
154 final var session
= loadSession(connection
, key
);
155 return isActive(session
);
156 } catch (SQLException e
) {
157 throw new RuntimeException("Failed read from session store", e
);
162 public void deleteSession(SignalProtocolAddress address
) {
163 final var key
= getKey(address
);
165 try (final var connection
= database
.getConnection()) {
166 deleteSession(connection
, key
);
167 } catch (SQLException e
) {
168 throw new RuntimeException("Failed update session store", e
);
173 public void deleteAllSessions(String name
) {
174 final var serviceId
= ServiceId
.parseOrThrow(name
);
175 deleteAllSessions(serviceId
);
178 public void deleteAllSessions(ServiceId serviceId
) {
179 try (final var connection
= database
.getConnection()) {
180 deleteAllSessions(connection
, serviceId
);
181 } catch (SQLException e
) {
182 throw new RuntimeException("Failed update session store", e
);
187 public void archiveSession(final SignalProtocolAddress address
) {
188 if (!UuidUtil
.isUuid(address
.getName())) {
192 final var key
= getKey(address
);
194 try (final var connection
= database
.getConnection()) {
195 connection
.setAutoCommit(false);
196 final var session
= loadSession(connection
, key
);
197 if (session
!= null) {
198 session
.archiveCurrentState();
199 storeSession(connection
, key
, session
);
202 } catch (SQLException e
) {
203 throw new RuntimeException("Failed update session store", e
);
208 public Set
<SignalProtocolAddress
> getAllAddressesWithActiveSessions(final List
<String
> addressNames
) {
209 final var serviceIdsCommaSeparated
= addressNames
.stream()
210 .map(ServiceId
::parseOrThrow
)
211 .map(ServiceId
::toByteArray
)
212 .map(uuid
-> "x'" + Hex
.toStringCondensed(uuid
) + "'")
213 .collect(Collectors
.joining(","));
216 SELECT s.uuid, s.device_id, s.record
218 WHERE s.account_id_type = ? AND s.uuid IN (%s)
220 ).formatted(TABLE_SESSION
, serviceIdsCommaSeparated
);
221 try (final var connection
= database
.getConnection()) {
222 try (final var statement
= connection
.prepareStatement(sql
)) {
223 statement
.setInt(1, accountIdType
);
224 return Utils
.executeQueryForStream(statement
,
225 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
226 .filter(pair
-> isActive(pair
.second()))
228 .map(key
-> key
.serviceId().toProtocolAddress(key
.deviceId()))
229 .collect(Collectors
.toSet());
231 } catch (SQLException e
) {
232 throw new RuntimeException("Failed read from session store", e
);
236 public void archiveAllSessions() {
239 SELECT s.uuid, s.device_id, s.record
241 WHERE s.account_id_type = ?
243 ).formatted(TABLE_SESSION
);
244 try (final var connection
= database
.getConnection()) {
245 connection
.setAutoCommit(false);
246 final List
<Pair
<Key
, SessionRecord
>> records
;
247 try (final var statement
= connection
.prepareStatement(sql
)) {
248 statement
.setInt(1, accountIdType
);
249 records
= Utils
.executeQueryForStream(statement
,
250 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
))).toList();
252 for (final var record : records
) {
253 record.second().archiveCurrentState();
254 storeSession(connection
, record.first(), record.second());
257 } catch (SQLException e
) {
258 throw new RuntimeException("Failed update session store", e
);
262 public void archiveSessions(final ServiceId serviceId
) {
265 SELECT s.uuid, s.device_id, s.record
267 WHERE s.account_id_type = ? AND s.uuid = ?
269 ).formatted(TABLE_SESSION
);
270 try (final var connection
= database
.getConnection()) {
271 connection
.setAutoCommit(false);
272 final List
<Pair
<Key
, SessionRecord
>> records
;
273 try (final var statement
= connection
.prepareStatement(sql
)) {
274 statement
.setInt(1, accountIdType
);
275 statement
.setBytes(2, serviceId
.toByteArray());
276 records
= Utils
.executeQueryForStream(statement
,
277 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
))).toList();
279 for (final var record : records
) {
280 record.second().archiveCurrentState();
281 storeSession(connection
, record.first(), record.second());
284 } catch (SQLException e
) {
285 throw new RuntimeException("Failed update session store", e
);
289 void addLegacySessions(final Collection
<Pair
<Key
, SessionRecord
>> sessions
) {
290 logger
.debug("Migrating legacy sessions to database");
291 long start
= System
.nanoTime();
292 try (final var connection
= database
.getConnection()) {
293 connection
.setAutoCommit(false);
294 for (final var pair
: sessions
) {
295 storeSession(connection
, pair
.first(), pair
.second());
298 } catch (SQLException e
) {
299 throw new RuntimeException("Failed update session store", e
);
301 logger
.debug("Complete sessions migration took {}ms", (System
.nanoTime() - start
) / 1000000);
304 private Key
getKey(final SignalProtocolAddress address
) {
305 final var serviceId
= ServiceId
.parseOrThrow(address
.getName());
306 return new Key(serviceId
, address
.getDeviceId());
309 private SessionRecord
loadSession(Connection connection
, final Key key
) throws SQLException
{
310 synchronized (cachedSessions
) {
311 final var session
= cachedSessions
.get(key
);
312 if (session
!= null) {
320 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
322 ).formatted(TABLE_SESSION
);
323 try (final var statement
= connection
.prepareStatement(sql
)) {
324 statement
.setInt(1, accountIdType
);
325 statement
.setBytes(2, key
.serviceId().toByteArray());
326 statement
.setInt(3, key
.deviceId());
327 return Utils
.executeQueryForOptional(statement
, this::getSessionRecordFromResultSet
).orElse(null);
331 private Key
getKeyFromResultSet(ResultSet resultSet
) throws SQLException
{
332 final var serviceId
= ServiceId
.parseOrThrow(resultSet
.getBytes("uuid"));
333 final var deviceId
= resultSet
.getInt("device_id");
334 return new Key(serviceId
, deviceId
);
337 private SessionRecord
getSessionRecordFromResultSet(ResultSet resultSet
) throws SQLException
{
339 final var record = resultSet
.getBytes("record");
340 return new SessionRecord(record);
341 } catch (Exception e
) {
342 logger
.warn("Failed to load session, resetting session: {}", e
.getMessage());
347 private void storeSession(
348 final Connection connection
, final Key key
, final SessionRecord session
349 ) throws SQLException
{
350 synchronized (cachedSessions
) {
351 cachedSessions
.put(key
, session
);
355 INSERT OR REPLACE INTO %s (account_id_type, uuid, device_id, record)
357 """.formatted(TABLE_SESSION
);
358 try (final var statement
= connection
.prepareStatement(sql
)) {
359 statement
.setInt(1, accountIdType
);
360 statement
.setBytes(2, key
.serviceId().toByteArray());
361 statement
.setInt(3, key
.deviceId());
362 statement
.setBytes(4, session
.serialize());
363 statement
.executeUpdate();
367 private void deleteAllSessions(final Connection connection
, final ServiceId serviceId
) throws SQLException
{
368 synchronized (cachedSessions
) {
369 cachedSessions
.clear();
375 WHERE s.account_id_type = ? AND s.uuid = ?
377 ).formatted(TABLE_SESSION
);
378 try (final var statement
= connection
.prepareStatement(sql
)) {
379 statement
.setInt(1, accountIdType
);
380 statement
.setBytes(2, serviceId
.toByteArray());
381 statement
.executeUpdate();
385 private void deleteSession(Connection connection
, final Key key
) throws SQLException
{
386 synchronized (cachedSessions
) {
387 cachedSessions
.remove(key
);
393 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
395 ).formatted(TABLE_SESSION
);
396 try (final var statement
= connection
.prepareStatement(sql
)) {
397 statement
.setInt(1, accountIdType
);
398 statement
.setBytes(2, key
.serviceId().toByteArray());
399 statement
.setInt(3, key
.deviceId());
400 statement
.executeUpdate();
404 private static boolean isActive(SessionRecord
record) {
405 return record != null
406 && record.hasSenderChain()
407 && record.getSessionVersion() == CiphertextMessage
.CURRENT_VERSION
;
410 record Key(ServiceId serviceId
, int deviceId
) {}