Reputation: 69997
I'm implementing a web application which, among other things, has to show and interact with webpages proxied to backend services. For this, I'm using the HTTP-Proxy-Servlet which works well most of the time.
However, certain backend services' webpages use websockets and the proxy servlet above doesn't support websockets.
I tried implementing it by reconstructing the websocket call towards the backend and then copying between streams, but that doesn't work. The browser reports "Invalid frame header" and Tomcat fails with
Error parsing HTTP request header
Invalid character found in method name. HTTP method names must be tokens
at org.apache.coyote.http11.Http11InputBuffer.parseRequestLine(Http11InputBuffer.java:414)
My code:
import java.io.IOException;
import java.net.*;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.*;
import javax.servlet.ServletException;
import javax.servlet.http.*;
import org.apache.http.HttpRequest;
import org.mitre.dsmiley.httpproxy.ProxyServlet;
/**
 * Proxy servlet that attempts to tunnel WebSocket handshakes to the backend over a raw socket.
 * Requests that carry no Sec-WebSocket-Key header are passed to ProxyServlet.service().
 *
 * NOTE(review): this variant copies bytes over the ordinary servlet request/response streams
 * and never calls HttpServletRequest.upgrade(); it is the version reported as failing with
 * "Invalid frame header" in the surrounding text.
 */
