Reputation: 1054
I have a simple REST controller that I use for accepting a file being uploaded from a HTML form. The project is Spring Boot 2.6.1 and Java 17. But the problem was also to be found in Spring Boot 2.3.7 and Java 15.
@PostMapping(path = "/file", consumes = MediaType.MULTIPART_FORM_DATA_VALUE)
public void handleFileUpload(@RequestParam("file") MultipartFile file) {
fileService.upload(file.getInputStream(), file.getOriginalFilename());
}
The problem is file
is always NULL. I found a lot of different answers about setting a MultipartResolver
bean or enabling spring.http.multipart.enabled = true
but nothing helped. I have a logging filter as one of the first filters in the chain. After debugging in the filter chain I found out that making a call to request.getParts()
made everything work. My filter look like this:
public class LoggingFilter extends GenericFilterBean {
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException {
HttpServletRequest httpServletRequest = (HttpServletRequest) request;
BufferedRequestWrapper bufferedRequest = new BufferedRequestWrapper(httpServletRequest);
BufferedResponseWrapper bufferedResponse = new BufferedResponseWrapper((HttpServletResponse) response);
filterChain.doFilter(bufferedRequest, bufferedResponse);
logRequest(httpServletRequest, bufferedRequest);
logResponse(httpServletRequest, bufferedResponse);
}
I changed the filter to:
public class LoggingFilter extends GenericFilterBean {
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException {
HttpServletRequest httpServletRequest = (HttpServletRequest) request;
if (request.getContentType() != null && request.getContentType().startsWith("multipart/form-data")) {
httpServletRequest.getParts(); // Trigger initialization of multi-part.
}
BufferedRequestWrapper bufferedRequest = new BufferedRequestWrapper(httpServletRequest);
BufferedResponseWrapper bufferedResponse = new BufferedResponseWrapper((HttpServletResponse) response);
filterChain.doFilter(bufferedRequest, bufferedResponse);
logRequest(httpServletRequest, bufferedRequest);
logResponse(httpServletRequest, bufferedResponse);
}
and everything was working. My question is; why is this needed? And is there a better way of doing this?
Below is a complete example where only the actual logging is removed because we use a custom logging framework.
package com.unwire.ticketing.filter.logging;
import lombok.Getter;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.output.TeeOutputStream;
import org.springframework.web.filter.GenericFilterBean;
import javax.servlet.*;
import javax.servlet.http.*;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Locale;
import java.util.stream.Collectors;
public class Log extends GenericFilterBean {
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException {
HttpServletRequest httpServletRequest = (HttpServletRequest) request;
if (request.getContentType() != null && request.getContentType().startsWith("multipart/form-data")) {
httpServletRequest.getParts(); // Trigger initialization of multi-part.
}
try {
BufferedRequestWrapper bufferedRequest = new BufferedRequestWrapper(httpServletRequest);
BufferedResponseWrapper bufferedResponse = new BufferedResponseWrapper((HttpServletResponse) response);
filterChain.doFilter(bufferedRequest, bufferedResponse);
logRequest(httpServletRequest, bufferedRequest);
logResponse(httpServletRequest, bufferedResponse);
} catch (Throwable t) {
}
}
private void logRequest(HttpServletRequest request, BufferedRequestWrapper bufferedRequest) throws IOException {
String body = bufferedRequest.getRequestBody();
// Log request
}
private void logResponse(HttpServletRequest httpServletRequest, BufferedResponseWrapper bufferedResponse) {
// Log response
}
private static final class BufferedRequestWrapper extends HttpServletRequestWrapper {
private final byte[] buffer;
BufferedRequestWrapper(HttpServletRequest req) throws IOException {
super(req);
if (req.getContentType() == null || (req.getContentType() != null && !req.getContentType().startsWith("application/x-www-form-urlencoded"))) {
// Read InputStream and store its content in a buffer.
InputStream is = req.getInputStream();
ByteArrayOutputStream baos = new ByteArrayOutputStream();
byte[] buf = new byte[1024];
int read;
while ((read = is.read(buf)) > 0) {
baos.write(buf, 0, read);
}
this.buffer = baos.toByteArray();
} else {
buffer = new byte[0];
}
}
@Override
public ServletInputStream getInputStream() {
return new BufferedServletInputStream(new ByteArrayInputStream(this.buffer));
}
@Override
public Collection<Part> getParts() throws IOException, ServletException {
return super.getParts();
}
String getRequestBody() throws IOException {
return IOUtils.readLines(this.getInputStream(), StandardCharsets.UTF_8.name()).stream()
.map(String::trim)
.collect(Collectors.joining());
}
}
private static final class BufferedServletInputStream extends ServletInputStream {
private final ByteArrayInputStream bais;
BufferedServletInputStream(ByteArrayInputStream bais) {
this.bais = bais;
}
@Override
public int available() {
return this.bais.available();
}
@Override
public int read() {
return this.bais.read();
}
@Override
public int read(byte[] buf, int off, int len) {
return this.bais.read(buf, off, len);
}
@Override
public boolean isFinished() {
return false;
}
@Override
public boolean isReady() {
return true;
}
@Override
public void setReadListener(ReadListener readListener) {
}
}
public static class TeeServletOutputStream extends ServletOutputStream {
private final TeeOutputStream targetStream;
TeeServletOutputStream(OutputStream one, OutputStream two) {
targetStream = new TeeOutputStream(one, two);
}
@Override
public void write(int arg0) throws IOException {
this.targetStream.write(arg0);
}
public void flush() throws IOException {
super.flush();
this.targetStream.flush();
}
public void close() throws IOException {
super.close();
this.targetStream.close();
}
@Override
public boolean isReady() {
return false;
}
@Override
public void setWriteListener(WriteListener writeListener) {
}
}
public class BufferedResponseWrapper implements HttpServletResponse {
HttpServletResponse original;
TeeServletOutputStream tee;
ByteArrayOutputStream bos;
@Getter
Long startTime;
BufferedResponseWrapper(HttpServletResponse response) {
this.original = response;
this.startTime = System.currentTimeMillis();
}
String getContent() {
if (bos != null) {
return bos.toString();
} else {
return "";
}
}
@Override
public PrintWriter getWriter() throws IOException {
return original.getWriter();
}
@Override
public ServletOutputStream getOutputStream() throws IOException {
if (tee == null) {
bos = new ByteArrayOutputStream();
tee = new TeeServletOutputStream(original.getOutputStream(), bos);
}
return tee;
}
@Override
public String getCharacterEncoding() {
return original.getCharacterEncoding();
}
@Override
public void setCharacterEncoding(String charset) {
original.setCharacterEncoding(charset);
}
@Override
public String getContentType() {
return original.getContentType();
}
@Override
public void setContentType(String type) {
original.setContentType(type);
}
@Override
public void setContentLength(int len) {
original.setContentLength(len);
}
@Override
public void setContentLengthLong(long l) {
original.setContentLengthLong(l);
}
@Override
public int getBufferSize() {
return original.getBufferSize();
}
@Override
public void setBufferSize(int size) {
original.setBufferSize(size);
}
@Override
public void flushBuffer() throws IOException {
if (tee != null) {
tee.flush();
}
}
@Override
public void resetBuffer() {
original.resetBuffer();
}
@Override
public boolean isCommitted() {
return original.isCommitted();
}
@Override
public void reset() {
original.reset();
}
@Override
public Locale getLocale() {
return original.getLocale();
}
@Override
public void setLocale(Locale loc) {
original.setLocale(loc);
}
@Override
public void addCookie(Cookie cookie) {
original.addCookie(cookie);
}
@Override
public boolean containsHeader(String name) {
return original.containsHeader(name);
}
@Override
public String encodeURL(String url) {
return original.encodeURL(url);
}
@Override
public String encodeRedirectURL(String url) {
return original.encodeRedirectURL(url);
}
@Override
public void sendError(int sc, String msg) throws IOException {
original.sendError(sc, msg);
}
@Override
public void sendError(int sc) throws IOException {
original.sendError(sc);
}
@Override
public void sendRedirect(String location) throws IOException {
original.sendRedirect(location);
}
@Override
public void setDateHeader(String name, long date) {
original.setDateHeader(name, date);
}
@Override
public void addDateHeader(String name, long date) {
original.addDateHeader(name, date);
}
@Override
public void setHeader(String name, String value) {
original.setHeader(name, value);
}
@Override
public void addHeader(String name, String value) {
original.addHeader(name, value);
}
@Override
public void setIntHeader(String name, int value) {
original.setIntHeader(name, value);
}
@Override
public void addIntHeader(String name, int value) {
original.addIntHeader(name, value);
}
@Override
public String getHeader(String arg0) {
return original.getHeader(arg0);
}
@Override
public Collection<String> getHeaderNames() {
return original.getHeaderNames();
}
@Override
public Collection<String> getHeaders(String arg0) {
return original.getHeaders(arg0);
}
@Override
public int getStatus() {
return original.getStatus();
}
@Override
public void setStatus(int sc) {
original.setStatus(sc);
}
}
}
Upvotes: 4
Views: 4460
Reputation: 1
This solution worked for logging a multipart file :)
public class RequestResponseLoggingFilter extends OncePerRequestFilter {
private static final Logger logAccess = LogManager.getLogger("info-log");
private static final List<MediaType> VISIBLE_TYPES = Arrays.asList(
MediaType.valueOf("text/*"),
MediaType.APPLICATION_FORM_URLENCODED,
MediaType.APPLICATION_JSON,
MediaType.APPLICATION_XML,
MediaType.valueOf("application/*+json"),
MediaType.valueOf("application/*+xml"),
MediaType.MULTIPART_FORM_DATA
);
private static final List<String> SENSITIVE_HEADERS = Arrays.asList(
"authorization",
"proxy-authorization"
);
@Value("${debug.request-response-log:true}")
private boolean enabled;
@ManagedOperation(description = "Enable logging of HTTP requests and responses")
public void enable() {
this.enabled = true;
}
@ManagedOperation(description = "Disable logging of HTTP requests and responses")
public void disable() {
this.enabled = false;
}
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
if (isAsyncDispatch(request)) {
filterChain.doFilter(request, response);
} else {
doFilterWrapped(wrapRequest(request), wrapResponse(response), filterChain);
}
}
protected void doFilterWrapped(ContentCachingRequestWrapper request, ContentCachingResponseWrapper response, FilterChain filterChain) throws ServletException, IOException {
StringBuilder msg = new StringBuilder();
try {
beforeRequest(request, response, msg);
filterChain.doFilter(request, response);
} finally {
afterRequest(request, response, msg);
if (logAccess.isInfoEnabled()) {
logAccess.info(msg.toString());
}
response.copyBodyToResponse();
}
}
protected void beforeRequest(ContentCachingRequestWrapper request, ContentCachingResponseWrapper response, StringBuilder msg) {
if (enabled && logAccess.isInfoEnabled()) {
msg.append("\n================================================================================= \n\n");
msg.append("Access ").append(request.getRequestURI()).append("\n");
msg.append("URL \t\t : ").append(request.getRequestURI()).append("\n");
msg.append("Header \t\t : \n");
logRequestHeader(request, "", msg);
msg.append("QueryParam \t : ");
String queryString = request.getQueryString();
if (queryString != null) {
msg.append(queryString);
}
msg.append("\nRequest \t : ");
logRequestBody(request, "", msg);
}
}
protected void afterRequest(ContentCachingRequestWrapper request, ContentCachingResponseWrapper response, StringBuilder msg) {
if (enabled && logAccess.isInfoEnabled()) {
msg.append("\nResponse \t : ");
logResponse(response, "", msg);
msg.append("\n================================================================================= \n");
}
}
private static void logRequestHeader(ContentCachingRequestWrapper request, String prefix, StringBuilder msg) {
Collections.list(request.getHeaderNames())
.forEach(headerName ->
Collections.list(request.getHeaders(headerName))
.forEach(headerValue -> {
if (isSensitiveHeader(headerName)) {
msg.append(String.format("%s %s: %s", prefix, headerName, "*******")).append("\n");
} else {
msg.append(String.format("%s %s: %s", prefix, headerName, headerValue)).append("\n");
}
}));
msg.append(prefix).append("\n");
}
private static void logRequestBody(ContentCachingRequestWrapper request, String prefix, StringBuilder msg) {
String contentType = request.getContentType();
if (contentType != null && contentType.startsWith("multipart/form-data")) {
try {
Collection<Part> parts = request.getParts();
msg.append("{");
for (Part part : parts) {
msg.append(String.format("\"%s\": ", part.getName()));
if (part.getSubmittedFileName() != null) {
InputStream inputStream = part.getInputStream();
byte[] fileBytes = IOUtils.toByteArray(inputStream);
String base64Content = Base64.getEncoder().encodeToString(fileBytes);
msg.append(String.format("\"%s\"", base64Content));
msg.append(",");
} else {
String fieldValue = IOUtils.toString(part.getInputStream(), StandardCharsets.UTF_8);
msg.append(String.format("\"%s\",", fieldValue));
}
}
if (msg.length() > 2 && msg.charAt(msg.length() - 1) == ',') {
msg.setLength(msg.length() - 1);
}
msg.append("}");
} catch (Exception e) {
msg.append(String.format("%s Failed to log multipart content due to: %s\n", prefix, e.getMessage()));
}
} else {
byte[] content = request.getContentAsByteArray();
if (content.length > 0) {
logContent(content, request.getContentType(), request.getCharacterEncoding(), prefix, msg);
}
}
}
private static void logResponse(ContentCachingResponseWrapper response, String prefix, StringBuilder msg) {
byte[] content = response.getContentAsByteArray();
if (content.length > 0) {
logContent(content, response.getContentType(), response.getCharacterEncoding(), prefix, msg);
}
}
private static void logContent(byte[] content, String contentType, String contentEncoding, String prefix, StringBuilder msg) {
MediaType mediaType = MediaType.valueOf(contentType);
boolean visible = VISIBLE_TYPES.stream().anyMatch(visibleType -> visibleType.includes(mediaType));
if (visible) {
try {
String contentString = new String(content, contentEncoding);
msg.append(prefix).append(contentString).append("\n");
} catch (UnsupportedEncodingException e) {
msg.append(String.format("%s [%d bytes content]", prefix, content.length)).append("\n");
}
} else {
msg.append(String.format("%s [%d bytes content]", prefix, content.length)).append("\n");
}
}
private static boolean isSensitiveHeader(String headerName) {
return SENSITIVE_HEADERS.contains(headerName.toLowerCase());
}
private static ContentCachingRequestWrapper wrapRequest(HttpServletRequest request) {
if (request instanceof ContentCachingRequestWrapper) {
return (ContentCachingRequestWrapper) request;
} else {
return new ContentCachingRequestWrapper(request);
}
}
private static ContentCachingResponseWrapper wrapResponse(HttpServletResponse response) {
if (response instanceof ContentCachingResponseWrapper) {
return (ContentCachingResponseWrapper) response;
} else {
return new ContentCachingResponseWrapper(response);
}
}
}
Upvotes: 0
Reputation: 2061
Please consider using ContentCachingRequestWrapper.
It's built-in of spring which help you can read caches all content read from the input stream and reader.
Be aware, with multipart file, spring already have a wrapper ... MultipartHttpServletRequest
Upvotes: 3