package org.asamk.signal.manager.storage.sessions;

import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientResolver;
import org.asamk.signal.manager.util.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.libsignal.NoSessionException;
import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.protocol.CiphertextMessage;
import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.signalservice.api.SignalServiceSessionStore;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
30 public class SessionStore
implements SignalServiceSessionStore
{
32 private final static Logger logger
= LoggerFactory
.getLogger(SessionStore
.class);
34 private final Map
<Key
, SessionRecord
> cachedSessions
= new HashMap
<>();
36 private final File sessionsPath
;
38 private final RecipientResolver resolver
;
41 final File sessionsPath
, final RecipientResolver resolver
43 this.sessionsPath
= sessionsPath
;
44 this.resolver
= resolver
;
48 public SessionRecord
loadSession(SignalProtocolAddress address
) {
49 final var key
= getKey(address
);
51 synchronized (cachedSessions
) {
52 final var session
= loadSessionLocked(key
);
53 if (session
== null) {
54 return new SessionRecord();
61 public List
<SessionRecord
> loadExistingSessions(final List
<SignalProtocolAddress
> addresses
) throws NoSessionException
{
62 final var keys
= addresses
.stream().map(this::getKey
).toList();
64 synchronized (cachedSessions
) {
65 final var sessions
= keys
.stream().map(this::loadSessionLocked
).filter(Objects
::nonNull
).toList();
67 if (sessions
.size() != addresses
.size()) {
68 String message
= "Mismatch! Asked for "
70 + " sessions, but only found "
74 throw new NoSessionException(message
);
82 public List
<Integer
> getSubDeviceSessions(String name
) {
83 final var recipientId
= resolveRecipient(name
);
85 synchronized (cachedSessions
) {
86 return getKeysLocked(recipientId
).stream()
87 // get all sessions for recipient except main device session
88 .filter(key
-> key
.deviceId() != 1 && key
.recipientId().equals(recipientId
))
95 public void storeSession(SignalProtocolAddress address
, SessionRecord session
) {
96 final var key
= getKey(address
);
98 synchronized (cachedSessions
) {
99 storeSessionLocked(key
, session
);
104 public boolean containsSession(SignalProtocolAddress address
) {
105 final var key
= getKey(address
);
107 synchronized (cachedSessions
) {
108 final var session
= loadSessionLocked(key
);
109 return isActive(session
);
114 public void deleteSession(SignalProtocolAddress address
) {
115 final var key
= getKey(address
);
117 synchronized (cachedSessions
) {
118 deleteSessionLocked(key
);
123 public void deleteAllSessions(String name
) {
124 final var recipientId
= resolveRecipient(name
);
125 deleteAllSessions(recipientId
);
128 public void deleteAllSessions(RecipientId recipientId
) {
129 synchronized (cachedSessions
) {
130 final var keys
= getKeysLocked(recipientId
);
131 for (var key
: keys
) {
132 deleteSessionLocked(key
);
138 public void archiveSession(final SignalProtocolAddress address
) {
139 final var key
= getKey(address
);
141 synchronized (cachedSessions
) {
142 archiveSessionLocked(key
);
147 public Set
<SignalProtocolAddress
> getAllAddressesWithActiveSessions(final List
<String
> addressNames
) {
148 final var recipientIdToNameMap
= addressNames
.stream()
149 .collect(Collectors
.toMap(this::resolveRecipient
, name
-> name
));
150 synchronized (cachedSessions
) {
151 return recipientIdToNameMap
.keySet()
153 .flatMap(recipientId
-> getKeysLocked(recipientId
).stream())
154 .filter(key
-> isActive(this.loadSessionLocked(key
)))
155 .map(key
-> new SignalProtocolAddress(recipientIdToNameMap
.get(key
.recipientId
), key
.deviceId()))
156 .collect(Collectors
.toSet());
160 public void archiveAllSessions() {
161 synchronized (cachedSessions
) {
162 final var keys
= getKeysLocked();
163 for (var key
: keys
) {
164 archiveSessionLocked(key
);
169 public void archiveSessions(final RecipientId recipientId
) {
170 synchronized (cachedSessions
) {
171 getKeysLocked().stream()
172 .filter(key
-> key
.recipientId
.equals(recipientId
))
173 .forEach(this::archiveSessionLocked
);
177 public void mergeRecipients(RecipientId recipientId
, RecipientId toBeMergedRecipientId
) {
178 synchronized (cachedSessions
) {
179 final var keys
= getKeysLocked(toBeMergedRecipientId
);
180 final var otherHasSession
= keys
.size() > 0;
181 if (!otherHasSession
) {
185 final var hasSession
= getKeysLocked(recipientId
).size() > 0;
187 logger
.debug("To be merged recipient had sessions, deleting.");
188 deleteAllSessions(toBeMergedRecipientId
);
190 logger
.debug("Only to be merged recipient had sessions, re-assigning to the new recipient.");
191 for (var key
: keys
) {
192 final var session
= loadSessionLocked(key
);
193 deleteSessionLocked(key
);
194 if (session
== null) {
197 final var newKey
= new Key(recipientId
, key
.deviceId());
198 storeSessionLocked(newKey
, session
);
205 * @param identifier can be either a serialized uuid or a e164 phone number
207 private RecipientId
resolveRecipient(String identifier
) {
208 return resolver
.resolveRecipient(identifier
);
211 private Key
getKey(final SignalProtocolAddress address
) {
212 final var recipientId
= resolveRecipient(address
.getName());
213 return new Key(recipientId
, address
.getDeviceId());
216 private List
<Key
> getKeysLocked(RecipientId recipientId
) {
217 final var files
= sessionsPath
.listFiles((_file
, s
) -> s
.startsWith(recipientId
.id() + "_"));
221 return parseFileNames(files
);
224 private Collection
<Key
> getKeysLocked() {
225 final var files
= sessionsPath
.listFiles();
229 return parseFileNames(files
);
232 final Pattern sessionFileNamePattern
= Pattern
.compile("([0-9]+)_([0-9]+)");
234 private List
<Key
> parseFileNames(final File
[] files
) {
235 return Arrays
.stream(files
)
236 .map(f
-> sessionFileNamePattern
.matcher(f
.getName()))
237 .filter(Matcher
::matches
)
239 final var recipientId
= resolver
.resolveRecipient(Long
.parseLong(matcher
.group(1)));
240 if (recipientId
== null) {
243 return new Key(recipientId
, Integer
.parseInt(matcher
.group(2)));
245 .filter(Objects
::nonNull
)
249 private File
getSessionFile(Key key
) {
251 IOUtils
.createPrivateDirectories(sessionsPath
);
252 } catch (IOException e
) {
253 throw new AssertionError("Failed to create sessions path", e
);
255 return new File(sessionsPath
, key
.recipientId().id() + "_" + key
.deviceId());
258 private SessionRecord
loadSessionLocked(final Key key
) {
260 final var session
= cachedSessions
.get(key
);
261 if (session
!= null) {
266 final var file
= getSessionFile(key
);
267 if (!file
.exists()) {
270 try (var inputStream
= new FileInputStream(file
)) {
271 final var session
= new SessionRecord(inputStream
.readAllBytes());
272 cachedSessions
.put(key
, session
);
274 } catch (IOException e
) {
275 logger
.warn("Failed to load session, resetting session: {}", e
.getMessage());
280 private void storeSessionLocked(final Key key
, final SessionRecord session
) {
281 cachedSessions
.put(key
, session
);
283 final var file
= getSessionFile(key
);
285 try (var outputStream
= new FileOutputStream(file
)) {
286 outputStream
.write(session
.serialize());
288 } catch (IOException e
) {
289 logger
.warn("Failed to store session, trying to delete file and retry: {}", e
.getMessage());
291 Files
.delete(file
.toPath());
292 try (var outputStream
= new FileOutputStream(file
)) {
293 outputStream
.write(session
.serialize());
295 } catch (IOException e2
) {
296 logger
.error("Failed to store session file {}: {}", file
, e2
.getMessage());
301 private void archiveSessionLocked(final Key key
) {
302 final var session
= loadSessionLocked(key
);
303 if (session
== null) {
306 session
.archiveCurrentState();
307 storeSessionLocked(key
, session
);
310 private void deleteSessionLocked(final Key key
) {
311 cachedSessions
.remove(key
);
313 final var file
= getSessionFile(key
);
314 if (!file
.exists()) {
318 Files
.delete(file
.toPath());
319 } catch (IOException e
) {
320 logger
.error("Failed to delete session file {}: {}", file
, e
.getMessage());
324 private static boolean isActive(SessionRecord
record) {
325 return record != null
326 && record.hasSenderChain()
327 && record.getSessionVersion() == CiphertextMessage
.CURRENT_VERSION
;
330 private record Key(RecipientId recipientId
, int deviceId
) {}