public class ProxyWithWebSocket extends ProxyServlet {
private static final long serialVersionUID = -2566573965489129976L;
// Thread pool that runs the two byte-copy tasks submitted from service().
protected ExecutorService exec;
/** Initializes the parent servlet and creates the cached thread pool. */
@Override
public void init() throws ServletException {
super.init();
exec = Executors.newCachedThreadPool();
}
/** Destroys the parent servlet and requests an orderly shutdown of the thread pool. */
@Override
public void destroy() {
super.destroy();
exec.shutdown();
}
/**
 * Handles one request. When a Sec-WebSocket-Key header is present, the handshake is replayed
 * to the backend over a plain socket, the backend's status and headers are copied to the
 * servlet response, and two tasks then copy bytes in both directions.
 */
@Override
protected void service(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
throws ServletException, IOException {
// The presence of this header is the only test used to detect a WebSocket handshake.
var wsKey = servletRequest.getHeader("Sec-WebSocket-Key");
if (wsKey != null) {
//initialize request attributes from caches if unset by a subclass by this point
if (servletRequest.getAttribute(ATTR_TARGET_URI) == null) {
servletRequest.setAttribute(ATTR_TARGET_URI, targetUri);
}
if (servletRequest.getAttribute(ATTR_TARGET_HOST) == null) {
servletRequest.setAttribute(ATTR_TARGET_HOST, targetHost);
}
String proxyRequestUri = rewriteUrlFromRequest(servletRequest);
URL u = new URL(proxyRequestUri);
var servletIn = servletRequest.getInputStream();
var servletOut = servletResponse.getOutputStream();
// The socket is closed when this try block exits, i.e. when service() returns.
try (Socket sock = new Socket(u.getHost(), u.getPort())) {
var sockIn = sock.getInputStream();
var sockOut = sock.getOutputStream();
// Rebuild the request line and replay every request header (first value only) to the backend.
StringBuilder req = new StringBuilder(512);
req.append("GET " + u.getFile()).append(" HTTP/1.1");
System.out.println(" > WS|" + req);
req.append("\r\n");
var en = servletRequest.getHeaderNames();
while (en.hasMoreElements()) {
var n = en.nextElement();
String header = servletRequest.getHeader(n);
System.out.println(" > WS| " + n + ": " + header);
req.append(n + ": " + header + "\r\n");
}
req.append("\r\n");
sockOut.write(req.toString().getBytes(StandardCharsets.UTF_8));
sockOut.flush();
// Read the backend response one byte at a time until the blank line (CR LF CR LF) or EOF.
StringBuilder responseBytes = new StringBuilder(512);
int b = 0;
while (b != -1) {
b = sockIn.read();
if (b != -1) {
responseBytes.append((char)b);
var len = responseBytes.length();
if (len >= 4
&& responseBytes.charAt(len - 4) == '\r'
&& responseBytes.charAt(len - 3) == '\n'
&& responseBytes.charAt(len - 2) == '\r'
&& responseBytes.charAt(len - 1) == '\n'
) {
break;
}
}
}
// rows[0] is the status line; the remaining rows are header lines.
String[] rows = responseBytes.toString().split("\r\n");
String response = rows[0];
System.out.println(" < WS|" + response);
// The status code is the text between the first and second space of the status line.
int idx1 = response.indexOf(' ');
int idx2 = response.indexOf(' ', idx1 + 1);
for (int i = 1; i < rows.length; i++) {
String line = rows[i];
int idx3 = line.indexOf(":");
var k = line.substring(0, idx3);
// idx3 + 2 skips the colon and exactly one following character.
var headerField = line.substring(idx3 + 2);
System.out.println(" < WS| " + k + ": " + headerField);
servletResponse.setHeader(k, headerField);
}
servletResponse.setStatus(Integer.parseInt(response.substring(idx1 + 1, idx2)));
servletResponse.flushBuffer();
System.out.println(" < WS| Flush");
// Task 1: copy bytes from the servlet request body to the backend socket.
var f1 = exec.submit(() -> {
var c = 0;
var bs = 0;
while ((bs = servletIn.read()) != -1) {
sockOut.write(bs);
c++;
}
System.out.println(" > WS| Done: " + c);
return null;
});
// Task 2: copy bytes from the backend socket to the servlet response, flushing after each byte.
var f2 = exec.submit(() -> {
var c = 0;
var bs = 0;
while ((bs = sockIn.read()) != -1) {
servletOut.write(bs);
servletOut.flush();
c++;
}
System.out.println(" < WS| Done: " + c);
return null;
});
// Block until the client-to-backend copy ends; if it fails, cancel the other direction and return.
try {
f1.get();
} catch (Exception ex) {
f2.cancel(true);
return;
}
// Then wait for the backend-to-client copy; any failure here is ignored.
try {
f2.get();
} catch (Exception ex) {
}
}
} else {
super.service(servletRequest, servletResponse);
}
}
}
A typical exchange looks like this (via those println):
> WS|GET /cellhub?id=NhWO8SnGyDb_Vrk23rmhVQ HTTP/1.1
> WS| host: localhost:8080
> WS| connection: Upgrade
> WS| pragma: no-cache
> WS| cache-control: no-cache
> WS| user-agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.71 Safari/537.36
> WS| upgrade: websocket
> WS| origin: http://localhost:8080
> WS| sec-websocket-version: 13
> WS| accept-encoding: gzip, deflate, br
> WS| accept-language: hu,hu-HU;q=0.9,en-US;q=0.8,en;q=0.7
> WS| cookie: JSESSIONID=57E4B30452BC3EB2657139DAF70E65AD; JSESSIONID=AD5E7BB5FE17B4072F3ABEE32B9479AC
> WS| sec-websocket-key: nrZWEb6Co4DKggUNwPeV8g==
> WS| sec-websocket-extensions: permessage-deflate; client_max_window_bits
< WS|HTTP/1.1 101 Switching Protocols
< WS| Connection: Upgrade
< WS| Date: Thu, 07 Oct 2021 13:18:41 GMT
< WS| Server: Kestrel
< WS| Upgrade: websocket
< WS| Sec-WebSocket-Accept: /9uN8ZF67WepGJQ3+DPBLMCBotc=
< WS| Flush
> WS| Done: 0
< WS| Done: 42
How can I make this work?
Edit
I found the HttpServletRequest.upgrade method, which appears to be intended for switching protocols. I've updated the part after the header copying:
// Status code parsed from the backend's status line (the text between its first two spaces).
int respCode = Integer.parseInt(response.substring(idx1 + 1, idx2));
if (respCode != 101) {
// Anything other than 101: relay the status to the client and mark the backend socket for closing.
servletResponse.setStatus(respCode);
servletResponse.flushBuffer();
System.out.println(" < WS| Flush");
closeSocket = true;
} else {
// 101: request a protocol upgrade and hand the backend socket and its streams to the handler.
var uh = servletRequest.upgrade(WsUpgradeHandler.class);
uh.preInit(exec, sockIn, sockOut, sock);
}
where WsUpgradeHandler is:
/**
 * Upgrade handler that relays bytes between the upgraded client connection and a backend
 * socket, using two tasks submitted to an executor (one per direction).
 *
 * NOTE(review): init() returns right after submitting both tasks; the surrounding discussion
 * identifies returning from init() as a problem in blocking mode.
 */
public static class WsUpgradeHandler implements HttpUpgradeHandler {
// Collaborators supplied through preInit() after the handler has been created.
ExecutorService exec;
InputStream sockIn;
OutputStream sockOut;
Socket sock;
// f1: client -> backend copy task; f2: backend -> client copy task.
Future<?> f1;
Future<?> f2;
public WsUpgradeHandler() { }
/** Stores the executor and the backend socket with its streams for later use by init(). */
public void preInit(ExecutorService exec, InputStream sockIn, OutputStream sockOut, Socket sock) {
this.exec = exec;
this.sockIn = sockIn;
this.sockOut = sockOut;
this.sock = sock;
}
/** Obtains the upgraded connection's streams and submits one copy task per direction. */
@Override
public void init(WebConnection wc) {
System.out.println(" * WS| Upgrade begin");
try {
var servletIn = wc.getInputStream();
var servletOut = wc.getOutputStream();
// Client -> backend: copy byte by byte until EOF, then close the backend output stream.
f1 = exec.submit(() -> {
System.out.println(" > WS| Client -> Backend");
var c = 0;
var bs = 0;
try {
while ((bs = servletIn.read()) != -1) {
sockOut.write(bs);
c++;
}
} catch (Exception exc) {
exc.printStackTrace();
} finally {
sockOut.close();
}
System.out.println(" > WS| Done: " + c);
return null;
});
// Backend -> client: copy byte by byte, flushing after each byte, then close the client output stream.
f2 = exec.submit(() -> {
System.out.println(" > WS| Backend -> Client");
var c = 0;
try {
var bs = 0;
while ((bs = sockIn.read()) != -1) {
servletOut.write(bs);
servletOut.flush();
c++;
}
} catch (Exception exc) {
exc.printStackTrace();
} finally {
servletOut.close();
}
System.out.println(" < WS| Done: " + c);
return null;
});
} catch (IOException ex) {
ex.printStackTrace();
}
}
/** Cancels both copy tasks and closes the backend socket; a close failure is ignored. */
@Override
public void destroy() {
System.out.println(" * WS| Upgrade closing");
// NOTE(review): f1/f2 are dereferenced without a null check; they stay null if init() failed before submitting.
f1.cancel(true);
f2.cancel(true);
try {
sock.close();
} catch (IOException ex) {
}
System.out.println(" * WS| Upgrade close");
}
}
This does work for passing messages around, but when the websocket connection from the browser ends, Tomcat's CPU utilization goes very high (no other activity should be happening at this point). It appears that some or all of Tomcat's NIO threads are spinning, and the thread pool I'm using no longer has any threads.
Upvotes: 3
Views: 984
Reputation: 69997
I think I managed to solve the issue.
The code above was almost correct, with one exception: apparently the init() method should not return when using blocking mode, as demonstrated by this Tomcat test example.
The second issue, namely the high CPU usage, was tracked down to a poller thread in Tomcat that had bugs in earlier versions. I was running my code in Tomcat 9.0.12, and once I upgraded to Tomcat 9.0.54, the CPU usage issues went away.
Thus the complete working code looks like this: (I know, I know, byte shoveling and manually preparing HTTP requests is not optimal, but that's what Loom is for, right ;)
import java.io.*;
import java.net.*;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.*;
import javax.servlet.ServletException;
import javax.servlet.http.*;
import org.apache.http.HttpRequest;
import org.mitre.dsmiley.httpproxy.ProxyServlet;
/**
 * Proxy servlet that adds WebSocket pass-through to {@link ProxyServlet}.
 *
 * <p>Ordinary requests are handled by the parent class. A request that carries a
 * {@code Sec-WebSocket-Key} header is treated as a WebSocket handshake: the handshake is
 * replayed to the backend over a plain socket and, if the backend answers
 * {@code 101 Switching Protocols}, the client connection is upgraded and raw bytes are relayed
 * in both directions by {@link WsUpgradeHandler}.
 *
 * <p>NOTE(review): the backend connection is always a plain {@link Socket}; an {@code https}
 * target would need TLS, which is not handled here.
 */
public class ProxyWithWebSocket extends ProxyServlet {

    private static final long serialVersionUID = -2566573965489129976L;

    /** Size of the buffer used when relaying bytes between client and backend. */
    private static final int COPY_BUFFER_SIZE = 8192;

    /** Pool that runs the backend-to-client relay task of every upgraded connection. */
    protected ExecutorService exec;

    /** Initializes the parent servlet and creates the relay thread pool. */
    @Override
    public void init() throws ServletException {
        super.init();
        exec = Executors.newCachedThreadPool();
    }

    /** Destroys the parent servlet and requests an orderly shutdown of the relay thread pool. */
    @Override
    public void destroy() {
        super.destroy();
        exec.shutdown();
    }

    /**
     * Copies the request headers as the parent does and additionally forwards the
     * {@code UserID} request attribute, when present, as a {@code UserID} header.
     */
    @Override
    protected void copyRequestHeaders(HttpServletRequest servletRequest, HttpRequest proxyRequest) {
        super.copyRequestHeaders(servletRequest, proxyRequest);
        String userId = (String) servletRequest.getAttribute("UserID");
        if (userId != null) {
            proxyRequest.addHeader("UserID", userId);
        }
    }

    /**
     * Dispatches a request: WebSocket handshakes are tunnelled to the backend, everything else
     * is delegated to {@link ProxyServlet}.
     *
     * @throws IOException if the backend cannot be reached or answers with a malformed
     *     status line
     */
    @Override
    protected void service(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
            throws ServletException, IOException {
        if (servletRequest.getHeader("Sec-WebSocket-Key") == null) {
            super.service(servletRequest, servletResponse);
            return;
        }

        // initialize request attributes from caches if unset by a subclass by this point
        if (servletRequest.getAttribute(ATTR_TARGET_URI) == null) {
            servletRequest.setAttribute(ATTR_TARGET_URI, targetUri);
        }
        if (servletRequest.getAttribute(ATTR_TARGET_HOST) == null) {
            servletRequest.setAttribute(ATTR_TARGET_HOST, targetHost);
        }

        URL u = new URL(rewriteUrlFromRequest(servletRequest));
        // URL.getPort() is -1 when the target URL carries no explicit port.
        int port = u.getPort() != -1 ? u.getPort() : u.getDefaultPort();
        Socket sock = new Socket(u.getHost(), port);
        // The socket is handed to the upgrade handler only on success; on every other path
        // (non-101 answer or any exception) it must be closed here.
        boolean upgraded = false;
        try {
            var sockIn = sock.getInputStream();
            var sockOut = sock.getOutputStream();

            // Header bytes are written as Latin-1, matching the Latin-1 decoding used for the
            // response below, so any non-ASCII header byte survives the round trip.
            sockOut.write(buildHandshakeRequest(servletRequest, u)
                    .getBytes(StandardCharsets.ISO_8859_1));
            sockOut.flush();

            String[] rows = readResponseHead(sockIn).split("\r\n");
            String statusLine = rows[0];
            System.out.println(" < WS|" + statusLine);
            int respCode = parseStatusCode(statusLine);
            copyResponseHeaders(rows, servletResponse);

            if (respCode != 101) {
                servletResponse.setStatus(respCode);
                servletResponse.flushBuffer();
                System.out.println(" < WS| Flush");
            } else {
                var uh = servletRequest.upgrade(WsUpgradeHandler.class);
                uh.preInit(exec, sockIn, sockOut, sock);
                upgraded = true;
            }
        } finally {
            if (!upgraded) {
                sock.close();
            }
        }
    }

    /** Builds the HTTP/1.1 handshake request, replaying every value of every client header. */
    private static String buildHandshakeRequest(HttpServletRequest servletRequest, URL target) {
        StringBuilder req = new StringBuilder(512);
        req.append("GET ").append(target.getFile()).append(" HTTP/1.1");
        System.out.println(" > WS|" + req);
        req.append("\r\n");
        var names = servletRequest.getHeaderNames();
        while (names.hasMoreElements()) {
            String name = names.nextElement();
            var values = servletRequest.getHeaders(name);
            while (values.hasMoreElements()) {
                String value = values.nextElement();
                System.out.println(" > WS| " + name + ": " + value);
                req.append(name).append(": ").append(value).append("\r\n");
            }
        }
        req.append("\r\n");
        return req.toString();
    }

    /**
     * Reads the response head (status line and headers) up to and including the blank line.
     * Reading is byte-wise on purpose: nothing past the blank line may be consumed, because
     * those bytes already belong to the WebSocket stream.
     */
    private static String readResponseHead(InputStream in) throws IOException {
        StringBuilder head = new StringBuilder(512);
        int b;
        while ((b = in.read()) != -1) {
            head.append((char) b);
            int len = head.length();
            if (len >= 4
                    && head.charAt(len - 4) == '\r'
                    && head.charAt(len - 3) == '\n'
                    && head.charAt(len - 2) == '\r'
                    && head.charAt(len - 1) == '\n') {
                break;
            }
        }
        return head.toString();
    }

    /**
     * Extracts the numeric status code from a status line such as
     * {@code HTTP/1.1 101 Switching Protocols}; the reason phrase is optional.
     *
     * @throws IOException if the line has no parsable status code (including an empty line
     *     when the backend closed the connection without answering)
     */
    private static int parseStatusCode(String statusLine) throws IOException {
        int first = statusLine.indexOf(' ');
        if (first < 0) {
            throw new IOException("Malformed status line from backend: '" + statusLine + "'");
        }
        int second = statusLine.indexOf(' ', first + 1);
        String code = second < 0
                ? statusLine.substring(first + 1)
                : statusLine.substring(first + 1, second);
        try {
            return Integer.parseInt(code.trim());
        } catch (NumberFormatException e) {
            throw new IOException("Malformed status line from backend: '" + statusLine + "'", e);
        }
    }

    /**
     * Copies the backend's response headers ({@code rows[1..]}) onto the servlet response.
     * Lines without a colon are skipped, repeated headers are preserved, and body-framing
     * headers are dropped because no response body is ever relayed through the servlet
     * response.
     */
    private static void copyResponseHeaders(String[] rows, HttpServletResponse servletResponse) {
        var seen = new java.util.HashSet<String>();
        for (int i = 1; i < rows.length; i++) {
            String line = rows[i];
            int colon = line.indexOf(':');
            if (colon <= 0) {
                continue;
            }
            String name = line.substring(0, colon).trim();
            String value = line.substring(colon + 1).trim();
            System.out.println(" < WS| " + name + ": " + value);
            if (name.equalsIgnoreCase("Content-Length")
                    || name.equalsIgnoreCase("Transfer-Encoding")) {
                continue;
            }
            // First occurrence replaces anything already set; later ones (e.g. Set-Cookie) add.
            if (seen.add(name.toLowerCase(java.util.Locale.ROOT))) {
                servletResponse.setHeader(name, value);
            } else {
                servletResponse.addHeader(name, value);
            }
        }
    }

    /**
     * Relays raw bytes between the upgraded client connection and the backend socket.
     *
     * <p>The backend-to-client direction runs on the executor; the client-to-backend direction
     * runs on the thread that calls {@link #init(WebConnection)}, which therefore does not
     * return until the connection ends.
     *
     * <p>NOTE(review): {@link #preInit} must have been called before the container invokes
     * {@code init}; this relies on the container calling {@code init} only after
     * {@code service()} has returned.
     */
    public static class WsUpgradeHandler implements HttpUpgradeHandler {
        ExecutorService exec;
        InputStream sockIn;
        OutputStream sockOut;
        Socket sock;
        /** Backend-to-client relay task; {@code null} until {@code init} has submitted it. */
        Future<?> f2;

        public WsUpgradeHandler() { }

        /** Supplies the executor and the already-connected backend socket with its streams. */
        public void preInit(ExecutorService exec, InputStream sockIn, OutputStream sockOut, Socket sock) {
            this.exec = exec;
            this.sockIn = sockIn;
            this.sockOut = sockOut;
            this.sock = sock;
        }

        /** Starts both relay directions and blocks until the client side has finished. */
        @Override
        public void init(WebConnection wc) {
            System.out.println(" * WS| Upgrade begin");
            try {
                var servletIn = wc.getInputStream();
                var servletOut = wc.getOutputStream();
                f2 = exec.submit(() -> {
                    System.out.println(" > WS| Backend -> Client");
                    long c = 0;
                    byte[] buf = new byte[COPY_BUFFER_SIZE];
                    try {
                        int n;
                        while ((n = sockIn.read(buf)) != -1) {
                            servletOut.write(buf, 0, n);
                            // Flush per chunk so frames reach the browser without delay.
                            servletOut.flush();
                            c += n;
                        }
                    } catch (SocketException | EOFException exc) {
                        // this is fine: the other side closed the connection
                    } catch (Exception exc) {
                        exc.printStackTrace();
                    } finally {
                        servletOut.close();
                    }
                    System.out.println(" < WS| Done: " + c);
                    return null;
                });
                System.out.println(" > WS| Client -> Backend");
                long c = 0;
                byte[] buf = new byte[COPY_BUFFER_SIZE];
                try {
                    int n;
                    while ((n = servletIn.read(buf)) != -1) {
                        sockOut.write(buf, 0, n);
                        c += n;
                    }
                } catch (SocketException | EOFException exc) {
                    // this is fine: the other side closed the connection
                } catch (Exception exc) {
                    exc.printStackTrace();
                } finally {
                    // Closing the socket's output stream also unblocks the backend reader.
                    sockOut.close();
                }
                System.out.println(" > WS| Done: " + c);
                f2.get();
            } catch (InterruptedException ex) {
                // Preserve the interrupt for the container thread that called us.
                Thread.currentThread().interrupt();
            } catch (Exception ex) {
                ex.printStackTrace();
            } finally {
                if (f2 != null) {
                    f2.cancel(true);
                }
            }
        }

        /** Cancels the relay task and closes the backend socket. */
        @Override
        public void destroy() {
            System.out.println(" * WS| Upgrade closing");
            if (f2 != null) {
                f2.cancel(true);
            }
            if (sock != null) {
                try {
                    sock.close();
                } catch (IOException ignored) {
                    // Best effort: the connection is being torn down anyway.
                }
            }
            System.out.println(" * WS| Upgrade close");
        }
    }
}
Upvotes: 3