nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
0773af9abc3b6b5027dafeb78e78baa361084798
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / sessions / SessionStore.java
1 package org.asamk.signal.manager.storage.sessions;
2
3 import org.asamk.signal.manager.storage.recipients.RecipientId;
4 import org.asamk.signal.manager.storage.recipients.RecipientResolver;
5 import org.asamk.signal.manager.util.IOUtils;
6 import org.asamk.signal.manager.util.Utils;
7 import org.slf4j.Logger;
8 import org.slf4j.LoggerFactory;
9 import org.whispersystems.libsignal.SignalProtocolAddress;
10 import org.whispersystems.libsignal.protocol.CiphertextMessage;
11 import org.whispersystems.libsignal.state.SessionRecord;
12 import org.whispersystems.signalservice.api.SignalServiceSessionStore;
13
14 import java.io.File;
15 import java.io.FileInputStream;
16 import java.io.FileOutputStream;
17 import java.io.IOException;
18 import java.nio.file.Files;
19 import java.util.Arrays;
20 import java.util.Collection;
21 import java.util.HashMap;
22 import java.util.List;
23 import java.util.Map;
24 import java.util.regex.Matcher;
25 import java.util.regex.Pattern;
26 import java.util.stream.Collectors;
27
28 public class SessionStore implements SignalServiceSessionStore {
29
30 private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
31
32 private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
33
34 private final File sessionsPath;
35
36 private final RecipientResolver resolver;
37
38 public SessionStore(
39 final File sessionsPath, final RecipientResolver resolver
40 ) {
41 this.sessionsPath = sessionsPath;
42 this.resolver = resolver;
43 }
44
45 @Override
46 public SessionRecord loadSession(SignalProtocolAddress address) {
47 final var key = getKey(address);
48
49 synchronized (cachedSessions) {
50 final var session = loadSessionLocked(key);
51 if (session == null) {
52 return new SessionRecord();
53 }
54 return session;
55 }
56 }
57
58 @Override
59 public List<Integer> getSubDeviceSessions(String name) {
60 final var recipientId = resolveRecipient(name);
61
62 synchronized (cachedSessions) {
63 return getKeysLocked(recipientId).stream()
64 // get all sessions for recipient except main device session
65 .filter(key -> key.getDeviceId() != 1 && key.getRecipientId().equals(recipientId))
66 .map(Key::getDeviceId)
67 .collect(Collectors.toList());
68 }
69 }
70
71 @Override
72 public void storeSession(SignalProtocolAddress address, SessionRecord session) {
73 final var key = getKey(address);
74
75 synchronized (cachedSessions) {
76 storeSessionLocked(key, session);
77 }
78 }
79
80 @Override
81 public boolean containsSession(SignalProtocolAddress address) {
82 final var key = getKey(address);
83
84 synchronized (cachedSessions) {
85 final var session = loadSessionLocked(key);
86 if (session == null) {
87 return false;
88 }
89
90 return session.hasSenderChain() && session.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
91 }
92 }
93
94 @Override
95 public void deleteSession(SignalProtocolAddress address) {
96 final var key = getKey(address);
97
98 synchronized (cachedSessions) {
99 deleteSessionLocked(key);
100 }
101 }
102
103 @Override
104 public void deleteAllSessions(String name) {
105 final var recipientId = resolveRecipient(name);
106 deleteAllSessions(recipientId);
107 }
108
109 public void deleteAllSessions(RecipientId recipientId) {
110 synchronized (cachedSessions) {
111 final var keys = getKeysLocked(recipientId);
112 for (var key : keys) {
113 deleteSessionLocked(key);
114 }
115 }
116 }
117
118 @Override
119 public void archiveSession(final SignalProtocolAddress address) {
120 final var key = getKey(address);
121
122 synchronized (cachedSessions) {
123 archiveSessionLocked(key);
124 }
125 }
126
127 public void archiveAllSessions() {
128 synchronized (cachedSessions) {
129 final var keys = getKeysLocked();
130 for (var key : keys) {
131 archiveSessionLocked(key);
132 }
133 }
134 }
135
136 public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
137 synchronized (cachedSessions) {
138 final var otherHasSession = getKeysLocked(toBeMergedRecipientId).size() > 0;
139 if (!otherHasSession) {
140 return;
141 }
142
143 final var hasSession = getKeysLocked(recipientId).size() > 0;
144 if (hasSession) {
145 logger.debug("To be merged recipient had sessions, deleting.");
146 deleteAllSessions(toBeMergedRecipientId);
147 } else {
148 logger.debug("To be merged recipient had sessions, re-assigning to the new recipient.");
149 final var keys = getKeysLocked(toBeMergedRecipientId);
150 for (var key : keys) {
151 final var session = loadSessionLocked(key);
152 deleteSessionLocked(key);
153 if (session == null) {
154 continue;
155 }
156 final var newKey = new Key(recipientId, key.getDeviceId());
157 storeSessionLocked(newKey, session);
158 }
159 }
160 }
161 }
162
163 /**
164 * @param identifier can be either a serialized uuid or a e164 phone number
165 */
166 private RecipientId resolveRecipient(String identifier) {
167 return resolver.resolveRecipient(Utils.getSignalServiceAddressFromIdentifier(identifier));
168 }
169
170 private Key getKey(final SignalProtocolAddress address) {
171 final var recipientId = resolveRecipient(address.getName());
172 return new Key(recipientId, address.getDeviceId());
173 }
174
175 private List<Key> getKeysLocked(RecipientId recipientId) {
176 final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.getId() + "_"));
177 if (files == null) {
178 return List.of();
179 }
180 return parseFileNames(files);
181 }
182
183 private Collection<Key> getKeysLocked() {
184 final var files = sessionsPath.listFiles();
185 if (files == null) {
186 return List.of();
187 }
188 return parseFileNames(files);
189 }
190
191 final Pattern sessionFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)");
192
193 private List<Key> parseFileNames(final File[] files) {
194 return Arrays.stream(files)
195 .map(f -> sessionFileNamePattern.matcher(f.getName()))
196 .filter(Matcher::matches)
197 .map(matcher -> new Key(RecipientId.of(Long.parseLong(matcher.group(1))),
198 Integer.parseInt(matcher.group(2))))
199 .collect(Collectors.toList());
200 }
201
202 private File getSessionPath(Key key) {
203 try {
204 IOUtils.createPrivateDirectories(sessionsPath);
205 } catch (IOException e) {
206 throw new AssertionError("Failed to create sessions path", e);
207 }
208 return new File(sessionsPath, key.getRecipientId().getId() + "_" + key.getDeviceId());
209 }
210
211 private SessionRecord loadSessionLocked(final Key key) {
212 {
213 final var session = cachedSessions.get(key);
214 if (session != null) {
215 return session;
216 }
217 }
218
219 final var file = getSessionPath(key);
220 if (!file.exists()) {
221 return null;
222 }
223 try (var inputStream = new FileInputStream(file)) {
224 final var session = new SessionRecord(inputStream.readAllBytes());
225 cachedSessions.put(key, session);
226 return session;
227 } catch (IOException e) {
228 logger.warn("Failed to load session, resetting session: {}", e.getMessage());
229 return null;
230 }
231 }
232
233 private void storeSessionLocked(final Key key, final SessionRecord session) {
234 cachedSessions.put(key, session);
235
236 final var file = getSessionPath(key);
237 try {
238 try (var outputStream = new FileOutputStream(file)) {
239 outputStream.write(session.serialize());
240 }
241 } catch (IOException e) {
242 logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
243 try {
244 Files.delete(file.toPath());
245 try (var outputStream = new FileOutputStream(file)) {
246 outputStream.write(session.serialize());
247 }
248 } catch (IOException e2) {
249 logger.error("Failed to store session file {}: {}", file, e2.getMessage());
250 }
251 }
252 }
253
254 private void archiveSessionLocked(final Key key) {
255 final var session = loadSessionLocked(key);
256 if (session == null) {
257 return;
258 }
259 session.archiveCurrentState();
260 storeSessionLocked(key, session);
261 }
262
263 private void deleteSessionLocked(final Key key) {
264 cachedSessions.remove(key);
265
266 final var file = getSessionPath(key);
267 if (!file.exists()) {
268 return;
269 }
270 try {
271 Files.delete(file.toPath());
272 } catch (IOException e) {
273 logger.error("Failed to delete session file {}: {}", file, e.getMessage());
274 }
275 }
276
277 private static final class Key {
278
279 private final RecipientId recipientId;
280 private final int deviceId;
281
282 public Key(final RecipientId recipientId, final int deviceId) {
283 this.recipientId = recipientId;
284 this.deviceId = deviceId;
285 }
286
287 public RecipientId getRecipientId() {
288 return recipientId;
289 }
290
291 public int getDeviceId() {
292 return deviceId;
293 }
294
295 @Override
296 public boolean equals(final Object o) {
297 if (this == o) return true;
298 if (o == null || getClass() != o.getClass()) return false;
299
300 final var key = (Key) o;
301
302 if (deviceId != key.deviceId) return false;
303 return recipientId.equals(key.recipientId);
304 }
305
306 @Override
307 public int hashCode() {
308 int result = recipientId.hashCode();
309 result = 31 * result + deviceId;
310 return result;
311 }
312 }
313 }