1 package org
.asamk
.signal
.manager
.storage
.sessions
;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.asamk.signal.manager.util.IOUtils;
import org.signal.libsignal.protocol.NoSessionException;
import org.signal.libsignal.protocol.SignalProtocolAddress;
import org.signal.libsignal.protocol.ecc.ECPublicKey;
import org.signal.libsignal.protocol.message.CiphertextMessage;
import org.signal.libsignal.protocol.state.SessionRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.signalservice.api.SignalServiceSessionStore;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
31 public class SessionStore
implements SignalServiceSessionStore
{
33 private final static Logger logger
= LoggerFactory
.getLogger(SessionStore
.class);
35 private final Map
<Key
, SessionRecord
> cachedSessions
= new HashMap
<>();
37 private final File sessionsPath
;
39 private final RecipientResolver resolver
;
42 final File sessionsPath
, final RecipientResolver resolver
44 this.sessionsPath
= sessionsPath
;
45 this.resolver
= resolver
;
49 public SessionRecord
loadSession(SignalProtocolAddress address
) {
50 final var key
= getKey(address
);
52 synchronized (cachedSessions
) {
53 final var session
= loadSessionLocked(key
);
54 if (session
== null) {
55 return new SessionRecord();
62 public List
<SessionRecord
> loadExistingSessions(final List
<SignalProtocolAddress
> addresses
) throws NoSessionException
{
63 final var keys
= addresses
.stream().map(this::getKey
).toList();
65 synchronized (cachedSessions
) {
66 final var sessions
= keys
.stream().map(this::loadSessionLocked
).filter(Objects
::nonNull
).toList();
68 if (sessions
.size() != addresses
.size()) {
69 String message
= "Mismatch! Asked for "
71 + " sessions, but only found "
75 throw new NoSessionException(message
);
83 public List
<Integer
> getSubDeviceSessions(String name
) {
84 final var recipientId
= resolveRecipient(name
);
86 synchronized (cachedSessions
) {
87 return getKeysLocked(recipientId
).stream()
88 // get all sessions for recipient except main device session
89 .filter(key
-> key
.deviceId() != 1 && key
.recipientId().equals(recipientId
))
95 public boolean isCurrentRatchetKey(RecipientId recipientId
, int deviceId
, ECPublicKey ratchetKey
) {
96 final var key
= new Key(recipientId
, deviceId
);
98 synchronized (cachedSessions
) {
99 final var session
= loadSessionLocked(key
);
100 if (session
== null) {
103 return session
.currentRatchetKeyMatches(ratchetKey
);
108 public void storeSession(SignalProtocolAddress address
, SessionRecord session
) {
109 final var key
= getKey(address
);
111 synchronized (cachedSessions
) {
112 storeSessionLocked(key
, session
);
117 public boolean containsSession(SignalProtocolAddress address
) {
118 final var key
= getKey(address
);
120 synchronized (cachedSessions
) {
121 final var session
= loadSessionLocked(key
);
122 return isActive(session
);
127 public void deleteSession(SignalProtocolAddress address
) {
128 final var key
= getKey(address
);
130 synchronized (cachedSessions
) {
131 deleteSessionLocked(key
);
136 public void deleteAllSessions(String name
) {
137 final var recipientId
= resolveRecipient(name
);
138 deleteAllSessions(recipientId
);
141 public void deleteAllSessions(RecipientId recipientId
) {
142 synchronized (cachedSessions
) {
143 final var keys
= getKeysLocked(recipientId
);
144 for (var key
: keys
) {
145 deleteSessionLocked(key
);
151 public void archiveSession(final SignalProtocolAddress address
) {
152 final var key
= getKey(address
);
154 synchronized (cachedSessions
) {
155 archiveSessionLocked(key
);
160 public Set
<SignalProtocolAddress
> getAllAddressesWithActiveSessions(final List
<String
> addressNames
) {
161 final var recipientIdToNameMap
= addressNames
.stream()
162 .collect(Collectors
.toMap(this::resolveRecipient
, name
-> name
));
163 synchronized (cachedSessions
) {
164 return recipientIdToNameMap
.keySet()
166 .flatMap(recipientId
-> getKeysLocked(recipientId
).stream())
167 .filter(key
-> isActive(this.loadSessionLocked(key
)))
168 .map(key
-> new SignalProtocolAddress(recipientIdToNameMap
.get(key
.recipientId
), key
.deviceId()))
169 .collect(Collectors
.toSet());
173 public void archiveAllSessions() {
174 synchronized (cachedSessions
) {
175 final var keys
= getKeysLocked();
176 for (var key
: keys
) {
177 archiveSessionLocked(key
);
182 public void archiveSessions(final RecipientId recipientId
) {
183 synchronized (cachedSessions
) {
184 getKeysLocked().stream()
185 .filter(key
-> key
.recipientId
.equals(recipientId
))
186 .forEach(this::archiveSessionLocked
);
190 public void mergeRecipients(RecipientId recipientId
, RecipientId toBeMergedRecipientId
) {
191 synchronized (cachedSessions
) {
192 final var keys
= getKeysLocked(toBeMergedRecipientId
);
193 final var otherHasSession
= keys
.size() > 0;
194 if (!otherHasSession
) {
198 final var hasSession
= getKeysLocked(recipientId
).size() > 0;
200 logger
.debug("To be merged recipient had sessions, deleting.");
201 deleteAllSessions(toBeMergedRecipientId
);
203 logger
.debug("Only to be merged recipient had sessions, re-assigning to the new recipient.");
204 for (var key
: keys
) {
205 final var session
= loadSessionLocked(key
);
206 deleteSessionLocked(key
);
207 if (session
== null) {
210 final var newKey
= new Key(recipientId
, key
.deviceId());
211 storeSessionLocked(newKey
, session
);
218 * @param identifier can be either a serialized uuid or a e164 phone number
220 private RecipientId
resolveRecipient(String identifier
) {
221 return resolver
.resolveRecipient(identifier
);
224 private Key
getKey(final SignalProtocolAddress address
) {
225 final var recipientId
= resolveRecipient(address
.getName());
226 return new Key(recipientId
, address
.getDeviceId());
229 private List
<Key
> getKeysLocked(RecipientId recipientId
) {
230 final var files
= sessionsPath
.listFiles((_file
, s
) -> s
.startsWith(recipientId
.id() + "_"));
234 return parseFileNames(files
);
237 private Collection
<Key
> getKeysLocked() {
238 final var files
= sessionsPath
.listFiles();
242 return parseFileNames(files
);
245 final Pattern sessionFileNamePattern
= Pattern
.compile("([0-9]+)_([0-9]+)");
247 private List
<Key
> parseFileNames(final File
[] files
) {
248 return Arrays
.stream(files
)
249 .map(f
-> sessionFileNamePattern
.matcher(f
.getName()))
250 .filter(Matcher
::matches
)
252 final var recipientId
= resolver
.resolveRecipient(Long
.parseLong(matcher
.group(1)));
253 if (recipientId
== null) {
256 return new Key(recipientId
, Integer
.parseInt(matcher
.group(2)));
258 .filter(Objects
::nonNull
)
262 private File
getSessionFile(Key key
) {
264 IOUtils
.createPrivateDirectories(sessionsPath
);
265 } catch (IOException e
) {
266 throw new AssertionError("Failed to create sessions path", e
);
268 return new File(sessionsPath
, key
.recipientId().id() + "_" + key
.deviceId());
271 private SessionRecord
loadSessionLocked(final Key key
) {
273 final var session
= cachedSessions
.get(key
);
274 if (session
!= null) {
279 final var file
= getSessionFile(key
);
280 if (!file
.exists()) {
283 try (var inputStream
= new FileInputStream(file
)) {
284 final var session
= new SessionRecord(inputStream
.readAllBytes());
285 cachedSessions
.put(key
, session
);
287 } catch (Exception e
) {
288 logger
.warn("Failed to load session, resetting session: {}", e
.getMessage());
293 private void storeSessionLocked(final Key key
, final SessionRecord session
) {
294 cachedSessions
.put(key
, session
);
296 final var file
= getSessionFile(key
);
298 try (var outputStream
= new FileOutputStream(file
)) {
299 outputStream
.write(session
.serialize());
301 } catch (IOException e
) {
302 logger
.warn("Failed to store session, trying to delete file and retry: {}", e
.getMessage());
304 Files
.delete(file
.toPath());
305 try (var outputStream
= new FileOutputStream(file
)) {
306 outputStream
.write(session
.serialize());
308 } catch (IOException e2
) {
309 logger
.error("Failed to store session file {}: {}", file
, e2
.getMessage());
314 private void archiveSessionLocked(final Key key
) {
315 final var session
= loadSessionLocked(key
);
316 if (session
== null) {
319 session
.archiveCurrentState();
320 storeSessionLocked(key
, session
);
323 private void deleteSessionLocked(final Key key
) {
324 cachedSessions
.remove(key
);
326 final var file
= getSessionFile(key
);
327 if (!file
.exists()) {
331 Files
.delete(file
.toPath());
332 } catch (IOException e
) {
333 logger
.error("Failed to delete session file {}: {}", file
, e
.getMessage());
337 private static boolean isActive(SessionRecord
record) {
338 return record != null
339 && record.hasSenderChain()
340 && record.getSessionVersion() == CiphertextMessage
.CURRENT_VERSION
;
343 private record Key(RecipientId recipientId
, int deviceId
) {}