package com.devplatform.alarm.config;

import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.util.regex.Pattern;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

/**
 * Request wrapper that filters request parameters, headers, and string attributes to mitigate
 * cross-site scripting (XSS).
 *
 * <p>Values are first stripped of a fixed list of script-like fragments and then have a handful of
 * half-width punctuation characters replaced with look-alike full-width characters (see
 * {@link #xssEncode(String)}). The stripping is a single pass over a fixed pattern list; it is a
 * mitigation, not a substitute for context-aware output encoding.
 *
 * @author admin
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {

  /** Flags for the patterns that only need case-insensitive matching. */
  private static final int IGNORE_CASE = Pattern.CASE_INSENSITIVE;

  /** Flags for the patterns whose wildcard must also span line terminators. */
  private static final int IGNORE_CASE_ANY_LINE =
      Pattern.CASE_INSENSITIVE | Pattern.MULTILINE | Pattern.DOTALL;

  /**
   * Fragments removed by {@link #stripXss(String)}, applied in array order. Compiled once because
   * {@code Pattern} instances are immutable and safe to share between threads; the order matters
   * since an earlier removal can change what a later pattern sees.
   */
  private static final Pattern[] STRIP_PATTERNS = {
    Pattern.compile("<script>(.*?)</script>", IGNORE_CASE),
    Pattern.compile("<(no)?script[^>]*>.*?</(no)?script>", IGNORE_CASE),
    Pattern.compile("src[\r\n]*=[\r\n]*\\\'(.*?)\\\'", IGNORE_CASE_ANY_LINE),
    Pattern.compile("src[\r\n]*=[\r\n]*\\\"(.*?)\\\"", IGNORE_CASE_ANY_LINE),
    Pattern.compile("</script>", IGNORE_CASE),
    Pattern.compile("<script(.*?)>", IGNORE_CASE_ANY_LINE),
    Pattern.compile("eval\\((.*?)\\)", IGNORE_CASE_ANY_LINE),
    Pattern.compile("expression\\((.*?)\\)", IGNORE_CASE_ANY_LINE),
    Pattern.compile("javascript:", IGNORE_CASE),
    Pattern.compile("vbscript:", IGNORE_CASE),
    Pattern.compile("onload(.*?)=", IGNORE_CASE_ANY_LINE),
    Pattern.compile("<iframe>(.*?)</iframe>", IGNORE_CASE),
    Pattern.compile("</iframe>", IGNORE_CASE),
    Pattern.compile("<iframe(.*?)>", IGNORE_CASE_ANY_LINE),
  };

  /** The request passed to the constructor, kept so callers can reach the unfiltered values. */
  HttpServletRequest orgRequest;

  /**
   * Wraps the given request.
   *
   * @param request the request whose values should be filtered
   */
  public XssHttpServletRequestWrapper(HttpServletRequest request) {
    super(request);
    orgRequest = request;
  }

  /**
   * Returns the named parameter with XSS filtering applied to both the lookup name and the value.
   *
   * <p>Use {@link #getOrgRequest()} to read the unfiltered value. {@code getParameterNames} and
   * {@code getParameterMap} are not overridden and therefore return unfiltered data.
   *
   * @param name the parameter name; it is filtered before the lookup
   * @return the filtered value, or {@code null} if the parameter is absent
   */
  @Override
  public String getParameter(String name) {
    String raw = super.getParameter(xssEncode(name));
    return raw == null ? null : xssEncode(raw);
  }

  /**
   * Returns all values of the named parameter, each filtered before being returned.
   *
   * @param name the parameter name; it is filtered before the lookup
   * @return a new array holding the filtered values, or {@code null} if the parameter is absent
   */
  @Override
  public String[] getParameterValues(String name) {
    String[] rawValues = super.getParameterValues(xssEncode(name));
    if (rawValues == null) {
      return null;
    }
    // Fill a fresh array: the one handed out by the wrapped request may be shared with the
    // container, and writing into it would alter what getOrgRequest() callers later read.
    String[] filtered = new String[rawValues.length];
    for (int i = 0; i < rawValues.length; i++) {
      filtered[i] = xssEncode(rawValues[i]);
    }
    return filtered;
  }

  /**
   * Returns the named attribute; {@code String} attributes are filtered, all other types are
   * returned untouched.
   *
   * @param name the attribute name; it is filtered before the lookup
   * @return the attribute value, filtered when it is a {@code String}, or {@code null} if absent
   */
  @Override
  public Object getAttribute(String name) {
    Object attribute = super.getAttribute(xssEncode(name));
    if (attribute instanceof String) {
      return xssEncode((String) attribute);
    }
    return attribute;
  }

  /**
   * Returns the named header with XSS filtering applied to both the lookup name and the value.
   *
   * <p>Use {@link #getOrgRequest()} to read the unfiltered value. {@code getHeaders} and
   * {@code getHeaderNames} are not overridden and therefore return unfiltered data.
   *
   * @param name the header name; it is filtered before the lookup
   * @return the filtered header value, or {@code null} if the header is absent
   */
  @Override
  public String getHeader(String name) {
    String raw = super.getHeader(xssEncode(name));
    return raw == null ? null : xssEncode(raw);
  }

  /**
   * Returns the request URI, URL-decoded as UTF-8 and then filtered.
   *
   * <p>If the URI contains a malformed percent-escape and cannot be decoded, the raw URI is
   * filtered instead, so a bad request does not surface as an unchecked exception here.
   *
   * @return the decoded and filtered request URI
   */
  public String getRequestUrl() {
    String rawUri = super.getRequestURI();
    String decoded;
    try {
      decoded = URLDecoder.decode(rawUri, "UTF-8");
    } catch (UnsupportedEncodingException e) {
      // Every Java platform is required to support UTF-8, so this cannot happen in practice.
      throw new IllegalStateException("UTF-8 charset is unavailable", e);
    } catch (IllegalArgumentException e) {
      // Malformed "%xx" sequence: fall back to the raw text; escape() neutralizes the '%'.
      decoded = rawUri;
    }
    // The filter is applied twice, as before. NOTE(review): presumably the second pass is meant
    // to catch fragments that only form once the first pass has removed something - confirm.
    return xssEncode(xssEncode(decoded));
  }

  /**
   * Replaces a fixed set of half-width characters ({@code > < ' " \ % ;}) with full-width or
   * typographic look-alikes; every other character is copied unchanged.
   *
   * @param s the text to escape; may be {@code null}
   * @return the escaped text, or {@code null} if {@code s} is {@code null}
   */
  public String escape(String s) {
    if (s == null) {
      return null;
    }
    StringBuilder sb = new StringBuilder(s.length() + 16);
    for (int i = 0; i < s.length(); i++) {
      char c = s.charAt(i);
      switch (c) {
        case '>':
          sb.append('＞');
          break;
        case '<':
          sb.append('＜');
          break;
        case '\'':
          sb.append('‘');
          break;
        case '\"':
          sb.append('“');
          break;
        case '\\':
          sb.append('＼');
          break;
        case '%':
          sb.append('％');
          break;
        case ';':
          sb.append('；');
          break;
        default:
          sb.append(c);
          break;
      }
    }
    return sb.toString();
  }

  /**
   * Filters a value: strips script-like fragments, then replaces the half-width characters that
   * commonly enable XSS with full-width equivalents.
   *
   * @param s the text to filter; may be {@code null} or empty
   * @return the filtered text; {@code null} and empty input are returned as-is
   */
  public String xssEncode(String s) {
    if (s == null || s.isEmpty()) {
      return s;
    }
    return escape(stripXss(s));
  }

  /**
   * Removes NUL characters, every match of {@link #STRIP_PATTERNS} (in order), and finally all
   * {@code ;}, {@code <} and {@code >} characters.
   *
   * @param value the text to strip; may be {@code null}
   * @return the stripped text, or {@code null} if {@code value} is {@code null}
   */
  private String stripXss(String value) {
    if (value == null) {
      return null;
    }

    // This step used to be replaceAll("", ""), which changes nothing. Stripping NUL characters is
    // the presumed intent (a NUL inside a keyword would otherwise defeat the patterns below).
    String stripped = value.replace("\0", "");

    for (Pattern pattern : STRIP_PATTERNS) {
      stripped = pattern.matcher(stripped).replaceAll("");
    }

    stripped = stripped.replace(";", "");
    stripped = stripped.replace("<", "");
    stripped = stripped.replace(">", "");
    return stripped;
  }

  /**
   * Returns the original, unfiltered request this wrapper was created with.
   *
   * @return the request passed to the constructor
   */
  public HttpServletRequest getOrgRequest() {
    return orgRequest;
  }

  /**
   * Returns the original, unfiltered request this wrapper was created with.
   *
   * @return the request passed to the constructor
   */
  @Override
  public HttpServletRequest getRequest() {
    return orgRequest;
  }

  /**
   * Static helper that unwraps a request if it is an {@code XssHttpServletRequestWrapper}.
   *
   * @param req any request, wrapped or not
   * @return the original request behind the wrapper, or {@code req} itself if it is not wrapped
   */
  public static HttpServletRequest getOrgRequest(HttpServletRequest req) {
    if (req instanceof XssHttpServletRequestWrapper) {
      return ((XssHttpServletRequestWrapper) req).getOrgRequest();
    }

    return req;
  }
}
