]> nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
0f26758315b19a65877a14924dce35f3a8e8e4aa
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.api.Pair;
4 import org.asamk.signal.manager.storage.Database;
5 import org.asamk.signal.manager.storage.Utils;
6 import org.signal.libsignal.protocol.NoSessionException;
7 import org.signal.libsignal.protocol.SignalProtocolAddress;
8 import org.signal.libsignal.protocol.ecc.ECPublicKey;
9 import org.signal.libsignal.protocol.message.CiphertextMessage;
10 import org.signal.libsignal.protocol.state.SessionRecord;
11 import org.signal.libsignal.protocol.util.Hex;
12 import org.slf4j.Logger;
13 import org.slf4j.LoggerFactory;
14 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
15 import org.whispersystems.signalservice.api.push.ServiceId;
16 import org.whispersystems.signalservice.api.push.ServiceIdType;
17 import org.whispersystems.signalservice.api.util.UuidUtil;
18
19 import java.sql.Connection;
20 import java.sql.ResultSet;
21 import java.sql.SQLException;
22 import java.util.ArrayList;
23 import java.util.Collection;
24 import java.util.HashMap;
25 import java.util.List;
26 import java.util.Map;
27 import java.util.Objects;
28 import java.util.Set;
29 import java.util.stream.Collectors;
30
31 public class SessionStore implements SignalServiceSessionStore {
32
33 private static final String TABLE_SESSION = "session";
34 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
35
36 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
37
38 private final Database database;
39 private final int accountIdType;
40
41 public static void createSql(Connection connection) throws SQLException {
42 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
43 try (final var statement = connection.createStatement()) {
44 statement.executeUpdate("""
45 CREATE TABLE session (
46 _id INTEGER PRIMARY KEY,
47 account_id_type INTEGER NOT NULL,
48 uuid BLOB NOT NULL,
49 device_id INTEGER NOT NULL,
50 record BLOB NOT NULL,
51 UNIQUE(account_id_type, uuid, device_id)
52 ) STRICT;
53 """);
54 }
55 }
56
57 public SessionStore(final Database database, final ServiceIdType serviceIdType) {
58 this.database = database;
59 this.accountIdType = Utils.getAccountIdType(serviceIdType);
60 }
61
62 @Override
63 public SessionRecord loadSession(SignalProtocolAddress address) {
64 final var key = getKey(address);
65 try (final var connection = database.getConnection()) {
66 final var session = loadSession(connection, key);
67 return Objects.requireNonNullElseGet(session, SessionRecord::new);
68 } catch (SQLException e) {
69 throw new RuntimeException("Failed read from session store", e);
70 }
71 }
72
73 @Override
74 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
75 final var keys = addresses.stream().map(this::getKey).toList();
76
77 try (final var connection = database.getConnection()) {
78 final var sessions = new ArrayList<SessionRecord>();
79 for (final var key : keys) {
80 final var sessionRecord = loadSession(connection, key);
81 if (sessionRecord != null) {
82 sessions.add(sessionRecord);
83 }
84 }
85
86 if (sessions.size() != addresses.size()) {
87 String message = "Mismatch! Asked for "
88 + addresses.size()
89 + " sessions, but only found "
90 + sessions.size()
91 + "!";
92 logger.warn(message);
93 throw new NoSessionException(message);
94 }
95
96 return sessions;
97 } catch (SQLException e) {
98 throw new RuntimeException("Failed read from session store", e);
99 }
100 }
101
102 @Override
103 public List<Integer> getSubDeviceSessions(String name) {
104 final var serviceId = ServiceId.parseOrThrow(name);
105 // get all sessions for recipient except primary device session
106 final var sql = (
107 """
108 SELECT s.device_id
109 FROM %s AS s
110 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id != 1
111 """
112 ).formatted(TABLE_SESSION);
113 try (final var connection = database.getConnection()) {
114 try (final var statement = connection.prepareStatement(sql)) {
115 statement.setInt(1, accountIdType);
116 statement.setBytes(2, serviceId.toByteArray());
117 return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList();
118 }
119 } catch (SQLException e) {
120 throw new RuntimeException("Failed read from session store", e);
121 }
122 }
123
124 public boolean isCurrentRatchetKey(ServiceId serviceId, int deviceId, ECPublicKey ratchetKey) {
125 final var key = new Key(serviceId, deviceId);
126
127 try (final var connection = database.getConnection()) {
128 final var session = loadSession(connection, key);
129 if (session == null) {
130 return false;
131 }
132 return session.currentRatchetKeyMatches(ratchetKey);
133 } catch (SQLException e) {
134 throw new RuntimeException("Failed read from session store", e);
135 }
136 }
137
138 @Override
139 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
140 final var key = getKey(address);
141
142 try (final var connection = database.getConnection()) {
143 storeSession(connection, key, session);
144 } catch (SQLException e) {
145 throw new RuntimeException("Failed read from session store", e);
146 }
147 }
148
149 @Override
150 public boolean containsSession(SignalProtocolAddress address) {
151 final var key = getKey(address);
152
153 try (final var connection = database.getConnection()) {
154 final var session = loadSession(connection, key);
155 return isActive(session);
156 } catch (SQLException e) {
157 throw new RuntimeException("Failed read from session store", e);
158 }
159 }
160
161 @Override
162 public void deleteSession(SignalProtocolAddress address) {
163 final var key = getKey(address);
164
165 try (final var connection = database.getConnection()) {
166 deleteSession(connection, key);
167 } catch (SQLException e) {
168 throw new RuntimeException("Failed update session store", e);
169 }
170 }
171
172 @Override
173 public void deleteAllSessions(String name) {
174 final var serviceId = ServiceId.parseOrThrow(name);
175 deleteAllSessions(serviceId);
176 }
177
178 public void deleteAllSessions(ServiceId serviceId) {
179 try (final var connection = database.getConnection()) {
180 deleteAllSessions(connection, serviceId);
181 } catch (SQLException e) {
182 throw new RuntimeException("Failed update session store", e);
183 }
184 }
185
186 @Override
187 public void archiveSession(final SignalProtocolAddress address) {
188 if (!UuidUtil.isUuid(address.getName())) {
189 return;
190 }
191
192 final var key = getKey(address);
193
194 try (final var connection = database.getConnection()) {
195 connection.setAutoCommit(false);
196 final var session = loadSession(connection, key);
197 if (session != null) {
198 session.archiveCurrentState();
199 storeSession(connection, key, session);
200 connection.commit();
201 }
202 } catch (SQLException e) {
203 throw new RuntimeException("Failed update session store", e);
204 }
205 }
206
207 @Override
208 public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
209 final var serviceIdsCommaSeparated = addressNames.stream()
210 .map(ServiceId::parseOrThrow)
211 .map(ServiceId::toByteArray)
212 .map(uuid -> "x'" + Hex.toStringCondensed(uuid) + "'")
213 .collect(Collectors.joining(","));
214 final var sql = (
215 """
216 SELECT s.uuid, s.device_id, s.record
217 FROM %s AS s
218 WHERE s.account_id_type = ? AND s.uuid IN (%s)
219 """
220 ).formatted(TABLE_SESSION, serviceIdsCommaSeparated);
221 try (final var connection = database.getConnection()) {
222 try (final var statement = connection.prepareStatement(sql)) {
223 statement.setInt(1, accountIdType);
224 return Utils.executeQueryForStream(statement,
225 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
226 .filter(pair -> isActive(pair.second()))
227 .map(Pair::first)
228 .map(key -> key.serviceId().toProtocolAddress(key.deviceId()))
229 .collect(Collectors.toSet());
230 }
231 } catch (SQLException e) {
232 throw new RuntimeException("Failed read from session store", e);
233 }
234 }
235
236 public void archiveAllSessions() {
237 final var sql = (
238 """
239 SELECT s.uuid, s.device_id, s.record
240 FROM %s AS s
241 WHERE s.account_id_type = ?
242 """
243 ).formatted(TABLE_SESSION);
244 try (final var connection = database.getConnection()) {
245 connection.setAutoCommit(false);
246 final List<Pair<Key, SessionRecord>> records;
247 try (final var statement = connection.prepareStatement(sql)) {
248 statement.setInt(1, accountIdType);
249 records = Utils.executeQueryForStream(statement,
250 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
251 .filter(Objects::nonNull)
252 .toList();
253 }
254 for (final var record : records) {
255 record.second().archiveCurrentState();
256 storeSession(connection, record.first(), record.second());
257 }
258 connection.commit();
259 } catch (SQLException e) {
260 throw new RuntimeException("Failed update session store", e);
261 }
262 }
263
264 public void archiveSessions(final ServiceId serviceId) {
265 final var sql = (
266 """
267 SELECT s.uuid, s.device_id, s.record
268 FROM %s AS s
269 WHERE s.account_id_type = ? AND s.uuid = ?
270 """
271 ).formatted(TABLE_SESSION);
272 try (final var connection = database.getConnection()) {
273 connection.setAutoCommit(false);
274 final List<Pair<Key, SessionRecord>> records;
275 try (final var statement = connection.prepareStatement(sql)) {
276 statement.setInt(1, accountIdType);
277 statement.setBytes(2, serviceId.toByteArray());
278 records = Utils.executeQueryForStream(statement,
279 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
280 .filter(Objects::nonNull)
281 .toList();
282 }
283 for (final var record : records) {
284 record.second().archiveCurrentState();
285 storeSession(connection, record.first(), record.second());
286 }
287 connection.commit();
288 } catch (SQLException e) {
289 throw new RuntimeException("Failed update session store", e);
290 }
291 }
292
293 void addLegacySessions(final Collection<Pair<Key, SessionRecord>> sessions) {
294 logger.debug("Migrating legacy sessions to database");
295 long start = System.nanoTime();
296 try (final var connection = database.getConnection()) {
297 connection.setAutoCommit(false);
298 for (final var pair : sessions) {
299 storeSession(connection, pair.first(), pair.second());
300 }
301 connection.commit();
302 } catch (SQLException e) {
303 throw new RuntimeException("Failed update session store", e);
304 }
305 logger.debug("Complete sessions migration took {}ms", (System.nanoTime() - start) / 1000000);
306 }
307
308 private Key getKey(final SignalProtocolAddress address) {
309 final var serviceId = ServiceId.parseOrThrow(address.getName());
310 return new Key(serviceId, address.getDeviceId());
311 }
312
313 private SessionRecord loadSession(Connection connection, final Key key) throws SQLException {
314 synchronized (cachedSessions) {
315 final var session = cachedSessions.get(key);
316 if (session != null) {
317 return session;
318 }
319 }
320 final var sql = (
321 """
322 SELECT s.record
323 FROM %s AS s
324 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
325 """
326 ).formatted(TABLE_SESSION);
327 try (final var statement = connection.prepareStatement(sql)) {
328 statement.setInt(1, accountIdType);
329 statement.setBytes(2, key.serviceId().toByteArray());
330 statement.setInt(3, key.deviceId());
331 return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null);
332 }
333 }
334
335 private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException {
336 final var serviceId = ServiceId.parseOrThrow(resultSet.getBytes("uuid"));
337 final var deviceId = resultSet.getInt("device_id");
338 return new Key(serviceId, deviceId);
339 }
340
341 private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) throws SQLException {
342 try {
343 final var record = resultSet.getBytes("record");
344 return new SessionRecord(record);
345 } catch (Exception e) {
346 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
347 return null;
348 }
349 }
350
351 private void storeSession(
352 final Connection connection, final Key key, final SessionRecord session
353 ) throws SQLException {
354 synchronized (cachedSessions) {
355 cachedSessions.put(key, session);
356 }
357
358 final var sql = """
359 INSERT OR REPLACE INTO %s (account_id_type, uuid, device_id, record)
360 VALUES (?, ?, ?, ?)
361 """.formatted(TABLE_SESSION);
362 try (final var statement = connection.prepareStatement(sql)) {
363 statement.setInt(1, accountIdType);
364 statement.setBytes(2, key.serviceId().toByteArray());
365 statement.setInt(3, key.deviceId());
366 statement.setBytes(4, session.serialize());
367 statement.executeUpdate();
368 }
369 }
370
371 private void deleteAllSessions(final Connection connection, final ServiceId serviceId) throws SQLException {
372 synchronized (cachedSessions) {
373 cachedSessions.clear();
374 }
375
376 final var sql = (
377 """
378 DELETE FROM %s AS s
379 WHERE s.account_id_type = ? AND s.uuid = ?
380 """
381 ).formatted(TABLE_SESSION);
382 try (final var statement = connection.prepareStatement(sql)) {
383 statement.setInt(1, accountIdType);
384 statement.setBytes(2, serviceId.toByteArray());
385 statement.executeUpdate();
386 }
387 }
388
389 private void deleteSession(Connection connection, final Key key) throws SQLException {
390 synchronized (cachedSessions) {
391 cachedSessions.remove(key);
392 }
393
394 final var sql = (
395 """
396 DELETE FROM %s AS s
397 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
398 """
399 ).formatted(TABLE_SESSION);
400 try (final var statement = connection.prepareStatement(sql)) {
401 statement.setInt(1, accountIdType);
402 statement.setBytes(2, key.serviceId().toByteArray());
403 statement.setInt(3, key.deviceId());
404 statement.executeUpdate();
405 }
406 }
407
408 private static boolean isActive(SessionRecord record) {
409 return record != null
410 && record.hasSenderChain()
411 && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
412 }
413
414 record Key(ServiceId serviceId, int deviceId) {}
415 }