]> nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
3f3ba4d6709929580fc1dc760158546d820d8248
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.api.Pair;
4 import org.asamk.signal.manager.storage.Database;
5 import org.asamk.signal.manager.storage.Utils;
6 import org.signal.libsignal.protocol.NoSessionException;
7 import org.signal.libsignal.protocol.SignalProtocolAddress;
8 import org.signal.libsignal.protocol.ecc.ECPublicKey;
9 import org.signal.libsignal.protocol.state.SessionRecord;
10 import org.slf4j.Logger;
11 import org.slf4j.LoggerFactory;
12 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
13 import org.whispersystems.signalservice.api.push.ServiceId;
14 import org.whispersystems.signalservice.api.push.ServiceIdType;
15
16 import java.sql.Connection;
17 import java.sql.ResultSet;
18 import java.sql.SQLException;
19 import java.util.ArrayList;
20 import java.util.Collection;
21 import java.util.HashMap;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.Objects;
25 import java.util.stream.Collectors;
26
27 public class SessionStore implements SignalServiceSessionStore {
28
29 private static final String TABLE_SESSION = "session";
30 private static final Logger logger = LoggerFactory.getLogger(SessionStore.class);
31
32 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
33
34 private final Database database;
35 private final int accountIdType;
36
37 public static void createSql(Connection connection) throws SQLException {
38 // When modifying the CREATE statement here, also add a migration in AccountDatabase.java
39 try (final var statement = connection.createStatement()) {
40 statement.executeUpdate("""
41 CREATE TABLE session (
42 _id INTEGER PRIMARY KEY,
43 account_id_type INTEGER NOT NULL,
44 address TEXT NOT NULL,
45 device_id INTEGER NOT NULL,
46 record BLOB NOT NULL,
47 UNIQUE(account_id_type, address, device_id)
48 ) STRICT;
49 """);
50 }
51 }
52
53 public SessionStore(final Database database, final ServiceIdType serviceIdType) {
54 this.database = database;
55 this.accountIdType = Utils.getAccountIdType(serviceIdType);
56 }
57
58 @Override
59 public SessionRecord loadSession(SignalProtocolAddress address) {
60 final var key = getKey(address);
61 try (final var connection = database.getConnection()) {
62 final var session = loadSession(connection, key);
63 return Objects.requireNonNullElseGet(session, SessionRecord::new);
64 } catch (SQLException e) {
65 throw new RuntimeException("Failed read from session store", e);
66 }
67 }
68
69 @Override
70 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
71 final var keys = addresses.stream().map(this::getKey).toList();
72
73 try (final var connection = database.getConnection()) {
74 final var sessions = new ArrayList<SessionRecord>();
75 for (final var key : keys) {
76 final var sessionRecord = loadSession(connection, key);
77 if (sessionRecord != null) {
78 sessions.add(sessionRecord);
79 }
80 }
81
82 if (sessions.size() != addresses.size()) {
83 String message = "Mismatch! Asked for "
84 + addresses.size()
85 + " sessions, but only found "
86 + sessions.size()
87 + "!";
88 logger.warn(message);
89 throw new NoSessionException(message);
90 }
91
92 return sessions;
93 } catch (SQLException e) {
94 throw new RuntimeException("Failed read from session store", e);
95 }
96 }
97
98 @Override
99 public List<Integer> getSubDeviceSessions(String name) {
100 final var serviceId = ServiceId.parseOrThrow(name);
101 // get all sessions for recipient except primary device session
102 final var sql = (
103 """
104 SELECT s.device_id
105 FROM %s AS s
106 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id != 1
107 """
108 ).formatted(TABLE_SESSION);
109 try (final var connection = database.getConnection()) {
110 try (final var statement = connection.prepareStatement(sql)) {
111 statement.setInt(1, accountIdType);
112 statement.setString(2, serviceId.toString());
113 return Utils.executeQueryForStream(statement, res -> res.getInt("device_id")).toList();
114 }
115 } catch (SQLException e) {
116 throw new RuntimeException("Failed read from session store", e);
117 }
118 }
119
120 public boolean isCurrentRatchetKey(ServiceId serviceId, int deviceId, ECPublicKey ratchetKey) {
121 final var key = new Key(serviceId.toString(), deviceId);
122
123 try (final var connection = database.getConnection()) {
124 final var session = loadSession(connection, key);
125 if (session == null) {
126 return false;
127 }
128 return session.currentRatchetKeyMatches(ratchetKey);
129 } catch (SQLException e) {
130 throw new RuntimeException("Failed read from session store", e);
131 }
132 }
133
134 @Override
135 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
136 final var key = getKey(address);
137
138 try (final var connection = database.getConnection()) {
139 storeSession(connection, key, session);
140 } catch (SQLException e) {
141 throw new RuntimeException("Failed read from session store", e);
142 }
143 }
144
145 @Override
146 public boolean containsSession(SignalProtocolAddress address) {
147 final var key = getKey(address);
148
149 try (final var connection = database.getConnection()) {
150 final var session = loadSession(connection, key);
151 return isActive(session);
152 } catch (SQLException e) {
153 throw new RuntimeException("Failed read from session store", e);
154 }
155 }
156
157 @Override
158 public void deleteSession(SignalProtocolAddress address) {
159 final var key = getKey(address);
160
161 try (final var connection = database.getConnection()) {
162 deleteSession(connection, key);
163 } catch (SQLException e) {
164 throw new RuntimeException("Failed update session store", e);
165 }
166 }
167
168 @Override
169 public void deleteAllSessions(String name) {
170 final var serviceId = ServiceId.parseOrThrow(name);
171 deleteAllSessions(serviceId);
172 }
173
174 public void deleteAllSessions(ServiceId serviceId) {
175 try (final var connection = database.getConnection()) {
176 deleteAllSessions(connection, serviceId.toString());
177 } catch (SQLException e) {
178 throw new RuntimeException("Failed update session store", e);
179 }
180 }
181
182 @Override
183 public void archiveSession(final SignalProtocolAddress address) {
184 final var key = getKey(address);
185
186 try (final var connection = database.getConnection()) {
187 connection.setAutoCommit(false);
188 final var session = loadSession(connection, key);
189 if (session != null) {
190 session.archiveCurrentState();
191 storeSession(connection, key, session);
192 connection.commit();
193 }
194 } catch (SQLException e) {
195 throw new RuntimeException("Failed update session store", e);
196 }
197 }
198
199 @Override
200 public Map<SignalProtocolAddress, SessionRecord> getAllAddressesWithActiveSessions(final List<String> addressNames) {
201 final var serviceIdsCommaSeparated = addressNames.stream()
202 .map(address -> "'" + address.replaceAll("'", "''") + "'")
203 .collect(Collectors.joining(","));
204 final var sql = (
205 """
206 SELECT s.address, s.device_id, s.record
207 FROM %s AS s
208 WHERE s.account_id_type = ? AND s.address IN (%s)
209 """
210 ).formatted(TABLE_SESSION, serviceIdsCommaSeparated);
211 try (final var connection = database.getConnection()) {
212 try (final var statement = connection.prepareStatement(sql)) {
213 statement.setInt(1, accountIdType);
214 return Utils.executeQueryForStream(statement,
215 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
216 .filter(pair -> isActive(pair.second()))
217 .collect(Collectors.toMap(pair -> new SignalProtocolAddress(pair.first().address(),
218 pair.first().deviceId()), Pair::second));
219 }
220 } catch (SQLException e) {
221 throw new RuntimeException("Failed read from session store", e);
222 }
223 }
224
225 public void archiveAllSessions() {
226 final var sql = (
227 """
228 SELECT s.address, s.device_id, s.record
229 FROM %s AS s
230 WHERE s.account_id_type = ?
231 """
232 ).formatted(TABLE_SESSION);
233 try (final var connection = database.getConnection()) {
234 connection.setAutoCommit(false);
235 final List<Pair<Key, SessionRecord>> records;
236 try (final var statement = connection.prepareStatement(sql)) {
237 statement.setInt(1, accountIdType);
238 records = Utils.executeQueryForStream(statement,
239 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
240 .filter(Objects::nonNull)
241 .toList();
242 }
243 for (final var record : records) {
244 record.second().archiveCurrentState();
245 storeSession(connection, record.first(), record.second());
246 }
247 connection.commit();
248 } catch (SQLException e) {
249 throw new RuntimeException("Failed update session store", e);
250 }
251 }
252
253 public void archiveSessions(final ServiceId serviceId) {
254 final var sql = (
255 """
256 SELECT s.address, s.device_id, s.record
257 FROM %s AS s
258 WHERE s.account_id_type = ? AND s.address = ?
259 """
260 ).formatted(TABLE_SESSION);
261 try (final var connection = database.getConnection()) {
262 connection.setAutoCommit(false);
263 final List<Pair<Key, SessionRecord>> records;
264 try (final var statement = connection.prepareStatement(sql)) {
265 statement.setInt(1, accountIdType);
266 statement.setString(2, serviceId.toString());
267 records = Utils.executeQueryForStream(statement,
268 res -> new Pair<>(getKeyFromResultSet(res), getSessionRecordFromResultSet(res)))
269 .filter(Objects::nonNull)
270 .toList();
271 }
272 for (final var record : records) {
273 record.second().archiveCurrentState();
274 storeSession(connection, record.first(), record.second());
275 }
276 connection.commit();
277 } catch (SQLException e) {
278 throw new RuntimeException("Failed update session store", e);
279 }
280 }
281
282 void addLegacySessions(final Collection<Pair<Key, SessionRecord>> sessions) {
283 logger.debug("Migrating legacy sessions to database");
284 long start = System.nanoTime();
285 try (final var connection = database.getConnection()) {
286 connection.setAutoCommit(false);
287 for (final var pair : sessions) {
288 storeSession(connection, pair.first(), pair.second());
289 }
290 connection.commit();
291 } catch (SQLException e) {
292 throw new RuntimeException("Failed update session store", e);
293 }
294 logger.debug("Complete sessions migration took {}ms", (System.nanoTime() - start) / 1000000);
295 }
296
297 private Key getKey(final SignalProtocolAddress address) {
298 return new Key(address.getName(), address.getDeviceId());
299 }
300
301 private SessionRecord loadSession(Connection connection, final Key key) throws SQLException {
302 synchronized (cachedSessions) {
303 final var session = cachedSessions.get(key);
304 if (session != null) {
305 return session;
306 }
307 }
308 final var sql = (
309 """
310 SELECT s.record
311 FROM %s AS s
312 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id = ?
313 """
314 ).formatted(TABLE_SESSION);
315 try (final var statement = connection.prepareStatement(sql)) {
316 statement.setInt(1, accountIdType);
317 statement.setString(2, key.address());
318 statement.setInt(3, key.deviceId());
319 return Utils.executeQueryForOptional(statement, this::getSessionRecordFromResultSet).orElse(null);
320 }
321 }
322
323 private Key getKeyFromResultSet(ResultSet resultSet) throws SQLException {
324 final var address = resultSet.getString("address");
325 final var deviceId = resultSet.getInt("device_id");
326 return new Key(address, deviceId);
327 }
328
329 private SessionRecord getSessionRecordFromResultSet(ResultSet resultSet) {
330 try {
331 final var record = resultSet.getBytes("record");
332 return new SessionRecord(record);
333 } catch (Exception e) {
334 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
335 return null;
336 }
337 }
338
339 private void storeSession(
340 final Connection connection, final Key key, final SessionRecord session
341 ) throws SQLException {
342 synchronized (cachedSessions) {
343 cachedSessions.put(key, session);
344 }
345
346 final var sql = """
347 INSERT OR REPLACE INTO %s (account_id_type, address, device_id, record)
348 VALUES (?, ?, ?, ?)
349 """.formatted(TABLE_SESSION);
350 try (final var statement = connection.prepareStatement(sql)) {
351 statement.setInt(1, accountIdType);
352 statement.setString(2, key.address());
353 statement.setInt(3, key.deviceId());
354 statement.setBytes(4, session.serialize());
355 statement.executeUpdate();
356 }
357 }
358
359 private void deleteAllSessions(final Connection connection, final String address) throws SQLException {
360 synchronized (cachedSessions) {
361 cachedSessions.clear();
362 }
363
364 final var sql = (
365 """
366 DELETE FROM %s AS s
367 WHERE s.account_id_type = ? AND s.address = ?
368 """
369 ).formatted(TABLE_SESSION);
370 try (final var statement = connection.prepareStatement(sql)) {
371 statement.setInt(1, accountIdType);
372 statement.setString(2, address);
373 statement.executeUpdate();
374 }
375 }
376
377 private void deleteSession(Connection connection, final Key key) throws SQLException {
378 synchronized (cachedSessions) {
379 cachedSessions.remove(key);
380 }
381
382 final var sql = (
383 """
384 DELETE FROM %s AS s
385 WHERE s.account_id_type = ? AND s.address = ? AND s.device_id = ?
386 """
387 ).formatted(TABLE_SESSION);
388 try (final var statement = connection.prepareStatement(sql)) {
389 statement.setInt(1, accountIdType);
390 statement.setString(2, key.address());
391 statement.setInt(3, key.deviceId());
392 statement.executeUpdate();
393 }
394 }
395
396 private static boolean isActive(SessionRecord record) {
397 return record != null && record.hasSenderChain();
398 }
399
400 record Key(String address, int deviceId) {}
401 }