]> nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
9b3e21dbcd6ce7ff8f506283e1ef034f596b224c
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.storage.recipients.RecipientId;
4 import org.asamk.signal.manager.storage.recipients.RecipientResolver;
5 import org.asamk.signal.manager.util.IOUtils;
6 import org.slf4j.Logger;
7 import org.slf4j.LoggerFactory;
8 import org.whispersystems.libsignal.NoSessionException;
9 import org.whispersystems.libsignal.SignalProtocolAddress;
10 import org.whispersystems.libsignal.protocol.CiphertextMessage;
11 import org.whispersystems.libsignal.state.SessionRecord;
12 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
13
14 import java.io.File;
15 import java.io.FileInputStream;
16 import java.io.FileOutputStream;
17 import java.io.IOException;
18 import java.nio.file.Files;
19 import java.util.Arrays;
20 import java.util.Collection;
21 import java.util.HashMap;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.Objects;
25 import java.util.Set;
26 import java.util.regex.Matcher;
27 import java.util.regex.Pattern;
28 import java.util.stream.Collectors;
29
30 public class SessionStore implements SignalServiceSessionStore {
31
32 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
33
34 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
35
36 private final File sessionsPath;
37
38 private final RecipientResolver resolver;
39
40 public SessionStore(
41 final File sessionsPath, final RecipientResolver resolver
42 ) {
43 this.sessionsPath = sessionsPath;
44 this.resolver = resolver;
45 }
46
47 @Override
48 public SessionRecord loadSession(SignalProtocolAddress address) {
49 final var key = getKey(address);
50
51 synchronized (cachedSessions) {
52 final var session = loadSessionLocked(key);
53 if (session == null) {
54 return new SessionRecord();
55 }
56 return session;
57 }
58 }
59
60 @Override
61 public List<SessionRecord> loadExistingSessions(final List<SignalProtocolAddress> addresses) throws NoSessionException {
62 final var keys = addresses.stream().map(this::getKey).toList();
63
64 synchronized (cachedSessions) {
65 final var sessions = keys.stream().map(this::loadSessionLocked).filter(Objects::nonNull).toList();
66
67 if (sessions.size() != addresses.size()) {
68 String message = "Mismatch! Asked for "
69 + addresses.size()
70 + " sessions, but only found "
71 + sessions.size()
72 + "!";
73 logger.warn(message);
74 throw new NoSessionException(message);
75 }
76
77 return sessions;
78 }
79 }
80
81 @Override
82 public List<Integer> getSubDeviceSessions(String name) {
83 final var recipientId = resolveRecipient(name);
84
85 synchronized (cachedSessions) {
86 return getKeysLocked(recipientId).stream()
87 // get all sessions for recipient except main device session
88 .filter(key -> key.deviceId() != 1 && key.recipientId().equals(recipientId))
89 .map(Key::deviceId)
90 .toList();
91 }
92 }
93
94 @Override
95 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
96 final var key = getKey(address);
97
98 synchronized (cachedSessions) {
99 storeSessionLocked(key, session);
100 }
101 }
102
103 @Override
104 public boolean containsSession(SignalProtocolAddress address) {
105 final var key = getKey(address);
106
107 synchronized (cachedSessions) {
108 final var session = loadSessionLocked(key);
109 return isActive(session);
110 }
111 }
112
113 @Override
114 public void deleteSession(SignalProtocolAddress address) {
115 final var key = getKey(address);
116
117 synchronized (cachedSessions) {
118 deleteSessionLocked(key);
119 }
120 }
121
122 @Override
123 public void deleteAllSessions(String name) {
124 final var recipientId = resolveRecipient(name);
125 deleteAllSessions(recipientId);
126 }
127
128 public void deleteAllSessions(RecipientId recipientId) {
129 synchronized (cachedSessions) {
130 final var keys = getKeysLocked(recipientId);
131 for (var key : keys) {
132 deleteSessionLocked(key);
133 }
134 }
135 }
136
137 @Override
138 public void archiveSession(final SignalProtocolAddress address) {
139 final var key = getKey(address);
140
141 synchronized (cachedSessions) {
142 archiveSessionLocked(key);
143 }
144 }
145
146 @Override
147 public Set<SignalProtocolAddress> getAllAddressesWithActiveSessions(final List<String> addressNames) {
148 final var recipientIdToNameMap = addressNames.stream()
149 .collect(Collectors.toMap(this::resolveRecipient, name -> name));
150 synchronized (cachedSessions) {
151 return recipientIdToNameMap.keySet()
152 .stream()
153 .flatMap(recipientId -> getKeysLocked(recipientId).stream())
154 .filter(key -> isActive(this.loadSessionLocked(key)))
155 .map(key -> new SignalProtocolAddress(recipientIdToNameMap.get(key.recipientId), key.deviceId()))
156 .collect(Collectors.toSet());
157 }
158 }
159
160 public void archiveAllSessions() {
161 synchronized (cachedSessions) {
162 final var keys = getKeysLocked();
163 for (var key : keys) {
164 archiveSessionLocked(key);
165 }
166 }
167 }
168
169 public void archiveSessions(final RecipientId recipientId) {
170 synchronized (cachedSessions) {
171 getKeysLocked().stream()
172 .filter(key -> key.recipientId.equals(recipientId))
173 .forEach(this::archiveSessionLocked);
174 }
175 }
176
177 public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
178 synchronized (cachedSessions) {
179 final var keys = getKeysLocked(toBeMergedRecipientId);
180 final var otherHasSession = keys.size() > 0;
181 if (!otherHasSession) {
182 return;
183 }
184
185 final var hasSession = getKeysLocked(recipientId).size() > 0;
186 if (hasSession) {
187 logger.debug("To be merged recipient had sessions, deleting.");
188 deleteAllSessions(toBeMergedRecipientId);
189 } else {
190 logger.debug("Only to be merged recipient had sessions, re-assigning to the new recipient.");
191 for (var key : keys) {
192 final var session = loadSessionLocked(key);
193 deleteSessionLocked(key);
194 if (session == null) {
195 continue;
196 }
197 final var newKey = new Key(recipientId, key.deviceId());
198 storeSessionLocked(newKey, session);
199 }
200 }
201 }
202 }
203
204 /**
205 * @param identifier can be either a serialized uuid or a e164 phone number
206 */
207 private RecipientId resolveRecipient(String identifier) {
208 return resolver.resolveRecipient(identifier);
209 }
210
211 private Key getKey(final SignalProtocolAddress address) {
212 final var recipientId = resolveRecipient(address.getName());
213 return new Key(recipientId, address.getDeviceId());
214 }
215
216 private List<Key> getKeysLocked(RecipientId recipientId) {
217 final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.id() + "_"));
218 if (files == null) {
219 return List.of();
220 }
221 return parseFileNames(files);
222 }
223
224 private Collection<Key> getKeysLocked() {
225 final var files = sessionsPath.listFiles();
226 if (files == null) {
227 return List.of();
228 }
229 return parseFileNames(files);
230 }
231
232 final Pattern sessionFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)");
233
234 private List<Key> parseFileNames(final File[] files) {
235 return Arrays.stream(files)
236 .map(f -> sessionFileNamePattern.matcher(f.getName()))
237 .filter(Matcher::matches)
238 .map(matcher -> {
239 final var recipientId = resolver.resolveRecipient(Long.parseLong(matcher.group(1)));
240 if (recipientId == null) {
241 return null;
242 }
243 return new Key(recipientId, Integer.parseInt(matcher.group(2)));
244 })
245 .filter(Objects::nonNull)
246 .toList();
247 }
248
249 private File getSessionFile(Key key) {
250 try {
251 IOUtils.createPrivateDirectories(sessionsPath);
252 } catch (IOException e) {
253 throw new AssertionError("Failed to create sessions path", e);
254 }
255 return new File(sessionsPath, key.recipientId().id() + "_" + key.deviceId());
256 }
257
258 private SessionRecord loadSessionLocked(final Key key) {
259 {
260 final var session = cachedSessions.get(key);
261 if (session != null) {
262 return session;
263 }
264 }
265
266 final var file = getSessionFile(key);
267 if (!file.exists()) {
268 return null;
269 }
270 try (var inputStream = new FileInputStream(file)) {
271 final var session = new SessionRecord(inputStream.readAllBytes());
272 cachedSessions.put(key, session);
273 return session;
274 } catch (IOException e) {
275 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
276 return null;
277 }
278 }
279
280 private void storeSessionLocked(final Key key, final SessionRecord session) {
281 cachedSessions.put(key, session);
282
283 final var file = getSessionFile(key);
284 try {
285 try (var outputStream = new FileOutputStream(file)) {
286 outputStream.write(session.serialize());
287 }
288 } catch (IOException e) {
289 logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
290 try {
291 Files.delete(file.toPath());
292 try (var outputStream = new FileOutputStream(file)) {
293 outputStream.write(session.serialize());
294 }
295 } catch (IOException e2) {
296 logger.error("Failed to store session file {}: {}", file, e2.getMessage());
297 }
298 }
299 }
300
301 private void archiveSessionLocked(final Key key) {
302 final var session = loadSessionLocked(key);
303 if (session == null) {
304 return;
305 }
306 session.archiveCurrentState();
307 storeSessionLocked(key, session);
308 }
309
310 private void deleteSessionLocked(final Key key) {
311 cachedSessions.remove(key);
312
313 final var file = getSessionFile(key);
314 if (!file.exists()) {
315 return;
316 }
317 try {
318 Files.delete(file.toPath());
319 } catch (IOException e) {
320 logger.error("Failed to delete session file {}: {}", file, e.getMessage());
321 }
322 }
323
324 private static boolean isActive(SessionRecord record) {
325 return record != null
326 && record.hasSenderChain()
327 && record.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
328 }
329
330 private record Key(RecipientId recipientId, int deviceId) {}
331 }