]> nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
ef0a055bfe4f0a46c4af79777696fd2614c10d06
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.storage.recipients.RecipientId;
4 import org.asamk.signal.manager.storage.recipients.RecipientResolver;
5 import org.asamk.signal.manager.util.IOUtils;
6 import org.asamk.signal.manager.util.Utils;
7 import org.slf4j.Logger;
8 import org.slf4j.LoggerFactory;
9 import org.whispersystems.libsignal.NoSessionException;
10 import org.whispersystems.libsignal.SignalProtocolAddress;
11 import org.whispersystems.libsignal.protocol.CiphertextMessage;
12 import org.whispersystems.libsignal.state.SessionRecord;
13 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
14
15 import java.io.File;
16 import java.io.FileInputStream;
17 import java.io.FileOutputStream;
18 import java.io.IOException;
19 import java.nio.file.Files;
20 import java.util.Arrays;
21 import java.util.Collection;
22 import java.util.HashMap;
23 import java.util.List;
24 import java.util.Map;
25 import java.util.Objects;
26 import java.util.regex.Matcher;
27 import java.util.regex.Pattern;
28 import java.util.stream.Collectors;
29
30 public class SessionStore implements SignalServiceSessionStore {
31
32 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
33
34 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
35
36 private final File sessionsPath;
37
38 private final RecipientResolver resolver;
39
40 public SessionStore(
41 final File sessionsPath, final RecipientResolver resolver
42 ) {
43 this.sessionsPath = sessionsPath;
44 this.resolver = resolver;
45 }
46
47 @Override
48 public SessionRecord loadSession(SignalProtocolAddress address) {
49 final var key = getKey(address);
50
51 synchronized (cachedSessions) {
52 final var session = loadSessionLocked(key);
53 if (session == null) {
54 return new SessionRecord();
55 }
56 return session;
57 }
58 }
59
60 @Override
61 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
62 final var keys = addresses.stream().map(this::getKey).collect(Collectors.toList());
63
64 synchronized (cachedSessions) {
65 final var sessions = keys.stream()
66 .map(this::loadSessionLocked)
67 .filter(Objects::nonNull)
68 .collect(Collectors.toList());
69
70 if (sessions.size() != addresses.size()) {
71 String message = "Mismatch! Asked for "
72 + addresses.size()
73 + " sessions, but only found "
74 + sessions.size()
75 + "!";
76 logger.warn(message);
77 throw new NoSessionException(message);
78 }
79
80 return sessions;
81 }
82 }
83
84 @Override
85 public List<Integer> getSubDeviceSessions(String name) {
86 final var recipientId = resolveRecipient(name);
87
88 synchronized (cachedSessions) {
89 return getKeysLocked(recipientId).stream()
90 // get all sessions for recipient except main device session
91 .filter(key -> key.getDeviceId() != 1 && key.getRecipientId().equals(recipientId))
92 .map(Key::getDeviceId)
93 .collect(Collectors.toList());
94 }
95 }
96
97 @Override
98 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
99 final var key = getKey(address);
100
101 synchronized (cachedSessions) {
102 storeSessionLocked(key, session);
103 }
104 }
105
106 @Override
107 public boolean containsSession(SignalProtocolAddress address) {
108 final var key = getKey(address);
109
110 synchronized (cachedSessions) {
111 final var session = loadSessionLocked(key);
112 if (session == null) {
113 return false;
114 }
115
116 return session.hasSenderChain() && session.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
117 }
118 }
119
120 @Override
121 public void deleteSession(SignalProtocolAddress address) {
122 final var key = getKey(address);
123
124 synchronized (cachedSessions) {
125 deleteSessionLocked(key);
126 }
127 }
128
129 @Override
130 public void deleteAllSessions(String name) {
131 final var recipientId = resolveRecipient(name);
132 deleteAllSessions(recipientId);
133 }
134
135 public void deleteAllSessions(RecipientId recipientId) {
136 synchronized (cachedSessions) {
137 final var keys = getKeysLocked(recipientId);
138 for (var key : keys) {
139 deleteSessionLocked(key);
140 }
141 }
142 }
143
144 @Override
145 public void archiveSession(final SignalProtocolAddress address) {
146 final var key = getKey(address);
147
148 synchronized (cachedSessions) {
149 archiveSessionLocked(key);
150 }
151 }
152
153 public void archiveAllSessions() {
154 synchronized (cachedSessions) {
155 final var keys = getKeysLocked();
156 for (var key : keys) {
157 archiveSessionLocked(key);
158 }
159 }
160 }
161
162 public void archiveSessions(final RecipientId recipientId) {
163 synchronized (cachedSessions) {
164 getKeysLocked().stream()
165 .filter(key -> key.recipientId.equals(recipientId))
166 .forEach(this::archiveSessionLocked);
167 }
168 }
169
170 public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
171 synchronized (cachedSessions) {
172 final var otherHasSession = getKeysLocked(toBeMergedRecipientId).size() > 0;
173 if (!otherHasSession) {
174 return;
175 }
176
177 final var hasSession = getKeysLocked(recipientId).size() > 0;
178 if (hasSession) {
179 logger.debug("To be merged recipient had sessions, deleting.");
180 deleteAllSessions(toBeMergedRecipientId);
181 } else {
182 logger.debug("To be merged recipient had sessions, re-assigning to the new recipient.");
183 final var keys = getKeysLocked(toBeMergedRecipientId);
184 for (var key : keys) {
185 final var session = loadSessionLocked(key);
186 deleteSessionLocked(key);
187 if (session == null) {
188 continue;
189 }
190 final var newKey = new Key(recipientId, key.getDeviceId());
191 storeSessionLocked(newKey, session);
192 }
193 }
194 }
195 }
196
197 /**
198 * @param identifier can be either a serialized uuid or a e164 phone number
199 */
200 private RecipientId resolveRecipient(String identifier) {
201 return resolver.resolveRecipient(Utils.getSignalServiceAddressFromIdentifier(identifier));
202 }
203
204 private Key getKey(final SignalProtocolAddress address) {
205 final var recipientId = resolveRecipient(address.getName());
206 return new Key(recipientId, address.getDeviceId());
207 }
208
209 private List<Key> getKeysLocked(RecipientId recipientId) {
210 final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.getId() + "_"));
211 if (files == null) {
212 return List.of();
213 }
214 return parseFileNames(files);
215 }
216
217 private Collection<Key> getKeysLocked() {
218 final var files = sessionsPath.listFiles();
219 if (files == null) {
220 return List.of();
221 }
222 return parseFileNames(files);
223 }
224
225 final Pattern sessionFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)");
226
227 private List<Key> parseFileNames(final File[] files) {
228 return Arrays.stream(files)
229 .map(f -> sessionFileNamePattern.matcher(f.getName()))
230 .filter(Matcher::matches)
231 .map(matcher -> new Key(RecipientId.of(Long.parseLong(matcher.group(1))),
232 Integer.parseInt(matcher.group(2))))
233 .collect(Collectors.toList());
234 }
235
236 private File getSessionFile(Key key) {
237 try {
238 IOUtils.createPrivateDirectories(sessionsPath);
239 } catch (IOException e) {
240 throw new AssertionError("Failed to create sessions path", e);
241 }
242 return new File(sessionsPath, key.getRecipientId().getId() + "_" + key.getDeviceId());
243 }
244
245 private SessionRecord loadSessionLocked(final Key key) {
246 {
247 final var session = cachedSessions.get(key);
248 if (session != null) {
249 return session;
250 }
251 }
252
253 final var file = getSessionFile(key);
254 if (!file.exists()) {
255 return null;
256 }
257 try (var inputStream = new FileInputStream(file)) {
258 final var session = new SessionRecord(inputStream.readAllBytes());
259 cachedSessions.put(key, session);
260 return session;
261 } catch (IOException e) {
262 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
263 return null;
264 }
265 }
266
267 private void storeSessionLocked(final Key key, final SessionRecord session) {
268 cachedSessions.put(key, session);
269
270 final var file = getSessionFile(key);
271 try {
272 try (var outputStream = new FileOutputStream(file)) {
273 outputStream.write(session.serialize());
274 }
275 } catch (IOException e) {
276 logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
277 try {
278 Files.delete(file.toPath());
279 try (var outputStream = new FileOutputStream(file)) {
280 outputStream.write(session.serialize());
281 }
282 } catch (IOException e2) {
283 logger.error("Failed to store session file {}: {}", file, e2.getMessage());
284 }
285 }
286 }
287
288 private void archiveSessionLocked(final Key key) {
289 final var session = loadSessionLocked(key);
290 if (session == null) {
291 return;
292 }
293 session.archiveCurrentState();
294 storeSessionLocked(key, session);
295 }
296
297 private void deleteSessionLocked(final Key key) {
298 cachedSessions.remove(key);
299
300 final var file = getSessionFile(key);
301 if (!file.exists()) {
302 return;
303 }
304 try {
305 Files.delete(file.toPath());
306 } catch (IOException e) {
307 logger.error("Failed to delete session file {}: {}", file, e.getMessage());
308 }
309 }
310
311 private static final class Key {
312
313 private final RecipientId recipientId;
314 private final int deviceId;
315
316 public Key(final RecipientId recipientId, final int deviceId) {
317 this.recipientId = recipientId;
318 this.deviceId = deviceId;
319 }
320
321 public RecipientId getRecipientId() {
322 return recipientId;
323 }
324
325 public int getDeviceId() {
326 return deviceId;
327 }
328
329 @Override
330 public boolean equals(final Object o) {
331 if (this == o) return true;
332 if (o == null || getClass() != o.getClass()) return false;
333
334 final var key = (Key) o;
335
336 if (deviceId != key.deviceId) return false;
337 return recipientId.equals(key.recipientId);
338 }
339
340 @Override
341 public int hashCode() {
342 int result = recipientId.hashCode();
343 result = 31 * result + deviceId;
344 return result;
345 }
346 }
347 }