From: AsamK Date: Thu, 15 Apr 2021 20:33:35 +0000 (+0200) Subject: Refactor sessions store X-Git-Tag: v0.8.2~47 X-Git-Url: https://git.nmode.ca/signal-cli/commitdiff_plain/f77519445cb81b0911b89b232edf17187ea8cef5?ds=sidebyside Refactor sessions store --- diff --git a/lib/src/main/java/org/asamk/signal/manager/Manager.java b/lib/src/main/java/org/asamk/signal/manager/Manager.java index 5d0e7d3d..f672d0fe 100644 --- a/lib/src/main/java/org/asamk/signal/manager/Manager.java +++ b/lib/src/main/java/org/asamk/signal/manager/Manager.java @@ -263,7 +263,7 @@ public class Manager implements Closeable { } private IdentityKeyPair getIdentityKeyPair() { - return account.getSignalProtocolStore().getIdentityKeyPair(); + return account.getIdentityKeyPair(); } public int getDeviceId() { @@ -336,7 +336,7 @@ public class Manager implements Closeable { public void updateAccountAttributes() throws IOException { accountManager.setAccountAttributes(null, - account.getSignalProtocolStore().getLocalRegistrationId(), + account.getLocalRegistrationId(), true, // set legacy pin only if no KBS master key is set account.getPinMasterKey() == null ?
account.getRegistrationLockPin() : null, @@ -1441,7 +1441,7 @@ public class Manager implements Closeable { } private void handleEndSession(SignalServiceAddress source) { - account.getSignalProtocolStore().deleteAllSessions(source); + account.getSessionStore().deleteAllSessions(source.getIdentifier()); } private List handleSignalServiceDataMessage( diff --git a/lib/src/main/java/org/asamk/signal/manager/RegistrationManager.java b/lib/src/main/java/org/asamk/signal/manager/RegistrationManager.java index aad731a0..a16ead37 100644 --- a/lib/src/main/java/org/asamk/signal/manager/RegistrationManager.java +++ b/lib/src/main/java/org/asamk/signal/manager/RegistrationManager.java @@ -163,10 +163,10 @@ public class RegistrationManager implements Closeable { account.setRegistered(true); account.setUuid(UuidUtil.parseOrNull(response.getUuid())); account.setRegistrationLockPin(pin); - account.getSignalProtocolStore().archiveAllSessions(); + account.getSessionStore().archiveAllSessions(); account.getSignalProtocolStore() .saveIdentity(account.getSelfAddress(), - account.getSignalProtocolStore().getIdentityKeyPair().getPublicKey(), + account.getIdentityKeyPair().getPublicKey(), TrustLevel.TRUSTED_VERIFIED); Manager m = null; @@ -194,7 +194,7 @@ public class RegistrationManager implements Closeable { ) throws IOException { return accountManager.verifyAccountWithCode(verificationCode, null, - account.getSignalProtocolStore().getLocalRegistrationId(), + account.getLocalRegistrationId(), true, legacyPin, registrationLock, diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java index bec73b94..18526f2a 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java @@ -21,6 +21,7 @@ import org.asamk.signal.manager.storage.protocol.SignalServiceAddressResolver; import
org.asamk.signal.manager.storage.recipients.LegacyRecipientStore; import org.asamk.signal.manager.storage.recipients.RecipientId; import org.asamk.signal.manager.storage.recipients.RecipientStore; +import org.asamk.signal.manager.storage.sessions.SessionStore; import org.asamk.signal.manager.storage.stickers.StickerStore; import org.asamk.signal.manager.storage.threads.LegacyJsonThreadStore; import org.asamk.signal.manager.util.IOUtils; @@ -31,7 +32,9 @@ import org.signal.zkgroup.profiles.ProfileKey; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.whispersystems.libsignal.IdentityKeyPair; +import org.whispersystems.libsignal.SignalProtocolAddress; import org.whispersystems.libsignal.state.PreKeyRecord; +import org.whispersystems.libsignal.state.SessionRecord; import org.whispersystems.libsignal.state.SignedPreKeyRecord; import org.whispersystems.libsignal.util.Medium; import org.whispersystems.libsignal.util.Pair; @@ -77,6 +80,7 @@ public class SignalAccount implements Closeable { private boolean registered = false; private JsonSignalProtocolStore signalProtocolStore; + private SessionStore sessionStore; private JsonGroupStore groupStore; private JsonContactsStore contactStore; private RecipientStore recipientStore; @@ -125,11 +129,13 @@ public class SignalAccount implements Closeable { account.username = username; account.profileKey = profileKey; - account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId); account.groupStore = new JsonGroupStore(getGroupCachePath(dataPath, username)); account.contactStore = new JsonContactsStore(); account.recipientStore = RecipientStore.load(getRecipientsStoreFile(dataPath, username), account::mergeRecipients); + account.sessionStore = new SessionStore(getSessionsPath(dataPath, username), + account.recipientStore::resolveRecipient); + account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId, account.sessionStore); account.profileStore = new
ProfileStore(); account.stickerStore = new StickerStore(); @@ -166,11 +172,13 @@ public class SignalAccount implements Closeable { account.password = password; account.profileKey = profileKey; account.deviceId = deviceId; - account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId); account.groupStore = new JsonGroupStore(getGroupCachePath(dataPath, username)); account.contactStore = new JsonContactsStore(); account.recipientStore = RecipientStore.load(getRecipientsStoreFile(dataPath, username), account::mergeRecipients); + account.sessionStore = new SessionStore(getSessionsPath(dataPath, username), + account.recipientStore::resolveRecipient); + account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId, account.sessionStore); account.profileStore = new ProfileStore(); account.stickerStore = new StickerStore(); @@ -210,7 +218,7 @@ public class SignalAccount implements Closeable { } private void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) { - // TODO + sessionStore.mergeRecipients(recipientId, toBeMergedRecipientId); } public static File getFileName(File dataPath, String username) { @@ -229,6 +237,10 @@ public class SignalAccount implements Closeable { return new File(getUserPath(dataPath, username), "group-cache"); } + private static File getSessionsPath(File dataPath, String username) { + return new File(getUserPath(dataPath, username), "sessions"); + } + private static File getRecipientsStoreFile(File dataPath, String username) { return new File(getUserPath(dataPath, username), "recipients-store"); } @@ -304,6 +316,19 @@ public class SignalAccount implements Closeable { signalProtocolStore = jsonProcessor.convertValue(Utils.getNotNullNode(rootNode, "axolotlStore"), JsonSignalProtocolStore.class); + sessionStore = new SessionStore(getSessionsPath(dataPath, username), recipientStore::resolveRecipient); + if (signalProtocolStore.getLegacySessionStore() != null) { +
logger.debug("Migrating legacy session store."); + for (var session : signalProtocolStore.getLegacySessionStore().getSessions()) { + try { + sessionStore.storeSession(new SignalProtocolAddress(session.address.getIdentifier(), + session.deviceId), new SessionRecord(session.sessionRecord)); + } catch (IOException e) { + logger.warn("Failed to migrate session, ignoring", e); + } + } + } + signalProtocolStore.setSessionStore(sessionStore); registered = Utils.getNotNullNode(rootNode, "registered").asBoolean(); var groupStoreNode = rootNode.get("groupStore"); if (groupStoreNode != null) { @@ -355,10 +380,6 @@ public class SignalAccount implements Closeable { } } - for (var session : signalProtocolStore.getSessions()) { - session.address = recipientStore.resolveServiceAddress(session.address); - } - for (var identity : signalProtocolStore.getIdentities()) { identity.setAddress(recipientStore.resolveServiceAddress(identity.getAddress())); } @@ -464,6 +485,10 @@ public class SignalAccount implements Closeable { return signalProtocolStore; } + public SessionStore getSessionStore() { + return sessionStore; + } + public JsonGroupStore getGroupStore() { return groupStore; } @@ -516,6 +541,14 @@ public class SignalAccount implements Closeable { return deviceId == SignalServiceAddress.DEFAULT_DEVICE_ID; } + public IdentityKeyPair getIdentityKeyPair() { + return signalProtocolStore.getIdentityKeyPair(); + } + + public int getLocalRegistrationId() { + return signalProtocolStore.getLocalRegistrationId(); + } + public String getPassword() { return password; } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/protocol/JsonSessionStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/protocol/JsonSessionStore.java deleted file mode 100644 index 1b5384a4..00000000 --- a/lib/src/main/java/org/asamk/signal/manager/storage/protocol/JsonSessionStore.java +++ /dev/null @@ -1,214 +0,0 @@ -package org.asamk.signal.manager.storage.protocol; - -import
com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.JsonSerializer; -import com.fasterxml.jackson.databind.SerializerProvider; - -import org.asamk.signal.manager.util.Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.whispersystems.libsignal.SignalProtocolAddress; -import org.whispersystems.libsignal.protocol.CiphertextMessage; -import org.whispersystems.libsignal.state.SessionRecord; -import org.whispersystems.signalservice.api.SignalServiceSessionStore; -import org.whispersystems.signalservice.api.push.SignalServiceAddress; -import org.whispersystems.signalservice.api.util.UuidUtil; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Base64; -import java.util.LinkedList; -import java.util.List; - -class JsonSessionStore implements SignalServiceSessionStore { - - private final static Logger logger = LoggerFactory.getLogger(JsonSessionStore.class); - - private final List sessions = new ArrayList<>(); - - private SignalServiceAddressResolver resolver; - - public JsonSessionStore() { - } - - public void setResolver(final SignalServiceAddressResolver resolver) { - this.resolver = resolver; - } - - private SignalServiceAddress resolveSignalServiceAddress(String identifier) { - if (resolver != null) { - return resolver.resolveSignalServiceAddress(identifier); - } else { - return Utils.getSignalServiceAddressFromIdentifier(identifier); - } - } - - @Override - public synchronized SessionRecord loadSession(SignalProtocolAddress address) { - var serviceAddress = resolveSignalServiceAddress(address.getName()); - for (var info : sessions) { - if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) { - try { - return new SessionRecord(info.sessionRecord);
- } catch (IOException e) { - logger.warn("Failed to load session, resetting session: {}", e.getMessage()); - return new SessionRecord(); - } - } - } - - return new SessionRecord(); - } - - public synchronized List getSessions() { - return sessions; - } - - @Override - public synchronized List getSubDeviceSessions(String name) { - var serviceAddress = resolveSignalServiceAddress(name); - - var deviceIds = new LinkedList(); - for (var info : sessions) { - if (info.address.matches(serviceAddress) && info.deviceId != 1) { - deviceIds.add(info.deviceId); - } - } - - return deviceIds; - } - - @Override - public synchronized void storeSession(SignalProtocolAddress address, SessionRecord record) { - var serviceAddress = resolveSignalServiceAddress(address.getName()); - for (var info : sessions) { - if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) { - if (!info.address.getUuid().isPresent() || !info.address.getNumber().isPresent()) { - info.address = serviceAddress; - } - info.sessionRecord = record.serialize(); - return; - } - } - - sessions.add(new SessionInfo(serviceAddress, address.getDeviceId(), record.serialize())); - } - - @Override - public synchronized boolean containsSession(SignalProtocolAddress address) { - var serviceAddress = resolveSignalServiceAddress(address.getName()); - for (var info : sessions) { - if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) { - final SessionRecord sessionRecord; - try { - sessionRecord = new SessionRecord(info.sessionRecord); - } catch (IOException e) { - logger.warn("Failed to check session: {}", e.getMessage()); - return false; - } - - return sessionRecord.hasSenderChain() - && sessionRecord.getSessionVersion() == CiphertextMessage.CURRENT_VERSION; - } - } - return false; - } - - @Override - public synchronized void deleteSession(SignalProtocolAddress address) { - var serviceAddress = resolveSignalServiceAddress(address.getName()); - sessions.removeIf(info ->
info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()); - } - - @Override - public synchronized void deleteAllSessions(String name) { - var serviceAddress = resolveSignalServiceAddress(name); - deleteAllSessions(serviceAddress); - } - - public synchronized void deleteAllSessions(SignalServiceAddress serviceAddress) { - sessions.removeIf(info -> info.address.matches(serviceAddress)); - } - - @Override - public void archiveSession(final SignalProtocolAddress address) { - final var sessionRecord = loadSession(address); - if (sessionRecord == null) { - return; - } - sessionRecord.archiveCurrentState(); - storeSession(address, sessionRecord); - } - - public void archiveAllSessions() { - for (var info : sessions) { - try { - final var sessionRecord = new SessionRecord(info.sessionRecord); - sessionRecord.archiveCurrentState(); - info.sessionRecord = sessionRecord.serialize(); - } catch (IOException ignored) { - } - } - } - - public static class JsonSessionStoreDeserializer extends JsonDeserializer { - - @Override - public JsonSessionStore deserialize( - JsonParser jsonParser, DeserializationContext deserializationContext - ) throws IOException { - JsonNode node = jsonParser.getCodec().readTree(jsonParser); - - var sessionStore = new JsonSessionStore(); - - if (node.isArray()) { - for (var session : node) { - var sessionName = session.hasNonNull("name") ? session.get("name").asText() : null; - if (UuidUtil.isUuid(sessionName)) { - // Ignore sessions that were incorrectly created with UUIDs as name - continue; - } - - var uuid = session.hasNonNull("uuid") ? UuidUtil.parseOrNull(session.get("uuid").asText()) : null; - final var serviceAddress = uuid == null - ?
Utils.getSignalServiceAddressFromIdentifier(sessionName) - : new SignalServiceAddress(uuid, sessionName); - final var deviceId = session.get("deviceId").asInt(); - final var record = Base64.getDecoder().decode(session.get("record").asText()); - var sessionInfo = new SessionInfo(serviceAddress, deviceId, record); - sessionStore.sessions.add(sessionInfo); - } - } - - return sessionStore; - } - } - - public static class JsonSessionStoreSerializer extends JsonSerializer { - - @Override - public void serialize( - JsonSessionStore jsonSessionStore, JsonGenerator json, SerializerProvider serializerProvider - ) throws IOException { - json.writeStartArray(); - for (var sessionInfo : jsonSessionStore.sessions) { - json.writeStartObject(); - if (sessionInfo.address.getNumber().isPresent()) { - json.writeStringField("name", sessionInfo.address.getNumber().get()); - } - if (sessionInfo.address.getUuid().isPresent()) { - json.writeStringField("uuid", sessionInfo.address.getUuid().get().toString()); - } - json.writeNumberField("deviceId", sessionInfo.deviceId); - json.writeStringField("record", Base64.getEncoder().encodeToString(sessionInfo.sessionRecord)); - json.writeEndObject(); - } - json.writeEndArray(); - } - } - -} diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/protocol/JsonSignalProtocolStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/protocol/JsonSignalProtocolStore.java index d47f7a1a..dc45a0da 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/protocol/JsonSignalProtocolStore.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/protocol/JsonSignalProtocolStore.java @@ -1,10 +1,12 @@ package org.asamk.signal.manager.storage.protocol; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.annotation.JsonSerialize; import
org.asamk.signal.manager.TrustLevel; +import org.asamk.signal.manager.storage.sessions.SessionStore; import org.whispersystems.libsignal.IdentityKey; import org.whispersystems.libsignal.IdentityKeyPair; import org.whispersystems.libsignal.InvalidKeyIdException; @@ -17,6 +19,7 @@ import org.whispersystems.signalservice.api.push.SignalServiceAddress; import java.util.List; +@JsonIgnoreProperties(value = "sessionStore", allowSetters = true) public class JsonSignalProtocolStore implements SignalServiceProtocolStore { @JsonProperty("preKeys") @@ -25,9 +28,8 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore { private JsonPreKeyStore preKeyStore; @JsonProperty("sessionStore") - @JsonDeserialize(using = JsonSessionStore.JsonSessionStoreDeserializer.class) - @JsonSerialize(using = JsonSessionStore.JsonSessionStoreSerializer.class) - private JsonSessionStore sessionStore; + @JsonDeserialize(using = LegacyJsonSessionStore.JsonSessionStoreDeserializer.class) + private LegacyJsonSessionStore legacySessionStore; @JsonProperty("signedPreKeyStore") @JsonDeserialize(using = JsonSignedPreKeyStore.JsonSignedPreKeyStoreDeserializer.class) @@ -39,33 +41,30 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore { @JsonSerialize(using = JsonIdentityKeyStore.JsonIdentityKeyStoreSerializer.class) private JsonIdentityKeyStore identityKeyStore; - public JsonSignalProtocolStore() { - } + private SessionStore sessionStore; - public JsonSignalProtocolStore( - JsonPreKeyStore preKeyStore, - JsonSessionStore sessionStore, - JsonSignedPreKeyStore signedPreKeyStore, - JsonIdentityKeyStore identityKeyStore - ) { - this.preKeyStore = preKeyStore; - this.sessionStore = sessionStore; - this.signedPreKeyStore = signedPreKeyStore; - this.identityKeyStore = identityKeyStore; + public JsonSignalProtocolStore() { } - public JsonSignalProtocolStore(IdentityKeyPair identityKeyPair, int registrationId) { + public JsonSignalProtocolStore(IdentityKeyPair
identityKeyPair, int registrationId, SessionStore sessionStore) { preKeyStore = new JsonPreKeyStore(); - sessionStore = new JsonSessionStore(); + this.sessionStore = sessionStore; signedPreKeyStore = new JsonSignedPreKeyStore(); this.identityKeyStore = new JsonIdentityKeyStore(identityKeyPair, registrationId); } public void setResolver(final SignalServiceAddressResolver resolver) { - sessionStore.setResolver(resolver); identityKeyStore.setResolver(resolver); } + public void setSessionStore(final SessionStore sessionStore) { + this.sessionStore = sessionStore; + } + + public LegacyJsonSessionStore getLegacySessionStore() { + return legacySessionStore; + } + @Override public IdentityKeyPair getIdentityKeyPair() { return identityKeyStore.getIdentityKeyPair(); @@ -142,10 +141,6 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore { return sessionStore.loadSession(address); } - public List getSessions() { - return sessionStore.getSessions(); - } - @Override public List getSubDeviceSessions(String name) { return sessionStore.getSubDeviceSessions(name); @@ -171,19 +166,11 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore { sessionStore.deleteAllSessions(name); } - public void deleteAllSessions(SignalServiceAddress serviceAddress) { - sessionStore.deleteAllSessions(serviceAddress); - } - @Override public void archiveSession(final SignalProtocolAddress address) { sessionStore.archiveSession(address); } - public void archiveAllSessions() { - sessionStore.archiveAllSessions(); - } - @Override public SignedPreKeyRecord loadSignedPreKey(int signedPreKeyId) throws InvalidKeyIdException { return signedPreKeyStore.loadSignedPreKey(signedPreKeyId); diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/protocol/LegacyJsonSessionStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/protocol/LegacyJsonSessionStore.java new file mode 100644 index 00000000..c43bdb08 --- /dev/null +++
b/lib/src/main/java/org/asamk/signal/manager/storage/protocol/LegacyJsonSessionStore.java @@ -0,0 +1,60 @@ +package org.asamk.signal.manager.storage.protocol; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonNode; + +import org.asamk.signal.manager.util.Utils; +import org.whispersystems.signalservice.api.push.SignalServiceAddress; +import org.whispersystems.signalservice.api.util.UuidUtil; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; + +public class LegacyJsonSessionStore { + + private final List sessions = new ArrayList<>(); + + public LegacyJsonSessionStore() { + } + + public List getSessions() { + return sessions; + } + + public static class JsonSessionStoreDeserializer extends JsonDeserializer { + + @Override + public LegacyJsonSessionStore deserialize( + JsonParser jsonParser, DeserializationContext deserializationContext + ) throws IOException { + JsonNode node = jsonParser.getCodec().readTree(jsonParser); + + var sessionStore = new LegacyJsonSessionStore(); + + if (node.isArray()) { + for (var session : node) { + var sessionName = session.hasNonNull("name") ? session.get("name").asText() : null; + if (UuidUtil.isUuid(sessionName)) { + // Ignore sessions that were incorrectly created with UUIDs as name + continue; + } + + var uuid = session.hasNonNull("uuid") ? UuidUtil.parseOrNull(session.get("uuid").asText()) : null; + final var serviceAddress = uuid == null + ?
Utils.getSignalServiceAddressFromIdentifier(sessionName) + : new SignalServiceAddress(uuid, sessionName); + final var deviceId = session.get("deviceId").asInt(); + final var record = Base64.getDecoder().decode(session.get("record").asText()); + var sessionInfo = new SessionInfo(serviceAddress, deviceId, record); + sessionStore.sessions.add(sessionInfo); + } + } + + return sessionStore; + } + } +} diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientId.java b/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientId.java index 96e2c692..9d22d672 100644 --- a/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientId.java +++ b/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientId.java @@ -8,7 +8,11 @@ public class RecipientId { this.id = id; } - long getId() { + public static RecipientId of(long id) { + return new RecipientId(id); + } + + public long getId() { return id; } diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientResolver.java b/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientResolver.java new file mode 100644 index 00000000..c6d06d23 --- /dev/null +++ b/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientResolver.java @@ -0,0 +1,8 @@ +package org.asamk.signal.manager.storage.recipients; + +import org.whispersystems.signalservice.api.push.SignalServiceAddress; + +public interface RecipientResolver { + + RecipientId resolveRecipient(SignalServiceAddress address); +} diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java new file mode 100644 index 00000000..0773af9a --- /dev/null +++ b/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java @@ -0,0 +1,319 @@ +package org.asamk.signal.manager.storage.sessions; + +import
org.asamk.signal.manager.storage.recipients.RecipientId; +import org.asamk.signal.manager.storage.recipients.RecipientResolver; +import org.asamk.signal.manager.util.IOUtils; +import org.asamk.signal.manager.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.whispersystems.libsignal.SignalProtocolAddress; +import org.whispersystems.libsignal.protocol.CiphertextMessage; +import org.whispersystems.libsignal.state.SessionRecord; +import org.whispersystems.signalservice.api.SignalServiceSessionStore; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.file.Files; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +/** + * Session store that persists each session in its own file inside the sessions directory, + * named after the recipient id and device id (e.g. {@code 12_1}). + * Loaded sessions are cached in memory; cache and file access is guarded by + * synchronizing on the cache map. + */ +public class SessionStore implements SignalServiceSessionStore { + + private final static Logger logger = LoggerFactory.getLogger(SessionStore.class); + + private final Map cachedSessions = new HashMap<>(); + + private final File sessionsPath; + + private final RecipientResolver resolver; + + public SessionStore( + final File sessionsPath, final RecipientResolver resolver + ) { + this.sessionsPath = sessionsPath; + this.resolver = resolver; + } + + @Override + public SessionRecord loadSession(SignalProtocolAddress address) { + final var key = getKey(address); + + synchronized (cachedSessions) { + final var session = loadSessionLocked(key); + if (session == null) { + return new SessionRecord(); + } + return session; + } + } + + @Override + public List getSubDeviceSessions(String name) { + final var recipientId = resolveRecipient(name); + + synchronized (cachedSessions) { + return getKeysLocked(recipientId).stream() + // get all sessions for recipient except main device session + .filter(key -> key.getDeviceId() != 1 && key.getRecipientId().equals(recipientId)) +
.map(Key::getDeviceId) + .collect(Collectors.toList()); + } + } + + @Override + public void storeSession(SignalProtocolAddress address, SessionRecord session) { + final var key = getKey(address); + + synchronized (cachedSessions) { + storeSessionLocked(key, session); + } + } + + @Override + public boolean containsSession(SignalProtocolAddress address) { + final var key = getKey(address); + + synchronized (cachedSessions) { + final var session = loadSessionLocked(key); + if (session == null) { + return false; + } + + return session.hasSenderChain() && session.getSessionVersion() == CiphertextMessage.CURRENT_VERSION; + } + } + + @Override + public void deleteSession(SignalProtocolAddress address) { + final var key = getKey(address); + + synchronized (cachedSessions) { + deleteSessionLocked(key); + } + } + + @Override + public void deleteAllSessions(String name) { + final var recipientId = resolveRecipient(name); + deleteAllSessions(recipientId); + } + + public void deleteAllSessions(RecipientId recipientId) { + synchronized (cachedSessions) { + final var keys = getKeysLocked(recipientId); + for (var key : keys) { + deleteSessionLocked(key); + } + } + } + + @Override + public void archiveSession(final SignalProtocolAddress address) { + final var key = getKey(address); + + synchronized (cachedSessions) { + archiveSessionLocked(key); + } + } + + public void archiveAllSessions() { + synchronized (cachedSessions) { + final var keys = getKeysLocked(); + for (var key : keys) { + archiveSessionLocked(key); + } + } + } + + public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) { + synchronized (cachedSessions) { + final var otherHasSession = getKeysLocked(toBeMergedRecipientId).size() > 0; + if (!otherHasSession) { + return; + } + + final var hasSession = getKeysLocked(recipientId).size() > 0; + if (hasSession) { + logger.debug("To be merged recipient had sessions, deleting."); + deleteAllSessions(toBeMergedRecipientId); + } else { +
logger.debug("To be merged recipient had sessions, re-assigning to the new recipient."); + final var keys = getKeysLocked(toBeMergedRecipientId); + for (var key : keys) { + final var session = loadSessionLocked(key); + deleteSessionLocked(key); + if (session == null) { + continue; + } + final var newKey = new Key(recipientId, key.getDeviceId()); + storeSessionLocked(newKey, session); + } + } + } + } + + /** + * @param identifier can be either a serialized uuid or an e164 phone number + */ + private RecipientId resolveRecipient(String identifier) { + return resolver.resolveRecipient(Utils.getSignalServiceAddressFromIdentifier(identifier)); + } + + private Key getKey(final SignalProtocolAddress address) { + final var recipientId = resolveRecipient(address.getName()); + return new Key(recipientId, address.getDeviceId()); + } + + private List getKeysLocked(RecipientId recipientId) { + final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.getId() + "_")); + if (files == null) { + return List.of(); + } + return parseFileNames(files); + } + + private Collection getKeysLocked() { + final var files = sessionsPath.listFiles(); + if (files == null) { + return List.of(); + } + return parseFileNames(files); + } + + private final static Pattern sessionFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)"); + + private List parseFileNames(final File[] files) { + return Arrays.stream(files) + .map(f -> sessionFileNamePattern.matcher(f.getName())) + .filter(Matcher::matches) + .map(matcher -> new Key(RecipientId.of(Long.parseLong(matcher.group(1))), + Integer.parseInt(matcher.group(2)))) + .collect(Collectors.toList()); + } + + private File getSessionPath(Key key) { + try { + IOUtils.createPrivateDirectories(sessionsPath); + } catch (IOException e) { + throw new AssertionError("Failed to create sessions path", e); + } + return new File(sessionsPath, key.getRecipientId().getId() + "_" + key.getDeviceId()); + } + + private SessionRecord loadSessionLocked(final Key key) { +
{ + final var session = cachedSessions.get(key); + if (session != null) { + return session; + } + } + + final var file = getSessionPath(key); + if (!file.exists()) { + return null; + } + try (var inputStream = new FileInputStream(file)) { + final var session = new SessionRecord(inputStream.readAllBytes()); + cachedSessions.put(key, session); + return session; + } catch (IOException e) { + logger.warn("Failed to load session, resetting session: {}", e.getMessage()); + return null; + } + } + + private void storeSessionLocked(final Key key, final SessionRecord session) { + cachedSessions.put(key, session); + + final var file = getSessionPath(key); + try { + try (var outputStream = new FileOutputStream(file)) { + outputStream.write(session.serialize()); + } + } catch (IOException e) { + logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage()); + try { + Files.delete(file.toPath()); + try (var outputStream = new FileOutputStream(file)) { + outputStream.write(session.serialize()); + } + } catch (IOException e2) { + logger.error("Failed to store session file {}: {}", file, e2.getMessage()); + } + } + } + + private void archiveSessionLocked(final Key key) { + final var session = loadSessionLocked(key); + if (session == null) { + return; + } + session.archiveCurrentState(); + storeSessionLocked(key, session); + } + + private void deleteSessionLocked(final Key key) { + cachedSessions.remove(key); + + final var file = getSessionPath(key); + if (!file.exists()) { + return; + } + try { + Files.delete(file.toPath()); + } catch (IOException e) { + logger.error("Failed to delete session file {}: {}", file, e.getMessage()); + } + } + + private static final class Key { + + private final RecipientId recipientId; + private final int deviceId; + + public Key(final RecipientId recipientId, final int deviceId) { + this.recipientId = recipientId; + this.deviceId = deviceId; + } + + public RecipientId getRecipientId() { + return recipientId; + } + +
public int getDeviceId() { + return deviceId; + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + final var key = (Key) o; + + if (deviceId != key.deviceId) return false; + return recipientId.equals(key.recipientId); + } + + @Override + public int hashCode() { + int result = recipientId.hashCode(); + result = 31 * result + deviceId; + return result; + } + } +}