]> nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
ad0e198ea6b03b13f431766e65eed004c8cd2d30
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.api.Pair;
4 import org.asamk.signal.manager.storage.Database;
5 import org.asamk.signal.manager.storage.Utils;
6 import org.asamk.signal.manager.storage.recipients.RecipientId;
7 import org.asamk.signal.manager.storage.recipients.RecipientIdCreator;
8 import org.asamk.signal.manager.storage.recipients.RecipientResolver;
9 import org.signal.libsignal.protocol.InvalidMessageException;
10 import org.signal.libsignal.protocol.NoSessionException;
11 import org.signal.libsignal.protocol.SignalProtocolAddress;
12 import org.signal.libsignal.protocol.ecc.ECPublicKey;
13 import org.signal.libsignal.protocol.message.CiphertextMessage;
14 import org.signal.libsignal.protocol.state.SessionRecord;
15 import org.slf4j.Logger;
16 import org.slf4j.LoggerFactory;
17 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
18 import org.whispersystems.signalservice.api.push.ServiceIdType;
19
20 import java.sql.Connection;
21 import java.sql.ResultSet;
22 import java.sql.SQLException;
23 import java.util.ArrayList;
24 import java.util.Collection;
25 import java.util.HashMap;
26 import java.util.List;
27 import java.util.Map;
28 import java.util.Objects;
29 import java.util.Set;
30 import java.util.stream.Collectors;
31
32 public class SessionStore implements SignalServiceSessionStore {
33
34 private static final String TABLE_SESSION = "session";
35 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
36
37 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
38
39 private final Database database;
40 private final int accountIdType;
41 private final RecipientResolver resolver;
42 private final RecipientIdCreator recipientIdCreator;
43
44 public static void createSql(Connection connection) throws SQLException {
45 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
46 try (final var statement = connection.createStatement()) {
47 statement.executeUpdate("""
48 CREATE TABLE session (
49 _id INTEGER PRIMARY KEY,
50 account_id_type INTEGER NOT NULL,
51 recipient_id INTEGER NOT NULL REFERENCES recipient (_id) ON DELETE CASCADE,
52 device_id INTEGER NOT NULL,
53 record BLOB NOT NULL,
54 UNIQUE(account_id_type, recipient_id, device_id)
55 );
56 """);
57 }
58 }
59
60 public SessionStore(
61 final Database database,
62 final ServiceIdType serviceIdType,
63 final RecipientResolver resolver,
64 final RecipientIdCreator recipientIdCreator
65 ) {
66 this.database = database;
67 this.accountIdType = Utils.getAccountIdType(serviceIdType);
68 this.resolver = resolver;
69 this.recipientIdCreator = recipientIdCreator;
70 }
71
72 @Override
73 public SessionRecord loadSession(SignalProtocolAddress address) {
74 final var key = getKey(address);
75 try (final var connection = database.getConnection()) {
76 final var session = loadSession(connection, key);
77 return Objects.requireNonNullElseGet(session, SessionRecord::new);
78 } catch (SQLException e) {
79 throw new RuntimeException("Failed read from session store", e);
80 }
81 }
82
83 @Override
84 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
85 final var keys = addresses.stream().map(this::getKey).toList();
86
87 try (final var connection = database.getConnection()) {
88 final var sessions = new ArrayList<SessionRecord>();
89 for (final var key : keys) {
90 final var sessionRecord = loadSession(connection, key);
91 if (sessionRecord != null) {
92 sessions.add(sessionRecord);
93 }
94 }
95
96 if (sessions.size() != addresses.size()) {
97 String message = "Mismatch! Asked for "
98 + addresses.size()
99 + " sessions, but only found "
100 + sessions.size()
101 + "!";
102 logger.warn(message);
103 throw new NoSessionException(message);
104 }
105
106 return sessions;
107 } catch (SQLException e) {
108 throw new RuntimeException("Failed read from session store", e);
109 }
110 }
111
112 @Override
113 public List<Integer> getSubDeviceSessions(String name) {
114 final var recipientId = resolver.resolveRecipient(name);
115 // get all sessions for recipient except primary device session
116 final var sql = (
117 """
118 SELECT s.device_id
119 FROM %s AS s
120 WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id != 1
121 """
122 ).formatted(TABLE_SESSION);
123 try (final var connection = database.getConnection()) {
124 try (final var statement = connection.prepareStatement(sql)) {
125 statement.setInt(1, accountIdType);
126 statement.setLong(2, recipientId.id());
127 return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList();
128 }
129 } catch (SQLException e) {
130 throw new RuntimeException("Failed read from session store", e);
131 }
132 }
133
134 public boolean isCurrentRatchetKey(RecipientId recipientId, int deviceId, ECPublicKey ratchetKey) {
135 final var key = new Key(recipientId, deviceId);
136
137 try (final var connection = database.getConnection()) {
138 final var session = loadSession(connection, key);
139 if (session == null) {
140 return false;
141 }
142 return session.currentRatchetKeyMatches(ratchetKey);
143 } catch (SQLException e) {
144 throw new RuntimeException("Failed read from session store", e);
145 }
146 }
147
148 @Override
149 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
150 final var key = getKey(address);
151
152 try (final var connection = database.getConnection()) {
153 storeSession(connection, key, session);
154 } catch (SQLException e) {
155 throw new RuntimeException("Failed read from session store", e);
156 }
157 }
158
159 @Override
160 public boolean containsSession(SignalProtocolAddress address) {
161 final var key = getKey(address);
162
163 try (final var connection = database.getConnection()) {
164 final var session = loadSession(connection, key);
165 return isActive(session);
166 } catch (SQLException e) {
167 throw new RuntimeException("Failed read from session store", e);
168 }
169 }
170
171 @Override
172 public void deleteSession(SignalProtocolAddress address) {
173 final var key = getKey(address);
174
175 try (final var connection = database.getConnection()) {
176 deleteSession(connection, key);
177 } catch (SQLException e) {
178 throw new RuntimeException("Failed update session store", e);
179 }
180 }
181
182 @Override
183 public void deleteAllSessions(String name) {
184 final var recipientId = resolver.resolveRecipient(name);
185 deleteAllSessions(recipientId);
186 }
187
188 public void deleteAllSessions(RecipientId recipientId) {
189 try (final var connection = database.getConnection()) {
190 deleteAllSessions(connection, recipientId);
191 } catch (SQLException e) {
192 throw new RuntimeException("Failed update session store", e);
193 }
194 }
195
196 @Override
197 public void archiveSession(final SignalProtocolAddress address) {
198 final var key = getKey(address);
199
200 try (final var connection = database.getConnection()) {
201 connection.setAutoCommit(false);
202 final var session = loadSession(connection, key);
203 if (session != null) {
204 session.archiveCurrentState();
205 storeSession(connection, key, session);
206 connection.commit();
207 }
208 } catch (SQLException e) {
209 throw new RuntimeException("Failed update session store", e);
210 }
211 }
212
213 @Override
214 public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
215 final var recipientIdToNameMap = addressNames.stream()
216 .collect(Collectors.toMap(resolver::resolveRecipient, name -> name));
217 final var recipientIdsCommaSeparated = recipientIdToNameMap.keySet()
218 .stream()
219 .map(recipientId -> String.valueOf(recipientId.id()))
220 .collect(Collectors.joining(","));
221 final var sql = (
222 """
223 SELECT s.recipient_id, s.device_id, s.record
224 FROM %s AS s
225 WHERE s.account_id_type = ? AND s.recipient_id IN (%s)
226 """
227 ).formatted(TABLE_SESSION, recipientIdsCommaSeparated);
228 try (final var connection = database.getConnection()) {
229 try (final var statement = connection.prepareStatement(sql)) {
230 statement.setInt(1, accountIdType);
231 return Utils.executeQueryForStream(statement,
232 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
233 .filter(pair -> isActive(pair.second()))
234 .map(Pair::first)
235 .map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId),
236 key.deviceId()))
237 .collect(Collectors.toSet());
238 }
239 } catch (SQLException e) {
240 throw new RuntimeException("Failed read from session store", e);
241 }
242 }
243
244 public void archiveAllSessions() {
245 final var sql = (
246 """
247 SELECT s.recipient_id, s.device_id, s.record
248 FROM %s AS s
249 WHERE s.account_id_type = ?
250 """
251 ).formatted(TABLE_SESSION);
252 try (final var connection = database.getConnection()) {
253 connection.setAutoCommit(false);
254 final List<Pair<Key, SessionRecord>> records;
255 try (final var statement = connection.prepareStatement(sql)) {
256 statement.setInt(1, accountIdType);
257 records = Utils.executeQueryForStream(statement,
258 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
259 }
260 for (final var record : records) {
261 record.second().archiveCurrentState();
262 storeSession(connection, record.first(), record.second());
263 }
264 connection.commit();
265 } catch (SQLException e) {
266 throw new RuntimeException("Failed update session store", e);
267 }
268 }
269
270 public void archiveSessions(final RecipientId recipientId) {
271 final var sql = (
272 """
273 SELECT s.recipient_id, s.device_id, s.record
274 FROM %s AS s
275 WHERE s.account_id_type = ? AND s.recipient_id = ?
276 """
277 ).formatted(TABLE_SESSION);
278 try (final var connection = database.getConnection()) {
279 connection.setAutoCommit(false);
280 final List<Pair<Key, SessionRecord>> records;
281 try (final var statement = connection.prepareStatement(sql)) {
282 statement.setInt(1, accountIdType);
283 statement.setLong(2, recipientId.id());
284 records = Utils.executeQueryForStream(statement,
285 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
286 }
287 for (final var record : records) {
288 record.second().archiveCurrentState();
289 storeSession(connection, record.first(), record.second());
290 }
291 connection.commit();
292 } catch (SQLException e) {
293 throw new RuntimeException("Failed update session store", e);
294 }
295 }
296
297 public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
298 try (final var connection = database.getConnection()) {
299 connection.setAutoCommit(false);
300 synchronized (cachedSessions) {
301 cachedSessions.clear();
302 }
303
304 final var sql = """
305 UPDATE OR IGNORE %s
306 SET recipient_id = ?
307 WHERE account_id_type = ? AND recipient_id = ?
308 """.formatted(TABLE_SESSION);
309 try (final var statement = connection.prepareStatement(sql)) {
310 statement.setLong(1, recipientId.id());
311 statement.setInt(2, accountIdType);
312 statement.setLong(3, toBeMergedRecipientId.id());
313 final var rows = statement.executeUpdate();
314 if (rows > 0) {
315 logger.debug("Reassigned {} sessions of to be merged recipient.", rows);
316 }
317 }
318 // Delete all conflicting sessions now
319 deleteAllSessions(connection, toBeMergedRecipientId);
320 connection.commit();
321 } catch (SQLException e) {
322 throw new RuntimeException("Failed update session store", e);
323 }
324 }
325
326 void addLegacySessions(final Collection<Pair<Key, SessionRecord>> sessions) {
327 logger.debug("Migrating legacy sessions to database");
328 long start = System.nanoTime();
329 try (final var connection = database.getConnection()) {
330 connection.setAutoCommit(false);
331 for (final var pair : sessions) {
332 storeSession(connection, pair.first(), pair.second());
333 }
334 connection.commit();
335 } catch (SQLException e) {
336 throw new RuntimeException("Failed update session store", e);
337 }
338 logger.debug("Complete sessions migration took {}ms", (System.nanoTime() - start) / 1000000);
339 }
340
341 private Key getKey(final SignalProtocolAddress address) {
342 final var recipientId = resolver.resolveRecipient(address.getName());
343 return new Key(recipientId, address.getDeviceId());
344 }
345
346 private SessionRecord loadSession(Connection connection, final Key key) throws SQLException {
347 synchronized (cachedSessions) {
348 final var session = cachedSessions.get(key);
349 if (session != null) {
350 return session;
351 }
352 }
353 final var sql = (
354 """
355 SELECT s.record
356 FROM %s AS s
357 WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id = ?
358 """
359 ).formatted(TABLE_SESSION);
360 try (final var statement = connection.prepareStatement(sql)) {
361 statement.setInt(1, accountIdType);
362 statement.setLong(2, key.recipientId().id());
363 statement.setInt(3, key.deviceId());
364 return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null);
365 }
366 }
367
368 private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException {
369 final var recipientId = resultSet.getLong("recipient_id");
370 final var deviceId = resultSet.getInt("device_id");
371 return new Key(recipientIdCreator.create(recipientId), deviceId);
372 }
373
374 private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) throws SQLException {
375 try {
376 final var record = resultSet.getBytes("record");
377 return new SessionRecord(record);
378 } catch (InvalidMessageException e) {
379 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
380 return null;
381 }
382 }
383
384 private void storeSession(
385 final Connection connection, final Key key, final SessionRecord session
386 ) throws SQLException {
387 synchronized (cachedSessions) {
388 cachedSessions.put(key, session);
389 }
390
391 final var sql = """
392 INSERT OR REPLACE INTO %s (account_id_type, recipient_id, device_id, record)
393 VALUES (?, ?, ?, ?)
394 """.formatted(TABLE_SESSION);
395 try (final var statement = connection.prepareStatement(sql)) {
396 statement.setInt(1, accountIdType);
397 statement.setLong(2, key.recipientId().id());
398 statement.setInt(3, key.deviceId());
399 statement.setBytes(4, session.serialize());
400 statement.executeUpdate();
401 }
402 }
403
404 private void deleteAllSessions(final Connection connection, final RecipientId recipientId) throws SQLException {
405 synchronized (cachedSessions) {
406 cachedSessions.clear();
407 }
408
409 final var sql = (
410 """
411 DELETE FROM %s AS s
412 WHERE s.account_id_type = ? AND s.recipient_id = ?
413 """
414 ).formatted(TABLE_SESSION);
415 try (final var statement = connection.prepareStatement(sql)) {
416 statement.setInt(1, accountIdType);
417 statement.setLong(2, recipientId.id());
418 statement.executeUpdate();
419 }
420 }
421
422 private void deleteSession(Connection connection, final Key key) throws SQLException {
423 synchronized (cachedSessions) {
424 cachedSessions.remove(key);
425 }
426
427 final var sql = (
428 """
429 DELETE FROM %s AS s
430 WHERE s.account_id_type = ? AND s.recipient_id = ? AND s.device_id = ?
431 """
432 ).formatted(TABLE_SESSION);
433 try (final var statement = connection.prepareStatement(sql)) {
434 statement.setInt(1, accountIdType);
435 statement.setLong(2, key.recipientId().id());
436 statement.setInt(3, key.deviceId());
437 statement.executeUpdate();
438 }
439 }
440
441 private static boolean isActive(SessionRecord record) {
442 return record != null
443 && record.hasSenderChain()
444 && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
445 }
446
447 record Key(RecipientId recipientId, int deviceId) {}
448 }