user1563721
user1563721

Reputation: 1561

WebSocket server: get the client certificate in onOpen

I have a simple ServerEndpoint running on WildFly 10, configured as wss with mutual TLS, so a client certificate is required. I have no problem connecting to the endpoint, so mutual authentication is working correctly, but I can't access the client certificate in the onOpen method. I am trying to get it using getUserPrincipal(), but I always get null.

I need to get client certificate for authorization purposes.

import java.io.IOException;
import java.security.Principal;
import java.util.logging.Level;
import java.util.logging.Logger;

import javax.servlet.http.HttpSession;
import javax.websocket.EndpointConfig;
import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;

/**
 * Echo endpoint served at {@code /test}.
 *
 * <p>The {@code GetHttpSessionConfigurator} copies the HTTP session of the upgrade request into
 * the endpoint's user properties under {@code HttpSession.class.getName()}, which is how
 * {@link #onOpen} reaches servlet-level state.
 *
 * <p>NOTE(review): {@link Session#getUserPrincipal()} was observed to be {@code null} even though
 * mutual TLS succeeded. Presumably this is because the transport-level client certificate is not
 * turned into a container principal without a CLIENT-CERT login configuration. Confirm against
 * the WildFly security setup.
 */
@ServerEndpoint(value = "/test", configurator = GetHttpSessionConfigurator.class)
public class TestWebSocketEndPoint {

    private static final Logger LOG = Logger.getLogger(TestWebSocketEndPoint.class.getName());

    // State captured when the connection is opened.
    private Session wsSession;
    private HttpSession httpSession;

    /**
     * Called after the handshake completes. Remembers the sessions and greets the client.
     *
     * @param session the newly opened WebSocket session
     * @param config  endpoint configuration whose user properties carry the HTTP session
     */
    @OnOpen
    public void onOpen(Session session, EndpointConfig config) {
        this.wsSession = session;
        // Null when the configurator found no HTTP session for the upgrade request.
        this.httpSession = (HttpSession) config.getUserProperties().get(HttpSession.class.getName());
        LOG.info(session.getId() + " has opened a connection");
        sendText(session, "Connection Established");
    }

    /**
     * Handles a text message from the client by echoing it back unchanged.
     *
     * @param message the received text
     * @param session the session the message arrived on
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        LOG.info("Message from " + session.getId() + ": " + message);
        sendText(session, message);
    }

    /**
     * Called when the connection is closed.
     *
     * <p>Note: you can't send messages to the client from this method.
     *
     * @param session the session that ended
     */
    @OnClose
    public void onClose(Session session) {
        LOG.info("Session " + session.getId() + " has ended");
    }

    /** Sends text synchronously, logging rather than propagating any I/O failure. */
    private static void sendText(Session session, String text) {
        try {
            session.getBasicRemote().sendText(text);
        } catch (IOException ex) {
            LOG.log(Level.WARNING, "Failed to send message to session " + session.getId(), ex);
        }
    }
}

GetHttpSessionConfigurator:

import java.security.Principal;
import java.util.List;
import java.util.Map;

import javax.servlet.http.HttpSession;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;

/**
 * Handshake configurator that exposes the HTTP session of the WebSocket upgrade request to the
 * endpoint through {@link ServerEndpointConfig#getUserProperties()}, under the key
 * {@code HttpSession.class.getName()}.
 *
 * <p>NOTE(review): some containers share one {@code ServerEndpointConfig}, and therefore one
 * user-properties map, between all connections of an endpoint. Concurrent handshakes could then
 * overwrite each other's entry. Verify this for the target WildFly/Undertow version.
 */
public class GetHttpSessionConfigurator extends ServerEndpointConfig.Configurator {

    /**
     * Copies the handshake's HTTP session, if there is one, into the endpoint user properties.
     *
     * @param config   endpoint configuration whose user properties receive the session
     * @param request  the WebSocket upgrade request
     * @param response the handshake response (not modified)
     */
    @Override
    public void modifyHandshake(ServerEndpointConfig config,
                                HandshakeRequest request,
                                HandshakeResponse response)
    {
        HttpSession httpSession = (HttpSession) request.getHttpSession();
        Map<String, Object> userProperties = config.getUserProperties();
        if (httpSession != null) {
            userProperties.put(HttpSession.class.getName(), httpSession);
        } else {
            // Removing the key still makes the endpoint's lookup return null. It also avoids
            // relying on the map accepting null values, and does not leave an entry from an
            // earlier handshake in place should the map be shared.
            userProperties.remove(HttpSession.class.getName());
        }
    }
}

RequestListener:

import java.security.Principal;
import java.security.cert.X509Certificate;

import javax.servlet.ServletRequestEvent;
import javax.servlet.ServletRequestListener;
import javax.servlet.annotation.WebListener;
import javax.servlet.http.HttpServletRequest;

@WebListener
public class RequestListener implements ServletRequestListener {

    public void requestDestroyed(ServletRequestEvent sre) {
        // TODO Auto-generated method stub

    }

    public void requestInitialized(ServletRequestEvent sre) {
        ((HttpServletRequest) sre.getServletRequest()).getSession();
        Principal p = ((HttpServletRequest) sre.getServletRequest()).getUserPrincipal();

        boolean secure = ((HttpServletRequest) sre.getServletRequest()).isSecure();
        String authType = ((HttpServletRequest) sre.getServletRequest()).getAuthType();

        X509Certificate[] certs = (X509Certificate[]) ((HttpServletRequest) sre.getServletRequest()).getAttribute("javax.servlet.request.X509Certificate");
    }

}

The WebSocket client is a standalone application that uses TooTallNate/Java-WebSocket and connects securely:

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.util.Enumeration;

import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.java_websocket.WebSocketImpl;

/**
 * Interactive console client that opens a mutually authenticated {@code wss://} connection to the
 * test endpoint and sends each line typed on standard input as a text message.
 *
 * <p>Typing {@code close}, or reaching the end of standard input, closes the connection and ends
 * the program. Setup and I/O failures are logged with their stack trace and exit with status 1.
 */
public class TestClient {

    private static final Log log = LogFactory.getLog(TestClient.class);

    public static void main(String[] args) throws URISyntaxException {
        WebSocketImpl.DEBUG = true;

        WSRAClient wsRaClient = new WSRAClient(new URI("wss://localhost:8443/TestWebSocket-0.0.1-SNAPSHOT/test"));

        String keystoreFile = "keystore.p12";
        char[] keystorePassword = "keystore".toCharArray();

        String truststoreFile = "truststore.jks";
        char[] truststorePassword = "truststore".toCharArray();

        try {
            SSLContext ssl = createSslContext(keystoreFile, keystorePassword,
                    truststoreFile, truststorePassword);
            SSLSocketFactory factory = ssl.getSocketFactory();
            wsRaClient.setSocket(factory.createSocket());

            if (!wsRaClient.connectBlocking()) {
                log.error("Could not connect to the WebSocket server");
                System.exit(1);
            }
            runConsoleLoop(wsRaClient);
        } catch (NoSuchAlgorithmException | KeyStoreException | CertificateException
                | UnrecoverableKeyException | KeyManagementException | IOException e) {
            log.error("WebSocket client failed", e);
            System.exit(1);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Interrupted while connecting", e);
            System.exit(1);
        }
    }

    /**
     * Builds a TLSv1.2 context that presents the client key from the key store and trusts the
     * certificates in the trust store.
     */
    private static SSLContext createSslContext(String keystoreFile, char[] keystorePassword,
                                               String truststoreFile, char[] truststorePassword)
            throws NoSuchAlgorithmException, KeyStoreException, CertificateException,
                   IOException, UnrecoverableKeyException, KeyManagementException {
        SSLContext ssl = SSLContext.getInstance("TLSv1.2");

        log.info("Configuring SSL keystore");
        KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
        // The key store is a .p12 file, so load it as PKCS12 rather than the JVM default type.
        KeyStore store = loadKeyStore("keystore", keystoreFile, "PKCS12", keystorePassword);
        kmf.init(store, keystorePassword);
        KeyManager[] keyManagers = kmf.getKeyManagers();

        log.info("Configuring SSL truststore");
        TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        KeyStore truststore = loadKeyStore("truststore", truststoreFile, "JKS", truststorePassword);
        tmf.init(truststore);
        TrustManager[] trustManagers = tmf.getTrustManagers();

        ssl.init(keyManagers, trustManagers, new SecureRandom());
        return ssl;
    }

    /**
     * Loads a key store of the given type from a file and logs each alias and certificate at debug
     * level. The file stream is always closed.
     */
    private static KeyStore loadKeyStore(String label, String file, String type, char[] password)
            throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException {
        KeyStore store = KeyStore.getInstance(type);
        log.debug("Loading " + label);
        try (FileInputStream in = new FileInputStream(file)) {
            store.load(in, password);
        }
        log.debug("Number of " + label + " certificates: " + store.size());
        Enumeration<String> aliases = store.aliases();
        while (aliases.hasMoreElements()) {
            String alias = aliases.nextElement();
            log.debug("alias name: " + alias);
            Certificate certificate = store.getCertificate(alias);
            // getCertificate() may return null (e.g. for secret-key entries); log it without an NPE.
            log.debug(certificate);
        }
        return store;
    }

    /**
     * Sends each console line to the server until {@code close} is typed or input ends, and then
     * closes the connection.
     */
    private static void runConsoleLoop(WSRAClient client) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        String line;
        while ((line = reader.readLine()) != null && !line.equals("close")) {
            client.send(line);
        }
        client.close();
    }

}

Upvotes: 2

Views: 1683

Answers (1)

Saar peer
Saar peer

Reputation: 847

See: Accessing HttpServletRequest properties within a WebSocket @ServerEndpoint

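  1. Create a servlet filter whose URL pattern matches the WebSocket handshake request.
  2. In the filter, read the request attribute of interest (here `javax.servlet.request.X509Certificate`) and store it in the HTTP session before continuing the filter chain.
  3. Finally, read it from the session, which is available during the handshake via `HandshakeRequest.getHttpSession()` in the configurator's `modifyHandshake`.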

Upvotes: 3

Related Questions