nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
5738408d6826f9861e14dbb7df9b7d2d13128ed4
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.storage.recipients.RecipientId;
4 import org.asamk.signal.manager.storage.recipients.RecipientResolver;
5 import org.asamk.signal.manager.util.IOUtils;
6 import org.slf4j.Logger;
7 import org.slf4j.LoggerFactory;
8 import org.whispersystems.libsignal.NoSessionException;
9 import org.whispersystems.libsignal.SignalProtocolAddress;
10 import org.whispersystems.libsignal.protocol.CiphertextMessage;
11 import org.whispersystems.libsignal.state.SessionRecord;
12 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
13
14 import java.io.File;
15 import java.io.FileInputStream;
16 import java.io.FileOutputStream;
17 import java.io.IOException;
18 import java.nio.file.Files;
19 import java.util.Arrays;
20 import java.util.Collection;
21 import java.util.HashMap;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.Objects;
25 import java.util.Set;
26 import java.util.regex.Matcher;
27 import java.util.regex.Pattern;
28 import java.util.stream.Collectors;
29
30 public class SessionStore implements SignalServiceSessionStore {
31
32 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
33
34 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
35
36 private final File sessionsPath;
37
38 private final RecipientResolver resolver;
39
40 public SessionStore(
41 final File sessionsPath, final RecipientResolver resolver
42 ) {
43 this.sessionsPath = sessionsPath;
44 this.resolver = resolver;
45 }
46
47 @Override
48 public SessionRecord loadSession(SignalProtocolAddress address) {
49 final var key = getKey(address);
50
51 synchronized (cachedSessions) {
52 final var session = loadSessionLocked(key);
53 if (session == null) {
54 return new SessionRecord();
55 }
56 return session;
57 }
58 }
59
60 @Override
61 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
62 final var keys = addresses.stream().map(this::getKey).collect(Collectors.toList());
63
64 synchronized (cachedSessions) {
65 final var sessions = keys.stream()
66 .map(this::loadSessionLocked)
67 .filter(Objects::nonNull)
68 .collect(Collectors.toList());
69
70 if (sessions.size() != addresses.size()) {
71 String message = "Mismatch! Asked for "
72 + addresses.size()
73 + " sessions, but only found "
74 + sessions.size()
75 + "!";
76 logger.warn(message);
77 throw new NoSessionException(message);
78 }
79
80 return sessions;
81 }
82 }
83
84 @Override
85 public List<Integer> getSubDeviceSessions(String name) {
86 final var recipientId = resolveRecipient(name);
87
88 synchronized (cachedSessions) {
89 return getKeysLocked(recipientId).stream()
90 // get all sessions for recipient except main device session
91 .filter(key -> key.getDeviceId() != 1 && key.getRecipientId().equals(recipientId))
92 .map(Key::getDeviceId)
93 .collect(Collectors.toList());
94 }
95 }
96
97 @Override
98 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
99 final var key = getKey(address);
100
101 synchronized (cachedSessions) {
102 storeSessionLocked(key, session);
103 }
104 }
105
106 @Override
107 public boolean containsSession(SignalProtocolAddress address) {
108 final var key = getKey(address);
109
110 synchronized (cachedSessions) {
111 final var session = loadSessionLocked(key);
112 if (session == null) {
113 return false;
114 }
115
116 return session.hasSenderChain() && session.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
117 }
118 }
119
120 @Override
121 public void deleteSession(SignalProtocolAddress address) {
122 final var key = getKey(address);
123
124 synchronized (cachedSessions) {
125 deleteSessionLocked(key);
126 }
127 }
128
129 @Override
130 public void deleteAllSessions(String name) {
131 final var recipientId = resolveRecipient(name);
132 deleteAllSessions(recipientId);
133 }
134
135 public void deleteAllSessions(RecipientId recipientId) {
136 synchronized (cachedSessions) {
137 final var keys = getKeysLocked(recipientId);
138 for (var key : keys) {
139 deleteSessionLocked(key);
140 }
141 }
142 }
143
144 @Override
145 public void archiveSession(final SignalProtocolAddress address) {
146 final var key = getKey(address);
147
148 synchronized (cachedSessions) {
149 archiveSessionLocked(key);
150 }
151 }
152
153 @Override
154 public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
155 final var recipientIdToNameMap = addressNames.stream()
156 .collect(Collectors.toMap(this::resolveRecipient, name -> name));
157 synchronized (cachedSessions) {
158 return recipientIdToNameMap.keySet()
159 .stream()
160 .flatMap(recipientId -> getKeysLocked(recipientId).stream())
161 .map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId), key.getDeviceId()))
162 .collect(Collectors.toSet());
163 }
164 }
165
166 public void archiveAllSessions() {
167 synchronized (cachedSessions) {
168 final var keys = getKeysLocked();
169 for (var key : keys) {
170 archiveSessionLocked(key);
171 }
172 }
173 }
174
175 public void archiveSessions(final RecipientId recipientId) {
176 synchronized (cachedSessions) {
177 getKeysLocked().stream()
178 .filter(key -> key.recipientId.equals(recipientId))
179 .forEach(this::archiveSessionLocked);
180 }
181 }
182
183 public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
184 synchronized (cachedSessions) {
185 final var otherHasSession = getKeysLocked(toBeMergedRecipientId).size() > 0;
186 if (!otherHasSession) {
187 return;
188 }
189
190 final var hasSession = getKeysLocked(recipientId).size() > 0;
191 if (hasSession) {
192 logger.debug("To be merged recipient had sessions, deleting.");
193 deleteAllSessions(toBeMergedRecipientId);
194 } else {
195 logger.debug("To be merged recipient had sessions, re-assigning to the new recipient.");
196 final var keys = getKeysLocked(toBeMergedRecipientId);
197 for (var key : keys) {
198 final var session = loadSessionLocked(key);
199 deleteSessionLocked(key);
200 if (session == null) {
201 continue;
202 }
203 final var newKey = new Key(recipientId, key.getDeviceId());
204 storeSessionLocked(newKey, session);
205 }
206 }
207 }
208 }
209
210 /**
211 * @param identifier can be either a serialized uuid or a e164 phone number
212 */
213 private RecipientId resolveRecipient(String identifier) {
214 return resolver.resolveRecipient(identifier);
215 }
216
217 private Key getKey(final SignalProtocolAddress address) {
218 final var recipientId = resolveRecipient(address.getName());
219 return new Key(recipientId, address.getDeviceId());
220 }
221
222 private List<Key> getKeysLocked(RecipientId recipientId) {
223 final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.getId() + "_"));
224 if (files == null) {
225 return List.of();
226 }
227 return parseFileNames(files);
228 }
229
230 private Collection<Key> getKeysLocked() {
231 final var files = sessionsPath.listFiles();
232 if (files == null) {
233 return List.of();
234 }
235 return parseFileNames(files);
236 }
237
238 final Pattern sessionFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)");
239
240 private List<Key> parseFileNames(final File[] files) {
241 return Arrays.stream(files)
242 .map(f -> sessionFileNamePattern.matcher(f.getName()))
243 .filter(Matcher::matches)
244 .map(matcher -> new Key(RecipientId.of(Long.parseLong(matcher.group(1))),
245 Integer.parseInt(matcher.group(2))))
246 .collect(Collectors.toList());
247 }
248
249 private File getSessionFile(Key key) {
250 try {
251 IOUtils.createPrivateDirectories(sessionsPath);
252 } catch (IOException e) {
253 throw new AssertionError("Failed to create sessions path", e);
254 }
255 return new File(sessionsPath, key.getRecipientId().getId() + "_" + key.getDeviceId());
256 }
257
258 private SessionRecord loadSessionLocked(final Key key) {
259 {
260 final var session = cachedSessions.get(key);
261 if (session != null) {
262 return session;
263 }
264 }
265
266 final var file = getSessionFile(key);
267 if (!file.exists()) {
268 return null;
269 }
270 try (var inputStream = new FileInputStream(file)) {
271 final var session = new SessionRecord(inputStream.readAllBytes());
272 cachedSessions.put(key, session);
273 return session;
274 } catch (IOException e) {
275 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
276 return null;
277 }
278 }
279
280 private void storeSessionLocked(final Key key, final SessionRecord session) {
281 cachedSessions.put(key, session);
282
283 final var file = getSessionFile(key);
284 try {
285 try (var outputStream = new FileOutputStream(file)) {
286 outputStream.write(session.serialize());
287 }
288 } catch (IOException e) {
289 logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
290 try {
291 Files.delete(file.toPath());
292 try (var outputStream = new FileOutputStream(file)) {
293 outputStream.write(session.serialize());
294 }
295 } catch (IOException e2) {
296 logger.error("Failed to store session file {}: {}", file, e2.getMessage());
297 }
298 }
299 }
300
301 private void archiveSessionLocked(final Key key) {
302 final var session = loadSessionLocked(key);
303 if (session == null) {
304 return;
305 }
306 session.archiveCurrentState();
307 storeSessionLocked(key, session);
308 }
309
310 private void deleteSessionLocked(final Key key) {
311 cachedSessions.remove(key);
312
313 final var file = getSessionFile(key);
314 if (!file.exists()) {
315 return;
316 }
317 try {
318 Files.delete(file.toPath());
319 } catch (IOException e) {
320 logger.error("Failed to delete session file {}: {}", file, e.getMessage());
321 }
322 }
323
324 private static final class Key {
325
326 private final RecipientId recipientId;
327 private final int deviceId;
328
329 public Key(final RecipientId recipientId, final int deviceId) {
330 this.recipientId = recipientId;
331 this.deviceId = deviceId;
332 }
333
334 public RecipientId getRecipientId() {
335 return recipientId;
336 }
337
338 public int getDeviceId() {
339 return deviceId;
340 }
341
342 @Override
343 public boolean equals(final Object o) {
344 if (this == o) return true;
345 if (o == null || getClass() != o.getClass()) return false;
346
347 final var key = (Key) o;
348
349 if (deviceId != key.deviceId) return false;
350 return recipientId.equals(key.recipientId);
351 }
352
353 @Override
354 public int hashCode() {
355 int result = recipientId.hashCode();
356 result = 31 * result + deviceId;
357 return result;
358 }
359 }
360 }