nmode's Git Repositories - signal-cli/blob - lib/src/main/java/org/asamk/signal/manager/SignalWebSocketHealthMonitor.java
Reconnect websockets after errors
[signal-cli] / lib / src / main / java / org / asamk / signal / manager / SignalWebSocketHealthMonitor.java
1 package org.asamk.signal.manager;
2
3 import org.slf4j.Logger;
4 import org.slf4j.LoggerFactory;
5 import org.whispersystems.libsignal.util.guava.Preconditions;
6 import org.whispersystems.signalservice.api.SignalWebSocket;
7 import org.whispersystems.signalservice.api.util.SleepTimer;
8 import org.whispersystems.signalservice.api.websocket.HealthMonitor;
9 import org.whispersystems.signalservice.api.websocket.WebSocketConnectionState;
10 import org.whispersystems.signalservice.internal.websocket.WebSocketConnection;
11
12 import java.util.Arrays;
13 import java.util.concurrent.TimeUnit;
14
15 import io.reactivex.rxjava3.schedulers.Schedulers;
16
17 /**
18 * Monitors the health of the identified and unidentified WebSockets. If either one appears to be
19 * unhealthy, will trigger restarting both.
20 * <p>
21 * The monitor is also responsible for sending heartbeats/keep-alive messages to prevent
22 * timeouts.
23 */
24 public final class SignalWebSocketHealthMonitor implements HealthMonitor {
25
26 private final static Logger logger = LoggerFactory.getLogger(SignalWebSocketHealthMonitor.class);
27
28 private static final long KEEP_ALIVE_SEND_CADENCE = TimeUnit.SECONDS.toMillis(WebSocketConnection.KEEPALIVE_TIMEOUT_SECONDS);
29 private static final long MAX_TIME_SINCE_SUCCESSFUL_KEEP_ALIVE = KEEP_ALIVE_SEND_CADENCE * 3;
30
31 private SignalWebSocket signalWebSocket;
32 private final SleepTimer sleepTimer;
33
34 private volatile KeepAliveSender keepAliveSender;
35
36 private final HealthState identified = new HealthState();
37 private final HealthState unidentified = new HealthState();
38
39 public SignalWebSocketHealthMonitor(SleepTimer sleepTimer) {
40 this.sleepTimer = sleepTimer;
41 }
42
43 public void monitor(SignalWebSocket signalWebSocket) {
44 Preconditions.checkNotNull(signalWebSocket);
45 Preconditions.checkArgument(this.signalWebSocket == null, "monitor can only be called once");
46
47 this.signalWebSocket = signalWebSocket;
48
49 //noinspection ResultOfMethodCallIgnored
50 signalWebSocket.getWebSocketState()
51 .subscribeOn(Schedulers.computation())
52 .observeOn(Schedulers.computation())
53 .distinctUntilChanged()
54 .subscribe(s -> onStateChange(s, identified));
55
56 //noinspection ResultOfMethodCallIgnored
57 signalWebSocket.getUnidentifiedWebSocketState()
58 .subscribeOn(Schedulers.computation())
59 .observeOn(Schedulers.computation())
60 .distinctUntilChanged()
61 .subscribe(s -> onStateChange(s, unidentified));
62 }
63
64 private synchronized void onStateChange(WebSocketConnectionState connectionState, HealthState healthState) {
65 switch (connectionState) {
66 case CONNECTED:
67 logger.debug("WebSocket is now connected");
68 break;
69 case AUTHENTICATION_FAILED:
70 logger.debug("WebSocket authentication failed");
71 break;
72 case FAILED:
73 logger.debug("WebSocket connection failed");
74 break;
75 }
76
77 healthState.needsKeepAlive = connectionState == WebSocketConnectionState.CONNECTED;
78
79 if (keepAliveSender == null && isKeepAliveNecessary()) {
80 keepAliveSender = new KeepAliveSender();
81 keepAliveSender.start();
82 } else if (keepAliveSender != null && !isKeepAliveNecessary()) {
83 keepAliveSender.shutdown();
84 keepAliveSender = null;
85 }
86 }
87
88 @Override
89 public void onKeepAliveResponse(long sentTimestamp, boolean isIdentifiedWebSocket) {
90 if (isIdentifiedWebSocket) {
91 identified.lastKeepAliveReceived = System.currentTimeMillis();
92 } else {
93 unidentified.lastKeepAliveReceived = System.currentTimeMillis();
94 }
95 }
96
97 @Override
98 public void onMessageError(int status, boolean isIdentifiedWebSocket) {
99 if (status == 409) {
100 HealthState healthState = (isIdentifiedWebSocket ? identified : unidentified);
101 if (healthState.mismatchErrorTracker.addSample(System.currentTimeMillis())) {
102 logger.warn("Received too many mismatch device errors, forcing new websockets.");
103 signalWebSocket.forceNewWebSockets();
104 signalWebSocket.connect();
105 }
106 }
107 }
108
109 private boolean isKeepAliveNecessary() {
110 return identified.needsKeepAlive || unidentified.needsKeepAlive;
111 }
112
113 private static class HealthState {
114
115 private final HttpErrorTracker mismatchErrorTracker = new HttpErrorTracker(5, TimeUnit.MINUTES.toMillis(1));
116
117 private volatile boolean needsKeepAlive;
118 private volatile long lastKeepAliveReceived;
119 }
120
121 /**
122 * Sends periodic heartbeats/keep-alives over both WebSockets to prevent connection timeouts. If
123 * either WebSocket fails 3 times to get a return heartbeat both are forced to be recreated.
124 */
125 private class KeepAliveSender extends Thread {
126
127 private volatile boolean shouldKeepRunning = true;
128
129 public void run() {
130 identified.lastKeepAliveReceived = System.currentTimeMillis();
131 unidentified.lastKeepAliveReceived = System.currentTimeMillis();
132
133 while (shouldKeepRunning && isKeepAliveNecessary()) {
134 try {
135 sleepTimer.sleep(KEEP_ALIVE_SEND_CADENCE);
136
137 if (shouldKeepRunning && isKeepAliveNecessary()) {
138 long keepAliveRequiredSinceTime = System.currentTimeMillis()
139 - MAX_TIME_SINCE_SUCCESSFUL_KEEP_ALIVE;
140
141 if (identified.lastKeepAliveReceived < keepAliveRequiredSinceTime
142 || unidentified.lastKeepAliveReceived < keepAliveRequiredSinceTime) {
143 logger.warn("Missed keep alives, identified last: "
144 + identified.lastKeepAliveReceived
145 + " unidentified last: "
146 + unidentified.lastKeepAliveReceived
147 + " needed by: "
148 + keepAliveRequiredSinceTime);
149 signalWebSocket.forceNewWebSockets();
150 signalWebSocket.connect();
151 } else {
152 signalWebSocket.sendKeepAlive();
153 }
154 }
155 } catch (Throwable e) {
156 logger.warn("Error occured in KeepAliveSender, ignoring ...", e);
157 }
158 }
159 }
160
161 public void shutdown() {
162 shouldKeepRunning = false;
163 }
164 }
165
166 private final static class HttpErrorTracker {
167
168 private final long[] timestamps;
169 private final long errorTimeRange;
170
171 public HttpErrorTracker(int samples, long errorTimeRange) {
172 this.timestamps = new long[samples];
173 this.errorTimeRange = errorTimeRange;
174 }
175
176 public synchronized boolean addSample(long now) {
177 long errorsMustBeAfter = now - errorTimeRange;
178 int count = 1;
179 int minIndex = 0;
180
181 for (int i = 0; i < timestamps.length; i++) {
182 if (timestamps[i] < errorsMustBeAfter) {
183 timestamps[i] = 0;
184 } else if (timestamps[i] != 0) {
185 count++;
186 }
187
188 if (timestamps[i] < timestamps[minIndex]) {
189 minIndex = i;
190 }
191 }
192
193 timestamps[minIndex] = now;
194
195 if (count >= timestamps.length) {
196 Arrays.fill(timestamps, 0);
197 return true;
198 }
199 return false;
200 }
201 }
202 }