nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
067dc5d3c26d8dafee854251b993469209416065
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.api.Pair;
4 import org.asamk.signal.manager.storage.Database;
5 import org.asamk.signal.manager.storage.Utils;
6 import org.signal.libsignal.protocol.NoSessionException;
7 import org.signal.libsignal.protocol.SignalProtocolAddress;
8 import org.signal.libsignal.protocol.ecc.ECPublicKey;
9 import org.signal.libsignal.protocol.state.SessionRecord;
10 import org.slf4j.Logger;
11 import org.slf4j.LoggerFactory;
12 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
13 import org.whispersystems.signalservice.api.push.ServiceId;
14 import org.whispersystems.signalservice.api.push.ServiceIdType;
15
16 import java.sql.Connection;
17 import java.sql.ResultSet;
18 import java.sql.SQLException;
19 import java.util.ArrayList;
20 import java.util.Collection;
21 import java.util.HashMap;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.Objects;
25 import java.util.stream.Collectors;
26
27 public class SessionStore implements SignalServiceSessionStore {
28
29 private static final String TABLE_SESSION = "session";
30 private static final Logger logger = LoggerFactory.getLogger(SessionStore.class);
31
32 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
33
34 private final Database database;
35 private final int accountIdType;
36
37 public static void createSql(Connection connection) throws SQLException {
38 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
39 try (final var statement = connection.createStatement()) {
40 statement.executeUpdate("""
41 CREATE TABLE session (
42 _id INTEGER PRIMARY KEY,
43 account_id_type INTEGER NOT NULL,
44 address TEXT NOT NULL,
45 device_id INTEGER NOT NULL,
46 record BLOB NOT NULL,
47 UNIQUE(account_id_type, address, device_id)
48 ) STRICT;
49 """);
50 }
51 }
52
53 public SessionStore(final Database database, final ServiceIdType serviceIdType) {
54 this.database = database;
55 this.accountIdType = Utils.getAccountIdType(serviceIdType);
56 }
57
58 @Override
59 public SessionRecord loadSession(SignalProtocolAddress address) {
60 final var key = getKey(address);
61 try (final var connection = database.getConnection()) {
62 final var sessionRecord = Objects.requireNonNullElseGet(loadSession(connection, key), SessionRecord::new);
63 synchronized (cachedSessions) {
64 cachedSessions.put(key, sessionRecord);
65 }
66 return sessionRecord;
67 } catch (SQLException e) {
68 throw new RuntimeException("Failed read from session store", e);
69 }
70 }
71
72 @Override
73 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
74 final var keys = addresses.stream().map(this::getKey).toList();
75
76 try (final var connection = database.getConnection()) {
77 final var sessions = new ArrayList<SessionRecord>();
78 for (final var key : keys) {
79 final var sessionRecord = loadSession(connection, key);
80 if (sessionRecord != null) {
81 sessions.add(sessionRecord);
82 }
83 }
84
85 if (sessions.size() != addresses.size()) {
86 String message = "Mismatch! Asked for "
87 + addresses.size()
88 + " sessions, but only found "
89 + sessions.size()
90 + "!";
91 logger.warn(message);
92 throw new NoSessionException(message);
93 }
94
95 return sessions;
96 } catch (SQLException e) {
97 throw new RuntimeException("Failed read from session store", e);
98 }
99 }
100
101 @Override
102 public List<Integer> getSubDeviceSessions(String name) {
103 final var serviceId = ServiceId.parseOrThrow(name);
104 // get all sessions for recipient except primary device session
105 final var sql = (
106 """
107 SELECT s.device_id
108 FROM %s AS s
109 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id != 1
110 """
111 ).formatted(TABLE_SESSION);
112 try (final var connection = database.getConnection()) {
113 try (final var statement = connection.prepareStatement(sql)) {
114 statement.setInt(1, accountIdType);
115 statement.setString(2, serviceId.toString());
116 return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList();
117 }
118 } catch (SQLException e) {
119 throw new RuntimeException("Failed read from session store", e);
120 }
121 }
122
123 public boolean isCurrentRatchetKey(ServiceId serviceId, int deviceId, ECPublicKey ratchetKey) {
124 final var key = new Key(serviceId.toString(), deviceId);
125
126 try (final var connection = database.getConnection()) {
127 final var session = loadSession(connection, key);
128 if (session == null) {
129 return false;
130 }
131 return session.currentRatchetKeyMatches(ratchetKey);
132 } catch (SQLException e) {
133 throw new RuntimeException("Failed read from session store", e);
134 }
135 }
136
137 @Override
138 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
139 final var key = getKey(address);
140
141 try (final var connection = database.getConnection()) {
142 storeSession(connection, key, session);
143 } catch (SQLException e) {
144 throw new RuntimeException("Failed read from session store", e);
145 }
146 }
147
148 @Override
149 public boolean containsSession(SignalProtocolAddress address) {
150 final var key = getKey(address);
151
152 try (final var connection = database.getConnection()) {
153 final var session = loadSession(connection, key);
154 final var active = isActive(session);
155 logger.trace("Contains session {}: {} (active: {})", address, session != null, active);
156 return active;
157 } catch (SQLException e) {
158 throw new RuntimeException("Failed read from session store", e);
159 }
160 }
161
162 @Override
163 public void deleteSession(SignalProtocolAddress address) {
164 final var key = getKey(address);
165
166 try (final var connection = database.getConnection()) {
167 deleteSession(connection, key);
168 } catch (SQLException e) {
169 throw new RuntimeException("Failed update session store", e);
170 }
171 }
172
173 @Override
174 public void deleteAllSessions(String name) {
175 final var serviceId = ServiceId.parseOrThrow(name);
176 deleteAllSessions(serviceId);
177 }
178
179 public void deleteAllSessions(ServiceId serviceId) {
180 try (final var connection = database.getConnection()) {
181 deleteAllSessions(connection, serviceId.toString());
182 } catch (SQLException e) {
183 throw new RuntimeException("Failed update session store", e);
184 }
185 }
186
187 @Override
188 public void archiveSession(final SignalProtocolAddress address) {
189 final var key = getKey(address);
190
191 try (final var connection = database.getConnection()) {
192 connection.setAutoCommit(false);
193 final var session = loadSession(connection, key);
194 if (session != null) {
195 session.archiveCurrentState();
196 storeSession(connection, key, session);
197 connection.commit();
198 }
199 } catch (SQLException e) {
200 throw new RuntimeException("Failed update session store", e);
201 }
202 }
203
204 @Override
205 public Map<SignalProtocolAddress, SessionRecord> getAllAddressesWithActiveSessions(final List<String> addressNames) {
206 final var serviceIdsCommaSeparated = addressNames.stream()
207 .map(address -> "'" + address.replaceAll("'", "''") + "'")
208 .collect(Collectors.joining(","));
209 final var sql = (
210 """
211 SELECT s.address, s.device_id, s.record
212 FROM %s AS s
213 WHERE s.account_id_type = ? AND s.address IN (%s)
214 """
215 ).formatted(TABLE_SESSION, serviceIdsCommaSeparated);
216 try (final var connection = database.getConnection()) {
217 try (final var statement = connection.prepareStatement(sql)) {
218 statement.setInt(1, accountIdType);
219 return Utils.executeQueryForStream(statement,
220 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
221 .filter(pair -> isActive(pair.second()))
222 .collect(Collectors.toMap(pair -> new SignalProtocolAddress(pair.first().address(),
223 pair.first().deviceId()), Pair::second));
224 }
225 } catch (SQLException e) {
226 throw new RuntimeException("Failed read from session store", e);
227 }
228 }
229
230 public void archiveAllSessions() {
231 final var sql = (
232 """
233 SELECT s.address, s.device_id, s.record
234 FROM %s AS s
235 WHERE s.account_id_type = ?
236 """
237 ).formatted(TABLE_SESSION);
238 try (final var connection = database.getConnection()) {
239 connection.setAutoCommit(false);
240 final List<Pair<Key, SessionRecord>> records;
241 try (final var statement = connection.prepareStatement(sql)) {
242 statement.setInt(1, accountIdType);
243 records = Utils.executeQueryForStream(statement,
244 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
245 .filter(Objects::nonNull)
246 .toList();
247 }
248 for (final var record : records) {
249 record.second().archiveCurrentState();
250 storeSession(connection, record.first(), record.second());
251 }
252 connection.commit();
253 } catch (SQLException e) {
254 throw new RuntimeException("Failed update session store", e);
255 }
256 }
257
258 public void archiveSessions(final ServiceId serviceId) {
259 final var sql = (
260 """
261 SELECT s.address, s.device_id, s.record
262 FROM %s AS s
263 WHERE s.account_id_type = ? AND s.address = ?
264 """
265 ).formatted(TABLE_SESSION);
266 try (final var connection = database.getConnection()) {
267 connection.setAutoCommit(false);
268 final List<Pair<Key, SessionRecord>> records;
269 try (final var statement = connection.prepareStatement(sql)) {
270 statement.setInt(1, accountIdType);
271 statement.setString(2, serviceId.toString());
272 records = Utils.executeQueryForStream(statement,
273 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
274 .filter(Objects::nonNull)
275 .toList();
276 }
277 for (final var record : records) {
278 record.second().archiveCurrentState();
279 storeSession(connection, record.first(), record.second());
280 }
281 connection.commit();
282 } catch (SQLException e) {
283 throw new RuntimeException("Failed update session store", e);
284 }
285 }
286
287 void addLegacySessions(final Collection<Pair<Key, SessionRecord>> sessions) {
288 logger.debug("Migrating legacy sessions to database");
289 long start = System.nanoTime();
290 try (final var connection = database.getConnection()) {
291 connection.setAutoCommit(false);
292 for (final var pair : sessions) {
293 storeSession(connection, pair.first(), pair.second());
294 }
295 connection.commit();
296 } catch (SQLException e) {
297 throw new RuntimeException("Failed update session store", e);
298 }
299 logger.debug("Complete sessions migration took {}ms", (System.nanoTime() - start) / 1000000);
300 }
301
302 private Key getKey(final SignalProtocolAddress address) {
303 return new Key(address.getName(), address.getDeviceId());
304 }
305
306 private SessionRecord loadSession(Connection connection, final Key key) throws SQLException {
307 synchronized (cachedSessions) {
308 final var session = cachedSessions.get(key);
309 if (session != null) {
310 return session;
311 }
312 }
313 final var sql = (
314 """
315 SELECT s.record
316 FROM %s AS s
317 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id = ?
318 """
319 ).formatted(TABLE_SESSION);
320 try (final var statement = connection.prepareStatement(sql)) {
321 statement.setInt(1, accountIdType);
322 statement.setString(2, key.address());
323 statement.setInt(3, key.deviceId());
324 return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null);
325 }
326 }
327
328 private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException {
329 final var address = resultSet.getString("address");
330 final var deviceId = resultSet.getInt("device_id");
331 return new Key(address, deviceId);
332 }
333
334 private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) {
335 try {
336 final var record = resultSet.getBytes("record");
337 return new SessionRecord(record);
338 } catch (Exception e) {
339 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
340 return null;
341 }
342 }
343
344 private void storeSession(
345 final Connection connection, final Key key, final SessionRecord session
346 ) throws SQLException {
347 synchronized (cachedSessions) {
348 cachedSessions.put(key, session);
349 }
350
351 final var sql = """
352 INSERT OR REPLACE INTO %s (account_id_type, address, device_id, record)
353 VALUES (?, ?, ?, ?)
354 """.formatted(TABLE_SESSION);
355 try (final var statement = connection.prepareStatement(sql)) {
356 statement.setInt(1, accountIdType);
357 statement.setString(2, key.address());
358 statement.setInt(3, key.deviceId());
359 statement.setBytes(4, session.serialize());
360 statement.executeUpdate();
361 }
362 }
363
364 private void deleteAllSessions(final Connection connection, final String address) throws SQLException {
365 synchronized (cachedSessions) {
366 cachedSessions.clear();
367 }
368
369 final var sql = (
370 """
371 DELETE FROM %s AS s
372 WHERE s.account_id_type = ? AND s.address = ?
373 """
374 ).formatted(TABLE_SESSION);
375 try (final var statement = connection.prepareStatement(sql)) {
376 statement.setInt(1, accountIdType);
377 statement.setString(2, address);
378 statement.executeUpdate();
379 }
380 }
381
382 private void deleteSession(Connection connection, final Key key) throws SQLException {
383 synchronized (cachedSessions) {
384 cachedSessions.remove(key);
385 }
386
387 final var sql = (
388 """
389 DELETE FROM %s AS s
390 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id = ?
391 """
392 ).formatted(TABLE_SESSION);
393 try (final var statement = connection.prepareStatement(sql)) {
394 statement.setInt(1, accountIdType);
395 statement.setString(2, key.address());
396 statement.setInt(3, key.deviceId());
397 statement.executeUpdate();
398 }
399 }
400
401 private static boolean isActive(SessionRecord record) {
402 return record != null && record.hasSenderChain();
403 }
404
405 record Key(String address, int deviceId) {}
406 }