Files
spring-soap/spring-security-modules/spring-5-security/src/main/java/com/baeldung/xss/XSSRequestWrapper.java
Hamid Reza Sharifi 7d5be17ce2 Bael-4684-Prevent Cross-Site Scripting (XSS) in a Spring application-(new) (#10480)
* #bael-4684: add main source code

* #bael-4684: add test

* #bael-4684: add required dependencies
2021-02-12 10:50:52 +00:00

124 lines
3.5 KiB
Java

package com.baeldung.xss;
import org.apache.commons.codec.Charsets;
import org.apache.commons.io.IOUtils;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import static com.baeldung.xss.XSSUtils.stripXSS;
/**
 * Request wrapper that passes parameter and header values through
 * {@code XSSUtils.stripXSS} and keeps an in-memory copy of the request body.
 *
 * <p>The body of the wrapped request is read once, on the first call to
 * {@link #getInputStream()} or {@link #getReader()}, and stored as UTF-8 bytes.
 * {@link #resetInputStream(byte[])} replaces that copy, e.g. with a sanitized body.
 *
 * <p>All readers and streams handed out share one underlying position: a body that has
 * been fully consumed is only readable again after {@link #resetInputStream(byte[])}.
 */
public class XSSRequestWrapper extends HttpServletRequestWrapper {

    /** Cached request body; {@code null} until the body has been read or reset. */
    private byte[] rawData;
    private final HttpServletRequest request;
    private final ResettableServletInputStream servletStream;

    /**
     * Wraps the given request.
     *
     * @param request the request whose parameters, headers and body are exposed through this wrapper
     */
    public XSSRequestWrapper(HttpServletRequest request) {
        super(request);
        this.request = request;
        this.servletStream = new ResettableServletInputStream();
    }

    /**
     * Replaces the cached body and rewinds the stream to the start of the new data.
     *
     * <p>{@link #getReader()} decodes the body as UTF-8, so the bytes passed here are
     * expected to be UTF-8 encoded text when the body is consumed through a reader.
     *
     * @param newRawData the new body bytes; must not be {@code null}
     * @throws NullPointerException if {@code newRawData} is {@code null}
     */
    public void resetInputStream(byte[] newRawData) {
        // Build the stream first so a null argument fails before any state is modified.
        ByteArrayInputStream replacement = new ByteArrayInputStream(newRawData);
        rawData = newRawData;
        servletStream.stream = replacement;
    }

    /**
     * Returns the stream over the cached body, reading the body from the wrapped request
     * first if that has not happened yet.
     *
     * @return the shared stream over the cached body
     * @throws IOException if reading the body of the wrapped request fails
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        cacheBodyIfNeeded();
        return servletStream;
    }

    /**
     * Returns a reader over the cached body, reading the body from the wrapped request
     * first if that has not happened yet.
     *
     * @return a new reader on top of the shared body stream
     * @throws IOException if reading the body of the wrapped request fails
     */
    @Override
    public BufferedReader getReader() throws IOException {
        cacheBodyIfNeeded();
        // Decode with the same charset the body was cached in; relying on the platform
        // default here would corrupt non-ASCII content on non-UTF-8 systems.
        return new BufferedReader(new InputStreamReader(servletStream, Charsets.UTF_8));
    }

    /** Reads the wrapped request's body once and stores it as UTF-8 bytes. */
    private void cacheBodyIfNeeded() throws IOException {
        if (rawData == null) {
            rawData = IOUtils.toByteArray(this.request.getReader(), Charsets.UTF_8);
            servletStream.stream = new ByteArrayInputStream(rawData);
        }
    }

    /**
     * Servlet input stream backed by a replaceable in-memory buffer.
     *
     * <p>Declared static because it needs no access to the enclosing wrapper instance.
     */
    private static final class ResettableServletInputStream extends ServletInputStream {

        /** Current backing data; starts empty so the stream is usable before any body is cached. */
        private ByteArrayInputStream stream = new ByteArrayInputStream(new byte[0]);

        @Override
        public int read() throws IOException {
            return stream.read();
        }

        @Override
        public int read(byte[] buffer, int offset, int length) throws IOException {
            // Bulk reads go straight to the in-memory stream instead of looping over read().
            return stream.read(buffer, offset, length);
        }

        @Override
        public boolean isFinished() {
            // Finished once every byte of the backing buffer has been consumed.
            return stream.available() == 0;
        }

        @Override
        public boolean isReady() {
            // The data is entirely in memory, so a read never blocks.
            return true;
        }

        @Override
        public void setReadListener(ReadListener readListener) {
            // Non-blocking reads are not supported by this wrapper; the listener is ignored.
        }
    }

    /**
     * Returns all values of the given parameter, each passed through {@code stripXSS}.
     *
     * @param parameter the parameter name
     * @return a new array with the processed values, or {@code null} if the wrapped
     *         request returns {@code null} for this parameter
     */
    @Override
    public String[] getParameterValues(String parameter) {
        String[] values = super.getParameterValues(parameter);
        if (values == null) {
            return null;
        }
        String[] encodedValues = new String[values.length];
        for (int i = 0; i < values.length; i++) {
            encodedValues[i] = stripXSS(values[i]);
        }
        return encodedValues;
    }

    /**
     * Returns the value of the given parameter passed through {@code stripXSS}.
     *
     * @param parameter the parameter name
     * @return the result of {@code stripXSS} applied to the wrapped request's value
     */
    @Override
    public String getParameter(String parameter) {
        return stripXSS(super.getParameter(parameter));
    }

    /**
     * Returns the value of the given header passed through {@code stripXSS}.
     *
     * @param name the header name
     * @return the result of {@code stripXSS} applied to the wrapped request's value
     */
    @Override
    public String getHeader(String name) {
        return stripXSS(super.getHeader(name));
    }

    /**
     * Returns the values of the given header, split on commas, with each token passed
     * through {@code stripXSS}.
     *
     * @param name the header name
     * @return an enumeration of the processed tokens, or {@code null} if the wrapped
     *         request returns {@code null} for this header
     */
    @Override
    public Enumeration<String> getHeaders(String name) {
        Enumeration<String> headers = super.getHeaders(name);
        if (headers == null) {
            // The servlet API allows null when the container denies access to headers.
            return null;
        }
        List<String> result = new ArrayList<>();
        while (headers.hasMoreElements()) {
            // NOTE(review): splitting on ',' also breaks up values that legitimately contain
            // commas (e.g. HTTP dates); kept as-is for backward compatibility.
            for (String token : headers.nextElement().split(",")) {
                result.add(stripXSS(token));
            }
        }
        return Collections.enumeration(result);
    }
}