]> nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
97f428baf99b8785de96f45ecae5a9397d20ec5e
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.api.Pair;
4 import org.asamk.signal.manager.storage.Database;
5 import org.asamk.signal.manager.storage.Utils;
6 import org.signal.libsignal.protocol.NoSessionException;
7 import org.signal.libsignal.protocol.SignalProtocolAddress;
8 import org.signal.libsignal.protocol.ecc.ECPublicKey;
9 import org.signal.libsignal.protocol.state.SessionRecord;
10 import org.slf4j.Logger;
11 import org.slf4j.LoggerFactory;
12 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
13 import org.whispersystems.signalservice.api.push.ServiceId;
14 import org.whispersystems.signalservice.api.push.ServiceIdType;
15
16 import java.sql.Connection;
17 import java.sql.ResultSet;
18 import java.sql.SQLException;
19 import java.util.ArrayList;
20 import java.util.Collection;
21 import java.util.HashMap;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.Objects;
25 import java.util.Set;
26 import java.util.stream.Collectors;
27
28 public class SessionStore implements SignalServiceSessionStore {
29
30 private static final String TABLE_SESSION = "session";
31 private static final Logger logger = LoggerFactory.getLogger(SessionStore.class);
32
33 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
34
35 private final Database database;
36 private final int accountIdType;
37
38 public static void createSql(Connection connection) throws SQLException {
39 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
40 try (final var statement = connection.createStatement()) {
41 statement.executeUpdate("""
42 CREATE TABLE session (
43 _id INTEGER PRIMARY KEY,
44 account_id_type INTEGER NOT NULL,
45 address TEXT NOT NULL,
46 device_id INTEGER NOT NULL,
47 record BLOB NOT NULL,
48 UNIQUE(account_id_type, address, device_id)
49 ) STRICT;
50 """);
51 }
52 }
53
54 public SessionStore(final Database database, final ServiceIdType serviceIdType) {
55 this.database = database;
56 this.accountIdType = Utils.getAccountIdType(serviceIdType);
57 }
58
59 @Override
60 public SessionRecord loadSession(SignalProtocolAddress address) {
61 final var key = getKey(address);
62 try (final var connection = database.getConnection()) {
63 final var session = loadSession(connection, key);
64 return Objects.requireNonNullElseGet(session, SessionRecord::new);
65 } catch (SQLException e) {
66 throw new RuntimeException("Failed read from session store", e);
67 }
68 }
69
70 @Override
71 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
72 final var keys = addresses.stream().map(this::getKey).toList();
73
74 try (final var connection = database.getConnection()) {
75 final var sessions = new ArrayList<SessionRecord>();
76 for (final var key : keys) {
77 final var sessionRecord = loadSession(connection, key);
78 if (sessionRecord != null) {
79 sessions.add(sessionRecord);
80 }
81 }
82
83 if (sessions.size() != addresses.size()) {
84 String message = "Mismatch! Asked for "
85 + addresses.size()
86 + " sessions, but only found "
87 + sessions.size()
88 + "!";
89 logger.warn(message);
90 throw new NoSessionException(message);
91 }
92
93 return sessions;
94 } catch (SQLException e) {
95 throw new RuntimeException("Failed read from session store", e);
96 }
97 }
98
99 @Override
100 public List<Integer> getSubDeviceSessions(String name) {
101 final var serviceId = ServiceId.parseOrThrow(name);
102 // get all sessions for recipient except primary device session
103 final var sql = (
104 """
105 SELECT s.device_id
106 FROM %s AS s
107 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id != 1
108 """
109 ).formatted(TABLE_SESSION);
110 try (final var connection = database.getConnection()) {
111 try (final var statement = connection.prepareStatement(sql)) {
112 statement.setInt(1, accountIdType);
113 statement.setString(2, serviceId.toString());
114 return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList();
115 }
116 } catch (SQLException e) {
117 throw new RuntimeException("Failed read from session store", e);
118 }
119 }
120
121 public boolean isCurrentRatchetKey(ServiceId serviceId, int deviceId, ECPublicKey ratchetKey) {
122 final var key = new Key(serviceId.toString(), deviceId);
123
124 try (final var connection = database.getConnection()) {
125 final var session = loadSession(connection, key);
126 if (session == null) {
127 return false;
128 }
129 return session.currentRatchetKeyMatches(ratchetKey);
130 } catch (SQLException e) {
131 throw new RuntimeException("Failed read from session store", e);
132 }
133 }
134
135 @Override
136 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
137 final var key = getKey(address);
138
139 try (final var connection = database.getConnection()) {
140 storeSession(connection, key, session);
141 } catch (SQLException e) {
142 throw new RuntimeException("Failed read from session store", e);
143 }
144 }
145
146 @Override
147 public boolean containsSession(SignalProtocolAddress address) {
148 final var key = getKey(address);
149
150 try (final var connection = database.getConnection()) {
151 final var session = loadSession(connection, key);
152 return isActive(session);
153 } catch (SQLException e) {
154 throw new RuntimeException("Failed read from session store", e);
155 }
156 }
157
158 @Override
159 public void deleteSession(SignalProtocolAddress address) {
160 final var key = getKey(address);
161
162 try (final var connection = database.getConnection()) {
163 deleteSession(connection, key);
164 } catch (SQLException e) {
165 throw new RuntimeException("Failed update session store", e);
166 }
167 }
168
169 @Override
170 public void deleteAllSessions(String name) {
171 final var serviceId = ServiceId.parseOrThrow(name);
172 deleteAllSessions(serviceId);
173 }
174
175 public void deleteAllSessions(ServiceId serviceId) {
176 try (final var connection = database.getConnection()) {
177 deleteAllSessions(connection, serviceId.toString());
178 } catch (SQLException e) {
179 throw new RuntimeException("Failed update session store", e);
180 }
181 }
182
183 @Override
184 public void archiveSession(final SignalProtocolAddress address) {
185 final var key = getKey(address);
186
187 try (final var connection = database.getConnection()) {
188 connection.setAutoCommit(false);
189 final var session = loadSession(connection, key);
190 if (session != null) {
191 session.archiveCurrentState();
192 storeSession(connection, key, session);
193 connection.commit();
194 }
195 } catch (SQLException e) {
196 throw new RuntimeException("Failed update session store", e);
197 }
198 }
199
200 @Override
201 public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
202 final var serviceIdsCommaSeparated = addressNames.stream()
203 .map(address -> "'" + address.replaceAll("'", "''") + "'")
204 .collect(Collectors.joining(","));
205 final var sql = (
206 """
207 SELECT s.address, s.device_id, s.record
208 FROM %s AS s
209 WHERE s.account_id_type = ? AND s.address IN (%s)
210 """
211 ).formatted(TABLE_SESSION, serviceIdsCommaSeparated);
212 try (final var connection = database.getConnection()) {
213 try (final var statement = connection.prepareStatement(sql)) {
214 statement.setInt(1, accountIdType);
215 return Utils.executeQueryForStream(statement,
216 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
217 .filter(pair -> isActive(pair.second()))
218 .map(Pair::first)
219 .map(key -> new SignalProtocolAddress(key.address(), key.deviceId()))
220 .collect(Collectors.toSet());
221 }
222 } catch (SQLException e) {
223 throw new RuntimeException("Failed read from session store", e);
224 }
225 }
226
227 public void archiveAllSessions() {
228 final var sql = (
229 """
230 SELECT s.address, s.device_id, s.record
231 FROM %s AS s
232 WHERE s.account_id_type = ?
233 """
234 ).formatted(TABLE_SESSION);
235 try (final var connection = database.getConnection()) {
236 connection.setAutoCommit(false);
237 final List<Pair<Key, SessionRecord>> records;
238 try (final var statement = connection.prepareStatement(sql)) {
239 statement.setInt(1, accountIdType);
240 records = Utils.executeQueryForStream(statement,
241 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
242 .filter(Objects::nonNull)
243 .toList();
244 }
245 for (final var record : records) {
246 record.second().archiveCurrentState();
247 storeSession(connection, record.first(), record.second());
248 }
249 connection.commit();
250 } catch (SQLException e) {
251 throw new RuntimeException("Failed update session store", e);
252 }
253 }
254
255 public void archiveSessions(final ServiceId serviceId) {
256 final var sql = (
257 """
258 SELECT s.address, s.device_id, s.record
259 FROM %s AS s
260 WHERE s.account_id_type = ? AND s.address = ?
261 """
262 ).formatted(TABLE_SESSION);
263 try (final var connection = database.getConnection()) {
264 connection.setAutoCommit(false);
265 final List<Pair<Key, SessionRecord>> records;
266 try (final var statement = connection.prepareStatement(sql)) {
267 statement.setInt(1, accountIdType);
268 statement.setString(2, serviceId.toString());
269 records = Utils.executeQueryForStream(statement,
270 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
271 .filter(Objects::nonNull)
272 .toList();
273 }
274 for (final var record : records) {
275 record.second().archiveCurrentState();
276 storeSession(connection, record.first(), record.second());
277 }
278 connection.commit();
279 } catch (SQLException e) {
280 throw new RuntimeException("Failed update session store", e);
281 }
282 }
283
284 void addLegacySessions(final Collection<Pair<Key, SessionRecord>> sessions) {
285 logger.debug("Migrating legacy sessions to database");
286 long start = System.nanoTime();
287 try (final var connection = database.getConnection()) {
288 connection.setAutoCommit(false);
289 for (final var pair : sessions) {
290 storeSession(connection, pair.first(), pair.second());
291 }
292 connection.commit();
293 } catch (SQLException e) {
294 throw new RuntimeException("Failed update session store", e);
295 }
296 logger.debug("Complete sessions migration took {}ms", (System.nanoTime() - start) / 1000000);
297 }
298
299 private Key getKey(final SignalProtocolAddress address) {
300 return new Key(address.getName(), address.getDeviceId());
301 }
302
303 private SessionRecord loadSession(Connection connection, final Key key) throws SQLException {
304 synchronized (cachedSessions) {
305 final var session = cachedSessions.get(key);
306 if (session != null) {
307 return session;
308 }
309 }
310 final var sql = (
311 """
312 SELECT s.record
313 FROM %s AS s
314 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id = ?
315 """
316 ).formatted(TABLE_SESSION);
317 try (final var statement = connection.prepareStatement(sql)) {
318 statement.setInt(1, accountIdType);
319 statement.setString(2, key.address());
320 statement.setInt(3, key.deviceId());
321 return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null);
322 }
323 }
324
325 private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException {
326 final var address = resultSet.getString("address");
327 final var deviceId = resultSet.getInt("device_id");
328 return new Key(address, deviceId);
329 }
330
331 private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) {
332 try {
333 final var record = resultSet.getBytes("record");
334 return new SessionRecord(record);
335 } catch (Exception e) {
336 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
337 return null;
338 }
339 }
340
341 private void storeSession(
342 final Connection connection, final Key key, final SessionRecord session
343 ) throws SQLException {
344 synchronized (cachedSessions) {
345 cachedSessions.put(key, session);
346 }
347
348 final var sql = """
349 INSERT OR REPLACE INTO %s (account_id_type, address, device_id, record)
350 VALUES (?, ?, ?, ?)
351 """.formatted(TABLE_SESSION);
352 try (final var statement = connection.prepareStatement(sql)) {
353 statement.setInt(1, accountIdType);
354 statement.setString(2, key.address());
355 statement.setInt(3, key.deviceId());
356 statement.setBytes(4, session.serialize());
357 statement.executeUpdate();
358 }
359 }
360
361 private void deleteAllSessions(final Connection connection, final String address) throws SQLException {
362 synchronized (cachedSessions) {
363 cachedSessions.clear();
364 }
365
366 final var sql = (
367 """
368 DELETE FROM %s AS s
369 WHERE s.account_id_type = ? AND s.address = ?
370 """
371 ).formatted(TABLE_SESSION);
372 try (final var statement = connection.prepareStatement(sql)) {
373 statement.setInt(1, accountIdType);
374 statement.setString(2, address);
375 statement.executeUpdate();
376 }
377 }
378
379 private void deleteSession(Connection connection, final Key key) throws SQLException {
380 synchronized (cachedSessions) {
381 cachedSessions.remove(key);
382 }
383
384 final var sql = (
385 """
386 DELETE FROM %s AS s
387 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id = ?
388 """
389 ).formatted(TABLE_SESSION);
390 try (final var statement = connection.prepareStatement(sql)) {
391 statement.setInt(1, accountIdType);
392 statement.setString(2, key.address());
393 statement.setInt(3, key.deviceId());
394 statement.executeUpdate();
395 }
396 }
397
398 private static boolean isActive(SessionRecord record) {
399 return record != null && record.hasSenderChain();
400 }
401
402 record Key(String address, int deviceId) {}
403 }