1 package org
.asamk
.signal
.manager
.storage
.sessions
;
3 import org
.asamk
.signal
.manager
.api
.Pair
;
4 import org
.asamk
.signal
.manager
.storage
.Database
;
5 import org
.asamk
.signal
.manager
.storage
.Utils
;
6 import org
.signal
.libsignal
.protocol
.NoSessionException
;
7 import org
.signal
.libsignal
.protocol
.SignalProtocolAddress
;
8 import org
.signal
.libsignal
.protocol
.ecc
.ECPublicKey
;
9 import org
.signal
.libsignal
.protocol
.state
.SessionRecord
;
10 import org
.slf4j
.Logger
;
11 import org
.slf4j
.LoggerFactory
;
12 import org
.whispersystems
.signalservice
.api
.SignalServiceSessionStore
;
13 import org
.whispersystems
.signalservice
.api
.push
.ServiceId
;
14 import org
.whispersystems
.signalservice
.api
.push
.ServiceIdType
;
16 import java
.sql
.Connection
;
17 import java
.sql
.ResultSet
;
18 import java
.sql
.SQLException
;
19 import java
.util
.ArrayList
;
20 import java
.util
.Collection
;
21 import java
.util
.HashMap
;
22 import java
.util
.List
;
24 import java
.util
.Objects
;
26 import java
.util
.stream
.Collectors
;
28 public class SessionStore
implements SignalServiceSessionStore
{
30 private static final String TABLE_SESSION
= "session";
31 private final static Logger logger
= LoggerFactory
.getLogger(SessionStore
.class);
33 private final Map
<Key
, SessionRecord
> cachedSessions
= new HashMap
<>();
35 private final Database database
;
36 private final int accountIdType
;
/**
 * Creates the {@code session} table used by this store.
 * <p>
 * Rows are unique per (account_id_type, address, device_id). Other methods in this class also read and
 * write a {@code record} column of this table.
 *
 * @param connection connection on which the DDL statement is executed
 * @throws SQLException if creating the statement or executing the DDL fails
 */
38 public static void createSql(Connection connection
) throws SQLException
{
39 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
40 try (final var statement
= connection
.createStatement()) {
41 statement
.executeUpdate("""
42 CREATE TABLE session (
43 _id INTEGER PRIMARY KEY,
44 account_id_type INTEGER NOT NULL,
45 address TEXT NOT NULL,
46 device_id INTEGER NOT NULL,
48 UNIQUE(account_id_type, address, device_id)
/**
 * Creates a session store backed by {@code database}, restricted to the rows of the account id type
 * derived from {@code serviceIdType}.
 *
 * @param database      provider of JDBC connections
 * @param serviceIdType service id type whose sessions this instance manages
 */
public SessionStore(final Database database, final ServiceIdType serviceIdType) {
    this.accountIdType = Utils.getAccountIdType(serviceIdType);
    this.database = database;
}
60 public SessionRecord
loadSession(SignalProtocolAddress address
) {
61 final var key
= getKey(address
);
62 try (final var connection
= database
.getConnection()) {
63 final var session
= loadSession(connection
, key
);
64 return Objects
.requireNonNullElseGet(session
, SessionRecord
::new);
65 } catch (SQLException e
) {
66 throw new RuntimeException("Failed read from session store", e
);
/**
 * Loads the stored session of every address in {@code addresses}, failing if any of them has none.
 * <p>
 * Found sessions are collected in the order of {@code addresses}.
 *
 * @param addresses remote addresses whose sessions are requested
 * @throws NoSessionException if fewer sessions were found than addresses were requested
 * @throws RuntimeException wrapping an {@link SQLException} if the database read fails
 */
71 public List
<SessionRecord
> loadExistingSessions(final List
<SignalProtocolAddress
> addresses
) throws NoSessionException
{
72 final var keys
= addresses
.stream().map(this::getKey
).toList();
74 try (final var connection
= database
.getConnection()) {
75 final var sessions
= new ArrayList
<SessionRecord
>();
76 for (final var key
: keys
) {
77 final var sessionRecord
= loadSession(connection
, key
);
// Addresses without a stored session (null from the private loadSession) are skipped here;
// the size comparison below turns any such gap into a NoSessionException.
78 if (sessionRecord
!= null) {
79 sessions
.add(sessionRecord
);
// Any missing session makes the counts differ.
83 if (sessions
.size() != addresses
.size()) {
84 String message
= "Mismatch! Asked for "
86 + " sessions, but only found "
90 throw new NoSessionException(message
);
94 } catch (SQLException e
) {
95 throw new RuntimeException("Failed read from session store", e
);
100 public List
<Integer
> getSubDeviceSessions(String name
) {
101 final var serviceId
= ServiceId
.parseOrThrow(name
);
102 // get all sessions for recipient except primary device session
107 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id != 1
109 ).formatted(TABLE_SESSION
);
110 try (final var connection
= database
.getConnection()) {
111 try (final var statement
= connection
.prepareStatement(sql
)) {
112 statement
.setInt(1, accountIdType
);
113 statement
.setString(2, serviceId
.toString());
114 return Utils
.executeQueryForStream(statement
, res
-> res
.getInt("device_id")).toList();
116 } catch (SQLException e
) {
117 throw new RuntimeException("Failed read from session store", e
);
121 public boolean isCurrentRatchetKey(ServiceId serviceId
, int deviceId
, ECPublicKey ratchetKey
) {
122 final var key
= new Key(serviceId
.toString(), deviceId
);
124 try (final var connection
= database
.getConnection()) {
125 final var session
= loadSession(connection
, key
);
126 if (session
== null) {
129 return session
.currentRatchetKeyMatches(ratchetKey
);
130 } catch (SQLException e
) {
131 throw new RuntimeException("Failed read from session store", e
);
136 public void storeSession(SignalProtocolAddress address
, SessionRecord session
) {
137 final var key
= getKey(address
);
139 try (final var connection
= database
.getConnection()) {
140 storeSession(connection
, key
, session
);
141 } catch (SQLException e
) {
142 throw new RuntimeException("Failed read from session store", e
);
147 public boolean containsSession(SignalProtocolAddress address
) {
148 final var key
= getKey(address
);
150 try (final var connection
= database
.getConnection()) {
151 final var session
= loadSession(connection
, key
);
152 return isActive(session
);
153 } catch (SQLException e
) {
154 throw new RuntimeException("Failed read from session store", e
);
159 public void deleteSession(SignalProtocolAddress address
) {
160 final var key
= getKey(address
);
162 try (final var connection
= database
.getConnection()) {
163 deleteSession(connection
, key
);
164 } catch (SQLException e
) {
165 throw new RuntimeException("Failed update session store", e
);
170 public void deleteAllSessions(String name
) {
171 final var serviceId
= ServiceId
.parseOrThrow(name
);
172 deleteAllSessions(serviceId
);
175 public void deleteAllSessions(ServiceId serviceId
) {
176 try (final var connection
= database
.getConnection()) {
177 deleteAllSessions(connection
, serviceId
.toString());
178 } catch (SQLException e
) {
179 throw new RuntimeException("Failed update session store", e
);
184 public void archiveSession(final SignalProtocolAddress address
) {
185 final var key
= getKey(address
);
187 try (final var connection
= database
.getConnection()) {
188 connection
.setAutoCommit(false);
189 final var session
= loadSession(connection
, key
);
190 if (session
!= null) {
191 session
.archiveCurrentState();
192 storeSession(connection
, key
, session
);
195 } catch (SQLException e
) {
196 throw new RuntimeException("Failed update session store", e
);
201 public Set
<SignalProtocolAddress
> getAllAddressesWithActiveSessions(final List
<String
> addressNames
) {
202 final var serviceIdsCommaSeparated
= addressNames
.stream()
203 .map(address
-> "'" + address
.replaceAll("'", "''") + "'")
204 .collect(Collectors
.joining(","));
207 SELECT s.address, s.device_id, s.record
209 WHERE s.account_id_type = ? AND s.address IN (%s)
211 ).formatted(TABLE_SESSION
, serviceIdsCommaSeparated
);
212 try (final var connection
= database
.getConnection()) {
213 try (final var statement
= connection
.prepareStatement(sql
)) {
214 statement
.setInt(1, accountIdType
);
215 return Utils
.executeQueryForStream(statement
,
216 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
217 .filter(pair
-> isActive(pair
.second()))
219 .map(key
-> new SignalProtocolAddress(key
.address(), key
.deviceId()))
220 .collect(Collectors
.toSet());
222 } catch (SQLException e
) {
223 throw new RuntimeException("Failed read from session store", e
);
227 public void archiveAllSessions() {
230 SELECT s.address, s.device_id, s.record
232 WHERE s.account_id_type = ?
234 ).formatted(TABLE_SESSION
);
235 try (final var connection
= database
.getConnection()) {
236 connection
.setAutoCommit(false);
237 final List
<Pair
<Key
, SessionRecord
>> records
;
238 try (final var statement
= connection
.prepareStatement(sql
)) {
239 statement
.setInt(1, accountIdType
);
240 records
= Utils
.executeQueryForStream(statement
,
241 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
242 .filter(Objects
::nonNull
)
245 for (final var record : records
) {
246 record.second().archiveCurrentState();
247 storeSession(connection
, record.first(), record.second());
250 } catch (SQLException e
) {
251 throw new RuntimeException("Failed update session store", e
);
255 public void archiveSessions(final ServiceId serviceId
) {
258 SELECT s.address, s.device_id, s.record
260 WHERE s.account_id_type = ? AND s.address = ?
262 ).formatted(TABLE_SESSION
);
263 try (final var connection
= database
.getConnection()) {
264 connection
.setAutoCommit(false);
265 final List
<Pair
<Key
, SessionRecord
>> records
;
266 try (final var statement
= connection
.prepareStatement(sql
)) {
267 statement
.setInt(1, accountIdType
);
268 statement
.setString(2, serviceId
.toString());
269 records
= Utils
.executeQueryForStream(statement
,
270 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
271 .filter(Objects
::nonNull
)
274 for (final var record : records
) {
275 record.second().archiveCurrentState();
276 storeSession(connection
, record.first(), record.second());
279 } catch (SQLException e
) {
280 throw new RuntimeException("Failed update session store", e
);
284 void addLegacySessions(final Collection
<Pair
<Key
, SessionRecord
>> sessions
) {
285 logger
.debug("Migrating legacy sessions to database");
286 long start
= System
.nanoTime();
287 try (final var connection
= database
.getConnection()) {
288 connection
.setAutoCommit(false);
289 for (final var pair
: sessions
) {
290 storeSession(connection
, pair
.first(), pair
.second());
293 } catch (SQLException e
) {
294 throw new RuntimeException("Failed update session store", e
);
296 logger
.debug("Complete sessions migration took {}ms", (System
.nanoTime() - start
) / 1000000);
299 private Key
getKey(final SignalProtocolAddress address
) {
300 return new Key(address
.getName(), address
.getDeviceId());
303 private SessionRecord
loadSession(Connection connection
, final Key key
) throws SQLException
{
304 synchronized (cachedSessions
) {
305 final var session
= cachedSessions
.get(key
);
306 if (session
!= null) {
314 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id = ?
316 ).formatted(TABLE_SESSION
);
317 try (final var statement
= connection
.prepareStatement(sql
)) {
318 statement
.setInt(1, accountIdType
);
319 statement
.setString(2, key
.address());
320 statement
.setInt(3, key
.deviceId());
321 return Utils
.executeQueryForOptional(statement
, this::getSessionRecordFromResultSet
).orElse(null);
325 private Key
getKeyFromResultSet(ResultSet resultSet
) throws SQLException
{
326 final var address
= resultSet
.getString("address");
327 final var deviceId
= resultSet
.getInt("device_id");
328 return new Key(address
, deviceId
);
/**
 * Parses the {@code record} column of the current result-set row into a {@link SessionRecord}.
 * <p>
 * Any exception raised while reading the column or parsing the bytes is caught and logged as a warning
 * rather than propagated.
 * NOTE(review): confirm what the catch branch returns. isActive() and the archive methods tolerate a
 * null record, which suggests null is expected.
 */
331 private SessionRecord
getSessionRecordFromResultSet(ResultSet resultSet
) throws SQLException
{
333 final var record = resultSet
.getBytes("record");
334 return new SessionRecord(record);
// NOTE(review): catching Exception also swallows SQLExceptions from getBytes, despite the throws clause.
335 } catch (Exception e
) {
336 logger
.warn("Failed to load session, resetting session: {}", e
.getMessage());
341 private void storeSession(
342 final Connection connection
, final Key key
, final SessionRecord session
343 ) throws SQLException
{
344 synchronized (cachedSessions
) {
345 cachedSessions
.put(key
, session
);
349 INSERT OR REPLACE INTO %s (account_id_type, address, device_id, record)
351 """.formatted(TABLE_SESSION
);
352 try (final var statement
= connection
.prepareStatement(sql
)) {
353 statement
.setInt(1, accountIdType
);
354 statement
.setString(2, key
.address());
355 statement
.setInt(3, key
.deviceId());
356 statement
.setBytes(4, session
.serialize());
357 statement
.executeUpdate();
361 private void deleteAllSessions(final Connection connection
, final String address
) throws SQLException
{
362 synchronized (cachedSessions
) {
363 cachedSessions
.clear();
369 WHERE s.account_id_type = ? AND s.address = ?
371 ).formatted(TABLE_SESSION
);
372 try (final var statement
= connection
.prepareStatement(sql
)) {
373 statement
.setInt(1, accountIdType
);
374 statement
.setString(2, address
);
375 statement
.executeUpdate();
379 private void deleteSession(Connection connection
, final Key key
) throws SQLException
{
380 synchronized (cachedSessions
) {
381 cachedSessions
.remove(key
);
387 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id = ?
389 ).formatted(TABLE_SESSION
);
390 try (final var statement
= connection
.prepareStatement(sql
)) {
391 statement
.setInt(1, accountIdType
);
392 statement
.setString(2, key
.address());
393 statement
.setInt(3, key
.deviceId());
394 statement
.executeUpdate();
398 private static boolean isActive(SessionRecord
record) {
399 return record != null && record.hasSenderChain();
/**
 * Identifies one session: the remote address string plus its device id. Used both as the in-memory cache
 * key and to address rows of the session table.
 *
 * @param address  remote address string
 * @param deviceId device id of that address
 */
record Key(String address, int deviceId) { }