]> nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
564df1d0b27883864c67b4f9432cd9217ffc77e2
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.api.Pair;
4 import org.asamk.signal.manager.storage.Database;
5 import org.asamk.signal.manager.storage.Utils;
6 import org.signal.libsignal.protocol.NoSessionException;
7 import org.signal.libsignal.protocol.SignalProtocolAddress;
8 import org.signal.libsignal.protocol.ecc.ECPublicKey;
9 import org.signal.libsignal.protocol.message.CiphertextMessage;
10 import org.signal.libsignal.protocol.state.SessionRecord;
11 import org.signal.libsignal.protocol.util.Hex;
12 import org.slf4j.Logger;
13 import org.slf4j.LoggerFactory;
14 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
15 import org.whispersystems.signalservice.api.push.ServiceId;
16 import org.whispersystems.signalservice.api.push.ServiceIdType;
17 import org.whispersystems.signalservice.api.util.UuidUtil;
18
19 import java.sql.Connection;
20 import java.sql.ResultSet;
21 import java.sql.SQLException;
22 import java.util.ArrayList;
23 import java.util.Collection;
24 import java.util.HashMap;
25 import java.util.List;
26 import java.util.Map;
27 import java.util.Objects;
28 import java.util.Set;
29 import java.util.stream.Collectors;
30
31 public class SessionStore implements SignalServiceSessionStore {
32
33 private static final String TABLE_SESSION = "session";
34 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
35
36 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
37
38 private final Database database;
39 private final int accountIdType;
40
41 public static void createSql(Connection connection) throws SQLException {
42 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
43 try (final var statement = connection.createStatement()) {
44 statement.executeUpdate("""
45 CREATE TABLE session (
46 _id INTEGER PRIMARY KEY,
47 account_id_type INTEGER NOT NULL,
48 uuid BLOB NOT NULL,
49 device_id INTEGER NOT NULL,
50 record BLOB NOT NULL,
51 UNIQUE(account_id_type, uuid, device_id)
52 ) STRICT;
53 """);
54 }
55 }
56
57 public SessionStore(final Database database, final ServiceIdType serviceIdType) {
58 this.database = database;
59 this.accountIdType = Utils.getAccountIdType(serviceIdType);
60 }
61
62 @Override
63 public SessionRecord loadSession(SignalProtocolAddress address) {
64 final var key = getKey(address);
65 try (final var connection = database.getConnection()) {
66 final var session = loadSession(connection, key);
67 return Objects.requireNonNullElseGet(session, SessionRecord::new);
68 } catch (SQLException e) {
69 throw new RuntimeException("Failed read from session store", e);
70 }
71 }
72
73 @Override
74 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
75 final var keys = addresses.stream().map(this::getKey).toList();
76
77 try (final var connection = database.getConnection()) {
78 final var sessions = new ArrayList<SessionRecord>();
79 for (final var key : keys) {
80 final var sessionRecord = loadSession(connection, key);
81 if (sessionRecord != null) {
82 sessions.add(sessionRecord);
83 }
84 }
85
86 if (sessions.size() != addresses.size()) {
87 String message = "Mismatch! Asked for "
88 + addresses.size()
89 + " sessions, but only found "
90 + sessions.size()
91 + "!";
92 logger.warn(message);
93 throw new NoSessionException(message);
94 }
95
96 return sessions;
97 } catch (SQLException e) {
98 throw new RuntimeException("Failed read from session store", e);
99 }
100 }
101
102 @Override
103 public List<Integer> getSubDeviceSessions(String name) {
104 final var serviceId = ServiceId.parseOrThrow(name);
105 // get all sessions for recipient except primary device session
106 final var sql = (
107 """
108 SELECT s.device_id
109 FROM %s AS s
110 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id != 1
111 """
112 ).formatted(TABLE_SESSION);
113 try (final var connection = database.getConnection()) {
114 try (final var statement = connection.prepareStatement(sql)) {
115 statement.setInt(1, accountIdType);
116 statement.setBytes(2, serviceId.toByteArray());
117 return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList();
118 }
119 } catch (SQLException e) {
120 throw new RuntimeException("Failed read from session store", e);
121 }
122 }
123
124 public boolean isCurrentRatchetKey(ServiceId serviceId, int deviceId, ECPublicKey ratchetKey) {
125 final var key = new Key(serviceId, deviceId);
126
127 try (final var connection = database.getConnection()) {
128 final var session = loadSession(connection, key);
129 if (session == null) {
130 return false;
131 }
132 return session.currentRatchetKeyMatches(ratchetKey);
133 } catch (SQLException e) {
134 throw new RuntimeException("Failed read from session store", e);
135 }
136 }
137
138 @Override
139 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
140 final var key = getKey(address);
141
142 try (final var connection = database.getConnection()) {
143 storeSession(connection, key, session);
144 } catch (SQLException e) {
145 throw new RuntimeException("Failed read from session store", e);
146 }
147 }
148
149 @Override
150 public boolean containsSession(SignalProtocolAddress address) {
151 final var key = getKey(address);
152
153 try (final var connection = database.getConnection()) {
154 final var session = loadSession(connection, key);
155 return isActive(session);
156 } catch (SQLException e) {
157 throw new RuntimeException("Failed read from session store", e);
158 }
159 }
160
161 @Override
162 public void deleteSession(SignalProtocolAddress address) {
163 final var key = getKey(address);
164
165 try (final var connection = database.getConnection()) {
166 deleteSession(connection, key);
167 } catch (SQLException e) {
168 throw new RuntimeException("Failed update session store", e);
169 }
170 }
171
172 @Override
173 public void deleteAllSessions(String name) {
174 final var serviceId = ServiceId.parseOrThrow(name);
175 deleteAllSessions(serviceId);
176 }
177
178 public void deleteAllSessions(ServiceId serviceId) {
179 try (final var connection = database.getConnection()) {
180 deleteAllSessions(connection, serviceId);
181 } catch (SQLException e) {
182 throw new RuntimeException("Failed update session store", e);
183 }
184 }
185
186 @Override
187 public void archiveSession(final SignalProtocolAddress address) {
188 if (!UuidUtil.isUuid(address.getName())) {
189 return;
190 }
191
192 final var key = getKey(address);
193
194 try (final var connection = database.getConnection()) {
195 connection.setAutoCommit(false);
196 final var session = loadSession(connection, key);
197 if (session != null) {
198 session.archiveCurrentState();
199 storeSession(connection, key, session);
200 connection.commit();
201 }
202 } catch (SQLException e) {
203 throw new RuntimeException("Failed update session store", e);
204 }
205 }
206
207 @Override
208 public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
209 final var serviceIdsCommaSeparated = addressNames.stream()
210 .map(ServiceId::parseOrThrow)
211 .map(ServiceId::toByteArray)
212 .map(uuid -> "x'" + Hex.toStringCondensed(uuid) + "'")
213 .collect(Collectors.joining(","));
214 final var sql = (
215 """
216 SELECT s.uuid, s.device_id, s.record
217 FROM %s AS s
218 WHERE s.account_id_type = ? AND s.uuid IN (%s)
219 """
220 ).formatted(TABLE_SESSION, serviceIdsCommaSeparated);
221 try (final var connection = database.getConnection()) {
222 try (final var statement = connection.prepareStatement(sql)) {
223 statement.setInt(1, accountIdType);
224 return Utils.executeQueryForStream(statement,
225 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
226 .filter(pair -> isActive(pair.second()))
227 .map(Pair::first)
228 .map(key -> key.serviceId().toProtocolAddress(key.deviceId()))
229 .collect(Collectors.toSet());
230 }
231 } catch (SQLException e) {
232 throw new RuntimeException("Failed read from session store", e);
233 }
234 }
235
236 public void archiveAllSessions() {
237 final var sql = (
238 """
239 SELECT s.uuid, s.device_id, s.record
240 FROM %s AS s
241 WHERE s.account_id_type = ?
242 """
243 ).formatted(TABLE_SESSION);
244 try (final var connection = database.getConnection()) {
245 connection.setAutoCommit(false);
246 final List<Pair<Key, SessionRecord>> records;
247 try (final var statement = connection.prepareStatement(sql)) {
248 statement.setInt(1, accountIdType);
249 records = Utils.executeQueryForStream(statement,
250 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
251 }
252 for (final var record : records) {
253 record.second().archiveCurrentState();
254 storeSession(connection, record.first(), record.second());
255 }
256 connection.commit();
257 } catch (SQLException e) {
258 throw new RuntimeException("Failed update session store", e);
259 }
260 }
261
262 public void archiveSessions(final ServiceId serviceId) {
263 final var sql = (
264 """
265 SELECT s.uuid, s.device_id, s.record
266 FROM %s AS s
267 WHERE s.account_id_type = ? AND s.uuid = ?
268 """
269 ).formatted(TABLE_SESSION);
270 try (final var connection = database.getConnection()) {
271 connection.setAutoCommit(false);
272 final List<Pair<Key, SessionRecord>> records;
273 try (final var statement = connection.prepareStatement(sql)) {
274 statement.setInt(1, accountIdType);
275 statement.setBytes(2, serviceId.toByteArray());
276 records = Utils.executeQueryForStream(statement,
277 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
278 }
279 for (final var record : records) {
280 record.second().archiveCurrentState();
281 storeSession(connection, record.first(), record.second());
282 }
283 connection.commit();
284 } catch (SQLException e) {
285 throw new RuntimeException("Failed update session store", e);
286 }
287 }
288
289 void addLegacySessions(final Collection<Pair<Key, SessionRecord>> sessions) {
290 logger.debug("Migrating legacy sessions to database");
291 long start = System.nanoTime();
292 try (final var connection = database.getConnection()) {
293 connection.setAutoCommit(false);
294 for (final var pair : sessions) {
295 storeSession(connection, pair.first(), pair.second());
296 }
297 connection.commit();
298 } catch (SQLException e) {
299 throw new RuntimeException("Failed update session store", e);
300 }
301 logger.debug("Complete sessions migration took {}ms", (System.nanoTime() - start) / 1000000);
302 }
303
304 private Key getKey(final SignalProtocolAddress address) {
305 final var serviceId = ServiceId.parseOrThrow(address.getName());
306 return new Key(serviceId, address.getDeviceId());
307 }
308
309 private SessionRecord loadSession(Connection connection, final Key key) throws SQLException {
310 synchronized (cachedSessions) {
311 final var session = cachedSessions.get(key);
312 if (session != null) {
313 return session;
314 }
315 }
316 final var sql = (
317 """
318 SELECT s.record
319 FROM %s AS s
320 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
321 """
322 ).formatted(TABLE_SESSION);
323 try (final var statement = connection.prepareStatement(sql)) {
324 statement.setInt(1, accountIdType);
325 statement.setBytes(2, key.serviceId().toByteArray());
326 statement.setInt(3, key.deviceId());
327 return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null);
328 }
329 }
330
331 private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException {
332 final var serviceId = ServiceId.parseOrThrow(resultSet.getBytes("uuid"));
333 final var deviceId = resultSet.getInt("device_id");
334 return new Key(serviceId, deviceId);
335 }
336
337 private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) throws SQLException {
338 try {
339 final var record = resultSet.getBytes("record");
340 return new SessionRecord(record);
341 } catch (Exception e) {
342 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
343 return null;
344 }
345 }
346
347 private void storeSession(
348 final Connection connection, final Key key, final SessionRecord session
349 ) throws SQLException {
350 synchronized (cachedSessions) {
351 cachedSessions.put(key, session);
352 }
353
354 final var sql = """
355 INSERT OR REPLACE INTO %s (account_id_type, uuid, device_id, record)
356 VALUES (?, ?, ?, ?)
357 """.formatted(TABLE_SESSION);
358 try (final var statement = connection.prepareStatement(sql)) {
359 statement.setInt(1, accountIdType);
360 statement.setBytes(2, key.serviceId().toByteArray());
361 statement.setInt(3, key.deviceId());
362 statement.setBytes(4, session.serialize());
363 statement.executeUpdate();
364 }
365 }
366
367 private void deleteAllSessions(final Connection connection, final ServiceId serviceId) throws SQLException {
368 synchronized (cachedSessions) {
369 cachedSessions.clear();
370 }
371
372 final var sql = (
373 """
374 DELETE FROM %s AS s
375 WHERE s.account_id_type = ? AND s.uuid = ?
376 """
377 ).formatted(TABLE_SESSION);
378 try (final var statement = connection.prepareStatement(sql)) {
379 statement.setInt(1, accountIdType);
380 statement.setBytes(2, serviceId.toByteArray());
381 statement.executeUpdate();
382 }
383 }
384
385 private void deleteSession(Connection connection, final Key key) throws SQLException {
386 synchronized (cachedSessions) {
387 cachedSessions.remove(key);
388 }
389
390 final var sql = (
391 """
392 DELETE FROM %s AS s
393 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
394 """
395 ).formatted(TABLE_SESSION);
396 try (final var statement = connection.prepareStatement(sql)) {
397 statement.setInt(1, accountIdType);
398 statement.setBytes(2, key.serviceId().toByteArray());
399 statement.setInt(3, key.deviceId());
400 statement.executeUpdate();
401 }
402 }
403
404 private static boolean isActive(SessionRecord record) {
405 return record != null
406 && record.hasSenderChain()
407 && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
408 }
409
410 record Key(ServiceId serviceId, int deviceId) {}
411 }