nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
bae4fdf31271557778981c4667d0e63a3443bfbb
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.storage.recipients.RecipientId;
4 import org.asamk.signal.manager.storage.recipients.RecipientResolver;
5 import org.asamk.signal.manager.util.IOUtils;
6 import org.slf4j.Logger;
7 import org.slf4j.LoggerFactory;
8 import org.whispersystems.libsignal.NoSessionException;
9 import org.whispersystems.libsignal.SignalProtocolAddress;
10 import org.whispersystems.libsignal.protocol.CiphertextMessage;
11 import org.whispersystems.libsignal.state.SessionRecord;
12 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
13
14 import java.io.File;
15 import java.io.FileInputStream;
16 import java.io.FileOutputStream;
17 import java.io.IOException;
18 import java.nio.file.Files;
19 import java.util.Arrays;
20 import java.util.Collection;
21 import java.util.HashMap;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.Objects;
25 import java.util.Set;
26 import java.util.regex.Matcher;
27 import java.util.regex.Pattern;
28 import java.util.stream.Collectors;
29
30 public class SessionStore implements SignalServiceSessionStore {
31
32 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
33
34 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
35
36 private final File sessionsPath;
37
38 private final RecipientResolver resolver;
39
40 public SessionStore(
41 final File sessionsPath, final RecipientResolver resolver
42 ) {
43 this.sessionsPath = sessionsPath;
44 this.resolver = resolver;
45 }
46
47 @Override
48 public SessionRecord loadSession(SignalProtocolAddress address) {
49 final var key = getKey(address);
50
51 synchronized (cachedSessions) {
52 final var session = loadSessionLocked(key);
53 if (session == null) {
54 return new SessionRecord();
55 }
56 return session;
57 }
58 }
59
60 @Override
61 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
62 final var keys = addresses.stream().map(this::getKey).collect(Collectors.toList());
63
64 synchronized (cachedSessions) {
65 final var sessions = keys.stream()
66 .map(this::loadSessionLocked)
67 .filter(Objects::nonNull)
68 .collect(Collectors.toList());
69
70 if (sessions.size() != addresses.size()) {
71 String message = "Mismatch! Asked for "
72 + addresses.size()
73 + " sessions, but only found "
74 + sessions.size()
75 + "!";
76 logger.warn(message);
77 throw new NoSessionException(message);
78 }
79
80 return sessions;
81 }
82 }
83
84 @Override
85 public List<Integer> getSubDeviceSessions(String name) {
86 final var recipientId = resolveRecipient(name);
87
88 synchronized (cachedSessions) {
89 return getKeysLocked(recipientId).stream()
90 // get all sessions for recipient except main device session
91 .filter(key -> key.getDeviceId() != 1 && key.getRecipientId().equals(recipientId))
92 .map(Key::getDeviceId)
93 .collect(Collectors.toList());
94 }
95 }
96
97 @Override
98 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
99 final var key = getKey(address);
100
101 synchronized (cachedSessions) {
102 storeSessionLocked(key, session);
103 }
104 }
105
106 @Override
107 public boolean containsSession(SignalProtocolAddress address) {
108 final var key = getKey(address);
109
110 synchronized (cachedSessions) {
111 final var session = loadSessionLocked(key);
112 return isActive(session);
113 }
114 }
115
116 @Override
117 public void deleteSession(SignalProtocolAddress address) {
118 final var key = getKey(address);
119
120 synchronized (cachedSessions) {
121 deleteSessionLocked(key);
122 }
123 }
124
125 @Override
126 public void deleteAllSessions(String name) {
127 final var recipientId = resolveRecipient(name);
128 deleteAllSessions(recipientId);
129 }
130
131 public void deleteAllSessions(RecipientId recipientId) {
132 synchronized (cachedSessions) {
133 final var keys = getKeysLocked(recipientId);
134 for (var key : keys) {
135 deleteSessionLocked(key);
136 }
137 }
138 }
139
140 @Override
141 public void archiveSession(final SignalProtocolAddress address) {
142 final var key = getKey(address);
143
144 synchronized (cachedSessions) {
145 archiveSessionLocked(key);
146 }
147 }
148
149 @Override
150 public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
151 final var recipientIdToNameMap = addressNames.stream()
152 .collect(Collectors.toMap(this::resolveRecipient, name -> name));
153 synchronized (cachedSessions) {
154 return recipientIdToNameMap.keySet()
155 .stream()
156 .flatMap(recipientId -> getKeysLocked(recipientId).stream())
157 .filter(key -> isActive(this.loadSessionLocked(key)))
158 .map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId), key.getDeviceId()))
159 .collect(Collectors.toSet());
160 }
161 }
162
163 public void archiveAllSessions() {
164 synchronized (cachedSessions) {
165 final var keys = getKeysLocked();
166 for (var key : keys) {
167 archiveSessionLocked(key);
168 }
169 }
170 }
171
172 public void archiveSessions(final RecipientId recipientId) {
173 synchronized (cachedSessions) {
174 getKeysLocked().stream()
175 .filter(key -> key.recipientId.equals(recipientId))
176 .forEach(this::archiveSessionLocked);
177 }
178 }
179
180 public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
181 synchronized (cachedSessions) {
182 final var keys = getKeysLocked(toBeMergedRecipientId);
183 final var otherHasSession = keys.size() > 0;
184 if (!otherHasSession) {
185 return;
186 }
187
188 final var hasSession = getKeysLocked(recipientId).size() > 0;
189 if (hasSession) {
190 logger.debug("To be merged recipient had sessions, deleting.");
191 deleteAllSessions(toBeMergedRecipientId);
192 } else {
193 logger.debug("Only to be merged recipient had sessions, re-assigning to the new recipient.");
194 for (var key : keys) {
195 final var session = loadSessionLocked(key);
196 deleteSessionLocked(key);
197 if (session == null) {
198 continue;
199 }
200 final var newKey = new Key(recipientId, key.getDeviceId());
201 storeSessionLocked(newKey, session);
202 }
203 }
204 }
205 }
206
207 /**
208 * @param identifier can be either a serialized uuid or a e164 phone number
209 */
210 private RecipientId resolveRecipient(String identifier) {
211 return resolver.resolveRecipient(identifier);
212 }
213
214 private Key getKey(final SignalProtocolAddress address) {
215 final var recipientId = resolveRecipient(address.getName());
216 return new Key(recipientId, address.getDeviceId());
217 }
218
219 private List<Key> getKeysLocked(RecipientId recipientId) {
220 final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.getId() + "_"));
221 if (files == null) {
222 return List.of();
223 }
224 return parseFileNames(files);
225 }
226
227 private Collection<Key> getKeysLocked() {
228 final var files = sessionsPath.listFiles();
229 if (files == null) {
230 return List.of();
231 }
232 return parseFileNames(files);
233 }
234
235 final Pattern sessionFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)");
236
237 private List<Key> parseFileNames(final File[] files) {
238 return Arrays.stream(files)
239 .map(f -> sessionFileNamePattern.matcher(f.getName()))
240 .filter(Matcher::matches)
241 .map(matcher -> new Key(RecipientId.of(Long.parseLong(matcher.group(1))),
242 Integer.parseInt(matcher.group(2))))
243 .collect(Collectors.toList());
244 }
245
246 private File getSessionFile(Key key) {
247 try {
248 IOUtils.createPrivateDirectories(sessionsPath);
249 } catch (IOException e) {
250 throw new AssertionError("Failed to create sessions path", e);
251 }
252 return new File(sessionsPath, key.getRecipientId().getId() + "_" + key.getDeviceId());
253 }
254
255 private SessionRecord loadSessionLocked(final Key key) {
256 {
257 final var session = cachedSessions.get(key);
258 if (session != null) {
259 return session;
260 }
261 }
262
263 final var file = getSessionFile(key);
264 if (!file.exists()) {
265 return null;
266 }
267 try (var inputStream = new FileInputStream(file)) {
268 final var session = new SessionRecord(inputStream.readAllBytes());
269 cachedSessions.put(key, session);
270 return session;
271 } catch (IOException e) {
272 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
273 return null;
274 }
275 }
276
277 private void storeSessionLocked(final Key key, final SessionRecord session) {
278 cachedSessions.put(key, session);
279
280 final var file = getSessionFile(key);
281 try {
282 try (var outputStream = new FileOutputStream(file)) {
283 outputStream.write(session.serialize());
284 }
285 } catch (IOException e) {
286 logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
287 try {
288 Files.delete(file.toPath());
289 try (var outputStream = new FileOutputStream(file)) {
290 outputStream.write(session.serialize());
291 }
292 } catch (IOException e2) {
293 logger.error("Failed to store session file {}: {}", file, e2.getMessage());
294 }
295 }
296 }
297
298 private void archiveSessionLocked(final Key key) {
299 final var session = loadSessionLocked(key);
300 if (session == null) {
301 return;
302 }
303 session.archiveCurrentState();
304 storeSessionLocked(key, session);
305 }
306
307 private void deleteSessionLocked(final Key key) {
308 cachedSessions.remove(key);
309
310 final var file = getSessionFile(key);
311 if (!file.exists()) {
312 return;
313 }
314 try {
315 Files.delete(file.toPath());
316 } catch (IOException e) {
317 logger.error("Failed to delete session file {}: {}", file, e.getMessage());
318 }
319 }
320
321 private static boolean isActive(SessionRecord record) {
322 return record != null
323 && record.hasSenderChain()
324 && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
325 }
326
327 private static final class Key {
328
329 private final RecipientId recipientId;
330 private final int deviceId;
331
332 public Key(final RecipientId recipientId, final int deviceId) {
333 this.recipientId = recipientId;
334 this.deviceId = deviceId;
335 }
336
337 public RecipientId getRecipientId() {
338 return recipientId;
339 }
340
341 public int getDeviceId() {
342 return deviceId;
343 }
344
345 @Override
346 public boolean equals(final Object o) {
347 if (this == o) return true;
348 if (o == null || getClass() != o.getClass()) return false;
349
350 final var key = (Key) o;
351
352 if (deviceId != key.deviceId) return false;
353 return recipientId.equals(key.recipientId);
354 }
355
356 @Override
357 public int hashCode() {
358 int result = recipientId.hashCode();
359 result = 31 * result + deviceId;
360 return result;
361 }
362 }
363 }