1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
use std::collections::HashMap;

use error::Result;
use http::StatusCode;
use httprequest::HttpRequest;
use httpresponse::HttpResponse;
use middleware::{Middleware, Response};

type ErrorHandler<S> = Fn(&HttpRequest<S>, HttpResponse) -> Result<Response>;

/// `Middleware` for allowing custom handlers for responses.
///
/// Use the `ErrorHandlers::handler()` method to register a custom error
/// handler for a specific status code. A handler can modify the existing
/// response or create a completely new one. Responses whose status code has
/// no registered handler are passed through unchanged.
///
/// ## Example
///
/// ```rust
/// # extern crate actix_web;
/// use actix_web::middleware::{ErrorHandlers, Response};
/// use actix_web::{http, App, HttpRequest, HttpResponse, Result};
///
/// fn render_500<S>(_: &HttpRequest<S>, resp: HttpResponse) -> Result<Response> {
///     let mut builder = resp.into_builder();
///     builder.header(http::header::CONTENT_TYPE, "application/json");
///     Ok(Response::Done(builder.into()))
/// }
///
/// fn main() {
///     let app = App::new()
///         .middleware(
///             ErrorHandlers::new()
///                 .handler(http::StatusCode::INTERNAL_SERVER_ERROR, render_500),
///         )
///         .resource("/test", |r| {
///             r.method(http::Method::GET).f(|_| HttpResponse::Ok());
///             r.method(http::Method::HEAD)
///                 .f(|_| HttpResponse::MethodNotAllowed());
///         })
///         .finish();
/// }
/// ```
pub struct ErrorHandlers<S> {
    /// Registered handlers, keyed by the exact response status they apply to.
    handlers: HashMap<StatusCode, Box<ErrorHandler<S>>>,
}

impl<S> Default for ErrorHandlers<S> {
    fn default() -> Self {
        ErrorHandlers {
            handlers: HashMap::new(),
        }
    }
}

impl<S> ErrorHandlers<S> {
    /// Construct new `ErrorHandlers` instance
    pub fn new() -> Self {
        ErrorHandlers::default()
    }

    /// Register error handler for specified status code
    pub fn handler<F>(mut self, status: StatusCode, handler: F) -> Self
    where
        F: Fn(&HttpRequest<S>, HttpResponse) -> Result<Response> + 'static,
    {
        self.handlers.insert(status, Box::new(handler));
        self
    }
}

impl<S: 'static> Middleware<S> for ErrorHandlers<S> {
    /// Route `resp` to the handler registered for its status code.
    ///
    /// When no handler matches, the response is returned unchanged wrapped in
    /// `Response::Done`; otherwise the handler's result is returned as-is.
    fn response(&self, req: &HttpRequest<S>, resp: HttpResponse) -> Result<Response> {
        let status = resp.status();
        match self.handlers.get(&status) {
            Some(on_error) => on_error(req, resp),
            None => Ok(Response::Done(resp)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use error::{Error, ErrorInternalServerError};
    use http::header::CONTENT_TYPE;
    use http::StatusCode;
    use httpmessage::HttpMessage;
    use middleware::Started;
    use test::{self, TestRequest};

    /// Test handler: stamps a recognizable `Content-Type` value onto the
    /// response so tests can tell whether it ran.
    fn render_500<S>(_: &HttpRequest<S>, resp: HttpResponse) -> Result<Response> {
        let mut builder = resp.into_builder();
        builder.header(CONTENT_TYPE, "0001");
        Ok(Response::Done(builder.into()))
    }

    #[test]
    fn test_handler() {
        let mw =
            ErrorHandlers::new().handler(StatusCode::INTERNAL_SERVER_ERROR, render_500);

        // `Middleware::response` only takes a shared borrow of the request.
        let req = TestRequest::default().finish();

        // A status with a registered handler is routed through that handler.
        let resp = HttpResponse::InternalServerError().finish();
        let resp = match mw.response(&req, resp) {
            Ok(Response::Done(resp)) => resp,
            _ => panic!("expected Ok(Response::Done) for a handled status"),
        };
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

        // A status without a registered handler passes through untouched.
        let resp = HttpResponse::Ok().finish();
        let resp = match mw.response(&req, resp) {
            Ok(Response::Done(resp)) => resp,
            _ => panic!("expected Ok(Response::Done) for an unhandled status"),
        };
        assert!(!resp.headers().contains_key(CONTENT_TYPE));
    }

    /// Middleware whose `start` always fails, used to produce a 500 response
    /// before the route handler runs.
    struct MiddlewareOne;

    impl<S> Middleware<S> for MiddlewareOne {
        fn start(&self, _: &HttpRequest<S>) -> Result<Started, Error> {
            Err(ErrorInternalServerError("middleware error"))
        }
    }

    #[test]
    fn test_middleware_start_error() {
        // The 500 raised by `MiddlewareOne::start` must still reach the
        // registered error handler.
        let mut srv = test::TestServer::new(move |app| {
            app.middleware(
                ErrorHandlers::new()
                    .handler(StatusCode::INTERNAL_SERVER_ERROR, render_500),
            ).middleware(MiddlewareOne)
                .handler(|_| HttpResponse::Ok())
        });

        let request = srv.get().finish().unwrap();
        let response = srv.execute(request.send()).unwrap();
        assert_eq!(response.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }
}