229 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			229 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright 2014 beego Author. All Rights Reserved.
 | 
						|
//
 | 
						|
// Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
// you may not use this file except in compliance with the License.
 | 
						|
// You may obtain a copy of the License at
 | 
						|
//
 | 
						|
//      http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
//
 | 
						|
// Unless required by applicable law or agreed to in writing, software
 | 
						|
// distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
// See the License for the specific language governing permissions and
 | 
						|
// limitations under the License.
 | 
						|
 | 
						|
// Package cors provides handlers to enable CORS support.
 | 
						|
// Usage
 | 
						|
//	import (
 | 
						|
// 		"github.com/astaxie/beego"
 | 
						|
//		"github.com/astaxie/beego/plugins/cors"
 | 
						|
// )
 | 
						|
//
 | 
						|
//	func main() {
 | 
						|
//		// CORS for https://*.foo.com origins, allowing:
 | 
						|
//		// - PUT and PATCH methods
 | 
						|
//		// - Origin header
 | 
						|
//		// - Credentials sharing
 | 
						|
//		beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{
 | 
						|
//			AllowOrigins:     []string{"https://*.foo.com"},
 | 
						|
//			AllowMethods:     []string{"PUT", "PATCH"},
 | 
						|
//			AllowHeaders:     []string{"Origin"},
 | 
						|
//			ExposeHeaders:    []string{"Content-Length"},
 | 
						|
//			AllowCredentials: true,
 | 
						|
//		}))
 | 
						|
//		beego.Run()
 | 
						|
//	}
 | 
						|
package cors
 | 
						|
 | 
						|
import (
 | 
						|
	"net/http"
 | 
						|
	"regexp"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/astaxie/beego"
 | 
						|
	"github.com/astaxie/beego/context"
 | 
						|
)
 | 
						|
 | 
						|
// Response headers this package emits for accepted origins.
const (
	headerAllowOrigin      = "Access-Control-Allow-Origin"
	headerAllowCredentials = "Access-Control-Allow-Credentials"
	headerAllowHeaders     = "Access-Control-Allow-Headers"
	headerAllowMethods     = "Access-Control-Allow-Methods"
	headerExposeHeaders    = "Access-Control-Expose-Headers"
	headerMaxAge           = "Access-Control-Max-Age"
)

// Request headers the filter built by Allow reads from incoming requests.
const (
	headerOrigin         = "Origin"
	headerRequestMethod  = "Access-Control-Request-Method"
	headerRequestHeaders = "Access-Control-Request-Headers"
)
 | 
						|
 | 
						|
var (
	// allowOriginPatterns collects the anchored regular expressions that Allow
	// derives from Options.AllowOrigins. It is package-level state: every call
	// to Allow appends to this same slice.
	allowOriginPatterns = []string{}

	// defaultAllowHeaders is what Allow uses when Options.AllowHeaders is empty.
	defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"}
)
 | 
						|
 | 
						|
// Options describes the CORS policy applied by the filter that Allow builds.
type Options struct {
	// AllowAllOrigins accepts every origin; responses then carry
	// Access-Control-Allow-Origin: *.
	AllowAllOrigins bool

	// AllowOrigins lists the accepted origins. Besides exact origins, the
	// wildcards "*" (any run of characters) and "?" (a single character) are
	// supported, e.g. "https://*.example.com".
	AllowOrigins []string

	// AllowCredentials is reported in Access-Control-Allow-Credentials and
	// controls whether auth credentials such as cookies may be shared.
	AllowCredentials bool

	// AllowMethods lists the HTTP methods advertised in
	// Access-Control-Allow-Methods.
	AllowMethods []string

	// AllowHeaders lists the request headers clients may send. Allow fills in
	// a default set when this is left empty.
	AllowHeaders []string

	// ExposeHeaders lists the response headers advertised in
	// Access-Control-Expose-Headers.
	ExposeHeaders []string

	// MaxAge is sent as Access-Control-Max-Age in whole seconds; it is omitted
	// when not positive.
	MaxAge time.Duration
}
 | 
						|
 | 
						|
// Header converts options into CORS headers.
 | 
						|
