nmode's Git Repositories - signal-cli/commitdiff
Refactor sessions store
author AsamK <asamk@gmx.de>
Thu, 15 Apr 2021 20:33:35 +0000 (22:33 +0200)
committer AsamK <asamk@gmx.de>
Sat, 1 May 2021 06:46:00 +0000 (08:46 +0200)
lib/src/main/java/org/asamk/signal/manager/Manager.java
lib/src/main/java/org/asamk/signal/manager/RegistrationManager.java
lib/src/main/java/org/asamk/signal/manager/storage/SignalAccount.java
lib/src/main/java/org/asamk/signal/manager/storage/protocol/JsonSessionStore.java [deleted file]
lib/src/main/java/org/asamk/signal/manager/storage/protocol/JsonSignalProtocolStore.java
lib/src/main/java/org/asamk/signal/manager/storage/protocol/LegacyJsonSessionStore.java [new file with mode: 0644]
lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientId.java
lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientResolver.java [new file with mode: 0644]
lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java [new file with mode: 0644]

index 5d0e7d3dfca0ec4605918fdc9f034f9839b63c60..f672d0fe41a7ee72c5236793243a30d1360cb66c 100644 (file)
@@ -263,7 +263,7 @@ public class Manager implements Closeable {
     }
 
     private IdentityKeyPair getIdentityKeyPair() {
-        return account.getSignalProtocolStore().getIdentityKeyPair();
+        return account.getIdentityKeyPair();
     }
 
     public int getDeviceId() {
@@ -336,7 +336,7 @@ public class Manager implements Closeable {
 
     public void updateAccountAttributes() throws IOException {
         accountManager.setAccountAttributes(null,
-                account.getSignalProtocolStore().getLocalRegistrationId(),
+                account.getLocalRegistrationId(),
                 true,
                 // set legacy pin only if no KBS master key is set
                 account.getPinMasterKey() == null ? account.getRegistrationLockPin() : null,
@@ -1441,7 +1441,7 @@ public class Manager implements Closeable {
     }
 
     private void handleEndSession(SignalServiceAddress source) {
-        account.getSignalProtocolStore().deleteAllSessions(source);
+        account.getSessionStore().deleteAllSessions(source.getIdentifier());
     }
 
     private List<HandleAction> handleSignalServiceDataMessage(
index aad731a0a4f100460665b909e178feefa6eeddfc..a16ead37c3a26b5680cb1add71d50d43c054fe09 100644 (file)
@@ -163,10 +163,10 @@ public class RegistrationManager implements Closeable {
         account.setRegistered(true);
         account.setUuid(UuidUtil.parseOrNull(response.getUuid()));
         account.setRegistrationLockPin(pin);
-        account.getSignalProtocolStore().archiveAllSessions();
+        account.getSessionStore().archiveAllSessions();
         account.getSignalProtocolStore()
                 .saveIdentity(account.getSelfAddress(),
-                        account.getSignalProtocolStore().getIdentityKeyPair().getPublicKey(),
+                        account.getIdentityKeyPair().getPublicKey(),
                         TrustLevel.TRUSTED_VERIFIED);
 
         Manager m = null;
@@ -194,7 +194,7 @@ public class RegistrationManager implements Closeable {
     ) throws IOException {
         return accountManager.verifyAccountWithCode(verificationCode,
                 null,
-                account.getSignalProtocolStore().getLocalRegistrationId(),
+                account.getLocalRegistrationId(),
                 true,
                 legacyPin,
                 registrationLock,
index bec73b9409b3823595cdba80e3e794262c0697fd..18526f2a4a4568f0104dff8eb687fe6bb267d787 100644 (file)
@@ -21,6 +21,7 @@ import org.asamk.signal.manager.storage.protocol.SignalServiceAddressResolver;
 import org.asamk.signal.manager.storage.recipients.LegacyRecipientStore;
 import org.asamk.signal.manager.storage.recipients.RecipientId;
 import org.asamk.signal.manager.storage.recipients.RecipientStore;
+import org.asamk.signal.manager.storage.sessions.SessionStore;
 import org.asamk.signal.manager.storage.stickers.StickerStore;
 import org.asamk.signal.manager.storage.threads.LegacyJsonThreadStore;
 import org.asamk.signal.manager.util.IOUtils;
@@ -31,7 +32,9 @@ import org.signal.zkgroup.profiles.ProfileKey;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.whispersystems.libsignal.IdentityKeyPair;
+import org.whispersystems.libsignal.SignalProtocolAddress;
 import org.whispersystems.libsignal.state.PreKeyRecord;
+import org.whispersystems.libsignal.state.SessionRecord;
 import org.whispersystems.libsignal.state.SignedPreKeyRecord;
 import org.whispersystems.libsignal.util.Medium;
 import org.whispersystems.libsignal.util.Pair;
@@ -77,6 +80,7 @@ public class SignalAccount implements Closeable {
     private boolean registered = false;
 
     private JsonSignalProtocolStore signalProtocolStore;
+    private SessionStore sessionStore;
     private JsonGroupStore groupStore;
     private JsonContactsStore contactStore;
     private RecipientStore recipientStore;
@@ -125,11 +129,13 @@ public class SignalAccount implements Closeable {
 
         account.username = username;
         account.profileKey = profileKey;
-        account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId);
         account.groupStore = new JsonGroupStore(getGroupCachePath(dataPath, username));
         account.contactStore = new JsonContactsStore();
         account.recipientStore = RecipientStore.load(getRecipientsStoreFile(dataPath, username),
                 account::mergeRecipients);
+        account.sessionStore = new SessionStore(getSessionsPath(dataPath, username),
+                account.recipientStore::resolveRecipient);
+        account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId, account.sessionStore);
         account.profileStore = new ProfileStore();
         account.stickerStore = new StickerStore();
 
@@ -166,11 +172,13 @@ public class SignalAccount implements Closeable {
         account.password = password;
         account.profileKey = profileKey;
         account.deviceId = deviceId;
-        account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId);
         account.groupStore = new JsonGroupStore(getGroupCachePath(dataPath, username));
         account.contactStore = new JsonContactsStore();
         account.recipientStore = RecipientStore.load(getRecipientsStoreFile(dataPath, username),
                 account::mergeRecipients);
+        account.sessionStore = new SessionStore(getSessionsPath(dataPath, username),
+                account.recipientStore::resolveRecipient);
+        account.signalProtocolStore = new JsonSignalProtocolStore(identityKey, registrationId, account.sessionStore);
         account.profileStore = new ProfileStore();
         account.stickerStore = new StickerStore();
 
@@ -210,7 +218,7 @@ public class SignalAccount implements Closeable {
     }
 
     private void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
-        // TODO
+        sessionStore.mergeRecipients(recipientId, toBeMergedRecipientId);
     }
 
     public static File getFileName(File dataPath, String username) {
@@ -229,6 +237,10 @@ public class SignalAccount implements Closeable {
         return new File(getUserPath(dataPath, username), "group-cache");
     }
 
+    private static File getSessionsPath(File dataPath, String username) {
+        return new File(getUserPath(dataPath, username), "sessions");
+    }
+
     private static File getRecipientsStoreFile(File dataPath, String username) {
         return new File(getUserPath(dataPath, username), "recipients-store");
     }
@@ -304,6 +316,19 @@ public class SignalAccount implements Closeable {
 
         signalProtocolStore = jsonProcessor.convertValue(Utils.getNotNullNode(rootNode, "axolotlStore"),
                 JsonSignalProtocolStore.class);
+        sessionStore = new SessionStore(getSessionsPath(dataPath, username), recipientStore::resolveRecipient);
+        if (signalProtocolStore.getLegacySessionStore() != null) {
+            logger.debug("Migrating legacy session store.");
+            for (var session : signalProtocolStore.getLegacySessionStore().getSessions()) {
+                try {
+                    sessionStore.storeSession(new SignalProtocolAddress(session.address.getIdentifier(),
+                            session.deviceId), new SessionRecord(session.sessionRecord));
+                } catch (IOException e) {
+                    logger.warn("Failed to migrate session, ignoring", e);
+                }
+            }
+        }
+        signalProtocolStore.setSessionStore(sessionStore);
         registered = Utils.getNotNullNode(rootNode, "registered").asBoolean();
         var groupStoreNode = rootNode.get("groupStore");
         if (groupStoreNode != null) {
@@ -355,10 +380,6 @@ public class SignalAccount implements Closeable {
                 }
             }
 
-            for (var session : signalProtocolStore.getSessions()) {
-                session.address = recipientStore.resolveServiceAddress(session.address);
-            }
-
             for (var identity : signalProtocolStore.getIdentities()) {
                 identity.setAddress(recipientStore.resolveServiceAddress(identity.getAddress()));
             }
@@ -464,6 +485,10 @@ public class SignalAccount implements Closeable {
         return signalProtocolStore;
     }
 
+    public SessionStore getSessionStore() {
+        return sessionStore;
+    }
+
     public JsonGroupStore getGroupStore() {
         return groupStore;
     }
@@ -516,6 +541,14 @@ public class SignalAccount implements Closeable {
         return deviceId == SignalServiceAddress.DEFAULT_DEVICE_ID;
     }
 
+    public IdentityKeyPair getIdentityKeyPair() {
+        return signalProtocolStore.getIdentityKeyPair();
+    }
+
+    public int getLocalRegistrationId() {
+        return signalProtocolStore.getLocalRegistrationId();
+    }
+
     public String getPassword() {
         return password;
     }
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/protocol/JsonSessionStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/protocol/JsonSessionStore.java
deleted file mode 100644 (file)
index 1b5384a..0000000
+++ /dev/null
@@ -1,214 +0,0 @@
-package org.asamk.signal.manager.storage.protocol;
-
-import com.fasterxml.jackson.core.JsonGenerator;
-import com.fasterxml.jackson.core.JsonParser;
-import com.fasterxml.jackson.databind.DeserializationContext;
-import com.fasterxml.jackson.databind.JsonDeserializer;
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.JsonSerializer;
-import com.fasterxml.jackson.databind.SerializerProvider;
-
-import org.asamk.signal.manager.util.Utils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-import org.whispersystems.libsignal.SignalProtocolAddress;
-import org.whispersystems.libsignal.protocol.CiphertextMessage;
-import org.whispersystems.libsignal.state.SessionRecord;
-import org.whispersystems.signalservice.api.SignalServiceSessionStore;
-import org.whispersystems.signalservice.api.push.SignalServiceAddress;
-import org.whispersystems.signalservice.api.util.UuidUtil;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Base64;
-import java.util.LinkedList;
-import java.util.List;
-
-class JsonSessionStore implements SignalServiceSessionStore {
-
-    private final static Logger logger = LoggerFactory.getLogger(JsonSessionStore.class);
-
-    private final List<SessionInfo> sessions = new ArrayList<>();
-
-    private SignalServiceAddressResolver resolver;
-
-    public JsonSessionStore() {
-    }
-
-    public void setResolver(final SignalServiceAddressResolver resolver) {
-        this.resolver = resolver;
-    }
-
-    private SignalServiceAddress resolveSignalServiceAddress(String identifier) {
-        if (resolver != null) {
-            return resolver.resolveSignalServiceAddress(identifier);
-        } else {
-            return Utils.getSignalServiceAddressFromIdentifier(identifier);
-        }
-    }
-
-    @Override
-    public synchronized SessionRecord loadSession(SignalProtocolAddress address) {
-        var serviceAddress = resolveSignalServiceAddress(address.getName());
-        for (var info : sessions) {
-            if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) {
-                try {
-                    return new SessionRecord(info.sessionRecord);
-                } catch (IOException e) {
-                    logger.warn("Failed to load session, resetting session: {}", e.getMessage());
-                    return new SessionRecord();
-                }
-            }
-        }
-
-        return new SessionRecord();
-    }
-
-    public synchronized List<SessionInfo> getSessions() {
-        return sessions;
-    }
-
-    @Override
-    public synchronized List<Integer> getSubDeviceSessions(String name) {
-        var serviceAddress = resolveSignalServiceAddress(name);
-
-        var deviceIds = new LinkedList<Integer>();
-        for (var info : sessions) {
-            if (info.address.matches(serviceAddress) && info.deviceId != 1) {
-                deviceIds.add(info.deviceId);
-            }
-        }
-
-        return deviceIds;
-    }
-
-    @Override
-    public synchronized void storeSession(SignalProtocolAddress address, SessionRecord record) {
-        var serviceAddress = resolveSignalServiceAddress(address.getName());
-        for (var info : sessions) {
-            if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) {
-                if (!info.address.getUuid().isPresent() || !info.address.getNumber().isPresent()) {
-                    info.address = serviceAddress;
-                }
-                info.sessionRecord = record.serialize();
-                return;
-            }
-        }
-
-        sessions.add(new SessionInfo(serviceAddress, address.getDeviceId(), record.serialize()));
-    }
-
-    @Override
-    public synchronized boolean containsSession(SignalProtocolAddress address) {
-        var serviceAddress = resolveSignalServiceAddress(address.getName());
-        for (var info : sessions) {
-            if (info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId()) {
-                final SessionRecord sessionRecord;
-                try {
-                    sessionRecord = new SessionRecord(info.sessionRecord);
-                } catch (IOException e) {
-                    logger.warn("Failed to check session: {}", e.getMessage());
-                    return false;
-                }
-
-                return sessionRecord.hasSenderChain()
-                        && sessionRecord.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
-            }
-        }
-        return false;
-    }
-
-    @Override
-    public synchronized void deleteSession(SignalProtocolAddress address) {
-        var serviceAddress = resolveSignalServiceAddress(address.getName());
-        sessions.removeIf(info -> info.address.matches(serviceAddress) && info.deviceId == address.getDeviceId());
-    }
-
-    @Override
-    public synchronized void deleteAllSessions(String name) {
-        var serviceAddress = resolveSignalServiceAddress(name);
-        deleteAllSessions(serviceAddress);
-    }
-
-    public synchronized void deleteAllSessions(SignalServiceAddress serviceAddress) {
-        sessions.removeIf(info -> info.address.matches(serviceAddress));
-    }
-
-    @Override
-    public void archiveSession(final SignalProtocolAddress address) {
-        final var sessionRecord = loadSession(address);
-        if (sessionRecord == null) {
-            return;
-        }
-        sessionRecord.archiveCurrentState();
-        storeSession(address, sessionRecord);
-    }
-
-    public void archiveAllSessions() {
-        for (var info : sessions) {
-            try {
-                final var sessionRecord = new SessionRecord(info.sessionRecord);
-                sessionRecord.archiveCurrentState();
-                info.sessionRecord = sessionRecord.serialize();
-            } catch (IOException ignored) {
-            }
-        }
-    }
-
-    public static class JsonSessionStoreDeserializer extends JsonDeserializer<JsonSessionStore> {
-
-        @Override
-        public JsonSessionStore deserialize(
-                JsonParser jsonParser, DeserializationContext deserializationContext
-        ) throws IOException {
-            JsonNode node = jsonParser.getCodec().readTree(jsonParser);
-
-            var sessionStore = new JsonSessionStore();
-
-            if (node.isArray()) {
-                for (var session : node) {
-                    var sessionName = session.hasNonNull("name") ? session.get("name").asText() : null;
-                    if (UuidUtil.isUuid(sessionName)) {
-                        // Ignore sessions that were incorrectly created with UUIDs as name
-                        continue;
-                    }
-
-                    var uuid = session.hasNonNull("uuid") ? UuidUtil.parseOrNull(session.get("uuid").asText()) : null;
-                    final var serviceAddress = uuid == null
-                            ? Utils.getSignalServiceAddressFromIdentifier(sessionName)
-                            : new SignalServiceAddress(uuid, sessionName);
-                    final var deviceId = session.get("deviceId").asInt();
-                    final var record = Base64.getDecoder().decode(session.get("record").asText());
-                    var sessionInfo = new SessionInfo(serviceAddress, deviceId, record);
-                    sessionStore.sessions.add(sessionInfo);
-                }
-            }
-
-            return sessionStore;
-        }
-    }
-
-    public static class JsonSessionStoreSerializer extends JsonSerializer<JsonSessionStore> {
-
-        @Override
-        public void serialize(
-                JsonSessionStore jsonSessionStore, JsonGenerator json, SerializerProvider serializerProvider
-        ) throws IOException {
-            json.writeStartArray();
-            for (var sessionInfo : jsonSessionStore.sessions) {
-                json.writeStartObject();
-                if (sessionInfo.address.getNumber().isPresent()) {
-                    json.writeStringField("name", sessionInfo.address.getNumber().get());
-                }
-                if (sessionInfo.address.getUuid().isPresent()) {
-                    json.writeStringField("uuid", sessionInfo.address.getUuid().get().toString());
-                }
-                json.writeNumberField("deviceId", sessionInfo.deviceId);
-                json.writeStringField("record", Base64.getEncoder().encodeToString(sessionInfo.sessionRecord));
-                json.writeEndObject();
-            }
-            json.writeEndArray();
-        }
-    }
-
-}
index d47f7a1a2890ef6add8d05c03b7079e12e6783fd..dc45a0da901d28154d82756fa3922511712d4ddd 100644 (file)
@@ -1,10 +1,12 @@
 package org.asamk.signal.manager.storage.protocol;
 
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
 import com.fasterxml.jackson.databind.annotation.JsonSerialize;
 
 import org.asamk.signal.manager.TrustLevel;
+import org.asamk.signal.manager.storage.sessions.SessionStore;
 import org.whispersystems.libsignal.IdentityKey;
 import org.whispersystems.libsignal.IdentityKeyPair;
 import org.whispersystems.libsignal.InvalidKeyIdException;
@@ -17,6 +19,7 @@ import org.whispersystems.signalservice.api.push.SignalServiceAddress;
 
 import java.util.List;
 
+@JsonIgnoreProperties(value = "sessionStore", allowSetters = true)
 public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
 
     @JsonProperty("preKeys")
@@ -25,9 +28,8 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
     private JsonPreKeyStore preKeyStore;
 
     @JsonProperty("sessionStore")
-    @JsonDeserialize(using = JsonSessionStore.JsonSessionStoreDeserializer.class)
-    @JsonSerialize(using = JsonSessionStore.JsonSessionStoreSerializer.class)
-    private JsonSessionStore sessionStore;
+    @JsonDeserialize(using = LegacyJsonSessionStore.JsonSessionStoreDeserializer.class)
+    private LegacyJsonSessionStore legacySessionStore;
 
     @JsonProperty("signedPreKeyStore")
     @JsonDeserialize(using = JsonSignedPreKeyStore.JsonSignedPreKeyStoreDeserializer.class)
@@ -39,33 +41,30 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
     @JsonSerialize(using = JsonIdentityKeyStore.JsonIdentityKeyStoreSerializer.class)
     private JsonIdentityKeyStore identityKeyStore;
 
-    public JsonSignalProtocolStore() {
-    }
+    private SessionStore sessionStore;
 
-    public JsonSignalProtocolStore(
-            JsonPreKeyStore preKeyStore,
-            JsonSessionStore sessionStore,
-            JsonSignedPreKeyStore signedPreKeyStore,
-            JsonIdentityKeyStore identityKeyStore
-    ) {
-        this.preKeyStore = preKeyStore;
-        this.sessionStore = sessionStore;
-        this.signedPreKeyStore = signedPreKeyStore;
-        this.identityKeyStore = identityKeyStore;
+    public JsonSignalProtocolStore() {
     }
 
-    public JsonSignalProtocolStore(IdentityKeyPair identityKeyPair, int registrationId) {
+    public JsonSignalProtocolStore(IdentityKeyPair identityKeyPair, int registrationId, SessionStore sessionStore) {
         preKeyStore = new JsonPreKeyStore();
-        sessionStore = new JsonSessionStore();
+        this.sessionStore = sessionStore;
         signedPreKeyStore = new JsonSignedPreKeyStore();
         this.identityKeyStore = new JsonIdentityKeyStore(identityKeyPair, registrationId);
     }
 
     public void setResolver(final SignalServiceAddressResolver resolver) {
-        sessionStore.setResolver(resolver);
         identityKeyStore.setResolver(resolver);
     }
 
+    public void setSessionStore(final SessionStore sessionStore) {
+        this.sessionStore = sessionStore;
+    }
+
+    public LegacyJsonSessionStore getLegacySessionStore() {
+        return legacySessionStore;
+    }
+
     @Override
     public IdentityKeyPair getIdentityKeyPair() {
         return identityKeyStore.getIdentityKeyPair();
@@ -142,10 +141,6 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
         return sessionStore.loadSession(address);
     }
 
-    public List<SessionInfo> getSessions() {
-        return sessionStore.getSessions();
-    }
-
     @Override
     public List<Integer> getSubDeviceSessions(String name) {
         return sessionStore.getSubDeviceSessions(name);
@@ -171,19 +166,11 @@ public class JsonSignalProtocolStore implements SignalServiceProtocolStore {
         sessionStore.deleteAllSessions(name);
     }
 
-    public void deleteAllSessions(SignalServiceAddress serviceAddress) {
-        sessionStore.deleteAllSessions(serviceAddress);
-    }
-
     @Override
     public void archiveSession(final SignalProtocolAddress address) {
         sessionStore.archiveSession(address);
     }
 
-    public void archiveAllSessions() {
-        sessionStore.archiveAllSessions();
-    }
-
     @Override
     public SignedPreKeyRecord loadSignedPreKey(int signedPreKeyId) throws InvalidKeyIdException {
         return signedPreKeyStore.loadSignedPreKey(signedPreKeyId);
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/protocol/LegacyJsonSessionStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/protocol/LegacyJsonSessionStore.java
new file mode 100644 (file)
index 0000000..c43bdb0
--- /dev/null
@@ -0,0 +1,60 @@
+package org.asamk.signal.manager.storage.protocol;
+
+import com.fasterxml.jackson.core.JsonParser;
+import com.fasterxml.jackson.databind.DeserializationContext;
+import com.fasterxml.jackson.databind.JsonDeserializer;
+import com.fasterxml.jackson.databind.JsonNode;
+
+import org.asamk.signal.manager.util.Utils;
+import org.whispersystems.signalservice.api.push.SignalServiceAddress;
+import org.whispersystems.signalservice.api.util.UuidUtil;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.List;
+
+public class LegacyJsonSessionStore {
+
+    private final List<SessionInfo> sessions = new ArrayList<>();
+
+    public LegacyJsonSessionStore() {
+    }
+
+    public List<SessionInfo> getSessions() {
+        return sessions;
+    }
+
+    public static class JsonSessionStoreDeserializer extends JsonDeserializer<LegacyJsonSessionStore> {
+
+        @Override
+        public LegacyJsonSessionStore deserialize(
+                JsonParser jsonParser, DeserializationContext deserializationContext
+        ) throws IOException {
+            JsonNode node = jsonParser.getCodec().readTree(jsonParser);
+
+            var sessionStore = new LegacyJsonSessionStore();
+
+            if (node.isArray()) {
+                for (var session : node) {
+                    var sessionName = session.hasNonNull("name") ? session.get("name").asText() : null;
+                    if (UuidUtil.isUuid(sessionName)) {
+                        // Ignore sessions that were incorrectly created with UUIDs as name
+                        continue;
+                    }
+
+                    var uuid = session.hasNonNull("uuid") ? UuidUtil.parseOrNull(session.get("uuid").asText()) : null;
+                    final var serviceAddress = uuid == null
+                            ? Utils.getSignalServiceAddressFromIdentifier(sessionName)
+                            : new SignalServiceAddress(uuid, sessionName);
+                    final var deviceId = session.get("deviceId").asInt();
+                    final var record = Base64.getDecoder().decode(session.get("record").asText());
+                    var sessionInfo = new SessionInfo(serviceAddress, deviceId, record);
+                    sessionStore.sessions.add(sessionInfo);
+                }
+            }
+
+            return sessionStore;
+        }
+    }
+}
index 96e2c6922d58b6c91fc394e75cd8e3a51b51f64c..9d22d672dcf428b5591c3c5e42c5d81fcbc3ab25 100644 (file)
@@ -8,7 +8,11 @@ public class RecipientId {
         this.id = id;
     }
 
-    long getId() {
+    public static RecipientId of(long id) {
+        return new RecipientId(id);
+    }
+
+    public long getId() {
         return id;
     }
 
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientResolver.java b/lib/src/main/java/org/asamk/signal/manager/storage/recipients/RecipientResolver.java
new file mode 100644 (file)
index 0000000..c6d06d2
--- /dev/null
@@ -0,0 +1,8 @@
+package org.asamk.signal.manager.storage.recipients;
+
+import org.whispersystems.signalservice.api.push.SignalServiceAddress;
+
+public interface RecipientResolver {
+
+    RecipientId resolveRecipient(SignalServiceAddress address);
+}
diff --git a/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java b/lib/src/main/java/org/asamk/signal/manager/storage/sessions/SessionStore.java
new file mode 100644 (file)
index 0000000..0773af9
--- /dev/null
@@ -0,0 +1,313 @@
+package org.asamk.signal.manager.storage.sessions;
+
+import org.asamk.signal.manager.storage.recipients.RecipientId;
+import org.asamk.signal.manager.storage.recipients.RecipientResolver;
+import org.asamk.signal.manager.util.IOUtils;
+import org.asamk.signal.manager.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.whispersystems.libsignal.SignalProtocolAddress;
+import org.whispersystems.libsignal.protocol.CiphertextMessage;
+import org.whispersystems.libsignal.state.SessionRecord;
+import org.whispersystems.signalservice.api.SignalServiceSessionStore;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+
+public class SessionStore implements SignalServiceSessionStore {
+
+    private final static Logger logger = LoggerFactory.getLogger(SessionStore.class);
+
+    private final Map<Key, SessionRecord> cachedSessions = new HashMap<>();
+
+    private final File sessionsPath;
+
+    private final RecipientResolver resolver;
+
+    public SessionStore(
+            final File sessionsPath, final RecipientResolver resolver
+    ) {
+        this.sessionsPath = sessionsPath;
+        this.resolver = resolver;
+    }
+
+    @Override
+    public SessionRecord loadSession(SignalProtocolAddress address) {
+        final var key = getKey(address);
+
+        synchronized (cachedSessions) {
+            final var session = loadSessionLocked(key);
+            if (session == null) {
+                return new SessionRecord();
+            }
+            return session;
+        }
+    }
+
+    @Override
+    public List<Integer> getSubDeviceSessions(String name) {
+        final var recipientId = resolveRecipient(name);
+
+        synchronized (cachedSessions) {
+            return getKeysLocked(recipientId).stream()
+                    // get all sessions for recipient except main device session
+                    .filter(key -> key.getDeviceId() != 1 && key.getRecipientId().equals(recipientId))
+                    .map(Key::getDeviceId)
+                    .collect(Collectors.toList());
+        }
+    }
+
+    @Override
+    public void storeSession(SignalProtocolAddress address, SessionRecord session) {
+        final var key = getKey(address);
+
+        synchronized (cachedSessions) {
+            storeSessionLocked(key, session);
+        }
+    }
+
+    @Override
+    public boolean containsSession(SignalProtocolAddress address) {
+        final var key = getKey(address);
+
+        synchronized (cachedSessions) {
+            final var session = loadSessionLocked(key);
+            if (session == null) {
+                return false;
+            }
+
+            return session.hasSenderChain() && session.getSessionVersion() == CiphertextMessage.CURRENT_VERSION;
+        }
+    }
+
+    @Override
+    public void deleteSession(SignalProtocolAddress address) {
+        final var key = getKey(address);
+
+        synchronized (cachedSessions) {
+            deleteSessionLocked(key);
+        }
+    }
+
+    @Override
+    public void deleteAllSessions(String name) {
+        final var recipientId = resolveRecipient(name);
+        deleteAllSessions(recipientId);
+    }
+
+    public void deleteAllSessions(RecipientId recipientId) {
+        synchronized (cachedSessions) {
+            final var keys = getKeysLocked(recipientId);
+            for (var key : keys) {
+                deleteSessionLocked(key);
+            }
+        }
+    }
+
+    @Override
+    public void archiveSession(final SignalProtocolAddress address) {
+        final var key = getKey(address);
+
+        synchronized (cachedSessions) {
+            archiveSessionLocked(key);
+        }
+    }
+
+    public void archiveAllSessions() {
+        synchronized (cachedSessions) {
+            final var keys = getKeysLocked();
+            for (var key : keys) {
+                archiveSessionLocked(key);
+            }
+        }
+    }
+
+    public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
+        synchronized (cachedSessions) {
+            final var otherHasSession = getKeysLocked(toBeMergedRecipientId).size() > 0;
+            if (!otherHasSession) {
+                return;
+            }
+
+            final var hasSession = getKeysLocked(recipientId).size() > 0;
+            if (hasSession) {
+                logger.debug("To be merged recipient had sessions, deleting.");
+                deleteAllSessions(toBeMergedRecipientId);
+            } else {
+                logger.debug("To be merged recipient had sessions, re-assigning to the new recipient.");
+                final var keys = getKeysLocked(toBeMergedRecipientId);
+                for (var key : keys) {
+                    final var session = loadSessionLocked(key);
+                    deleteSessionLocked(key);
+                    if (session == null) {
+                        continue;
+                    }
+                    final var newKey = new Key(recipientId, key.getDeviceId());
+                    storeSessionLocked(newKey, session);
+                }
+            }
+        }
+    }
+
+    /**
+     * @param identifier can be either a serialized uuid or a e164 phone number
+     */
+    private RecipientId resolveRecipient(String identifier) {
+        return resolver.resolveRecipient(Utils.getSignalServiceAddressFromIdentifier(identifier));
+    }
+
+    private Key getKey(final SignalProtocolAddress address) {
+        final var recipientId = resolveRecipient(address.getName());
+        return new Key(recipientId, address.getDeviceId());
+    }
+
+    private List<Key> getKeysLocked(RecipientId recipientId) {
+        final var files = sessionsPath.listFiles((_file, s) -> s.startsWith(recipientId.getId() + "_"));
+        if (files == null) {
+            return List.of();
+        }
+        return parseFileNames(files);
+    }
+
+    private Collection<Key> getKeysLocked() {
+        final var files = sessionsPath.listFiles();
+        if (files == null) {
+            return List.of();
+        }
+        return parseFileNames(files);
+    }
+
+    final Pattern sessionFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)");
+
+    private List<Key> parseFileNames(final File[] files) {
+        return Arrays.stream(files)
+                .map(f -> sessionFileNamePattern.matcher(f.getName()))
+                .filter(Matcher::matches)
+                .map(matcher -> new Key(RecipientId.of(Long.parseLong(matcher.group(1))),
+                        Integer.parseInt(matcher.group(2))))
+                .collect(Collectors.toList());
+    }
+
+    private File getSessionPath(Key key) {
+        try {
+            IOUtils.createPrivateDirectories(sessionsPath);
+        } catch (IOException e) {
+            throw new AssertionError("Failed to create sessions path", e);
+        }
+        return new File(sessionsPath, key.getRecipientId().getId() + "_" + key.getDeviceId());
+    }
+
+    private SessionRecord loadSessionLocked(final Key key) {
+        {
+            final var session = cachedSessions.get(key);
+            if (session != null) {
+                return session;
+            }
+        }
+
+        final var file = getSessionPath(key);
+        if (!file.exists()) {
+            return null;
+        }
+        try (var inputStream = new FileInputStream(file)) {
+            final var session = new SessionRecord(inputStream.readAllBytes());
+            cachedSessions.put(key, session);
+            return session;
+        } catch (IOException e) {
+            logger.warn("Failed to load session, resetting session: {}", e.getMessage());
+            return null;
+        }
+    }
+
+    private void storeSessionLocked(final Key key, final SessionRecord session) {
+        cachedSessions.put(key, session);
+
+        final var file = getSessionPath(key);
+        try {
+            try (var outputStream = new FileOutputStream(file)) {
+                outputStream.write(session.serialize());
+            }
+        } catch (IOException e) {
+            logger.warn("Failed to store session, trying to delete file and retry: {}", e.getMessage());
+            try {
+                Files.delete(file.toPath());
+                try (var outputStream = new FileOutputStream(file)) {
+                    outputStream.write(session.serialize());
+                }
+            } catch (IOException e2) {
+                logger.error("Failed to store session file {}: {}", file, e2.getMessage());
+            }
+        }
+    }
+
+    private void archiveSessionLocked(final Key key) {
+        final var session = loadSessionLocked(key);
+        if (session == null) {
+            return;
+        }
+        session.archiveCurrentState();
+        storeSessionLocked(key, session);
+    }
+
+    private void deleteSessionLocked(final Key key) {
+        cachedSessions.remove(key);
+
+        final var file = getSessionPath(key);
+        if (!file.exists()) {
+            return;
+        }
+        try {
+            Files.delete(file.toPath());
+        } catch (IOException e) {
+            logger.error("Failed to delete session file {}: {}", file, e.getMessage());
+        }
+    }
+
+    private static final class Key {
+
+        private final RecipientId recipientId;
+        private final int deviceId;
+
+        public Key(final RecipientId recipientId, final int deviceId) {
+            this.recipientId = recipientId;
+            this.deviceId = deviceId;
+        }
+
+        public RecipientId getRecipientId() {
+            return recipientId;
+        }
+
+        public int getDeviceId() {
+            return deviceId;
+        }
+
+        @Override
+        public boolean equals(final Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+
+            final var key = (Key) o;
+
+            if (deviceId != key.deviceId) return false;
+            return recipientId.equals(key.recipientId);
+        }
+
+        @Override
+        public int hashCode() {
+            int result = recipientId.hashCode();
+            result = 31 * result + deviceId;
+            return result;
+        }
+    }
+}