]> nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
8011a69d5ce4e1e9c4d3aa3361cdbc3c5ace9a1f
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.storage.recipients.RecipientId;
4 import org.asamk.signal.manager.storage.recipients.RecipientResolver;
5 import org.asamk.signal.manager.util.IOUtils;
6 import org.slf4j.Logger;
7 import org.slf4j.LoggerFactory;
8 import org.whispersystems.libsignal.NoSessionException;
9 import org.whispersystems.libsignal.SignalProtocolAddress;
10 import org.whispersystems.libsignal.protocol.CiphertextMessage;
11 import org.whispersystems.libsignal.state.SessionRecord;
12 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
13
14 import java.io.File;
15 import java.io.FileInputStream;
16 import java.io.FileOutputStream;
17 import java.io.IOException;
18 import java.nio.file.Files;
19 import java.util.Arrays;
20 import java.util.Collection;
21 import java.util.HashMap;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.Objects;
25 import java.util.Set;
26 import java.util.regex.Matcher;
27 import java.util.regex.Pattern;
28 import java.util.stream.Collectors;
29
30 public class SessionStore implements SignalServiceSessionStore {
31
32 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
33
34 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
35
36 private final File sessionsPath;
37
38 private final RecipientResolver resolver;
39
40 public SessionStore(
41 final File sessionsPath, final RecipientResolver resolver
42 ) {
43 this.sessionsPath = sessionsPath;
44 this.resolver = resolver;
45 }
46
47 @Override
48 public SessionRecord loadSession(SignalProtocolAddress address) {
49 final var key = getKey(address);
50
51 synchronized (cachedSessions) {
52 final var session = loadSessionLocked(key);
53 if (session == null) {
54 return new SessionRecord();
55 }
56 return session;
57 }
58 }
59
60 @Override
61 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
62 final var keys = addresses.stream().map(this::getKey).collect(Collectors.toList());
63
64 synchronized (cachedSessions) {
65 final var sessions = keys.stream()
66 .map(this::loadSessionLocked)
67 .filter(Objects::nonNull)
68 .collect(Collectors.toList());
69
70 if (sessions.size() != addresses.size()) {
71 String message = "Mismatch! Asked for "
72 + addresses.size()
73 + " sessions, but only found "
74 + sessions.size()
75 + "!";
76 logger.warn(message);
77 throw new NoSessionException(message);
78 }
79
80 return sessions;
81 }
82 }
83
84 @Override
85 public List<Integer> getSubDeviceSessions(String name) {
86 final var recipientId = resolveRecipient(name);
87
88 synchronized (cachedSessions) {
89 return getKeysLocked(recipientId).stream()
90 // get all sessions for recipient except main device session
91 .filter(key -> key.deviceId() != 1 && key.recipientId().equals(recipientId))
92 .map(Key::deviceId)
93 .collect(Collectors.toList());
94 }
95 }
96
97 @Override
98 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
99 final var key = getKey(address);
100
101 synchronized (cachedSessions) {
102 storeSessionLocked(key, session);
103 }
104 }
105
106 @Override
107 public boolean containsSession(SignalProtocolAddress address) {
108 final var key = getKey(address);
109
110 synchronized (cachedSessions) {
111 final var session = loadSessionLocked(key);
112 return isActive(session);
113 }
114 }
115
116 @Override
117 public void deleteSession(SignalProtocolAddress address) {
118 final var key = getKey(address);
119
120 synchronized (cachedSessions) {
121 deleteSessionLocked(key);
122 }
123 }
124
125 @Override
126 public void deleteAllSessions(String name) {
127 final var recipientId = resolveRecipient(name);
128 deleteAllSessions(recipientId);
129 }
130
131 public void deleteAllSessions(RecipientId recipientId) {
132 synchronized (cachedSessions) {
133 final var keys = getKeysLocked(recipientId);
134 for (var key : keys) {
135 deleteSessionLocked(key);
136 }
137 }
138 }
139
140 @Override
141 public void archiveSession(final SignalProtocolAddress address) {
142 final var key = getKey(address);
143
144 synchronized (cachedSessions) {
145 archiveSessionLocked(key);
146 }
147 }
148
149 @Override
150 public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
151 final var recipientIdToNameMap = addressNames.stream()
152 .collect(Collectors.toMap(this::resolveRecipient, name -> name));
153 synchronized (cachedSessions) {
154 return recipientIdToNameMap.keySet()
155 .stream()
156 .flatMap(recipientId -> getKeysLocked(recipientId).stream())
157 .filter(key -> isActive(this.loadSessionLocked(key)))
158 .map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId), key.deviceId()))
159 .collect(Collectors.toSet());
160 }
161 }
162
163 public void archiveAllSessions() {
164 synchronized (cachedSessions) {
165 final var keys = getKeysLocked();
166 for (var key : keys) {
167 archiveSessionLocked(key);
168 }
169 }
170 }
171
172 public void archiveSessions(final RecipientId recipientId) {
173 synchronized (cachedSessions) {
174 getKeysLocked().stream()
175 .filter(key -> key.recipientId.equals(recipientId))
176 .forEach(this::archiveSessionLocked);
177 }
178 }
179
180 public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
181 synchronized (cachedSessions) {
182 final var keys = getKeysLocked(toBeMergedRecipientId);
183 final var otherHasSession = keys.size() > 0;
184 if (!otherHasSession) {
185 return;
186 }
187
188 final var hasSession = getKeysLocked(recipientId).size() > 0;
189 if (hasSession) {
190 logger.debug("To be merged recipient had sessions, deleting.");
191 deleteAllSessions(toBeMergedRecipientId);
192 } else {
193 logger.debug("Only to be merged recipient had sessions, re-assigning to the new recipient.");
194 for (var key : keys) {
195 final var session = loadSessionLocked(key);
196 deleteSessionLocked(key);
197 if (session == null) {
198 continue;
199 }
200 final var newKey = new Key(recipientId, key.deviceId());
201 storeSessionLocked(newKey, session);
202 }
203 }
204 }
205 }
206
207 /**
208 * @param identifier can be either a serialized uuid or a e164 phone number
209 */
210 private RecipientId resolveRecipient(String identifier) {
211 return resolver.resolveRecipient(identifier);
212 }
213
214 private Key getKey(final SignalProtocolAddress address) {
215 final var recipientId = resolveRecipient(address.getName());
216 return new Key(recipientId, address.getDeviceId());
217 }
218
219 private List<Key> getKeysLocked(RecipientId recipientId) {
220 final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.id() + "_"));
221 if (files == null) {
222 return List.of();
223 }
224 return parseFileNames(files);
225 }
226
227 private Collection<Key> getKeysLocked() {
228 final var files = sessionsPath.listFiles();
229 if (files == null) {
230 return List.of();
231 }
232 return parseFileNames(files);
233 }
234
235 final Pattern sessionFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)");
236
237 private List<Key> parseFileNames(final File[] files) {
238 return Arrays.stream(files)
239 .map(f -> sessionFileNamePattern.matcher(f.getName()))
240 .filter(Matcher::matches)
241 .map(matcher -> {
242 final var recipientId = resolver.resolveRecipient(Long.parseLong(matcher.group(1)));
243 if (recipientId == null) {
244 return null;
245 }
246 return new Key(recipientId, Integer.parseInt(matcher.group(2)));
247 })
248 .filter(Objects::nonNull)
249 .collect(Collectors.toList());
250 }
251
252 private File getSessionFile(Key key) {
253 try {
254 IOUtils.createPrivateDirectories(sessionsPath);
255 } catch (IOException e) {
256 throw new AssertionError("Failed to create sessions path", e);
257 }
258 return new File(sessionsPath, key.recipientId().id() + "_" + key.deviceId());
259 }
260
261 private SessionRecord loadSessionLocked(final Key key) {
262 {
263 final var session = cachedSessions.get(key);
264 if (session != null) {
265 return session;
266 }
267 }
268
269 final var file = getSessionFile(key);
270 if (!file.exists()) {
271 return null;
272 }
273 try (var inputStream = new FileInputStream(file)) {
274 final var session = new SessionRecord(inputStream.readAllBytes());
275 cachedSessions.put(key, session);
276 return session;
277 } catch (IOException e) {
278 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
279 return null;
280 }
281 }
282
283 private void storeSessionLocked(final Key key, final SessionRecord session) {
284 cachedSessions.put(key, session);
285
286 final var file = getSessionFile(key);
287 try {
288 try (var outputStream = new FileOutputStream(file)) {
289 outputStream.write(session.serialize());
290 }
291 } catch (IOException e) {
292 logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
293 try {
294 Files.delete(file.toPath());
295 try (var outputStream = new FileOutputStream(file)) {
296 outputStream.write(session.serialize());
297 }
298 } catch (IOException e2) {
299 logger.error("Failed to store session file {}: {}", file, e2.getMessage());
300 }
301 }
302 }
303
304 private void archiveSessionLocked(final Key key) {
305 final var session = loadSessionLocked(key);
306 if (session == null) {
307 return;
308 }
309 session.archiveCurrentState();
310 storeSessionLocked(key, session);
311 }
312
313 private void deleteSessionLocked(final Key key) {
314 cachedSessions.remove(key);
315
316 final var file = getSessionFile(key);
317 if (!file.exists()) {
318 return;
319 }
320 try {
321 Files.delete(file.toPath());
322 } catch (IOException e) {
323 logger.error("Failed to delete session file {}: {}", file, e.getMessage());
324 }
325 }
326
327 private static boolean isActive(SessionRecord record) {
328 return record != null
329 && record.hasSenderChain()
330 && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
331 }
332
333 private record Key(RecipientId recipientId, int deviceId) {}
334 }