]> nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
ec933742ea702abd77206f618154e5fee0906a6d
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.api.Pair;
4 import org.asamk.signal.manager.storage.Database;
5 import org.asamk.signal.manager.storage.Utils;
6 import org.signal.libsignal.protocol.InvalidMessageException;
7 import org.signal.libsignal.protocol.NoSessionException;
8 import org.signal.libsignal.protocol.SignalProtocolAddress;
9 import org.signal.libsignal.protocol.ecc.ECPublicKey;
10 import org.signal.libsignal.protocol.message.CiphertextMessage;
11 import org.signal.libsignal.protocol.state.SessionRecord;
12 import org.signal.libsignal.protocol.util.Hex;
13 import org.slf4j.Logger;
14 import org.slf4j.LoggerFactory;
15 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
16 import org.whispersystems.signalservice.api.push.ServiceId;
17 import org.whispersystems.signalservice.api.push.ServiceIdType;
18 import org.whispersystems.signalservice.api.util.UuidUtil;
19
20 import java.sql.Connection;
21 import java.sql.ResultSet;
22 import java.sql.SQLException;
23 import java.util.ArrayList;
24 import java.util.Collection;
25 import java.util.HashMap;
26 import java.util.List;
27 import java.util.Map;
28 import java.util.Objects;
29 import java.util.Set;
30 import java.util.stream.Collectors;
31
32 public class SessionStore implements SignalServiceSessionStore {
33
34 private static final String TABLE_SESSION = "session";
35 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
36
37 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
38
39 private final Database database;
40 private final int accountIdType;
41
42 public static void createSql(Connection connection) throws SQLException {
43 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
44 try (final var statement = connection.createStatement()) {
45 statement.executeUpdate("""
46 CREATE TABLE session (
47 _id INTEGER PRIMARY KEY,
48 account_id_type INTEGER NOT NULL,
49 uuid BLOB NOT NULL,
50 device_id INTEGER NOT NULL,
51 record BLOB NOT NULL,
52 UNIQUE(account_id_type, uuid, device_id)
53 ) STRICT;
54 """);
55 }
56 }
57
58 public SessionStore(final Database database, final ServiceIdType serviceIdType) {
59 this.database = database;
60 this.accountIdType = Utils.getAccountIdType(serviceIdType);
61 }
62
63 @Override
64 public SessionRecord loadSession(SignalProtocolAddress address) {
65 final var key = getKey(address);
66 try (final var connection = database.getConnection()) {
67 final var session = loadSession(connection, key);
68 return Objects.requireNonNullElseGet(session, SessionRecord::new);
69 } catch (SQLException e) {
70 throw new RuntimeException("Failed read from session store", e);
71 }
72 }
73
74 @Override
75 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
76 final var keys = addresses.stream().map(this::getKey).toList();
77
78 try (final var connection = database.getConnection()) {
79 final var sessions = new ArrayList<SessionRecord>();
80 for (final var key : keys) {
81 final var sessionRecord = loadSession(connection, key);
82 if (sessionRecord != null) {
83 sessions.add(sessionRecord);
84 }
85 }
86
87 if (sessions.size() != addresses.size()) {
88 String message = "Mismatch! Asked for "
89 + addresses.size()
90 + " sessions, but only found "
91 + sessions.size()
92 + "!";
93 logger.warn(message);
94 throw new NoSessionException(message);
95 }
96
97 return sessions;
98 } catch (SQLException e) {
99 throw new RuntimeException("Failed read from session store", e);
100 }
101 }
102
103 @Override
104 public List<Integer> getSubDeviceSessions(String name) {
105 final var serviceId = ServiceId.parseOrThrow(name);
106 // get all sessions for recipient except primary device session
107 final var sql = (
108 """
109 SELECT s.device_id
110 FROM %s AS s
111 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id != 1
112 """
113 ).formatted(TABLE_SESSION);
114 try (final var connection = database.getConnection()) {
115 try (final var statement = connection.prepareStatement(sql)) {
116 statement.setInt(1, accountIdType);
117 statement.setBytes(2, serviceId.toByteArray());
118 return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList();
119 }
120 } catch (SQLException e) {
121 throw new RuntimeException("Failed read from session store", e);
122 }
123 }
124
125 public boolean isCurrentRatchetKey(ServiceId serviceId, int deviceId, ECPublicKey ratchetKey) {
126 final var key = new Key(serviceId, deviceId);
127
128 try (final var connection = database.getConnection()) {
129 final var session = loadSession(connection, key);
130 if (session == null) {
131 return false;
132 }
133 return session.currentRatchetKeyMatches(ratchetKey);
134 } catch (SQLException e) {
135 throw new RuntimeException("Failed read from session store", e);
136 }
137 }
138
139 @Override
140 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
141 final var key = getKey(address);
142
143 try (final var connection = database.getConnection()) {
144 storeSession(connection, key, session);
145 } catch (SQLException e) {
146 throw new RuntimeException("Failed read from session store", e);
147 }
148 }
149
150 @Override
151 public boolean containsSession(SignalProtocolAddress address) {
152 final var key = getKey(address);
153
154 try (final var connection = database.getConnection()) {
155 final var session = loadSession(connection, key);
156 return isActive(session);
157 } catch (SQLException e) {
158 throw new RuntimeException("Failed read from session store", e);
159 }
160 }
161
162 @Override
163 public void deleteSession(SignalProtocolAddress address) {
164 final var key = getKey(address);
165
166 try (final var connection = database.getConnection()) {
167 deleteSession(connection, key);
168 } catch (SQLException e) {
169 throw new RuntimeException("Failed update session store", e);
170 }
171 }
172
173 @Override
174 public void deleteAllSessions(String name) {
175 final var serviceId = ServiceId.parseOrThrow(name);
176 deleteAllSessions(serviceId);
177 }
178
179 public void deleteAllSessions(ServiceId serviceId) {
180 try (final var connection = database.getConnection()) {
181 deleteAllSessions(connection, serviceId);
182 } catch (SQLException e) {
183 throw new RuntimeException("Failed update session store", e);
184 }
185 }
186
187 @Override
188 public void archiveSession(final SignalProtocolAddress address) {
189 if (!UuidUtil.isUuid(address.getName())) {
190 return;
191 }
192
193 final var key = getKey(address);
194
195 try (final var connection = database.getConnection()) {
196 connection.setAutoCommit(false);
197 final var session = loadSession(connection, key);
198 if (session != null) {
199 session.archiveCurrentState();
200 storeSession(connection, key, session);
201 connection.commit();
202 }
203 } catch (SQLException e) {
204 throw new RuntimeException("Failed update session store", e);
205 }
206 }
207
208 @Override
209 public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
210 final var serviceIdsCommaSeparated = addressNames.stream()
211 .map(ServiceId::parseOrThrow)
212 .map(ServiceId::toByteArray)
213 .map(uuid -> "x'" + Hex.toStringCondensed(uuid) + "'")
214 .collect(Collectors.joining(","));
215 final var sql = (
216 """
217 SELECT s.uuid, s.device_id, s.record
218 FROM %s AS s
219 WHERE s.account_id_type = ? AND s.uuid IN (%s)
220 """
221 ).formatted(TABLE_SESSION, serviceIdsCommaSeparated);
222 try (final var connection = database.getConnection()) {
223 try (final var statement = connection.prepareStatement(sql)) {
224 statement.setInt(1, accountIdType);
225 return Utils.executeQueryForStream(statement,
226 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
227 .filter(pair -> isActive(pair.second()))
228 .map(Pair::first)
229 .map(key -> key.serviceId().toProtocolAddress(key.deviceId()))
230 .collect(Collectors.toSet());
231 }
232 } catch (SQLException e) {
233 throw new RuntimeException("Failed read from session store", e);
234 }
235 }
236
237 public void archiveAllSessions() {
238 final var sql = (
239 """
240 SELECT s.uuid, s.device_id, s.record
241 FROM %s AS s
242 WHERE s.account_id_type = ?
243 """
244 ).formatted(TABLE_SESSION);
245 try (final var connection = database.getConnection()) {
246 connection.setAutoCommit(false);
247 final List<Pair<Key, SessionRecord>> records;
248 try (final var statement = connection.prepareStatement(sql)) {
249 statement.setInt(1, accountIdType);
250 records = Utils.executeQueryForStream(statement,
251 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
252 }
253 for (final var record : records) {
254 record.second().archiveCurrentState();
255 storeSession(connection, record.first(), record.second());
256 }
257 connection.commit();
258 } catch (SQLException e) {
259 throw new RuntimeException("Failed update session store", e);
260 }
261 }
262
263 public void archiveSessions(final ServiceId serviceId) {
264 final var sql = (
265 """
266 SELECT s.uuid, s.device_id, s.record
267 FROM %s AS s
268 WHERE s.account_id_type = ? AND s.uuid = ?
269 """
270 ).formatted(TABLE_SESSION);
271 try (final var connection = database.getConnection()) {
272 connection.setAutoCommit(false);
273 final List<Pair<Key, SessionRecord>> records;
274 try (final var statement = connection.prepareStatement(sql)) {
275 statement.setInt(1, accountIdType);
276 statement.setBytes(2, serviceId.toByteArray());
277 records = Utils.executeQueryForStream(statement,
278 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
279 }
280 for (final var record : records) {
281 record.second().archiveCurrentState();
282 storeSession(connection, record.first(), record.second());
283 }
284 connection.commit();
285 } catch (SQLException e) {
286 throw new RuntimeException("Failed update session store", e);
287 }
288 }
289
290 void addLegacySessions(final Collection<Pair<Key, SessionRecord>> sessions) {
291 logger.debug("Migrating legacy sessions to database");
292 long start = System.nanoTime();
293 try (final var connection = database.getConnection()) {
294 connection.setAutoCommit(false);
295 for (final var pair : sessions) {
296 storeSession(connection, pair.first(), pair.second());
297 }
298 connection.commit();
299 } catch (SQLException e) {
300 throw new RuntimeException("Failed update session store", e);
301 }
302 logger.debug("Complete sessions migration took {}ms", (System.nanoTime() - start) / 1000000);
303 }
304
305 private Key getKey(final SignalProtocolAddress address) {
306 final var serviceId = ServiceId.parseOrThrow(address.getName());
307 return new Key(serviceId, address.getDeviceId());
308 }
309
310 private SessionRecord loadSession(Connection connection, final Key key) throws SQLException {
311 synchronized (cachedSessions) {
312 final var session = cachedSessions.get(key);
313 if (session != null) {
314 return session;
315 }
316 }
317 final var sql = (
318 """
319 SELECT s.record
320 FROM %s AS s
321 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
322 """
323 ).formatted(TABLE_SESSION);
324 try (final var statement = connection.prepareStatement(sql)) {
325 statement.setInt(1, accountIdType);
326 statement.setBytes(2, key.serviceId().toByteArray());
327 statement.setInt(3, key.deviceId());
328 return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null);
329 }
330 }
331
332 private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException {
333 final var serviceId = ServiceId.parseOrThrow(resultSet.getBytes("uuid"));
334 final var deviceId = resultSet.getInt("device_id");
335 return new Key(serviceId, deviceId);
336 }
337
338 private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) throws SQLException {
339 try {
340 final var record = resultSet.getBytes("record");
341 return new SessionRecord(record);
342 } catch (InvalidMessageException e) {
343 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
344 return null;
345 }
346 }
347
348 private void storeSession(
349 final Connection connection, final Key key, final SessionRecord session
350 ) throws SQLException {
351 synchronized (cachedSessions) {
352 cachedSessions.put(key, session);
353 }
354
355 final var sql = """
356 INSERT OR REPLACE INTO %s (account_id_type, uuid, device_id, record)
357 VALUES (?, ?, ?, ?)
358 """.formatted(TABLE_SESSION);
359 try (final var statement = connection.prepareStatement(sql)) {
360 statement.setInt(1, accountIdType);
361 statement.setBytes(2, key.serviceId().toByteArray());
362 statement.setInt(3, key.deviceId());
363 statement.setBytes(4, session.serialize());
364 statement.executeUpdate();
365 }
366 }
367
368 private void deleteAllSessions(final Connection connection, final ServiceId serviceId) throws SQLException {
369 synchronized (cachedSessions) {
370 cachedSessions.clear();
371 }
372
373 final var sql = (
374 """
375 DELETE FROM %s AS s
376 WHERE s.account_id_type = ? AND s.uuid = ?
377 """
378 ).formatted(TABLE_SESSION);
379 try (final var statement = connection.prepareStatement(sql)) {
380 statement.setInt(1, accountIdType);
381 statement.setBytes(2, serviceId.toByteArray());
382 statement.executeUpdate();
383 }
384 }
385
386 private void deleteSession(Connection connection, final Key key) throws SQLException {
387 synchronized (cachedSessions) {
388 cachedSessions.remove(key);
389 }
390
391 final var sql = (
392 """
393 DELETE FROM %s AS s
394 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
395 """
396 ).formatted(TABLE_SESSION);
397 try (final var statement = connection.prepareStatement(sql)) {
398 statement.setInt(1, accountIdType);
399 statement.setBytes(2, key.serviceId().toByteArray());
400 statement.setInt(3, key.deviceId());
401 statement.executeUpdate();
402 }
403 }
404
405 private static boolean isActive(SessionRecord record) {
406 return record != null
407 && record.hasSenderChain()
408 && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
409 }
410
411 record Key(ServiceId serviceId, int deviceId) {}
412 }