nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
24327eb07420d1d3c3ce2b81fca94d52721947d9
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.storage.recipients.RecipientId;
4 import org.asamk.signal.manager.storage.recipients.RecipientResolver;
5 import org.asamk.signal.manager.util.IOUtils;
6 import org.signal.libsignal.protocol.NoSessionException;
7 import org.signal.libsignal.protocol.SignalProtocolAddress;
8 import org.signal.libsignal.protocol.ecc.ECPublicKey;
9 import org.signal.libsignal.protocol.message.CiphertextMessage;
10 import org.signal.libsignal.protocol.state.SessionRecord;
11 import org.slf4j.Logger;
12 import org.slf4j.LoggerFactory;
13 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
14
15 import java.io.File;
16 import java.io.FileInputStream;
17 import java.io.FileOutputStream;
18 import java.io.IOException;
19 import java.nio.file.Files;
20 import java.util.Arrays;
21 import java.util.Collection;
22 import java.util.HashMap;
23 import java.util.List;
24 import java.util.Map;
25 import java.util.Objects;
26 import java.util.Set;
27 import java.util.regex.Matcher;
28 import java.util.regex.Pattern;
29 import java.util.stream.Collectors;
30
31 public class SessionStore implements SignalServiceSessionStore {
32
33 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
34
35 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
36
37 private final File sessionsPath;
38
39 private final RecipientResolver resolver;
40
41 public SessionStore(
42 final File sessionsPath, final RecipientResolver resolver
43 ) {
44 this.sessionsPath = sessionsPath;
45 this.resolver = resolver;
46 }
47
48 @Override
49 public SessionRecord loadSession(SignalProtocolAddress address) {
50 final var key = getKey(address);
51
52 synchronized (cachedSessions) {
53 final var session = loadSessionLocked(key);
54 if (session == null) {
55 return new SessionRecord();
56 }
57 return session;
58 }
59 }
60
61 @Override
62 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
63 final var keys = addresses.stream().map(this::getKey).toList();
64
65 synchronized (cachedSessions) {
66 final var sessions = keys.stream().map(this::loadSessionLocked).filter(Objects::nonNull).toList();
67
68 if (sessions.size() != addresses.size()) {
69 String message = "Mismatch! Asked for "
70 + addresses.size()
71 + " sessions, but only found "
72 + sessions.size()
73 + "!";
74 logger.warn(message);
75 throw new NoSessionException(message);
76 }
77
78 return sessions;
79 }
80 }
81
82 @Override
83 public List<Integer> getSubDeviceSessions(String name) {
84 final var recipientId = resolveRecipient(name);
85
86 synchronized (cachedSessions) {
87 return getKeysLocked(recipientId).stream()
88 // get all sessions for recipient except main device session
89 .filter(key -> key.deviceId() != 1 && key.recipientId().equals(recipientId))
90 .map(Key::deviceId)
91 .toList();
92 }
93 }
94
95 public boolean isCurrentRatchetKey(RecipientId recipientId, int deviceId, ECPublicKey ratchetKey) {
96 final var key = new Key(recipientId, deviceId);
97
98 synchronized (cachedSessions) {
99 final var session = loadSessionLocked(key);
100 if (session == null) {
101 return false;
102 }
103 return session.currentRatchetKeyMatches(ratchetKey);
104 }
105 }
106
107 @Override
108 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
109 final var key = getKey(address);
110
111 synchronized (cachedSessions) {
112 storeSessionLocked(key, session);
113 }
114 }
115
116 @Override
117 public boolean containsSession(SignalProtocolAddress address) {
118 final var key = getKey(address);
119
120 synchronized (cachedSessions) {
121 final var session = loadSessionLocked(key);
122 return isActive(session);
123 }
124 }
125
126 @Override
127 public void deleteSession(SignalProtocolAddress address) {
128 final var key = getKey(address);
129
130 synchronized (cachedSessions) {
131 deleteSessionLocked(key);
132 }
133 }
134
135 @Override
136 public void deleteAllSessions(String name) {
137 final var recipientId = resolveRecipient(name);
138 deleteAllSessions(recipientId);
139 }
140
141 public void deleteAllSessions(RecipientId recipientId) {
142 synchronized (cachedSessions) {
143 final var keys = getKeysLocked(recipientId);
144 for (var key : keys) {
145 deleteSessionLocked(key);
146 }
147 }
148 }
149
150 @Override
151 public void archiveSession(final SignalProtocolAddress address) {
152 final var key = getKey(address);
153
154 synchronized (cachedSessions) {
155 archiveSessionLocked(key);
156 }
157 }
158
159 @Override
160 public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
161 final var recipientIdToNameMap = addressNames.stream()
162 .collect(Collectors.toMap(this::resolveRecipient, name -> name));
163 synchronized (cachedSessions) {
164 return recipientIdToNameMap.keySet()
165 .stream()
166 .flatMap(recipientId -> getKeysLocked(recipientId).stream())
167 .filter(key -> isActive(this.loadSessionLocked(key)))
168 .map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId), key.deviceId()))
169 .collect(Collectors.toSet());
170 }
171 }
172
173 public void archiveAllSessions() {
174 synchronized (cachedSessions) {
175 final var keys = getKeysLocked();
176 for (var key : keys) {
177 archiveSessionLocked(key);
178 }
179 }
180 }
181
182 public void archiveSessions(final RecipientId recipientId) {
183 synchronized (cachedSessions) {
184 getKeysLocked().stream()
185 .filter(key -> key.recipientId.equals(recipientId))
186 .forEach(this::archiveSessionLocked);
187 }
188 }
189
190 public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
191 synchronized (cachedSessions) {
192 final var keys = getKeysLocked(toBeMergedRecipientId);
193 final var otherHasSession = keys.size() > 0;
194 if (!otherHasSession) {
195 return;
196 }
197
198 final var hasSession = getKeysLocked(recipientId).size() > 0;
199 if (hasSession) {
200 logger.debug("To be merged recipient had sessions, deleting.");
201 deleteAllSessions(toBeMergedRecipientId);
202 } else {
203 logger.debug("Only to be merged recipient had sessions, re-assigning to the new recipient.");
204 for (var key : keys) {
205 final var session = loadSessionLocked(key);
206 deleteSessionLocked(key);
207 if (session == null) {
208 continue;
209 }
210 final var newKey = new Key(recipientId, key.deviceId());
211 storeSessionLocked(newKey, session);
212 }
213 }
214 }
215 }
216
217 /**
218 * @param identifier can be either a serialized uuid or a e164 phone number
219 */
220 private RecipientId resolveRecipient(String identifier) {
221 return resolver.resolveRecipient(identifier);
222 }
223
224 private Key getKey(final SignalProtocolAddress address) {
225 final var recipientId = resolveRecipient(address.getName());
226 return new Key(recipientId, address.getDeviceId());
227 }
228
229 private List<Key> getKeysLocked(RecipientId recipientId) {
230 final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.id() + "_"));
231 if (files == null) {
232 return List.of();
233 }
234 return parseFileNames(files);
235 }
236
237 private Collection<Key> getKeysLocked() {
238 final var files = sessionsPath.listFiles();
239 if (files == null) {
240 return List.of();
241 }
242 return parseFileNames(files);
243 }
244
245 final Pattern sessionFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)");
246
247 private List<Key> parseFileNames(final File[] files) {
248 return Arrays.stream(files)
249 .map(f -> sessionFileNamePattern.matcher(f.getName()))
250 .filter(Matcher::matches)
251 .map(matcher -> {
252 final var recipientId = resolver.resolveRecipient(Long.parseLong(matcher.group(1)));
253 if (recipientId == null) {
254 return null;
255 }
256 return new Key(recipientId, Integer.parseInt(matcher.group(2)));
257 })
258 .filter(Objects::nonNull)
259 .toList();
260 }
261
262 private File getSessionFile(Key key) {
263 try {
264 IOUtils.createPrivateDirectories(sessionsPath);
265 } catch (IOException e) {
266 throw new AssertionError("Failed to create sessions path", e);
267 }
268 return new File(sessionsPath, key.recipientId().id() + "_" + key.deviceId());
269 }
270
271 private SessionRecord loadSessionLocked(final Key key) {
272 {
273 final var session = cachedSessions.get(key);
274 if (session != null) {
275 return session;
276 }
277 }
278
279 final var file = getSessionFile(key);
280 if (!file.exists()) {
281 return null;
282 }
283 try (var inputStream = new FileInputStream(file)) {
284 final var session = new SessionRecord(inputStream.readAllBytes());
285 cachedSessions.put(key, session);
286 return session;
287 } catch (Exception e) {
288 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
289 return null;
290 }
291 }
292
293 private void storeSessionLocked(final Key key, final SessionRecord session) {
294 cachedSessions.put(key, session);
295
296 final var file = getSessionFile(key);
297 try {
298 try (var outputStream = new FileOutputStream(file)) {
299 outputStream.write(session.serialize());
300 }
301 } catch (IOException e) {
302 logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
303 try {
304 Files.delete(file.toPath());
305 try (var outputStream = new FileOutputStream(file)) {
306 outputStream.write(session.serialize());
307 }
308 } catch (IOException e2) {
309 logger.error("Failed to store session file {}: {}", file, e2.getMessage());
310 }
311 }
312 }
313
314 private void archiveSessionLocked(final Key key) {
315 final var session = loadSessionLocked(key);
316 if (session == null) {
317 return;
318 }
319 session.archiveCurrentState();
320 storeSessionLocked(key, session);
321 }
322
323 private void deleteSessionLocked(final Key key) {
324 cachedSessions.remove(key);
325
326 final var file = getSessionFile(key);
327 if (!file.exists()) {
328 return;
329 }
330 try {
331 Files.delete(file.toPath());
332 } catch (IOException e) {
333 logger.error("Failed to delete session file {}: {}", file, e.getMessage());
334 }
335 }
336
337 private static boolean isActive(SessionRecord record) {
338 return record != null
339 && record.hasSenderChain()
340 && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
341 }
342
343 private record Key(RecipientId recipientId, int deviceId) {}
344 }