]> nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/storage/senderKeys/SenderKeyRecordStore.java
f84903e44765fd011a3f4b2ef78ae693157b8e79
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / storage / senderKeys / SenderKeyRecordStore.java
1 package org.asamk.signal.manager.storage.senderKeys;
2
3 import org.asamk.signal.manager.storage.recipients.RecipientId;
4 import org.asamk.signal.manager.storage.recipients.RecipientResolver;
5 import org.asamk.signal.manager.util.IOUtils;
6 import org.slf4j.Logger;
7 import org.slf4j.LoggerFactory;
8 import org.whispersystems.libsignal.SignalProtocolAddress;
9 import org.whispersystems.libsignal.groups.state.SenderKeyRecord;
10
11 import java.io.File;
12 import java.io.FileInputStream;
13 import java.io.FileOutputStream;
14 import java.io.IOException;
15 import java.nio.file.Files;
16 import java.util.Arrays;
17 import java.util.HashMap;
18 import java.util.List;
19 import java.util.Map;
20 import java.util.UUID;
21 import java.util.regex.Matcher;
22 import java.util.regex.Pattern;
23 import java.util.stream.Collectors;
24
25 public class SenderKeyRecordStore implements org.whispersystems.libsignal.groups.state.SenderKeyStore {
26
27 private final static Logger logger = LoggerFactory.getLogger(SenderKeyRecordStore.class);
28
29 private final Map<Key, SenderKeyRecord> cachedSenderKeys = new HashMap<>();
30
31 private final File senderKeysPath;
32
33 private final RecipientResolver resolver;
34
35 public SenderKeyRecordStore(
36 final File senderKeysPath, final RecipientResolver resolver
37 ) {
38 this.senderKeysPath = senderKeysPath;
39 this.resolver = resolver;
40 }
41
42 @Override
43 public SenderKeyRecord loadSenderKey(final SignalProtocolAddress address, final UUID distributionId) {
44 final var key = getKey(address, distributionId);
45
46 synchronized (cachedSenderKeys) {
47 return loadSenderKeyLocked(key);
48 }
49 }
50
51 @Override
52 public void storeSenderKey(
53 final SignalProtocolAddress address, final UUID distributionId, final SenderKeyRecord record
54 ) {
55 final var key = getKey(address, distributionId);
56
57 synchronized (cachedSenderKeys) {
58 storeSenderKeyLocked(key, record);
59 }
60 }
61
62 public void deleteAll() {
63 synchronized (cachedSenderKeys) {
64 cachedSenderKeys.clear();
65 final var files = senderKeysPath.listFiles((_file, s) -> senderKeyFileNamePattern.matcher(s).matches());
66 if (files == null) {
67 return;
68 }
69
70 for (final var file : files) {
71 try {
72 Files.delete(file.toPath());
73 } catch (IOException e) {
74 logger.error("Failed to delete sender key file {}: {}", file, e.getMessage());
75 }
76 }
77 }
78 }
79
80 public void deleteAllFor(final RecipientId recipientId) {
81 synchronized (cachedSenderKeys) {
82 cachedSenderKeys.clear();
83 final var keys = getKeysLocked(recipientId);
84 for (var key : keys) {
85 deleteSenderKeyLocked(key);
86 }
87 }
88 }
89
90 public void mergeRecipients(RecipientId recipientId, RecipientId toBeMergedRecipientId) {
91 synchronized (cachedSenderKeys) {
92 final var keys = getKeysLocked(toBeMergedRecipientId);
93 final var otherHasSenderKeys = keys.size() > 0;
94 if (!otherHasSenderKeys) {
95 return;
96 }
97
98 logger.debug("Only to be merged recipient had sender keys, re-assigning to the new recipient.");
99 for (var key : keys) {
100 final var toBeMergedSenderKey = loadSenderKeyLocked(key);
101 deleteSenderKeyLocked(key);
102 if (toBeMergedSenderKey == null) {
103 continue;
104 }
105
106 final var newKey = new Key(recipientId, key.getDeviceId(), key.distributionId);
107 final var senderKeyRecord = loadSenderKeyLocked(newKey);
108 if (senderKeyRecord != null) {
109 continue;
110 }
111 storeSenderKeyLocked(newKey, senderKeyRecord);
112 }
113 }
114 }
115
116 /**
117 * @param identifier can be either a serialized uuid or a e164 phone number
118 */
119 private RecipientId resolveRecipient(String identifier) {
120 return resolver.resolveRecipient(identifier);
121 }
122
123 private Key getKey(final SignalProtocolAddress address, final UUID distributionId) {
124 final var recipientId = resolveRecipient(address.getName());
125 return new Key(recipientId, address.getDeviceId(), distributionId);
126 }
127
128 private List<Key> getKeysLocked(RecipientId recipientId) {
129 final var files = senderKeysPath.listFiles((_file, s) -> s.startsWith(recipientId.getId() + "_"));
130 if (files == null) {
131 return List.of();
132 }
133 return parseFileNames(files);
134 }
135
136 final Pattern senderKeyFileNamePattern = Pattern.compile("([0-9]+)_([0-9]+)_([0-9a-z\\-]+)");
137
138 private List<Key> parseFileNames(final File[] files) {
139 return Arrays.stream(files)
140 .map(f -> senderKeyFileNamePattern.matcher(f.getName()))
141 .filter(Matcher::matches)
142 .map(matcher -> new Key(RecipientId.of(Long.parseLong(matcher.group(1))),
143 Integer.parseInt(matcher.group(2)),
144 UUID.fromString(matcher.group(3))))
145 .collect(Collectors.toList());
146 }
147
148 private File getSenderKeyFile(Key key) {
149 try {
150 IOUtils.createPrivateDirectories(senderKeysPath);
151 } catch (IOException e) {
152 throw new AssertionError("Failed to create sender keys path", e);
153 }
154 return new File(senderKeysPath,
155 key.getRecipientId().getId() + "_" + key.getDeviceId() + "_" + key.distributionId.toString());
156 }
157
158 private SenderKeyRecord loadSenderKeyLocked(final Key key) {
159 {
160 final var senderKeyRecord = cachedSenderKeys.get(key);
161 if (senderKeyRecord != null) {
162 return senderKeyRecord;
163 }
164 }
165
166 final var file = getSenderKeyFile(key);
167 if (!file.exists()) {
168 return null;
169 }
170 try (var inputStream = new FileInputStream(file)) {
171 final var senderKeyRecord = new SenderKeyRecord(inputStream.readAllBytes());
172 cachedSenderKeys.put(key, senderKeyRecord);
173 return senderKeyRecord;
174 } catch (IOException e) {
175 logger.warn("Failed to load sender key, resetting sender key: {}", e.getMessage());
176 return null;
177 }
178 }
179
180 private void storeSenderKeyLocked(final Key key, final SenderKeyRecord senderKeyRecord) {
181 cachedSenderKeys.put(key, senderKeyRecord);
182
183 final var file = getSenderKeyFile(key);
184 try {
185 try (var outputStream = new FileOutputStream(file)) {
186 outputStream.write(senderKeyRecord.serialize());
187 }
188 } catch (IOException e) {
189 logger.warn("Failed to store sender key, trying to delete file and retry: {}", e.getMessage());
190 try {
191 Files.delete(file.toPath());
192 try (var outputStream = new FileOutputStream(file)) {
193 outputStream.write(senderKeyRecord.serialize());
194 }
195 } catch (IOException e2) {
196 logger.error("Failed to store sender key file {}: {}", file, e2.getMessage());
197 }
198 }
199 }
200
201 private void deleteSenderKeyLocked(final Key key) {
202 cachedSenderKeys.remove(key);
203
204 final var file = getSenderKeyFile(key);
205 if (!file.exists()) {
206 return;
207 }
208 try {
209 Files.delete(file.toPath());
210 } catch (IOException e) {
211 logger.error("Failed to delete sender key file {}: {}", file, e.getMessage());
212 }
213 }
214
215 private static final class Key {
216
217 private final RecipientId recipientId;
218 private final int deviceId;
219 private final UUID distributionId;
220
221 public Key(
222 final RecipientId recipientId, final int deviceId, final UUID distributionId
223 ) {
224 this.recipientId = recipientId;
225 this.deviceId = deviceId;
226 this.distributionId = distributionId;
227 }
228
229 public RecipientId getRecipientId() {
230 return recipientId;
231 }
232
233 public int getDeviceId() {
234 return deviceId;
235 }
236
237 public UUID getDistributionId() {
238 return distributionId;
239 }
240
241 @Override
242 public boolean equals(final Object o) {
243 if (this == o) return true;
244 if (o == null || getClass() != o.getClass()) return false;
245
246 final Key key = (Key) o;
247
248 if (deviceId != key.deviceId) return false;
249 if (!recipientId.equals(key.recipientId)) return false;
250 return distributionId.equals(key.distributionId);
251 }
252
253 @Override
254 public int hashCode() {
255 int result = recipientId.hashCode();
256 result = 31 * result + deviceId;
257 result = 31 * result + distributionId.hashCode();
258 return result;
259 }
260 }
261 }