1 package org
.asamk
.signal
.manager
.storage
.sessions
;
3 import org
.asamk
.signal
.manager
.storage
.recipients
.RecipientId
;
4 import org
.asamk
.signal
.manager
.storage
.recipients
.RecipientResolver
;
5 import org
.asamk
.signal
.manager
.util
.IOUtils
;
6 import org
.asamk
.signal
.manager
.util
.Utils
;
7 import org
.slf4j
.Logger
;
8 import org
.slf4j
.LoggerFactory
;
9 import org
.whispersystems
.libsignal
.SignalProtocolAddress
;
10 import org
.whispersystems
.libsignal
.protocol
.CiphertextMessage
;
11 import org
.whispersystems
.libsignal
.state
.SessionRecord
;
12 import org
.whispersystems
.signalservice
.api
.SignalServiceSessionStore
;
15 import java
.io
.FileInputStream
;
16 import java
.io
.FileOutputStream
;
17 import java
.io
.IOException
;
18 import java
.nio
.file
.Files
;
19 import java
.util
.Arrays
;
20 import java
.util
.Collection
;
21 import java
.util
.HashMap
;
22 import java
.util
.List
;
24 import java
.util
.regex
.Matcher
;
25 import java
.util
.regex
.Pattern
;
26 import java
.util
.stream
.Collectors
;
28 public class SessionStore
implements SignalServiceSessionStore
{
30 private final static Logger logger
= LoggerFactory
.getLogger(SessionStore
.class);
32 private final Map
<Key
, SessionRecord
> cachedSessions
= new HashMap
<>();
34 private final File sessionsPath
;
36 private final RecipientResolver resolver
;
39 final File sessionsPath
, final RecipientResolver resolver
41 this.sessionsPath
= sessionsPath
;
42 this.resolver
= resolver
;
46 public SessionRecord
loadSession(SignalProtocolAddress address
) {
47 final var key
= getKey(address
);
49 synchronized (cachedSessions
) {
50 final var session
= loadSessionLocked(key
);
51 if (session
== null) {
52 return new SessionRecord();
59 public List
<Integer
> getSubDeviceSessions(String name
) {
60 final var recipientId
= resolveRecipient(name
);
62 synchronized (cachedSessions
) {
63 return getKeysLocked(recipientId
).stream()
64 // get all sessions for recipient except main device session
65 .filter(key
-> key
.getDeviceId() != 1 && key
.getRecipientId().equals(recipientId
))
66 .map(Key
::getDeviceId
)
67 .collect(Collectors
.toList());
72 public void storeSession(SignalProtocolAddress address
, SessionRecord session
) {
73 final var key
= getKey(address
);
75 synchronized (cachedSessions
) {
76 storeSessionLocked(key
, session
);
81 public boolean containsSession(SignalProtocolAddress address
) {
82 final var key
= getKey(address
);
84 synchronized (cachedSessions
) {
85 final var session
= loadSessionLocked(key
);
86 if (session
== null) {
90 return session
.hasSenderChain() && session
.getSessionVersion() == CiphertextMessage
.CURRENT_VERSION
;
95 public void deleteSession(SignalProtocolAddress address
) {
96 final var key
= getKey(address
);
98 synchronized (cachedSessions
) {
99 deleteSessionLocked(key
);
104 public void deleteAllSessions(String name
) {
105 final var recipientId
= resolveRecipient(name
);
106 deleteAllSessions(recipientId
);
109 public void deleteAllSessions(RecipientId recipientId
) {
110 synchronized (cachedSessions
) {
111 final var keys
= getKeysLocked(recipientId
);
112 for (var key
: keys
) {
113 deleteSessionLocked(key
);
119 public void archiveSession(final SignalProtocolAddress address
) {
120 final var key
= getKey(address
);
122 synchronized (cachedSessions
) {
123 archiveSessionLocked(key
);
127 public void archiveAllSessions() {
128 synchronized (cachedSessions
) {
129 final var keys
= getKeysLocked();
130 for (var key
: keys
) {
131 archiveSessionLocked(key
);
136 public void mergeRecipients(RecipientId recipientId
, RecipientId toBeMergedRecipientId
) {
137 synchronized (cachedSessions
) {
138 final var otherHasSession
= getKeysLocked(toBeMergedRecipientId
).size() > 0;
139 if (!otherHasSession
) {
143 final var hasSession
= getKeysLocked(recipientId
).size() > 0;
145 logger
.debug("To be merged recipient had sessions, deleting.");
146 deleteAllSessions(toBeMergedRecipientId
);
148 logger
.debug("To be merged recipient had sessions, re-assigning to the new recipient.");
149 final var keys
= getKeysLocked(toBeMergedRecipientId
);
150 for (var key
: keys
) {
151 final var session
= loadSessionLocked(key
);
152 deleteSessionLocked(key
);
153 if (session
== null) {
156 final var newKey
= new Key(recipientId
, key
.getDeviceId());
157 storeSessionLocked(newKey
, session
);
164 * @param identifier can be either a serialized UUID or an E.164 phone number
166 private RecipientId
resolveRecipient(String identifier
) {
167 return resolver
.resolveRecipient(Utils
.getSignalServiceAddressFromIdentifier(identifier
));
170 private Key
getKey(final SignalProtocolAddress address
) {
171 final var recipientId
= resolveRecipient(address
.getName());
172 return new Key(recipientId
, address
.getDeviceId());
175 private List
<Key
> getKeysLocked(RecipientId recipientId
) {
176 final var files
= sessionsPath
.listFiles((_file
, s
) -> s
.startsWith(recipientId
.getId() + "_"));
180 return parseFileNames(files
);
183 private Collection
<Key
> getKeysLocked() {
184 final var files
= sessionsPath
.listFiles();
188 return parseFileNames(files
);
191 final Pattern sessionFileNamePattern
= Pattern
.compile("([0-9]+)_([0-9]+)");
193 private List
<Key
> parseFileNames(final File
[] files
) {
194 return Arrays
.stream(files
)
195 .map(f
-> sessionFileNamePattern
.matcher(f
.getName()))
196 .filter(Matcher
::matches
)
197 .map(matcher
-> new Key(RecipientId
.of(Long
.parseLong(matcher
.group(1))),
198 Integer
.parseInt(matcher
.group(2))))
199 .collect(Collectors
.toList());
202 private File
getSessionPath(Key key
) {
204 IOUtils
.createPrivateDirectories(sessionsPath
);
205 } catch (IOException e
) {
206 throw new AssertionError("Failed to create sessions path", e
);
208 return new File(sessionsPath
, key
.getRecipientId().getId() + "_" + key
.getDeviceId());
211 private SessionRecord
loadSessionLocked(final Key key
) {
213 final var session
= cachedSessions
.get(key
);
214 if (session
!= null) {
219 final var file
= getSessionPath(key
);
220 if (!file
.exists()) {
223 try (var inputStream
= new FileInputStream(file
)) {
224 final var session
= new SessionRecord(inputStream
.readAllBytes());
225 cachedSessions
.put(key
, session
);
227 } catch (IOException e
) {
228 logger
.warn("Failed to load session, resetting session: {}", e
.getMessage());
233 private void storeSessionLocked(final Key key
, final SessionRecord session
) {
234 cachedSessions
.put(key
, session
);
236 final var file
= getSessionPath(key
);
238 try (var outputStream
= new FileOutputStream(file
)) {
239 outputStream
.write(session
.serialize());
241 } catch (IOException e
) {
242 logger
.warn("Failed to store session, trying to delete file and retry: {}", e
.getMessage());
244 Files
.delete(file
.toPath());
245 try (var outputStream
= new FileOutputStream(file
)) {
246 outputStream
.write(session
.serialize());
248 } catch (IOException e2
) {
249 logger
.error("Failed to store session file {}: {}", file
, e2
.getMessage());
254 private void archiveSessionLocked(final Key key
) {
255 final var session
= loadSessionLocked(key
);
256 if (session
== null) {
259 session
.archiveCurrentState();
260 storeSessionLocked(key
, session
);
263 private void deleteSessionLocked(final Key key
) {
264 cachedSessions
.remove(key
);
266 final var file
= getSessionPath(key
);
267 if (!file
.exists()) {
271 Files
.delete(file
.toPath());
272 } catch (IOException e
) {
273 logger
.error("Failed to delete session file {}: {}", file
, e
.getMessage());
277 private static final class Key
{
279 private final RecipientId recipientId
;
280 private final int deviceId
;
282 public Key(final RecipientId recipientId
, final int deviceId
) {
283 this.recipientId
= recipientId
;
284 this.deviceId
= deviceId
;
287 public RecipientId
getRecipientId() {
291 public int getDeviceId() {
296 public boolean equals(final Object o
) {
297 if (this == o
) return true;
298 if (o
== null || getClass() != o
.getClass()) return false;
300 final var key
= (Key
) o
;
302 if (deviceId
!= key
.deviceId
) return false;
303 return recipientId
.equals(key
.recipientId
);
307 public int hashCode() {
308 int result
= recipientId
.hashCode();
309 result
= 31 * result
+ deviceId
;