func (o *Options) Header(origin string) (headers map[string]string) {
 | 
						|
	headers = make(map[string]string)
 | 
						|
	// if origin is not allowed, don't extend the headers
 | 
						|
	// with CORS headers.
 | 
						|
	if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	// add allow origin
 | 
						|
	if o.AllowAllOrigins {
 | 
						|
		headers[headerAllowOrigin] = "*"
 | 
						|
	} else {
 | 
						|
		headers[headerAllowOrigin] = origin
 | 
						|
	}
 | 
						|
 | 
						|
	// add allow credentials
 | 
						|
	headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
 | 
						|
 | 
						|
	// add allow methods
 | 
						|
	if len(o.AllowMethods) > 0 {
 | 
						|
		headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
 | 
						|
	}
 | 
						|
 | 
						|
	// add allow headers
 | 
						|
	if len(o.AllowHeaders) > 0 {
 | 
						|
		headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",")
 | 
						|
	}
 | 
						|
 | 
						|
	// add exposed header
 | 
						|
	if len(o.ExposeHeaders) > 0 {
 | 
						|
		headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
 | 
						|
	}
 | 
						|
	// add a max age header
 | 
						|
	if o.MaxAge > time.Duration(0) {
 | 
						|
		headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
 | 
						|
	}
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
// PreflightHeader converts options into CORS headers for a preflight response.
 | 
						|
func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) {
 | 
						|
	headers = make(map[string]string)
 | 
						|
	if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	// verify if requested method is allowed
 | 
						|
	for _, method := range o.AllowMethods {
 | 
						|
		if method == rMethod {
 | 
						|
			headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
 | 
						|
			break
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// verify if requested headers are allowed
 | 
						|
	var allowed []string
 | 
						|
	for _, rHeader := range strings.Split(rHeaders, ",") {
 | 
						|
		rHeader = strings.TrimSpace(rHeader)
 | 
						|
	lookupLoop:
 | 
						|
		for _, allowedHeader := range o.AllowHeaders {
 | 
						|
			if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) {
 | 
						|
				allowed = append(allowed, rHeader)
 | 
						|
				break lookupLoop
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
 | 
						|
	// add allow origin
 | 
						|
	if o.AllowAllOrigins {
 | 
						|
		headers[headerAllowOrigin] = "*"
 | 
						|
	} else {
 | 
						|
		headers[headerAllowOrigin] = origin
 | 
						|
	}
 | 
						|
 | 
						|
	// add allowed headers
 | 
						|
	if len(allowed) > 0 {
 | 
						|
		headers[headerAllowHeaders] = strings.Join(allowed, ",")
 | 
						|
	}
 | 
						|
 | 
						|
	// add exposed headers
 | 
						|
	if len(o.ExposeHeaders) > 0 {
 | 
						|
		headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
 | 
						|
	}
 | 
						|
	// add a max age header
 | 
						|
	if o.MaxAge > time.Duration(0) {
 | 
						|
		headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
 | 
						|
	}
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
// IsOriginAllowed looks up if the origin matches one of the patterns
 | 
						|
// generated from Options.AllowOrigins patterns.
 | 
						|
func (o *Options) IsOriginAllowed(origin string) (allowed bool) {
 | 
						|
	for _, pattern := range allowOriginPatterns {
 | 
						|
		allowed, _ = regexp.MatchString(pattern, origin)
 | 
						|
		if allowed {
 | 
						|
			return
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
// Allow enables CORS for requests those match the provided options.
 | 
						|
func Allow(opts *Options) beego.FilterFunc {
 | 
						|
	// Allow default headers if nothing is specified.
 | 
						|
	if len(opts.AllowHeaders) == 0 {
 | 
						|
		opts.AllowHeaders = defaultAllowHeaders
 | 
						|
	}
 | 
						|
 | 
						|
	for _, origin := range opts.AllowOrigins {
 | 
						|
		pattern := regexp.QuoteMeta(origin)
 | 
						|
		pattern = strings.Replace(pattern, "\\*", ".*", -1)
 | 
						|
		pattern = strings.Replace(pattern, "\\?", ".", -1)
 | 
						|
		allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$")
 | 
						|
	}
 | 
						|
 | 
						|
	return func(ctx *context.Context) {
 | 
						|
		var (
 | 
						|
			origin           = ctx.Input.Header(headerOrigin)
 | 
						|
			requestedMethod  = ctx.Input.Header(headerRequestMethod)
 | 
						|
			requestedHeaders = ctx.Input.Header(headerRequestHeaders)
 | 
						|
			// additional headers to be added
 | 
						|
			// to the response.
 | 
						|
			headers map[string]string
 | 
						|
		)
 | 
						|
 | 
						|
		if ctx.Input.Method() == "OPTIONS" &&
 | 
						|
			(requestedMethod != "" || requestedHeaders != "") {
 | 
						|
			headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders)
 | 
						|
			for key, value := range headers {
 | 
						|
				ctx.Output.Header(key, value)
 | 
						|
			}
 | 
						|
			ctx.ResponseWriter.WriteHeader(http.StatusOK)
 | 
						|
			return
 | 
						|
		}
 | 
						|
		headers = opts.Header(origin)
 | 
						|
 | 
						|
		for key, value := range headers {
 | 
						|
			ctx.Output.Header(key, value)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 |