nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
4bac3bd57ffdf22a396b81140e1fd3c6b78a5769
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.api.Pair;
4 import org.asamk.signal.manager.storage.Database;
5 import org.asamk.signal.manager.storage.Utils;
6 import org.signal.libsignal.protocol.InvalidMessageException;
7 import org.signal.libsignal.protocol.NoSessionException;
8 import org.signal.libsignal.protocol.SignalProtocolAddress;
9 import org.signal.libsignal.protocol.ecc.ECPublicKey;
10 import org.signal.libsignal.protocol.message.CiphertextMessage;
11 import org.signal.libsignal.protocol.state.SessionRecord;
12 import org.slf4j.Logger;
13 import org.slf4j.LoggerFactory;
14 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
15 import org.whispersystems.signalservice.api.push.ServiceId;
16 import org.whispersystems.signalservice.api.push.ServiceIdType;
17 import org.whispersystems.signalservice.api.util.UuidUtil;
18
19 import java.sql.Connection;
20 import java.sql.ResultSet;
21 import java.sql.SQLException;
22 import java.util.ArrayList;
23 import java.util.Collection;
24 import java.util.HashMap;
25 import java.util.List;
26 import java.util.Map;
27 import java.util.Objects;
28 import java.util.Set;
29 import java.util.stream.Collectors;
30
31 public class SessionStore implements SignalServiceSessionStore {
32
33 private static final String TABLE_SESSION = "session";
34 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
35
36 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
37
38 private final Database database;
39 private final int accountIdType;
40
41 public static void createSql(Connection connection) throws SQLException {
42 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
43 try (final var statement = connection.createStatement()) {
44 statement.executeUpdate("""
45 CREATE TABLE session (
46 _id INTEGER PRIMARY KEY,
47 account_id_type INTEGER NOT NULL,
48 uuid BLOB NOT NULL,
49 device_id INTEGER NOT NULL,
50 record BLOB NOT NULL,
51 UNIQUE(account_id_type, uuid, device_id)
52 ) STRICT;
53 """);
54 }
55 }
56
57 public SessionStore(final Database database, final ServiceIdType serviceIdType) {
58 this.database = database;
59 this.accountIdType = Utils.getAccountIdType(serviceIdType);
60 }
61
62 @Override
63 public SessionRecord loadSession(SignalProtocolAddress address) {
64 final var key = getKey(address);
65 try (final var connection = database.getConnection()) {
66 final var session = loadSession(connection, key);
67 return Objects.requireNonNullElseGet(session, SessionRecord::new);
68 } catch (SQLException e) {
69 throw new RuntimeException("Failed read from session store", e);
70 }
71 }
72
73 @Override
74 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
75 final var keys = addresses.stream().map(this::getKey).toList();
76
77 try (final var connection = database.getConnection()) {
78 final var sessions = new ArrayList<SessionRecord>();
79 for (final var key : keys) {
80 final var sessionRecord = loadSession(connection, key);
81 if (sessionRecord != null) {
82 sessions.add(sessionRecord);
83 }
84 }
85
86 if (sessions.size() != addresses.size()) {
87 String message = "Mismatch! Asked for "
88 + addresses.size()
89 + " sessions, but only found "
90 + sessions.size()
91 + "!";
92 logger.warn(message);
93 throw new NoSessionException(message);
94 }
95
96 return sessions;
97 } catch (SQLException e) {
98 throw new RuntimeException("Failed read from session store", e);
99 }
100 }
101
102 @Override
103 public List<Integer> getSubDeviceSessions(String name) {
104 final var serviceId = ServiceId.parseOrThrow(name);
105 // get all sessions for recipient except primary device session
106 final var sql = (
107 """
108 SELECT s.device_id
109 FROM %s AS s
110 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id != 1
111 """
112 ).formatted(TABLE_SESSION);
113 try (final var connection = database.getConnection()) {
114 try (final var statement = connection.prepareStatement(sql)) {
115 statement.setInt(1, accountIdType);
116 statement.setBytes(2, serviceId.toByteArray());
117 return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList();
118 }
119 } catch (SQLException e) {
120 throw new RuntimeException("Failed read from session store", e);
121 }
122 }
123
124 public boolean isCurrentRatchetKey(ServiceId serviceId, int deviceId, ECPublicKey ratchetKey) {
125 final var key = new Key(serviceId, deviceId);
126
127 try (final var connection = database.getConnection()) {
128 final var session = loadSession(connection, key);
129 if (session == null) {
130 return false;
131 }
132 return session.currentRatchetKeyMatches(ratchetKey);
133 } catch (SQLException e) {
134 throw new RuntimeException("Failed read from session store", e);
135 }
136 }
137
138 @Override
139 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
140 final var key = getKey(address);
141
142 try (final var connection = database.getConnection()) {
143 storeSession(connection, key, session);
144 } catch (SQLException e) {
145 throw new RuntimeException("Failed read from session store", e);
146 }
147 }
148
149 @Override
150 public boolean containsSession(SignalProtocolAddress address) {
151 final var key = getKey(address);
152
153 try (final var connection = database.getConnection()) {
154 final var session = loadSession(connection, key);
155 return isActive(session);
156 } catch (SQLException e) {
157 throw new RuntimeException("Failed read from session store", e);
158 }
159 }
160
161 @Override
162 public void deleteSession(SignalProtocolAddress address) {
163 final var key = getKey(address);
164
165 try (final var connection = database.getConnection()) {
166 deleteSession(connection, key);
167 } catch (SQLException e) {
168 throw new RuntimeException("Failed update session store", e);
169 }
170 }
171
172 @Override
173 public void deleteAllSessions(String name) {
174 final var serviceId = ServiceId.parseOrThrow(name);
175 deleteAllSessions(serviceId);
176 }
177
178 public void deleteAllSessions(ServiceId serviceId) {
179 try (final var connection = database.getConnection()) {
180 deleteAllSessions(connection, serviceId);
181 } catch (SQLException e) {
182 throw new RuntimeException("Failed update session store", e);
183 }
184 }
185
186 @Override
187 public void archiveSession(final SignalProtocolAddress address) {
188 if (!UuidUtil.isUuid(address.getName())) {
189 return;
190 }
191
192 final var key = getKey(address);
193
194 try (final var connection = database.getConnection()) {
195 connection.setAutoCommit(false);
196 final var session = loadSession(connection, key);
197 if (session != null) {
198 session.archiveCurrentState();
199 storeSession(connection, key, session);
200 connection.commit();
201 }
202 } catch (SQLException e) {
203 throw new RuntimeException("Failed update session store", e);
204 }
205 }
206
207 @Override
208 public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
209 final var serviceIdsCommaSeparated = addressNames.stream()
210 .map(ServiceId::parseOrThrow)
211 .map(ServiceId::toString)
212 .collect(Collectors.joining(","));
213 final var sql = (
214 """
215 SELECT s.uuid, s.device_id, s.record
216 FROM %s AS s
217 WHERE s.account_id_type = ? AND s.uuid IN (%s)
218 """
219 ).formatted(TABLE_SESSION, serviceIdsCommaSeparated);
220 try (final var connection = database.getConnection()) {
221 try (final var statement = connection.prepareStatement(sql)) {
222 statement.setInt(1, accountIdType);
223 return Utils.executeQueryForStream(statement,
224 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
225 .filter(pair -> isActive(pair.second()))
226 .map(Pair::first)
227 .map(key -> key.serviceId().toProtocolAddress(key.deviceId()))
228 .collect(Collectors.toSet());
229 }
230 } catch (SQLException e) {
231 throw new RuntimeException("Failed read from session store", e);
232 }
233 }
234
235 public void archiveAllSessions() {
236 final var sql = (
237 """
238 SELECT s.uuid, s.device_id, s.record
239 FROM %s AS s
240 WHERE s.account_id_type = ?
241 """
242 ).formatted(TABLE_SESSION);
243 try (final var connection = database.getConnection()) {
244 connection.setAutoCommit(false);
245 final List<Pair<Key, SessionRecord>> records;
246 try (final var statement = connection.prepareStatement(sql)) {
247 statement.setInt(1, accountIdType);
248 records = Utils.executeQueryForStream(statement,
249 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
250 }
251 for (final var record : records) {
252 record.second().archiveCurrentState();
253 storeSession(connection, record.first(), record.second());
254 }
255 connection.commit();
256 } catch (SQLException e) {
257 throw new RuntimeException("Failed update session store", e);
258 }
259 }
260
261 public void archiveSessions(final ServiceId serviceId) {
262 final var sql = (
263 """
264 SELECT s.uuid, s.device_id, s.record
265 FROM %s AS s
266 WHERE s.account_id_type = ? AND s.uuid = ?
267 """
268 ).formatted(TABLE_SESSION);
269 try (final var connection = database.getConnection()) {
270 connection.setAutoCommit(false);
271 final List<Pair<Key, SessionRecord>> records;
272 try (final var statement = connection.prepareStatement(sql)) {
273 statement.setInt(1, accountIdType);
274 statement.setBytes(2, serviceId.toByteArray());
275 records = Utils.executeQueryForStream(statement,
276 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res))).toList();
277 }
278 for (final var record : records) {
279 record.second().archiveCurrentState();
280 storeSession(connection, record.first(), record.second());
281 }
282 connection.commit();
283 } catch (SQLException e) {
284 throw new RuntimeException("Failed update session store", e);
285 }
286 }
287
288 void addLegacySessions(final Collection<Pair<Key, SessionRecord>> sessions) {
289 logger.debug("Migrating legacy sessions to database");
290 long start = System.nanoTime();
291 try (final var connection = database.getConnection()) {
292 connection.setAutoCommit(false);
293 for (final var pair : sessions) {
294 storeSession(connection, pair.first(), pair.second());
295 }
296 connection.commit();
297 } catch (SQLException e) {
298 throw new RuntimeException("Failed update session store", e);
299 }
300 logger.debug("Complete sessions migration took {}ms", (System.nanoTime() - start) / 1000000);
301 }
302
303 private Key getKey(final SignalProtocolAddress address) {
304 final var serviceId = ServiceId.parseOrThrow(address.getName());
305 return new Key(serviceId, address.getDeviceId());
306 }
307
308 private SessionRecord loadSession(Connection connection, final Key key) throws SQLException {
309 synchronized (cachedSessions) {
310 final var session = cachedSessions.get(key);
311 if (session != null) {
312 return session;
313 }
314 }
315 final var sql = (
316 """
317 SELECT s.record
318 FROM %s AS s
319 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
320 """
321 ).formatted(TABLE_SESSION);
322 try (final var statement = connection.prepareStatement(sql)) {
323 statement.setInt(1, accountIdType);
324 statement.setBytes(2, key.serviceId().toByteArray());
325 statement.setInt(3, key.deviceId());
326 return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null);
327 }
328 }
329
330 private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException {
331 final var serviceId = ServiceId.parseOrThrow(resultSet.getBytes("uuid"));
332 final var deviceId = resultSet.getInt("device_id");
333 return new Key(serviceId, deviceId);
334 }
335
336 private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) throws SQLException {
337 try {
338 final var record = resultSet.getBytes("record");
339 return new SessionRecord(record);
340 } catch (InvalidMessageException e) {
341 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
342 return null;
343 }
344 }
345
346 private void storeSession(
347 final Connection connection, final Key key, final SessionRecord session
348 ) throws SQLException {
349 synchronized (cachedSessions) {
350 cachedSessions.put(key, session);
351 }
352
353 final var sql = """
354 INSERT OR REPLACE INTO %s (account_id_type, uuid, device_id, record)
355 VALUES (?, ?, ?, ?)
356 """.formatted(TABLE_SESSION);
357 try (final var statement = connection.prepareStatement(sql)) {
358 statement.setInt(1, accountIdType);
359 statement.setBytes(2, key.serviceId().toByteArray());
360 statement.setInt(3, key.deviceId());
361 statement.setBytes(4, session.serialize());
362 statement.executeUpdate();
363 }
364 }
365
366 private void deleteAllSessions(final Connection connection, final ServiceId serviceId) throws SQLException {
367 synchronized (cachedSessions) {
368 cachedSessions.clear();
369 }
370
371 final var sql = (
372 """
373 DELETE FROM %s AS s
374 WHERE s.account_id_type = ? AND s.uuid = ?
375 """
376 ).formatted(TABLE_SESSION);
377 try (final var statement = connection.prepareStatement(sql)) {
378 statement.setInt(1, accountIdType);
379 statement.setBytes(2, serviceId.toByteArray());
380 statement.executeUpdate();
381 }
382 }
383
384 private void deleteSession(Connection connection, final Key key) throws SQLException {
385 synchronized (cachedSessions) {
386 cachedSessions.remove(key);
387 }
388
389 final var sql = (
390 """
391 DELETE FROM %s AS s
392 WHERE s.account_id_type = ? AND s.uuid = ? AND s.device_id = ?
393 """
394 ).formatted(TABLE_SESSION);
395 try (final var statement = connection.prepareStatement(sql)) {
396 statement.setInt(1, accountIdType);
397 statement.setBytes(2, key.serviceId().toByteArray());
398 statement.setInt(3, key.deviceId());
399 statement.executeUpdate();
400 }
401 }
402
403 private static boolean isActive(SessionRecord record) {
404 return record != null
405 && record.hasSenderChain()
406 && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
407 }
408
409 record Key(ServiceId serviceId, int deviceId) {}
410 }