}
private IdentityKeyPair getIdentityKeyPair() {
- return account.getSignalProtocolStore().getIdentityKeyPair();
+ return account.getIdentityKeyPair();
}
public int getDeviceId() {
public void updateAccountAttributes() throws IOException {
accountManager.setAccountAttributes(null,
- account.getSignalProtocolStore().getLocalRegistrationId(),
+ account.getLocalRegistrationId(),
true,
// set legacy pin only if no KBS master key is set
account.getPinMasterKey() == null ? account.getRegistrationLockPin() : null,
}
private void handleEndSession(SignalServiceAddress source) {
- account.getSignalProtocolStore().deleteAllSessions(source);
+ account.getSessionStore().deleteAllSessions(source.getIdentifier());
}
private List<HandleAction> handleSignalServiceDataMessage(
account.setRegistered(true);
account.setUuid(UuidUtil.parseOrNull(response.getUuid()));
account.setRegistrationLockPin(pin);
- account.getSignalProtocolStore().archiveAllSessions();
+ account.getSessionStore().archiveAllSessions();
account.getSignalProtocolStore()
.saveIdentity(account.getSelfAddress(),
- account.getSignalProtocolStore().getIdentityKeyPair().getPublicKey(),
+ account.getIdentityKeyPair().getPublicKey(),
TrustLevel.TRUSTED_VERIFIED);
Manager m = null;
) throws IOException {
return accountManager.verifyAccountWithCode(verificationCode,
null,
- account.getSignalProtocolStore().getLocalRegistrationId(),
+ account.getLocalRegistrationId(),
true,
legacyPin,
registrationLock,
import org.asamk.signal.manager.storage.recipients.LegacyRecipientStore;
import org.asamk.signal.manager.storage.recipients.RecipientId;
import org.asamk.signal.manager.storage.recipients.RecipientStore;
+import org.asamk.signal.manager.storage.sessions.SessionStore;
import org.asamk.signal.manager.storage.stickers.StickerStore;
import org.asamk.signal.manager.storage.threads.LegacyJsonThreadStore;
import org.asamk.signal.manager.util.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.libsignal.IdentityKeyPair;
+import org.whispersystems.libsignal.SignalProtocolAddress;
import org.whispersystems.libsignal.state.PreKeyRecord;
+import org.whispersystems.libsignal.state.SessionRecord;
import org.whispersystems.libsignal.state.SignedPreKeyRecord;
import org.whispersystems.libsignal.util.Medium;
import org.whispersystems.libsignal.util.Pair;
private boolean registered = false;
private JsonSignalProtocolStore signalProtocolStore;
+ private SessionStore sessionStore;
private JsonGroupStore groupStore;
private JsonContactsStore contactStore;
private RecipientStore recipientStore;
account.username = username;
account.profileKey = profileKey;
- account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId);
account.groupStore = new JsonGroupStore(getGroupCachePath(dataPath, username));
account.contactStore = new JsonContactsStore();
account.recipientStore = RecipientStore.load(getRecipientsStoreFile(dataPath, username),
account::mergeRecipients);
+ account.sessionStore = new SessionStore(getSessionsPath(dataPath, username),
+ account.recipientStore::resolveRecipient);
+ account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId, account.sessionStore);
account.profileStore = new ProfileStore();
account.stickerStore = new StickerStore();
account.password = password;
account.profileKey = profileKey;
account.deviceId = deviceId;
- account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId);
account.groupStore = new JsonGroupStore(getGroupCachePath(dataPath, username));
account.contactStore = new JsonContactsStore();
account.recipientStore = RecipientStore.load(getRecipientsStoreFile(dataPath, username),
account::mergeRecipients);
+ account.sessionStore = new SessionStore(getSessionsPath(dataPath, username),
+ account.recipientStore::resolveRecipient);
+ account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId, account.sessionStore);
account.profileStore = new ProfileStore();
account.stickerStore = new StickerStore();
}
private void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
- // TODO
+ sessionStore.mergeRecipients(recipientId, toBeMergedRecipientId);
}
public static File getFileName(File dataPath, String username) {
return new File(getUserPath(dataPath, username), "group-cache");
}
+ /**
+  * Builds the location of the "sessions" entry below the user path that
+  * {@code getUserPath} computes for the given data path and username.
+  */
+ private static File getSessionsPath(File dataPath, String username) {
+     final var userPath = getUserPath(dataPath, username);
+     return new File(userPath, "sessions");
+ }
+
private static File getRecipientsStoreFile(File dataPath, String username) {
return new File(getUserPath(dataPath, username), "recipients-store");
}
signalProtocolStore = jsonProcessor.convertValue(Utils.getNotNullNode(rootNode, "axolotlStore"),
JsonSignalProtocolStore.class);
+ sessionStore = new SessionStore(getSessionsPath(dataPath, username), recipientStore::resolveRecipient);
+ if (signalProtocolStore.getLegacySessionStore() != null) {
+ logger.debug("Migrating legacy session store.");
+ for (var session : signalProtocolStore.getLegacySessionStore().getSessions()) {
+ try {
+ sessionStore.storeSession(new SignalProtocolAddress(session.address.getIdentifier(),
+ session.deviceId), new SessionRecord(session.sessionRecord));
+ } catch (IOException e) {
+ logger.warn("Failed to migrate session, ignoring", e);
+ }
+ }
+ }
+ signalProtocolStore.setSessionStore(sessionStore);
registered = Utils.getNotNullNode(rootNode, "registered").asBoolean();
var groupStoreNode = rootNode.get("groupStore");
if (groupStoreNode != null) {
}
}
- for (var session : signalProtocolStore.getSessions()) {
- session.address = recipientStore.resolveServiceAddress(session.address);
- }
-
for (var identity : signalProtocolStore.getIdentities()) {
identity.setAddress(recipientStore.resolveServiceAddress(identity.getAddress()));
}
return signalProtocolStore;
}
+ /**
+  * Returns the session store held in this account's {@code sessionStore} field.
+  */
+ public SessionStore getSessionStore() {
+ return sessionStore;
+ }
+
public JsonGroupStore getGroupStore() {
return groupStore;
}
return deviceId == SignalServiceAddress.DEFAULT_DEVICE_ID;
}
+ /**
+  * Returns the identity key pair as reported by this account's signal protocol store.
+  */
+ public IdentityKeyPair getIdentityKeyPair() {
+     final var identityKeyPair = signalProtocolStore.getIdentityKeyPair();
+     return identityKeyPair;
+ }
+
+ /**
+  * Returns the local registration id as reported by this account's signal protocol store.
+  */
+ public int getLocalRegistrationId() {
+     final int registrationId = signalProtocolStore.getLocalRegistrationId();
+     return registrationId;
+ }
+
public String getPassword() {
return password;
}
+++ /dev/null
-package org.asamk.signal.manager.storage.protocol;
-
-import com.fasterxml.jackson.core.JsonGenerator;
-import com.fasterxml.jackson.core.JsonParser;
-import com.fasterxml.jackson.databind.DeserializationContext;
-import com.fasterxml.jackson.databind.JsonDeserializer;
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.JsonSerializer;
-import com.fasterxml.jackson.databind.SerializerProvider;
-
-import org.asamk.signal.manager.util.Utils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.whispersystems.libsignal.SignalProtocolAddress;
-import org.whispersystems.libsignal.protocol.CiphertextMessage;
-import org.whispersystems.libsignal.state.SessionRecord;
-import org.whispersystems.signalservice.api.SignalServiceSessionStore;
-import org.whispersystems.signalservice.api.push.SignalServiceAddress;
-import org.whispersystems.signalservice.api.util.UuidUtil;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Base64;
-import java.util.LinkedList;
-import java.util.List;
-
-class JsonSessionStore implements SignalServiceSessionStore {
-
- private final static Logger logger = LoggerFactory.getLogger(JsonSessionStore.class);
-
- private final List<SessionInfo> sessions = new ArrayList<>();
-
- private SignalServiceAddressResolver resolver;
-
- public JsonSessionStore() {
- }
-
- public void setResolver(final SignalServiceAddressResolver resolver) {
- this.resolver = resolver;
- }
-
- private SignalServiceAddress resolveSignalServiceAddress(String identifier) {
- if (resolver != null) {
- return resolver.resolveSignalServiceAddress(identifier);
- } else {
- return Utils.getSignalServiceAddressFromIdentifier(identifier);
- }
- }
-
- @Override
- public synchronized SessionRecord loadSession(SignalProtocolAddress address) {
- var serviceAddress = resolveSignalServiceAddress(address.getName());
- for (var info : sessions) {
- if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) {
- try {
- return new SessionRecord(info.sessionRecord);
- } catch (IOException e) {
- logger.warn("Failed to load session, resetting session: {}", e.getMessage());
- return new SessionRecord();
- }
- }
- }
-
- return new SessionRecord();
- }
-
- public synchronized List<SessionInfo> getSessions() {
- return sessions;
- }
-
- @Override
- public synchronized List<Integer> getSubDeviceSessions(String name) {
- var serviceAddress = resolveSignalServiceAddress(name);
-
- var deviceIds = new LinkedList<Integer>();
- for (var info : sessions) {
- if (info.address.matches(serviceAddress) && info.deviceId != 1) {
- deviceIds.add(info.deviceId);
- }
- }
-
- return deviceIds;
- }
-
- @Override
- public synchronized void storeSession(SignalProtocolAddress address, SessionRecord record) {
- var serviceAddress = resolveSignalServiceAddress(address.getName());
- for (var info : sessions) {
- if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) {
- if (!info.address.getUuid().isPresent() || !info.address.getNumber().isPresent()) {
- info.address = serviceAddress;
- }
- info.sessionRecord = record.serialize();
- return;
- }
- }
-
- sessions.add(new SessionInfo(serviceAddress, address.getDeviceId(), record.serialize()));
- }
-
- @Override
- public synchronized boolean containsSession(SignalProtocolAddress address) {
- var serviceAddress = resolveSignalServiceAddress(address.getName());
- for (var info : sessions) {
- if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) {
- final SessionRecord sessionRecord;
- try {
- sessionRecord = new SessionRecord(info.sessionRecord);
- } catch (IOException e) {
- logger.warn("Failed to check session: {}", e.getMessage());
- return false;
- }
-
- return sessionRecord.hasSenderChain()
- && sessionRecord.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
- }
- }
- return false;
- }
-
- @Override
- public synchronized void deleteSession(SignalProtocolAddress address) {
- var serviceAddress = resolveSignalServiceAddress(address.getName());
- sessions.removeIf(info -> info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId());
- }
-
- @Override
- public synchronized void deleteAllSessions(String name) {
- var serviceAddress = resolveSignalServiceAddress(name);
- deleteAllSessions(serviceAddress);
- }
-
- public synchronized void deleteAllSessions(SignalServiceAddress serviceAddress) {
- sessions.removeIf(info -> info.address.matches(serviceAddress));
- }
-
- @Override
- public void archiveSession(final SignalProtocolAddress address) {
- final var sessionRecord = loadSession(address);
- if (sessionRecord == null) {
- return;
- }
- sessionRecord.archiveCurrentState();
- storeSession(address, sessionRecord);
- }
-
- public void archiveAllSessions() {
- for (var info : sessions) {
- try {
- final var sessionRecord = new SessionRecord(info.sessionRecord);
- sessionRecord.archiveCurrentState();
- info.sessionRecord = sessionRecord.serialize();
- } catch (IOException ignored) {
- }
- }
- }
-
- public static class JsonSessionStoreDeserializer extends JsonDeserializer<JsonSessionStore> {
-
- @Override
- public JsonSessionStore deserialize(
- JsonParser jsonParser, DeserializationContext deserializationContext
- ) throws IOException {
- JsonNode node = jsonParser.getCodec().readTree(jsonParser);
-
- var sessionStore = new JsonSessionStore();
-
- if (node.isArray()) {
- for (var session : node) {
- var sessionName = session.hasNonNull("name") ? session.get("name").asText() : null;
- if (UuidUtil.isUuid(sessionName)) {
- // Ignore sessions that were incorrectly created with UUIDs as name
- continue;
- }
-
- var uuid = session.hasNonNull("uuid") ? UuidUtil.parseOrNull(session.get("uuid").asText()) : null;
- final var serviceAddress = uuid == null
- ? Utils.getSignalServiceAddressFromIdentifier(sessionName)
- : new SignalServiceAddress(uuid, sessionName);
- final var deviceId = session.get("deviceId").asInt();
- final var record = Base64.getDecoder().decode(session.get("record").asText());
- var sessionInfo = new SessionInfo(serviceAddress, deviceId, record);
- sessionStore.sessions.add(sessionInfo);
- }
- }
-
- return sessionStore;
- }
- }
-
- public static class JsonSessionStoreSerializer extends JsonSerializer<JsonSessionStore> {
-
- @Override
- public void serialize(
- JsonSessionStore jsonSessionStore, JsonGenerator json, SerializerProvider serializerProvider
- ) throws IOException {
- json.writeStartArray();
- for (var sessionInfo : jsonSessionStore.sessions) {
- json.writeStartObject();
- if (sessionInfo.address.getNumber().isPresent()) {
- json.writeStringField("name", sessionInfo.address.getNumber().get());
- }
- if (sessionInfo.address.getUuid().isPresent()) {
- json.writeStringField("uuid", sessionInfo.address.getUuid().get().toString());
- }
- json.writeNumberField("deviceId", sessionInfo.deviceId);
- json.writeStringField("record", Base64.getEncoder().encodeToString(sessionInfo.sessionRecord));
- json.writeEndObject();
- }
- json.writeEndArray();
- }
- }
-
-}
package org.asamk.signal.manager.storage.protocol;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.asamk.signal.manager.TrustLevel;
+import org.asamk.signal.manager.storage.sessions.SessionStore;
import org.whispersystems.libsignal.IdentityKey;
import org.whispersystems.libsignal.IdentityKeyPair;
import org.whispersystems.libsignal.InvalidKeyIdException;
import java.util.List;
+@JsonIgnoreProperties(value = "sessionStore", allowSetters = true)
public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
@JsonProperty("preKeys")
private JsonPreKeyStore preKeyStore;
@JsonProperty("sessionStore")
- @JsonDeserialize(using = JsonSessionStore.JsonSessionStoreDeserializer.class)
- @JsonSerialize(using = JsonSessionStore.JsonSessionStoreSerializer.class)
- private JsonSessionStore sessionStore;
+ @JsonDeserialize(using = LegacyJsonSessionStore.JsonSessionStoreDeserializer.class)
+ private LegacyJsonSessionStore legacySessionStore;
@JsonProperty("signedPreKeyStore")
@JsonDeserialize(using = JsonSignedPreKeyStore.JsonSignedPreKeyStoreDeserializer.class)
@JsonSerialize(using = JsonIdentityKeyStore.JsonIdentityKeyStoreSerializer.class)
private JsonIdentityKeyStore identityKeyStore;
- public JsonSignalProtocolStore() {
- }
+ private SessionStore sessionStore;
- public JsonSignalProtocolStore(
- JsonPreKeyStore preKeyStore,
- JsonSessionStore sessionStore,
- JsonSignedPreKeyStore signedPreKeyStore,
- JsonIdentityKeyStore identityKeyStore
- ) {
- this.preKeyStore = preKeyStore;
- this.sessionStore = sessionStore;
- this.signedPreKeyStore = signedPreKeyStore;
- this.identityKeyStore = identityKeyStore;
+ public JsonSignalProtocolStore() {
}
- public JsonSignalProtocolStore(IdentityKeyPair identityKeyPair, int registrationId) {
+ public JsonSignalProtocolStore(IdentityKeyPair identityKeyPair, int registrationId, SessionStore sessionStore) {
preKeyStore = new JsonPreKeyStore();
- sessionStore = new JsonSessionStore();
+ this.sessionStore = sessionStore;
signedPreKeyStore = new JsonSignedPreKeyStore();
this.identityKeyStore = new JsonIdentityKeyStore(identityKeyPair, registrationId);
}
public void setResolver(final SignalServiceAddressResolver resolver) {
- sessionStore.setResolver(resolver);
identityKeyStore.setResolver(resolver);
}
+ /**
+  * Replaces the session store this protocol store forwards its session operations to.
+  * The field is not filled by the no-argument constructor, so it has to be set through
+  * this method (or the three-argument constructor) before session methods are used.
+  */
+ public void setSessionStore(final SessionStore sessionStore) {
+ this.sessionStore = sessionStore;
+ }
+
+ /**
+  * Returns the sessions that were deserialized from the legacy "sessionStore" JSON property.
+  * NOTE(review): presumably null when the JSON carried no such property - callers should check.
+  */
+ public LegacyJsonSessionStore getLegacySessionStore() {
+ return legacySessionStore;
+ }
+
@Override
public IdentityKeyPair getIdentityKeyPair() {
return identityKeyStore.getIdentityKeyPair();
return sessionStore.loadSession(address);
}
- public List<SessionInfo> getSessions() {
- return sessionStore.getSessions();
- }
-
@Override
public List<Integer> getSubDeviceSessions(String name) {
return sessionStore.getSubDeviceSessions(name);
sessionStore.deleteAllSessions(name);
}
- public void deleteAllSessions(SignalServiceAddress serviceAddress) {
- sessionStore.deleteAllSessions(serviceAddress);
- }
-
@Override
public void archiveSession(final SignalProtocolAddress address) {
sessionStore.archiveSession(address);
}
- public void archiveAllSessions() {
- sessionStore.archiveAllSessions();
- }
-
@Override
public SignedPreKeyRecord loadSignedPreKey(int signedPreKeyId) throws InvalidKeyIdException {
return signedPreKeyStore.loadSignedPreKey(signedPreKeyId);
--- /dev/null
+package org.asamk.signal.manager.storage.protocol;
+
+import com.fasterxml.jackson.core.JsonParser;
+import com.fasterxml.jackson.databind.DeserializationContext;
+import com.fasterxml.jackson.databind.JsonDeserializer;
+import com.fasterxml.jackson.databind.JsonNode;
+
+import org.asamk.signal.manager.util.Utils;
+import org.whispersystems.signalservice.api.push.SignalServiceAddress;
+import org.whispersystems.signalservice.api.util.UuidUtil;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.List;
+
+/**
+ * Holder for sessions parsed from the legacy JSON array representation.
+ * <p>
+ * The class only supports reading: it offers a deserializer but no serializer.
+ */
+public class LegacyJsonSessionStore {
+
+    private final List<SessionInfo> sessions = new ArrayList<>();
+
+    public LegacyJsonSessionStore() {
+    }
+
+    /**
+     * @return the sessions that were read from JSON, in the order they appeared
+     */
+    public List<SessionInfo> getSessions() {
+        return sessions;
+    }
+
+    /**
+     * Reads a JSON array of objects with the fields {@code name}, {@code uuid}, {@code deviceId} and
+     * {@code record} (Base64). A non-array node yields an empty store.
+     * <p>
+     * Entries that cannot be turned into a session are skipped instead of failing the whole store:
+     * entries whose name is a UUID, entries with neither name nor uuid, entries without
+     * {@code deviceId} or {@code record}, and entries whose record is not valid Base64.
+     */
+    public static class JsonSessionStoreDeserializer extends JsonDeserializer<LegacyJsonSessionStore> {
+
+        @Override
+        public LegacyJsonSessionStore deserialize(
+                JsonParser jsonParser, DeserializationContext deserializationContext
+        ) throws IOException {
+            JsonNode node = jsonParser.getCodec().readTree(jsonParser);
+
+            var sessionStore = new LegacyJsonSessionStore();
+            if (!node.isArray()) {
+                return sessionStore;
+            }
+
+            for (var session : node) {
+                var sessionName = session.hasNonNull("name") ? session.get("name").asText() : null;
+                if (UuidUtil.isUuid(sessionName)) {
+                    // Ignore sessions that were incorrectly created with UUIDs as name
+                    continue;
+                }
+
+                var uuid = session.hasNonNull("uuid") ? UuidUtil.parseOrNull(session.get("uuid").asText()) : null;
+                if (uuid == null && sessionName == null) {
+                    // No identifier at all, there is nothing to build an address from
+                    continue;
+                }
+                if (!session.hasNonNull("deviceId") || !session.hasNonNull("record")) {
+                    // Incomplete entry, a session cannot be restored from it
+                    continue;
+                }
+
+                final byte[] record;
+                try {
+                    record = Base64.getDecoder().decode(session.get("record").asText());
+                } catch (IllegalArgumentException ignored) {
+                    // Record is not valid Base64, drop only this entry
+                    continue;
+                }
+
+                final var serviceAddress = uuid == null
+                        ? Utils.getSignalServiceAddressFromIdentifier(sessionName)
+                        : new SignalServiceAddress(uuid, sessionName);
+                final var deviceId = session.get("deviceId").asInt();
+                sessionStore.sessions.add(new SessionInfo(serviceAddress, deviceId, record));
+            }
+
+            return sessionStore;
+        }
+    }
+}
this.id = id;
}
- long getId() {
+ /**
+  * Static factory that wraps the given numeric id in a new {@code RecipientId}.
+  */
+ public static RecipientId of(long id) {
+ return new RecipientId(id);
+ }
+
+ public long getId() {
return id;
}
--- /dev/null
+package org.asamk.signal.manager.storage.recipients;
+
+import org.whispersystems.signalservice.api.push.SignalServiceAddress;
+
+/**
+ * Maps a service address to the id of the recipient it belongs to.
+ * <p>
+ * The interface has a single abstract method so that it can be supplied as a lambda or method reference.
+ */
+@FunctionalInterface
+public interface RecipientResolver {
+
+    /**
+     * @param address the service address to look up
+     * @return the id of the recipient for the given address
+     */
+    RecipientId resolveRecipient(SignalServiceAddress address);
+}
--- /dev/null
+package org.asamk.signal.manager.storage.sessions;
+
+import org.asamk.signal.manager.storage.recipients.RecipientId;
+import org.asamk.signal.manager.storage.recipients.RecipientResolver;
+import org.asamk.signal.manager.util.IOUtils;
+import org.asamk.signal.manager.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.libsignal.SignalProtocolAddress;
+import org.whispersystems.libsignal.protocol.CiphertextMessage;
+import org.whispersystems.libsignal.state.SessionRecord;
+import org.whispersystems.signalservice.api.SignalServiceSessionStore;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+
+/**
+ * Session store that keeps one file per session inside a dedicated directory.
+ * <p>
+ * Files are named {@code <recipientId>_<deviceId>} and contain the serialized {@link SessionRecord}.
+ * Records that were loaded or stored are additionally kept in an in-memory cache.
+ * Every file and cache access happens while holding the monitor of the cache map; resolving a
+ * protocol address to a recipient id happens before that lock is taken.
+ * <p>
+ * NOTE(review): cached {@link SessionRecord} instances are handed out directly, so a caller that
+ * mutates a loaded record also mutates the cached copy - confirm callers always store it afterwards.
+ */
+public class SessionStore implements SignalServiceSessionStore {
+
+    private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
+
+    /** Device id of a recipient's main device; it is left out of sub device listings. */
+    private final static int MAIN_DEVICE_ID = 1;
+
+    /** Session file names consist of the recipient id and the device id joined by an underscore. */
+    private final static Pattern sessionFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)");
+
+    private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
+
+    private final File sessionsPath;
+
+    private final RecipientResolver resolver;
+
+    /**
+     * @param sessionsPath directory holding the session files, created on demand
+     * @param resolver     used to map address identifiers to recipient ids
+     */
+    public SessionStore(
+            final File sessionsPath, final RecipientResolver resolver
+    ) {
+        this.sessionsPath = sessionsPath;
+        this.resolver = resolver;
+    }
+
+    /**
+     * Returns the stored session for the address, or a fresh empty record if none can be loaded.
+     */
+    @Override
+    public SessionRecord loadSession(SignalProtocolAddress address) {
+        final var key = getKey(address);
+
+        synchronized (cachedSessions) {
+            final var session = loadSessionLocked(key);
+            if (session == null) {
+                return new SessionRecord();
+            }
+            return session;
+        }
+    }
+
+    /**
+     * Lists the device ids with a session file for the named recipient, excluding the main device.
+     */
+    @Override
+    public List<Integer> getSubDeviceSessions(String name) {
+        final var recipientId = resolveRecipient(name);
+
+        synchronized (cachedSessions) {
+            return getKeysLocked(recipientId).stream()
+                    // get all sessions for recipient except main device session
+                    .filter(key -> key.getDeviceId() != MAIN_DEVICE_ID && key.getRecipientId().equals(recipientId))
+                    .map(Key::getDeviceId)
+                    .collect(Collectors.toList());
+        }
+    }
+
+    @Override
+    public void storeSession(SignalProtocolAddress address, SessionRecord session) {
+        final var key = getKey(address);
+
+        synchronized (cachedSessions) {
+            storeSessionLocked(key, session);
+        }
+    }
+
+    /**
+     * A session counts as present only if it has a sender chain and uses the current message version.
+     */
+    @Override
+    public boolean containsSession(SignalProtocolAddress address) {
+        final var key = getKey(address);
+
+        synchronized (cachedSessions) {
+            final var session = loadSessionLocked(key);
+            if (session == null) {
+                return false;
+            }
+
+            return session.hasSenderChain() && session.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
+        }
+    }
+
+    @Override
+    public void deleteSession(SignalProtocolAddress address) {
+        final var key = getKey(address);
+
+        synchronized (cachedSessions) {
+            deleteSessionLocked(key);
+        }
+    }
+
+    @Override
+    public void deleteAllSessions(String name) {
+        final var recipientId = resolveRecipient(name);
+        deleteAllSessions(recipientId);
+    }
+
+    /**
+     * Deletes every session file (and cache entry) that belongs to the given recipient.
+     */
+    public void deleteAllSessions(RecipientId recipientId) {
+        synchronized (cachedSessions) {
+            final var keys = getKeysLocked(recipientId);
+            for (var key : keys) {
+                deleteSessionLocked(key);
+            }
+        }
+    }
+
+    @Override
+    public void archiveSession(final SignalProtocolAddress address) {
+        final var key = getKey(address);
+
+        synchronized (cachedSessions) {
+            archiveSessionLocked(key);
+        }
+    }
+
+    /**
+     * Archives the current state of every session that has a file in the sessions directory.
+     */
+    public void archiveAllSessions() {
+        synchronized (cachedSessions) {
+            final var keys = getKeysLocked();
+            for (var key : keys) {
+                archiveSessionLocked(key);
+            }
+        }
+    }
+
+    /**
+     * Handles two recipients being merged into one.
+     * <p>
+     * If the surviving recipient already has sessions, the sessions of the merged recipient are deleted.
+     * Otherwise they are re-assigned to the surviving recipient, keeping their device ids.
+     *
+     * @param recipientId           the recipient that survives the merge
+     * @param toBeMergedRecipientId the recipient that is merged away
+     */
+    public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
+        synchronized (cachedSessions) {
+            // List the files of the merged recipient only once and reuse the result in both branches
+            final var toBeMergedKeys = getKeysLocked(toBeMergedRecipientId);
+            if (toBeMergedKeys.isEmpty()) {
+                return;
+            }
+
+            final var hasSession = !getKeysLocked(recipientId).isEmpty();
+            if (hasSession) {
+                logger.debug("To be merged recipient had sessions, deleting.");
+                for (var key : toBeMergedKeys) {
+                    deleteSessionLocked(key);
+                }
+            } else {
+                logger.debug("To be merged recipient had sessions, re-assigning to the new recipient.");
+                for (var key : toBeMergedKeys) {
+                    final var session = loadSessionLocked(key);
+                    deleteSessionLocked(key);
+                    if (session == null) {
+                        continue;
+                    }
+                    final var newKey = new Key(recipientId, key.getDeviceId());
+                    storeSessionLocked(newKey, session);
+                }
+            }
+        }
+    }
+
+    /**
+     * @param identifier can be either a serialized uuid or a e164 phone number
+     */
+    private RecipientId resolveRecipient(String identifier) {
+        return resolver.resolveRecipient(Utils.getSignalServiceAddressFromIdentifier(identifier));
+    }
+
+    private Key getKey(final SignalProtocolAddress address) {
+        final var recipientId = resolveRecipient(address.getName());
+        return new Key(recipientId, address.getDeviceId());
+    }
+
+    /**
+     * Lists the keys of all session files of one recipient; empty if the directory cannot be listed.
+     */
+    private List<Key> getKeysLocked(RecipientId recipientId) {
+        // Build the prefix once instead of once per directory entry
+        final var prefix = recipientId.getId() + "_";
+        final var files = sessionsPath.listFiles((dir, fileName) -> fileName.startsWith(prefix));
+        if (files == null) {
+            return List.of();
+        }
+        return parseFileNames(files);
+    }
+
+    /**
+     * Lists the keys of all session files; empty if the directory cannot be listed.
+     */
+    private Collection<Key> getKeysLocked() {
+        final var files = sessionsPath.listFiles();
+        if (files == null) {
+            return List.of();
+        }
+        return parseFileNames(files);
+    }
+
+    /**
+     * Converts file names into keys, dropping every file that is not a valid session file name.
+     */
+    private List<Key> parseFileNames(final File[] files) {
+        return Arrays.stream(files)
+                .map(file -> parseFileName(file.getName()))
+                .filter(key -> key != null)
+                .collect(Collectors.toList());
+    }
+
+    /**
+     * Parses a single session file name.
+     *
+     * @return the key, or null if the name has the wrong shape or its numbers are out of range
+     */
+    private static Key parseFileName(final String fileName) {
+        final var matcher = sessionFileNamePattern.matcher(fileName);
+        if (!matcher.matches()) {
+            return null;
+        }
+        try {
+            return new Key(RecipientId.of(Long.parseLong(matcher.group(1))), Integer.parseInt(matcher.group(2)));
+        } catch (NumberFormatException e) {
+            // The pattern accepts any number of digits, so a stray file can exceed the long/int range
+            logger.warn("Ignoring session file with out of range name: {}", fileName);
+            return null;
+        }
+    }
+
+    /**
+     * Returns the file for a key, creating the sessions directory first.
+     *
+     * @throws AssertionError if the sessions directory cannot be created
+     */
+    private File getSessionPath(Key key) {
+        try {
+            IOUtils.createPrivateDirectories(sessionsPath);
+        } catch (IOException e) {
+            throw new AssertionError("Failed to create sessions path", e);
+        }
+        return new File(sessionsPath, key.getRecipientId().getId() + "_" + key.getDeviceId());
+    }
+
+    /**
+     * Looks the session up in the cache and falls back to its file; a session read from file is cached.
+     *
+     * @return the session, or null if there is no file or it cannot be read
+     */
+    private SessionRecord loadSessionLocked(final Key key) {
+        final var cachedSession = cachedSessions.get(key);
+        if (cachedSession != null) {
+            return cachedSession;
+        }
+
+        final var file = getSessionPath(key);
+        if (!file.exists()) {
+            return null;
+        }
+        try (var inputStream = new FileInputStream(file)) {
+            final var session = new SessionRecord(inputStream.readAllBytes());
+            cachedSessions.put(key, session);
+            return session;
+        } catch (IOException e) {
+            logger.warn("Failed to load session, resetting session: {}", e.getMessage());
+            return null;
+        }
+    }
+
+    /**
+     * Caches the session and writes it to its file. A failed write is retried once after removing the
+     * file; if the retry fails as well the error is logged and the session stays in the cache only.
+     */
+    private void storeSessionLocked(final Key key, final SessionRecord session) {
+        cachedSessions.put(key, session);
+
+        final var file = getSessionPath(key);
+        // Serialize before opening the file, opening it for writing truncates the previous content
+        final var serialized = session.serialize();
+        try {
+            writeSessionFile(file, serialized);
+        } catch (IOException e) {
+            logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
+            try {
+                // The file may not exist if the first attempt failed before creating it
+                Files.deleteIfExists(file.toPath());
+                writeSessionFile(file, serialized);
+            } catch (IOException e2) {
+                logger.error("Failed to store session file {}: {}", file, e2.getMessage());
+            }
+        }
+    }
+
+    /**
+     * Writes the serialized session to the given file, replacing any previous content.
+     */
+    private static void writeSessionFile(final File file, final byte[] serialized) throws IOException {
+        try (var outputStream = new FileOutputStream(file)) {
+            outputStream.write(serialized);
+        }
+    }
+
+    private void archiveSessionLocked(final Key key) {
+        final var session = loadSessionLocked(key);
+        if (session == null) {
+            return;
+        }
+        session.archiveCurrentState();
+        storeSessionLocked(key, session);
+    }
+
+    /**
+     * Removes the session from the cache and deletes its file if there is one; failures are logged.
+     */
+    private void deleteSessionLocked(final Key key) {
+        cachedSessions.remove(key);
+
+        final var file = getSessionPath(key);
+        try {
+            // deleteIfExists avoids the gap between a separate existence check and the delete
+            Files.deleteIfExists(file.toPath());
+        } catch (IOException e) {
+            logger.error("Failed to delete session file {}: {}", file, e.getMessage());
+        }
+    }
+
+    /**
+     * Identifies a session by recipient and device; used as cache key and encoded in the file name.
+     */
+    private static final class Key {
+
+        private final RecipientId recipientId;
+        private final int deviceId;
+
+        public Key(final RecipientId recipientId, final int deviceId) {
+            this.recipientId = recipientId;
+            this.deviceId = deviceId;
+        }
+
+        public RecipientId getRecipientId() {
+            return recipientId;
+        }
+
+        public int getDeviceId() {
+            return deviceId;
+        }
+
+        @Override
+        public boolean equals(final Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+
+            final var key = (Key) o;
+
+            if (deviceId != key.deviceId) return false;
+            return recipientId.equals(key.recipientId);
+        }
+
+        @Override
+        public int hashCode() {
+            int result = recipientId.hashCode();
+            result = 31 * result + deviceId;
+            return result;
+        }
+    }
+}