Skip to content

Commit c0ddca6

Browse files
committed
[java] Tapping the Node session when there is WebSocket activity
Fixes #12223
1 parent 19a1813 commit c0ddca6

File tree

1 file changed

+67
-27
lines changed

1 file changed

+67
-27
lines changed

java/src/org/openqa/selenium/grid/node/ProxyNodeWebsockets.java

Lines changed: 67 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.openqa.selenium.Capabilities;
3434
import org.openqa.selenium.devtools.CdpEndpointFinder;
3535
import org.openqa.selenium.grid.data.Session;
36+
import org.openqa.selenium.internal.Require;
3637
import org.openqa.selenium.remote.SessionId;
3738
import org.openqa.selenium.remote.http.BinaryMessage;
3839
import org.openqa.selenium.remote.http.ClientConfig;
@@ -75,57 +76,68 @@ public Optional<Consumer<Message>> apply(String uri, Consumer<Message> downstrea
7576
return Optional.empty();
7677
}
7778

78-
String sessionId =
79-
Stream.of(fwdMatch, cdpMatch, bidiMatch, vncMatch)
80-
.filter(Objects::nonNull)
81-
.findFirst()
82-
.get()
83-
.getParameters()
84-
.get("sessionId");
79+
Optional<UrlTemplate.Match> firstMatch =
80+
Stream.of(fwdMatch, cdpMatch, bidiMatch, vncMatch).filter(Objects::nonNull).findFirst();
81+
82+
if (firstMatch.isEmpty()) {
83+
LOG.warning("No session id found in uri " + uri);
84+
return Optional.empty();
85+
}
86+
87+
String sessionId = firstMatch.get().getParameters().get("sessionId");
8588

8689
LOG.fine("Matching websockets for session id: " + sessionId);
8790
SessionId id = new SessionId(sessionId);
8891

8992
if (!node.isSessionOwner(id)) {
90-
LOG.info("Not owner of " + id);
93+
LOG.warning("Not owner of " + id);
9194
return Optional.empty();
9295
}
9396

9497
Session session = node.getSession(id);
9598
Capabilities caps = session.getCapabilities();
9699
LOG.fine("Scanning for endpoint: " + caps);
97100

101+
// Used by the ForwardingListener to notify the node that the session is still active
102+
Consumer<SessionId> sessionConsumer = node::isSessionOwner;
103+
98104
if (bidiMatch != null) {
99-
return findBiDiEndpoint(downstream, caps);
105+
return findBiDiEndpoint(downstream, caps, sessionConsumer, id);
100106
}
101107

102108
if (vncMatch != null) {
103-
return findVncEndpoint(downstream, caps);
109+
// Passing a fake consumer to the ForwardingListener to avoid sending a session notification
110+
// when VNC is used.
111+
sessionConsumer = fakeConsumer -> {};
112+
return findVncEndpoint(downstream, caps, sessionConsumer, id);
104113
}
105114

106115
// This match happens when a user wants to do CDP over Dynamic Grid
107116
if (fwdMatch != null) {
108117
LOG.info("Matched endpoint where CDP connection is being forwarded");
109-
return findCdpEndpoint(downstream, caps);
118+
return findCdpEndpoint(downstream, caps, sessionConsumer, id);
110119
}
111120
if (caps.getCapabilityNames().contains("se:forwardCdp")) {
112121
LOG.info("Found endpoint where CDP connection needs to be forwarded");
113-
return findForwardCdpEndpoint(downstream, caps);
122+
return findForwardCdpEndpoint(downstream, caps, sessionConsumer, id);
114123
}
115-
return findCdpEndpoint(downstream, caps);
124+
return findCdpEndpoint(downstream, caps, sessionConsumer, id);
116125
}
117126

