1 package org
.asamk
.signal
.manager
.storage
.sessions
;
3 import org
.asamk
.signal
.manager
.api
.Pair
;
4 import org
.asamk
.signal
.manager
.storage
.Database
;
5 import org
.asamk
.signal
.manager
.storage
.Utils
;
6 import org
.signal
.libsignal
.protocol
.NoSessionException
;
7 import org
.signal
.libsignal
.protocol
.SignalProtocolAddress
;
8 import org
.signal
.libsignal
.protocol
.ecc
.ECPublicKey
;
9 import org
.signal
.libsignal
.protocol
.state
.SessionRecord
;
10 import org
.slf4j
.Logger
;
11 import org
.slf4j
.LoggerFactory
;
12 import org
.whispersystems
.signalservice
.api
.SignalServiceSessionStore
;
13 import org
.whispersystems
.signalservice
.api
.push
.ServiceId
;
14 import org
.whispersystems
.signalservice
.api
.push
.ServiceIdType
;
16 import java
.sql
.Connection
;
17 import java
.sql
.ResultSet
;
18 import java
.sql
.SQLException
;
19 import java
.util
.ArrayList
;
20 import java
.util
.Collection
;
21 import java
.util
.HashMap
;
22 import java
.util
.List
;
24 import java
.util
.Objects
;
25 import java
.util
.stream
.Collectors
;
27 public class SessionStore
implements SignalServiceSessionStore
{
29 private static final String TABLE_SESSION
= "session";
30 private static final Logger logger
= LoggerFactory
.getLogger(SessionStore
.class);
32 private final Map
<Key
, SessionRecord
> cachedSessions
= new HashMap
<>();
34 private final Database database
;
35 private final int accountIdType
;
/**
 * Creates the {@code session} table on the given connection.
 * NOTE(review): only part of the CREATE statement is visible in this chunk. Other methods in this
 * class read and write a {@code record} column that does not appear in the visible column list;
 * confirm it is defined in the omitted lines.
 *
 * @param connection open JDBC connection on which to create the table
 * @throws SQLException if executing the DDL fails
 */
37 public static void createSql(Connection connection
) throws SQLException
{
38 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
39 try (final var statement
= connection
.createStatement()) {
// One row per (account_id_type, address, device_id), enforced by the UNIQUE constraint below.
40 statement
.executeUpdate("""
41 CREATE TABLE session (
42 _id INTEGER PRIMARY KEY,
43 account_id_type INTEGER NOT NULL,
44 address TEXT NOT NULL,
45 device_id INTEGER NOT NULL,
47 UNIQUE(account_id_type, address, device_id)
/**
 * Creates a session store that reads and writes through {@code database}, scoped to the account
 * identity type derived from {@code serviceIdType}.
 *
 * @param database      connection provider for the account database
 * @param serviceIdType identity type whose sessions this store manages
 */
public SessionStore(final Database database, final ServiceIdType serviceIdType) {
    this.accountIdType = Utils.getAccountIdType(serviceIdType);
    this.database = database;
}
59 public SessionRecord
loadSession(SignalProtocolAddress address
) {
60 final var key
= getKey(address
);
61 try (final var connection
= database
.getConnection()) {
62 final var session
= loadSession(connection
, key
);
63 return Objects
.requireNonNullElseGet(session
, SessionRecord
::new);
64 } catch (SQLException e
) {
65 throw new RuntimeException("Failed read from session store", e
);
/**
 * Loads the stored sessions for every address in {@code addresses} over a single connection.
 *
 * @param addresses protocol addresses whose sessions are all required
 * @return the sessions found (the success-path return is not visible in this chunk; presumably
 *     the collected list — confirm)
 * @throws NoSessionException if fewer sessions were found than addresses were requested
 * @throws RuntimeException wrapping any {@link SQLException} raised while reading
 */
70 public List
<SessionRecord
> loadExistingSessions(final List
<SignalProtocolAddress
> addresses
) throws NoSessionException
{
71 final var keys
= addresses
.stream().map(this::getKey
).toList();
// Look up each key; addresses with no stored record are skipped here rather than defaulted.
73 try (final var connection
= database
.getConnection()) {
74 final var sessions
= new ArrayList
<SessionRecord
>();
75 for (final var key
: keys
) {
76 final var sessionRecord
= loadSession(connection
, key
);
77 if (sessionRecord
!= null) {
78 sessions
.add(sessionRecord
);
// Any missing session aborts the whole call with NoSessionException.
// NOTE(review): the lines that finish building `message` are not visible in this chunk.
82 if (sessions
.size() != addresses
.size()) {
83 String message
= "Mismatch! Asked for "
85 + " sessions, but only found "
89 throw new NoSessionException(message
);
93 } catch (SQLException e
) {
94 throw new RuntimeException("Failed read from session store", e
);
/**
 * Lists the device ids that have a stored session for {@code name}, excluding device id 1.
 *
 * @param name service id string, parsed with {@code ServiceId.parseOrThrow}
 * @return device ids read from the {@code device_id} column of matching rows
 * @throws RuntimeException wrapping any {@link SQLException} raised while reading
 */
99 public List
<Integer
> getSubDeviceSessions(String name
) {
100 final var serviceId
= ServiceId
.parseOrThrow(name
);
101 // get all sessions for recipient except primary device session
106 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id != 1
108 ).formatted(TABLE_SESSION
);
// The address is bound as serviceId.toString(), not as the raw `name` argument.
109 try (final var connection
= database
.getConnection()) {
110 try (final var statement
= connection
.prepareStatement(sql
)) {
111 statement
.setInt(1, accountIdType
);
112 statement
.setString(2, serviceId
.toString());
113 return Utils
.executeQueryForStream(statement
, res
-> res
.getInt("device_id")).toList();
115 } catch (SQLException e
) {
116 throw new RuntimeException("Failed read from session store", e
);
/**
 * Checks whether the stored session for ({@code serviceId}, {@code deviceId}) currently uses
 * {@code ratchetKey}, as reported by {@code SessionRecord.currentRatchetKeyMatches}.
 *
 * @throws RuntimeException wrapping any {@link SQLException} raised while reading
 */
120 public boolean isCurrentRatchetKey(ServiceId serviceId
, int deviceId
, ECPublicKey ratchetKey
) {
121 final var key
= new Key(serviceId
.toString(), deviceId
);
123 try (final var connection
= database
.getConnection()) {
124 final var session
= loadSession(connection
, key
);
125 if (session
== null) {
// NOTE(review): the body of the null branch is not visible in this chunk (presumably
// `return false;`) — confirm.
128 return session
.currentRatchetKeyMatches(ratchetKey
);
129 } catch (SQLException e
) {
130 throw new RuntimeException("Failed read from session store", e
);
135 public void storeSession(SignalProtocolAddress address
, SessionRecord session
) {
136 final var key
= getKey(address
);
138 try (final var connection
= database
.getConnection()) {
139 storeSession(connection
, key
, session
);
140 } catch (SQLException e
) {
141 throw new RuntimeException("Failed read from session store", e
);
146 public boolean containsSession(SignalProtocolAddress address
) {
147 final var key
= getKey(address
);
149 try (final var connection
= database
.getConnection()) {
150 final var session
= loadSession(connection
, key
);
151 return isActive(session
);
152 } catch (SQLException e
) {
153 throw new RuntimeException("Failed read from session store", e
);
158 public void deleteSession(SignalProtocolAddress address
) {
159 final var key
= getKey(address
);
161 try (final var connection
= database
.getConnection()) {
162 deleteSession(connection
, key
);
163 } catch (SQLException e
) {
164 throw new RuntimeException("Failed update session store", e
);
169 public void deleteAllSessions(String name
) {
170 final var serviceId
= ServiceId
.parseOrThrow(name
);
171 deleteAllSessions(serviceId
);
174 public void deleteAllSessions(ServiceId serviceId
) {
175 try (final var connection
= database
.getConnection()) {
176 deleteAllSessions(connection
, serviceId
.toString());
177 } catch (SQLException e
) {
178 throw new RuntimeException("Failed update session store", e
);
/**
 * Archives the current state of the session stored for {@code address}, if there is one, and
 * writes the updated record back on the same connection.
 *
 * @throws RuntimeException wrapping any {@link SQLException} raised while reading or writing
 */
183 public void archiveSession(final SignalProtocolAddress address
) {
184 final var key
= getKey(address
);
// NOTE(review): auto-commit is disabled below, but no commit is visible in this chunk;
// confirm the omitted lines commit, otherwise the write may be discarded on close.
186 try (final var connection
= database
.getConnection()) {
187 connection
.setAutoCommit(false);
188 final var session
= loadSession(connection
, key
);
189 if (session
!= null) {
190 session
.archiveCurrentState();
191 storeSession(connection
, key
, session
);
194 } catch (SQLException e
) {
195 throw new RuntimeException("Failed update session store", e
);
/**
 * Returns every active stored session (non-null record with a sender chain) whose address is one
 * of {@code addressNames}, keyed by {@code SignalProtocolAddress(address, deviceId)}.
 *
 * @param addressNames address strings to include
 * @return map from protocol address to its active session record
 * @throws RuntimeException wrapping any {@link SQLException} raised while reading
 */
200 public Map
<SignalProtocolAddress
, SessionRecord
> getAllAddressesWithActiveSessions(final List
<String
> addressNames
) {
// Address names are inlined into the IN (...) list as SQL string literals, with ' escaped as ''.
// NOTE(review): binding them as ? parameters would avoid hand-escaping; check the SQLite
// host-parameter limit for large recipient lists before switching.
201 final var serviceIdsCommaSeparated
= addressNames
.stream()
202 .map(address
-> "'" + address
.replaceAll("'", "''") + "'")
203 .collect(Collectors
.joining(","));
206 SELECT s.address, s.device_id, s.record
208 WHERE s.account_id_type = ? AND s.address IN (%s)
210 ).formatted(TABLE_SESSION
, serviceIdsCommaSeparated
);
// Collectors.toMap throws on duplicate keys; the UNIQUE(account_id_type, address, device_id)
// constraint on the table rules out duplicates within one account type.
211 try (final var connection
= database
.getConnection()) {
212 try (final var statement
= connection
.prepareStatement(sql
)) {
213 statement
.setInt(1, accountIdType
);
214 return Utils
.executeQueryForStream(statement
,
215 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
216 .filter(pair
-> isActive(pair
.second()))
217 .collect(Collectors
.toMap(pair
-> new SignalProtocolAddress(pair
.first().address(),
218 pair
.first().deviceId()), Pair
::second
));
220 } catch (SQLException e
) {
221 throw new RuntimeException("Failed read from session store", e
);
/**
 * Archives the current state of every session stored for this store's account type and writes
 * each updated record back, on one connection with auto-commit disabled.
 *
 * @throws RuntimeException wrapping any {@link SQLException} raised while reading or writing
 */
225 public void archiveAllSessions() {
228 SELECT s.address, s.device_id, s.record
230 WHERE s.account_id_type = ?
232 ).formatted(TABLE_SESSION
);
// NOTE(review): no commit is visible in this chunk after auto-commit is disabled — confirm.
233 try (final var connection
= database
.getConnection()) {
234 connection
.setAutoCommit(false);
235 final List
<Pair
<Key
, SessionRecord
>> records
;
236 try (final var statement
= connection
.prepareStatement(sql
)) {
237 statement
.setInt(1, accountIdType
);
238 records
= Utils
.executeQueryForStream(statement
,
239 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
240 .filter(Objects
::nonNull
)
// NOTE(review): Objects::nonNull above tests the Pair, which is never null. If
// getSessionRecordFromResultSet can return null, record.second() below would throw NPE;
// consider filtering on pair.second() != null instead.
243 for (final var record : records
) {
244 record.second().archiveCurrentState();
245 storeSession(connection
, record.first(), record.second());
248 } catch (SQLException e
) {
249 throw new RuntimeException("Failed update session store", e
);
/**
 * Archives the current state of every stored session for {@code serviceId} (all device ids) and
 * writes each updated record back, on one connection with auto-commit disabled.
 *
 * @param serviceId recipient whose sessions are archived; bound as serviceId.toString()
 * @throws RuntimeException wrapping any {@link SQLException} raised while reading or writing
 */
253 public void archiveSessions(final ServiceId serviceId
) {
256 SELECT s.address, s.device_id, s.record
258 WHERE s.account_id_type = ? AND s.address = ?
260 ).formatted(TABLE_SESSION
);
// NOTE(review): no commit is visible in this chunk after auto-commit is disabled — confirm.
261 try (final var connection
= database
.getConnection()) {
262 connection
.setAutoCommit(false);
263 final List
<Pair
<Key
, SessionRecord
>> records
;
264 try (final var statement
= connection
.prepareStatement(sql
)) {
265 statement
.setInt(1, accountIdType
);
266 statement
.setString(2, serviceId
.toString());
267 records
= Utils
.executeQueryForStream(statement
,
268 res
-> new Pair
<>(getKeyFromResultSet(res
), getSessionRecordFromResultSet(res
)))
269 .filter(Objects
::nonNull
)
// NOTE(review): Objects::nonNull above tests the Pair, which is never null. If
// getSessionRecordFromResultSet can return null, record.second() below would throw NPE;
// consider filtering on pair.second() != null instead.
272 for (final var record : records
) {
273 record.second().archiveCurrentState();
274 storeSession(connection
, record.first(), record.second());
277 } catch (SQLException e
) {
278 throw new RuntimeException("Failed update session store", e
);
/**
 * Migrates legacy (key, record) pairs into the session table on one connection with auto-commit
 * disabled, logging the elapsed time in milliseconds at debug level.
 *
 * @param sessions legacy sessions to store
 * @throws RuntimeException wrapping any {@link SQLException} raised while writing
 */
282 void addLegacySessions(final Collection
<Pair
<Key
, SessionRecord
>> sessions
) {
283 logger
.debug("Migrating legacy sessions to database");
284 long start
= System
.nanoTime();
// NOTE(review): no commit is visible in this chunk after auto-commit is disabled — confirm.
285 try (final var connection
= database
.getConnection()) {
286 connection
.setAutoCommit(false);
287 for (final var pair
: sessions
) {
288 storeSession(connection
, pair
.first(), pair
.second());
291 } catch (SQLException e
) {
292 throw new RuntimeException("Failed update session store", e
);
// Convert the nanoTime delta to milliseconds for the log line.
294 logger
.debug("Complete sessions migration took {}ms", (System
.nanoTime() - start
) / 1000000);
297 private Key
getKey(final SignalProtocolAddress address
) {
298 return new Key(address
.getName(), address
.getDeviceId());
/**
 * Loads the record for {@code key}. It consults {@code cachedSessions} under its lock first and
 * otherwise queries the table by account type, address and device id.
 *
 * @param connection open connection to query on
 * @param key        address + device id to look up
 * @return the record, or {@code null} when no row matches
 * @throws SQLException if the query fails
 */
301 private SessionRecord
loadSession(Connection connection
, final Key key
) throws SQLException
{
302 synchronized (cachedSessions
) {
303 final var session
= cachedSessions
.get(key
);
304 if (session
!= null) {
312 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id = ?
314 ).formatted(TABLE_SESSION
);
// NOTE(review): the cache-hit branch is not visible in this chunk (presumably it returns the
// cached record). The visible lines also do not store a DB-loaded record into cachedSessions —
// confirm whether loads should populate the cache.
315 try (final var statement
= connection
.prepareStatement(sql
)) {
316 statement
.setInt(1, accountIdType
);
317 statement
.setString(2, key
.address());
318 statement
.setInt(3, key
.deviceId());
319 return Utils
.executeQueryForOptional(statement
, this::getSessionRecordFromResultSet
).orElse(null);
323 private Key
getKeyFromResultSet(ResultSet resultSet
) throws SQLException
{
324 final var address
= resultSet
.getString("address");
325 final var deviceId
= resultSet
.getInt("device_id");
326 return new Key(address
, deviceId
);
/**
 * Deserializes the {@code record} column of the current row into a {@link SessionRecord}.
 * Any exception is caught and logged as a warning ("resetting session").
 * NOTE(review): the opening {@code try} and the catch block's return value are not visible in
 * this chunk; callers filter or default on {@code null}, so this presumably returns null on
 * failure — confirm.
 */
329 private SessionRecord
getSessionRecordFromResultSet(ResultSet resultSet
) {
331 final var record = resultSet
.getBytes("record");
332 return new SessionRecord(record);
333 } catch (Exception e
) {
// NOTE(review): only e.getMessage() is logged; passing `e` itself as the last argument would
// keep the stack trace.
334 logger
.warn("Failed to load session, resetting session: {}", e
.getMessage());
/**
 * Caches {@code session} under {@code key} and upserts it into the table with
 * INSERT OR REPLACE, binding account type, address, device id and the serialized record.
 *
 * @param connection open connection to write on
 * @param key        address + device id of the session
 * @param session    record to store; serialized with {@code session.serialize()}
 * @throws SQLException if the write fails
 */
339 private void storeSession(
340 final Connection connection
, final Key key
, final SessionRecord session
341 ) throws SQLException
{
342 synchronized (cachedSessions
) {
// The cache is updated before the DB write is attempted.
343 cachedSessions
.put(key
, session
);
347 INSERT OR REPLACE INTO %s (account_id_type, address, device_id, record)
349 """.formatted(TABLE_SESSION
);
350 try (final var statement
= connection
.prepareStatement(sql
)) {
351 statement
.setInt(1, accountIdType
);
352 statement
.setString(2, key
.address());
353 statement
.setInt(3, key
.deviceId());
354 statement
.setBytes(4, session
.serialize());
355 statement
.executeUpdate();
/**
 * Deletes every session row for {@code address} (all device ids) for this store's account type.
 *
 * @param connection open connection to write on
 * @param address    address string bound to the {@code address} column
 * @throws SQLException if the delete fails
 */
359 private void deleteAllSessions(final Connection connection
, final String address
) throws SQLException
{
360 synchronized (cachedSessions
) {
// NOTE(review): this clears the entire cache, not just entries for `address`. That is correct
// but coarse; removeIf on key.address() would keep other recipients cached.
361 cachedSessions
.clear();
367 WHERE s.account_id_type = ? AND s.address = ?
369 ).formatted(TABLE_SESSION
);
370 try (final var statement
= connection
.prepareStatement(sql
)) {
371 statement
.setInt(1, accountIdType
);
372 statement
.setString(2, address
);
373 statement
.executeUpdate();
/**
 * Evicts {@code key} from the cache and deletes its row (matching account type, address and
 * device id) from the table.
 *
 * @param connection open connection to write on
 * @param key        address + device id of the session to delete
 * @throws SQLException if the delete fails
 */
377 private void deleteSession(Connection connection
, final Key key
) throws SQLException
{
378 synchronized (cachedSessions
) {
379 cachedSessions
.remove(key
);
385 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id = ?
387 ).formatted(TABLE_SESSION
);
388 try (final var statement
= connection
.prepareStatement(sql
)) {
389 statement
.setInt(1, accountIdType
);
390 statement
.setString(2, key
.address());
391 statement
.setInt(3, key
.deviceId());
392 statement
.executeUpdate();
396 private static boolean isActive(SessionRecord
record) {
397 return record != null && record.hasSenderChain();
400 record Key(String address
, int deviceId
) {}