Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions middleware/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
origin := req.Header.Get(echo.HeaderOrigin)
allowOrigin := ""

res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)

// Preflight request is an OPTIONS request, using three HTTP request headers: Access-Control-Request-Method,
// Access-Control-Request-Headers, and the Origin header. See: https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
// For simplicity we just consider method type and later `Origin` header.
Expand All @@ -211,6 +209,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {

// No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain
if origin == "" {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
if !preflight {
return next(c)
}
Expand Down Expand Up @@ -261,26 +260,36 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {

// Origin not allowed
if allowOrigin == "" {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
}

res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
}

// Simple request
if !preflight {
if exposeHeaders != "" {
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
err := next(c)
// Skip setting CORS headers when an upstream handler (e.g. reverse proxy) already set them.
if res.Header().Get(echo.HeaderAccessControlAllowOrigin) == "" {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
}
if exposeHeaders != "" {
res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders)
}
}
return next(c)
return err
}

// Preflight request
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
}
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)

Expand Down
42 changes: 42 additions & 0 deletions middleware/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (
"errors"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"testing"

"github.com/labstack/echo/v4"
Expand Down Expand Up @@ -683,3 +686,42 @@ func Test_allowOriginFunc(t *testing.T) {
}
}
}

func TestCORSNoDuplicateHeadersFromUpstream(t *testing.T) {
t.Parallel()

backend := echo.New()
backend.Use(CORS())
backend.GET("/", func(c echo.Context) error {
return c.String(http.StatusOK, "ok")
})
backendServer := httptest.NewServer(backend)
t.Cleanup(backendServer.Close)

backendURL, err := url.Parse(backendServer.URL)
assert.NoError(t, err)
reverseProxy := httputil.NewSingleHostReverseProxy(backendURL)

proxy := echo.New()
proxy.Use(CORS())
proxy.Any("/*", func(c echo.Context) error {
req := c.Request()
res := c.Response()
req.URL.Path = strings.TrimPrefix(req.URL.Path, "/proxy")
if req.URL.Path == "" {
req.URL.Path = "/"
}
reverseProxy.ServeHTTP(res, req)
return nil
})

req := httptest.NewRequest(http.MethodGet, "/proxy/", nil)
req.Header.Set(echo.HeaderOrigin, "http://example.com")
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)

assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, 1, len(rec.Header()[echo.HeaderAccessControlAllowOrigin]))
assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin))
assert.Equal(t, 1, len(rec.Header()[echo.HeaderVary]))
}