118127
private Optional<Consumer<Message>> findCdpEndpoint(
119-
Consumer<Message> downstream, Capabilities caps) {
120-
// Using strings here to avoid Node depending upon specific drivers.
128+
Consumer<Message> downstream,
129+
Capabilities caps,
130+
Consumer<SessionId> sessionConsumer,
131+
SessionId sessionId) {
132+
121133
for (String cdpEndpointCap : CDP_ENDPOINT_CAPS) {
122134
Optional<URI> reportedUri = CdpEndpointFinder.getReportedUri(cdpEndpointCap, caps);
123135
Optional<HttpClient> client =
124136
reportedUri.map(uri -> CdpEndpointFinder.getHttpClient(clientFactory, uri));
125137
Optional<URI> cdpUri;
126138

127139
try {
128-
cdpUri = client.flatMap(httpClient -> CdpEndpointFinder.getCdpEndPoint(httpClient));
140+
cdpUri = client.flatMap(CdpEndpointFinder::getCdpEndPoint);
129141
} catch (Exception e) {
130142
try {
131143
client.ifPresent(HttpClient::close);
@@ -137,7 +149,7 @@ private Optional<Consumer<Message>> findCdpEndpoint(
137149

138150
if (cdpUri.isPresent()) {
139151
LOG.log(getDebugLogLevel(), String.format("Endpoint found in %s", cdpEndpointCap));
140-
return cdpUri.map(cdp -> createWsEndPoint(cdp, downstream));
152+
return cdpUri.map(cdp -> createWsEndPoint(cdp, downstream, sessionConsumer, sessionId));
141153
} else {
142154
try {
143155
client.ifPresent(HttpClient::close);
@@ -154,30 +166,41 @@ private Optional<Consumer<Message>> findCdpEndpoint(
154166
}
155167

156168
private Optional<Consumer<Message>> findBiDiEndpoint(
157-
Consumer<Message> downstream, Capabilities caps) {
169+
Consumer<Message> downstream,
170+
Capabilities caps,
171+
Consumer<SessionId> sessionConsumer,
172+
SessionId sessionId) {
158173
try {
159174
URI uri = new URI(String.valueOf(caps.getCapability("webSocketUrl")));
160-
return Optional.of(uri).map(bidi -> createWsEndPoint(bidi, downstream));
175+
return Optional.of(uri)
176+
.map(bidi -> createWsEndPoint(bidi, downstream, sessionConsumer, sessionId));
161177
} catch (URISyntaxException e) {
162178
LOG.warning("Unable to create URI from: " + caps.getCapability("webSocketUrl"));
163179
return Optional.empty();
164180
}
165181
}
166182

167183
private Optional<Consumer<Message>> findForwardCdpEndpoint(
168-
Consumer<Message> downstream, Capabilities caps) {
184+
Consumer<Message> downstream,
185+
Capabilities caps,
186+
Consumer<SessionId> sessionConsumer,
187+
SessionId sessionId) {
169188
// When using Dynamic Grid, we need to connect to a container before using the debuggerAddress
170189
try {
171190
URI uri = new URI(String.valueOf(caps.getCapability("se:forwardCdp")));
172-
return Optional.of(uri).map(cdp -> createWsEndPoint(cdp, downstream));
191+
return Optional.of(uri)
192+
.map(cdp -> createWsEndPoint(cdp, downstream, sessionConsumer, sessionId));
173193
} catch (URISyntaxException e) {
174194
LOG.warning("Unable to create URI from: " + caps.getCapability("se:forwardCdp"));
175195
return Optional.empty();
176196
}
177197
}
178198

179199
private Optional<Consumer<Message>> findVncEndpoint(
180-
Consumer<Message> downstream, Capabilities caps) {
200+
Consumer<Message> downstream,
201+
Capabilities caps,
202+
Consumer<SessionId> sessionConsumer,
203+
SessionId sessionId) {
181204
String vncLocalAddress = (String) caps.getCapability("se:vncLocalAddress");
182205
Optional<URI> vncUri;
183206
try {
@@ -187,40 +210,57 @@ private Optional<Consumer<Message>> findVncEndpoint(
187210
return Optional.empty();
188211
}
189212
LOG.log(getDebugLogLevel(), String.format("Endpoint found in %s", "se:vncLocalAddress"));
190-
return vncUri.map(vnc -> createWsEndPoint(vnc, downstream));
213+
return vncUri.map(vnc -> createWsEndPoint(vnc, downstream, sessionConsumer, sessionId));
191214
}
192215

193-
private Consumer<Message> createWsEndPoint(URI uri, Consumer<Message> downstream) {
194-
Objects.requireNonNull(uri);
216+
private Consumer<Message> createWsEndPoint(
217+
URI uri,
218+
Consumer<Message> downstream,
219+
Consumer<SessionId> sessionConsumer,
220+
SessionId sessionId) {
221+
Require.nonNull("downstream", downstream);
222+
Require.nonNull("uri", uri);
223+
Require.nonNull("sessionConsumer", sessionConsumer);
224+
Require.nonNull("sessionId", sessionId);
195225

196226
LOG.info("Establishing connection to " + uri);
197227

198228
HttpClient client = clientFactory.createClient(ClientConfig.defaultConfig().baseUri(uri));
199229
WebSocket upstream =
200-
client.openSocket(new HttpRequest(GET, uri.toString()), new ForwardingListener(downstream));
230+
client.openSocket(
231+
new HttpRequest(GET, uri.toString()),
232+
new ForwardingListener(downstream, sessionConsumer, sessionId));
201233
return upstream::send;
202234
}
203235

204236
private static class ForwardingListener implements WebSocket.Listener {
205237
private final Consumer<Message> downstream;
238+
private final Consumer<SessionId> sessionConsumer;
239+
private final SessionId sessionId;
206240

207-
public ForwardingListener(Consumer<Message> downstream) {
241+
public ForwardingListener(
242+
Consumer<Message> downstream, Consumer<SessionId> sessionConsumer, SessionId sessionId) {
208243
this.downstream = Objects.requireNonNull(downstream);
244+
this.sessionConsumer = Objects.requireNonNull(sessionConsumer);
245+
this.sessionId = Objects.requireNonNull(sessionId);
209246
}
210247

211248
@Override
212249
public void onBinary(byte[] data) {
213250
downstream.accept(new BinaryMessage(data));
251+
sessionConsumer.accept(sessionId);
214252
}
215253

216254
@Override
217255
public void onClose(int code, String reason) {
218256
downstream.accept(new CloseMessage(code, reason));
257+
sessionConsumer.accept(sessionId);
219258
}
220259

221260
@Override
222261
public void onText(CharSequence data) {
223262
downstream.accept(new TextMessage(data));
263+
sessionConsumer.accept(sessionId);
224264
}
225265

226266
@Override

0 commit comments

Comments
 (0)