]> nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
3e5555f49afd2bb381c269b6b3fbe5e9bc569077
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.storage.recipients.RecipientId;
4 import org.asamk.signal.manager.storage.recipients.RecipientResolver;
5 import org.asamk.signal.manager.util.IOUtils;
6 import org.asamk.signal.manager.util.Utils;
7 import org.slf4j.Logger;
8 import org.slf4j.LoggerFactory;
9 import org.whispersystems.libsignal.SignalProtocolAddress;
10 import org.whispersystems.libsignal.protocol.CiphertextMessage;
11 import org.whispersystems.libsignal.state.SessionRecord;
12 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
13
14 import java.io.File;
15 import java.io.FileInputStream;
16 import java.io.FileOutputStream;
17 import java.io.IOException;
18 import java.nio.file.Files;
19 import java.util.Arrays;
20 import java.util.Collection;
21 import java.util.HashMap;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.regex.Matcher;
25 import java.util.regex.Pattern;
26 import java.util.stream.Collectors;
27
28 public class SessionStore implements SignalServiceSessionStore {
29
30 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
31
32 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
33
34 private final File sessionsPath;
35
36 private final RecipientResolver resolver;
37
38 public SessionStore(
39 final File sessionsPath, final RecipientResolver resolver
40 ) {
41 this.sessionsPath = sessionsPath;
42 this.resolver = resolver;
43 }
44
45 @Override
46 public SessionRecord loadSession(SignalProtocolAddress address) {
47 final var key = getKey(address);
48
49 synchronized (cachedSessions) {
50 final var session = loadSessionLocked(key);
51 if (session == null) {
52 return new SessionRecord();
53 }
54 return session;
55 }
56 }
57
58 @Override
59 public List<Integer> getSubDeviceSessions(String name) {
60 final var recipientId = resolveRecipient(name);
61
62 synchronized (cachedSessions) {
63 return getKeysLocked(recipientId).stream()
64 // get all sessions for recipient except main device session
65 .filter(key -> key.getDeviceId() != 1 && key.getRecipientId().equals(recipientId))
66 .map(Key::getDeviceId)
67 .collect(Collectors.toList());
68 }
69 }
70
71 @Override
72 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
73 final var key = getKey(address);
74
75 synchronized (cachedSessions) {
76 storeSessionLocked(key, session);
77 }
78 }
79
80 @Override
81 public boolean containsSession(SignalProtocolAddress address) {
82 final var key = getKey(address);
83
84 synchronized (cachedSessions) {
85 final var session = loadSessionLocked(key);
86 if (session == null) {
87 return false;
88 }
89
90 return session.hasSenderChain() && session.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
91 }
92 }
93
94 @Override
95 public void deleteSession(SignalProtocolAddress address) {
96 final var key = getKey(address);
97
98 synchronized (cachedSessions) {
99 deleteSessionLocked(key);
100 }
101 }
102
103 @Override
104 public void deleteAllSessions(String name) {
105 final var recipientId = resolveRecipient(name);
106 deleteAllSessions(recipientId);
107 }
108
109 public void deleteAllSessions(RecipientId recipientId) {
110 synchronized (cachedSessions) {
111 final var keys = getKeysLocked(recipientId);
112 for (var key : keys) {
113 deleteSessionLocked(key);
114 }
115 }
116 }
117
118 @Override
119 public void archiveSession(final SignalProtocolAddress address) {
120 final var key = getKey(address);
121
122 synchronized (cachedSessions) {
123 archiveSessionLocked(key);
124 }
125 }
126
127 public void archiveAllSessions() {
128 synchronized (cachedSessions) {
129 final var keys = getKeysLocked();
130 for (var key : keys) {
131 archiveSessionLocked(key);
132 }
133 }
134 }
135
136 public void archiveSessions(final RecipientId recipientId) {
137 synchronized (cachedSessions) {
138 getKeysLocked().stream()
139 .filter(key -> key.recipientId.equals(recipientId))
140 .forEach(this::archiveSessionLocked);
141 }
142 }
143
144 public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
145 synchronized (cachedSessions) {
146 final var otherHasSession = getKeysLocked(toBeMergedRecipientId).size() > 0;
147 if (!otherHasSession) {
148 return;
149 }
150
151 final var hasSession = getKeysLocked(recipientId).size() > 0;
152 if (hasSession) {
153 logger.debug("To be merged recipient had sessions, deleting.");
154 deleteAllSessions(toBeMergedRecipientId);
155 } else {
156 logger.debug("To be merged recipient had sessions, re-assigning to the new recipient.");
157 final var keys = getKeysLocked(toBeMergedRecipientId);
158 for (var key : keys) {
159 final var session = loadSessionLocked(key);
160 deleteSessionLocked(key);
161 if (session == null) {
162 continue;
163 }
164 final var newKey = new Key(recipientId, key.getDeviceId());
165 storeSessionLocked(newKey, session);
166 }
167 }
168 }
169 }
170
171 /**
172 * @param identifier can be either a serialized uuid or a e164 phone number
173 */
174 private RecipientId resolveRecipient(String identifier) {
175 return resolver.resolveRecipient(Utils.getSignalServiceAddressFromIdentifier(identifier));
176 }
177
178 private Key getKey(final SignalProtocolAddress address) {
179 final var recipientId = resolveRecipient(address.getName());
180 return new Key(recipientId, address.getDeviceId());
181 }
182
183 private List<Key> getKeysLocked(RecipientId recipientId) {
184 final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.getId() + "_"));
185 if (files == null) {
186 return List.of();
187 }
188 return parseFileNames(files);
189 }
190
191 private Collection<Key> getKeysLocked() {
192 final var files = sessionsPath.listFiles();
193 if (files == null) {
194 return List.of();
195 }
196 return parseFileNames(files);
197 }
198
199 final Pattern sessionFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)");
200
201 private List<Key> parseFileNames(final File[] files) {
202 return Arrays.stream(files)
203 .map(f -> sessionFileNamePattern.matcher(f.getName()))
204 .filter(Matcher::matches)
205 .map(matcher -> new Key(RecipientId.of(Long.parseLong(matcher.group(1))),
206 Integer.parseInt(matcher.group(2))))
207 .collect(Collectors.toList());
208 }
209
210 private File getSessionFile(Key key) {
211 try {
212 IOUtils.createPrivateDirectories(sessionsPath);
213 } catch (IOException e) {
214 throw new AssertionError("Failed to create sessions path", e);
215 }
216 return new File(sessionsPath, key.getRecipientId().getId() + "_" + key.getDeviceId());
217 }
218
219 private SessionRecord loadSessionLocked(final Key key) {
220 {
221 final var session = cachedSessions.get(key);
222 if (session != null) {
223 return session;
224 }
225 }
226
227 final var file = getSessionFile(key);
228 if (!file.exists()) {
229 return null;
230 }
231 try (var inputStream = new FileInputStream(file)) {
232 final var session = new SessionRecord(inputStream.readAllBytes());
233 cachedSessions.put(key, session);
234 return session;
235 } catch (IOException e) {
236 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
237 return null;
238 }
239 }
240
241 private void storeSessionLocked(final Key key, final SessionRecord session) {
242 cachedSessions.put(key, session);
243
244 final var file = getSessionFile(key);
245 try {
246 try (var outputStream = new FileOutputStream(file)) {
247 outputStream.write(session.serialize());
248 }
249 } catch (IOException e) {
250 logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
251 try {
252 Files.delete(file.toPath());
253 try (var outputStream = new FileOutputStream(file)) {
254 outputStream.write(session.serialize());
255 }
256 } catch (IOException e2) {
257 logger.error("Failed to store session file {}: {}", file, e2.getMessage());
258 }
259 }
260 }
261
262 private void archiveSessionLocked(final Key key) {
263 final var session = loadSessionLocked(key);
264 if (session == null) {
265 return;
266 }
267 session.archiveCurrentState();
268 storeSessionLocked(key, session);
269 }
270
271 private void deleteSessionLocked(final Key key) {
272 cachedSessions.remove(key);
273
274 final var file = getSessionFile(key);
275 if (!file.exists()) {
276 return;
277 }
278 try {
279 Files.delete(file.toPath());
280 } catch (IOException e) {
281 logger.error("Failed to delete session file {}: {}", file, e.getMessage());
282 }
283 }
284
285 private static final class Key {
286
287 private final RecipientId recipientId;
288 private final int deviceId;
289
290 public Key(final RecipientId recipientId, final int deviceId) {
291 this.recipientId = recipientId;
292 this.deviceId = deviceId;
293 }
294
295 public RecipientId getRecipientId() {
296 return recipientId;
297 }
298
299 public int getDeviceId() {
300 return deviceId;
301 }
302
303 @Override
304 public boolean equals(final Object o) {
305 if (this == o) return true;
306 if (o == null || getClass() != o.getClass()) return false;
307
308 final var key = (Key) o;
309
310 if (deviceId != key.deviceId) return false;
311 return recipientId.equals(key.recipientId);
312 }
313
314 @Override
315 public int hashCode() {
316 int result = recipientId.hashCode();
317 result = 31 * result + deviceId;
318 return result;
319 }
320 }
321 }