package org.asamk.signal.manager.storage.sessions;

import org.asamk.signal.manager.api.Pair;
import org.asamk.signal.manager.storage.Database;
import org.asamk.signal.manager.storage.Utils;
import org.signal.libsignal.protocol.NoSessionException;
import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.signal.libsignal.protocol.state.SessionRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.SignalServiceSessionStore;
import org.whispersystems.signalservice.api.push.ServiceId;
import org.whispersystems.signalservice.api.push.ServiceIdType;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
27 public class SessionStore
implements SignalServiceSessionStore
{
29 private static final String TABLE_SESSION
= "session";
30 private static final Logger logger
= LoggerFactory
.getLogger(SessionStore
.class);
32 private final Map
<Key
, SessionRecord
> cachedSessions
= new HashMap
<>();
34 private final Database database
;
35 private final int accountIdType
;
37 public static void createSql(Connection connection
) throws SQLException
{
38 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
39 try (final var statement
= connection
.createStatement()) {
40 statement
.executeUpdate("""
41 CREATE TABLE session (
42 _id INTEGER PRIMARY KEY,
43 account_id_type INTEGER NOT NULL,
44 address TEXT NOT NULL,
45 device_id INTEGER NOT NULL,
47 UNIQUE(account_id_type, address, device_id)
53 public SessionStore(final Database database
, final ServiceIdType serviceIdType
) {
54 this.database
= database
;
55 this.accountIdType
= Utils
.getAccountIdType(serviceIdType
);
59 public SessionRecord
loadSession(SignalProtocolAddress address
) {
60 final var key
= getKey(address
);
61 try (final var connection
= database
.getConnection()) {
62 final var sessionRecord
= Objects
.requireNonNullElseGet(loadSession(connection
, key
), SessionRecord
::new);
63 synchronized (cachedSessions
) {
64 cachedSessions
.put(key
, sessionRecord
);
67 } catch (SQLException e
) {
68 throw new RuntimeException("Failed read from session store", e
);
73 public List
<SessionRecord
> loadExistingSessions(final List
<SignalProtocolAddress
> addresses
) throws NoSessionException
{
74 final var keys
= addresses
.stream().map(this::getKey
).toList();
76 try (final var connection
= database
.getConnection()) {
77 final var sessions
= new ArrayList
<SessionRecord
>();
78 for (final var key
: keys
) {
79 final var sessionRecord
= loadSession(connection
, key
);
80 if (sessionRecord
!= null) {
81 sessions
.add(sessionRecord
);
85 if (sessions
.size() != addresses
.size()) {
86 String message
= "Mismatch! Asked for "
88 + " sessions, but only found "
92 throw new NoSessionException(message
);
96 } catch (SQLException e
) {
97 throw new RuntimeException("Failed read from session store", e
);
102 public List
<Integer
> getSubDeviceSessions(String name
) {
103 final var serviceId
= ServiceId
.parseOrThrow(name
);
104 // get all sessions for recipient except primary device session
109 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id != 1
111 ).formatted(TABLE_SESSION
);
112 try (final var connection
= database
.getConnection()) {
113 try (final var statement
= connection
.prepareStatement(sql
)) {
114 statement
.setInt(1, accountIdType
);
115 statement
.setString(2, serviceId
.toString());
116 return Utils
.executeQueryForStream(statement
, res
-> res
.getInt("device_id")).toList();
118 } catch (SQLException e
) {
119 throw new RuntimeException("Failed read from session store", e
);
123 public boolean isCurrentRatchetKey(ServiceId serviceId
, int deviceId
, ECPublicKey ratchetKey
) {
124 final var key
= new Key(serviceId
.toString(), deviceId
);
126 try (final var connection
= database
.getConnection()) {
127 final var session
= loadSession(connection
, key
);
128 if (session
== null) {
131 return session
.currentRatchetKeyMatches(ratchetKey
);
132 } catch (SQLException e
) {
133 throw new RuntimeException("Failed read from session store", e
);
138 public void storeSession(SignalProtocolAddress address
, SessionRecord session
) {
139 final var key
= getKey(address
);
141 try (final var connection
= database
.getConnection()) {
142 storeSession(connection
, key
, session
);
143 } catch (SQLException e
) {
144 throw new RuntimeException("Failed read from session store", e
);
149 public boolean containsSession(SignalProtocolAddress address
) {
150 final var key
= getKey(address
);
152 try (final var connection
= database
.getConnection()) {
153 final var session
= loadSession(connection
, key
);
154 final var active
= isActive(session
);
155 logger
.trace("Contains session {}: {} (active: {})", address
, session
!= null, active
);
157 } catch (SQLException e
) {
158 throw new RuntimeException("Failed read from session store", e
);
163 public void deleteSession(SignalProtocolAddress address
) {
164 final var key
= getKey(address
);
166 try (final var connection
= database
.getConnection()) {
167 deleteSession(connection
, key
);
168 } catch (SQLException e
) {
169 throw new RuntimeException("Failed update session store", e
);
174 public void deleteAllSessions(String name
) {
175 final var serviceId
= ServiceId
.parseOrThrow(name
);
176 deleteAllSessions(serviceId
);
179 public void deleteAllSessions(ServiceId serviceId
) {
180 try (final var connection
= database
.getConnection()) {
181 deleteAllSessions(connection
, serviceId
.toString());
182 } catch (SQLException e
) {
183 throw new RuntimeException("Failed update session store", e
);
188 public void archiveSession(final SignalProtocolAddress address
) {
189 final var key
= getKey(address
);
191 try (final var connection
= database
.getConnection()) {
192 connection
.setAutoCommit(false);
193 final var session
= loadSession(connection
, key
);
194 if (session
!= null) {
195 session
.archiveCurrentState();
196 storeSession(connection
, key
, session
);
199 } catch (SQLException e
) {
200 throw new RuntimeException("Failed update session store", e
);
205 public Map
<SignalProtocolAddress
, SessionRecord
> getAllAddressesWithActiveSessions(final List
<String
> addressNames
) {
206 final var serviceIdsCommaSeparated
= addressNames
.stream()
207 .map(address
-> "'" + address
.replaceAll("'", "''") + "'")
208 .collect(Collectors
.joining(","));
211 SELECT s.address, s.device_id, s.record
213 WHERE s.account_id_type = ? AND s.address IN (%s)
215 ).formatted(TABLE_SESSION
, serviceIdsCommaSeparated
);
216 try (final var connection
= database
.getConnection()) {
217 try (final var statement
= connection
.prepareStatement(sql
)) {
218 statement
.setInt(1, accountIdType
);
219 return Utils
.executeQueryForStream(statement
,
220 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
221 .filter(pair
-> isActive(pair
.second()))
222 .collect(Collectors
.toMap(pair
-> new SignalProtocolAddress(pair
.first().address(),
223 pair
.first().deviceId()), Pair
::second
));
225 } catch (SQLException e
) {
226 throw new RuntimeException("Failed read from session store", e
);
230 public void archiveAllSessions() {
233 SELECT s.address, s.device_id, s.record
235 WHERE s.account_id_type = ?
237 ).formatted(TABLE_SESSION
);
238 try (final var connection
= database
.getConnection()) {
239 connection
.setAutoCommit(false);
240 final List
<Pair
<Key
, SessionRecord
>> records
;
241 try (final var statement
= connection
.prepareStatement(sql
)) {
242 statement
.setInt(1, accountIdType
);
243 records
= Utils
.executeQueryForStream(statement
,
244 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
245 .filter(Objects
::nonNull
)
248 for (final var record : records
) {
249 record.second().archiveCurrentState();
250 storeSession(connection
, record.first(), record.second());
253 } catch (SQLException e
) {
254 throw new RuntimeException("Failed update session store", e
);
258 public void archiveSessions(final ServiceId serviceId
) {
261 SELECT s.address, s.device_id, s.record
263 WHERE s.account_id_type = ? AND s.address = ?
265 ).formatted(TABLE_SESSION
);
266 try (final var connection
= database
.getConnection()) {
267 connection
.setAutoCommit(false);
268 final List
<Pair
<Key
, SessionRecord
>> records
;
269 try (final var statement
= connection
.prepareStatement(sql
)) {
270 statement
.setInt(1, accountIdType
);
271 statement
.setString(2, serviceId
.toString());
272 records
= Utils
.executeQueryForStream(statement
,
273 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
274 .filter(Objects
::nonNull
)
277 for (final var record : records
) {
278 record.second().archiveCurrentState();
279 storeSession(connection
, record.first(), record.second());
282 } catch (SQLException e
) {
283 throw new RuntimeException("Failed update session store", e
);
287 void addLegacySessions(final Collection
<Pair
<Key
, SessionRecord
>> sessions
) {
288 logger
.debug("Migrating legacy sessions to database");
289 long start
= System
.nanoTime();
290 try (final var connection
= database
.getConnection()) {
291 connection
.setAutoCommit(false);
292 for (final var pair
: sessions
) {
293 storeSession(connection
, pair
.first(), pair
.second());
296 } catch (SQLException e
) {
297 throw new RuntimeException("Failed update session store", e
);
299 logger
.debug("Complete sessions migration took {}ms", (System
.nanoTime() - start
) / 1000000);
302 private Key
getKey(final SignalProtocolAddress address
) {
303 return new Key(address
.getName(), address
.getDeviceId());
306 private SessionRecord
loadSession(Connection connection
, final Key key
) throws SQLException
{
307 synchronized (cachedSessions
) {
308 final var session
= cachedSessions
.get(key
);
309 if (session
!= null) {
317 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id = ?
319 ).formatted(TABLE_SESSION
);
320 try (final var statement
= connection
.prepareStatement(sql
)) {
321 statement
.setInt(1, accountIdType
);
322 statement
.setString(2, key
.address());
323 statement
.setInt(3, key
.deviceId());
324 return Utils
.executeQueryForOptional(statement
, this::getSessionRecordFromResultSet
).orElse(null);
328 private Key
getKeyFromResultSet(ResultSet resultSet
) throws SQLException
{
329 final var address
= resultSet
.getString("address");
330 final var deviceId
= resultSet
.getInt("device_id");
331 return new Key(address
, deviceId
);
334 private SessionRecord
getSessionRecordFromResultSet(ResultSet resultSet
) {
336 final var record = resultSet
.getBytes("record");
337 return new SessionRecord(record);
338 } catch (Exception e
) {
339 logger
.warn("Failed to load session, resetting session: {}", e
.getMessage());
344 private void storeSession(
345 final Connection connection
, final Key key
, final SessionRecord session
346 ) throws SQLException
{
347 synchronized (cachedSessions
) {
348 cachedSessions
.put(key
, session
);
352 INSERT OR REPLACE INTO %s (account_id_type, address, device_id, record)
354 """.formatted(TABLE_SESSION
);
355 try (final var statement
= connection
.prepareStatement(sql
)) {
356 statement
.setInt(1, accountIdType
);
357 statement
.setString(2, key
.address());
358 statement
.setInt(3, key
.deviceId());
359 statement
.setBytes(4, session
.serialize());
360 statement
.executeUpdate();
364 private void deleteAllSessions(final Connection connection
, final String address
) throws SQLException
{
365 synchronized (cachedSessions
) {
366 cachedSessions
.clear();
372 WHERE s.account_id_type = ? AND s.address = ?
374 ).formatted(TABLE_SESSION
);
375 try (final var statement
= connection
.prepareStatement(sql
)) {
376 statement
.setInt(1, accountIdType
);
377 statement
.setString(2, address
);
378 statement
.executeUpdate();
382 private void deleteSession(Connection connection
, final Key key
) throws SQLException
{
383 synchronized (cachedSessions
) {
384 cachedSessions
.remove(key
);
390 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id = ?
392 ).formatted(TABLE_SESSION
);
393 try (final var statement
= connection
.prepareStatement(sql
)) {
394 statement
.setInt(1, accountIdType
);
395 statement
.setString(2, key
.address());
396 statement
.setInt(3, key
.deviceId());
397 statement
.executeUpdate();
401 private static boolean isActive(SessionRecord
record) {
402 return record != null && record.hasSenderChain();
405 record Key(String address
, int deviceId
) {}