227 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			227 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright 2014 beego Author. All Rights Reserved.
 | |
| //
 | |
| // Licensed under the Apache License, Version 2.0 (the "License");
 | |
| // you may not use this file except in compliance with the License.
 | |
| // You may obtain a copy of the License at
 | |
| //
 | |
| //      http://www.apache.org/licenses/LICENSE-2.0
 | |
| //
 | |
| // Unless required by applicable law or agreed to in writing, software
 | |
| // distributed under the License is distributed on an "AS IS" BASIS,
 | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| // See the License for the specific language governing permissions and
 | |
| // limitations under the License.
 | |
| 
 | |
// Package cors provides a beego filter that enables CORS
// (Cross-Origin Resource Sharing) support.
//
// Usage:
//
//	import (
//		"github.com/astaxie/beego"
//		"github.com/astaxie/beego/plugins/cors"
//	)
//
//	func main() {
//		// CORS for https://*.foo.com origins, allowing:
//		// - the PUT and PATCH methods
//		// - the Origin request header
//		// - sharing of credentials (cookies, auth headers)
//		// and exposing the Content-Length response header.
//		beego.InsertFilter("*", beego.BeforeRouter, cors.Allow(&cors.Options{
//			AllowOrigins:     []string{"https://*.foo.com"},
//			AllowMethods:     []string{"PUT", "PATCH"},
//			AllowHeaders:     []string{"Origin"},
//			ExposeHeaders:    []string{"Content-Length"},
//			AllowCredentials: true,
//		}))
//		beego.Run()
//	}
 | |
| package cors
 | |
| 
 | |
| import (
 | |
| 	"regexp"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/astaxie/beego"
 | |
| 	"github.com/astaxie/beego/context"
 | |
| )
 | |
| 
 | |
// CORS response header names written by this package.
const (
	headerAllowOrigin      = "Access-Control-Allow-Origin"
	headerAllowCredentials = "Access-Control-Allow-Credentials"
	headerAllowMethods     = "Access-Control-Allow-Methods"
	headerAllowHeaders     = "Access-Control-Allow-Headers"
	headerExposeHeaders    = "Access-Control-Expose-Headers"
	headerMaxAge           = "Access-Control-Max-Age"
)

// Request header names read by the filter returned from Allow.
const (
	headerOrigin         = "Origin"
	headerRequestMethod  = "Access-Control-Request-Method"
	headerRequestHeaders = "Access-Control-Request-Headers"
)
 | |
| 
 | |
// defaultAllowHeaders is the header list Allow falls back to when the caller
// leaves Options.AllowHeaders empty.
var defaultAllowHeaders = []string{"Origin", "Accept", "Content-Type", "Authorization"}

// allowOriginPatterns holds anchored regular expressions that Allow derives
// from Options.AllowOrigins. It is package-level state, internal to this
// package, and every call to Allow appends to it.
var allowOriginPatterns = []string{}
 | |
| 
 | |
// Options configures the CORS behaviour of the filter built by Allow.
type Options struct {
	// AllowAllOrigins accepts every origin; responses then carry
	// "Access-Control-Allow-Origin: *" instead of the request's origin.
	AllowAllOrigins bool
	// AllowOrigins lists the accepted origins. An entry may be an exact
	// origin or contain the wildcards '*' (any run of characters) and '?'
	// (a single character).
	AllowOrigins []string
	// AllowCredentials is reported in Access-Control-Allow-Credentials;
	// when true, browsers may share credentials such as cookies.
	AllowCredentials bool
	// AllowMethods and AllowHeaders list the request methods and request
	// headers the server accepts; ExposeHeaders lists the response headers
	// that scripts are allowed to read.
	AllowMethods, AllowHeaders, ExposeHeaders []string
	// MaxAge is sent in Access-Control-Max-Age, truncated to whole seconds;
	// a non-positive value omits that header.
	MaxAge time.Duration
}
 | |
| 
 | |
| // Header converts options into CORS headers.
 | |
| func (o *Options) Header(origin string) (headers map[string]string) {
 | |
| 	headers = make(map[string]string)
 | |
| 	// if origin is not allowed, don't extend the headers
 | |
| 	// with CORS headers.
 | |
| 	if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	// add allow origin
 | |
| 	if o.AllowAllOrigins {
 | |
| 		headers[headerAllowOrigin] = "*"
 | |
| 	} else {
 | |
| 		headers[headerAllowOrigin] = origin
 | |
| 	}
 | |
| 
 | |
| 	// add allow credentials
 | |
| 	headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
 | |
| 
 | |
| 	// add allow methods
 | |
| 	if len(o.AllowMethods) > 0 {
 | |
| 		headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
 | |
| 	}
 | |
| 
 | |
| 	// add allow headers
 | |
| 	if len(o.AllowHeaders) > 0 {
 | |
| 		headers[headerAllowHeaders] = strings.Join(o.AllowHeaders, ",")
 | |
| 	}
 | |
| 
 | |
| 	// add exposed header
 | |
| 	if len(o.ExposeHeaders) > 0 {
 | |
| 		headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
 | |
| 	}
 | |
| 	// add a max age header
 | |
| 	if o.MaxAge > time.Duration(0) {
 | |
| 		headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // PreflightHeader converts options into CORS headers for a preflight response.
 | |
| func (o *Options) PreflightHeader(origin, rMethod, rHeaders string) (headers map[string]string) {
 | |
| 	headers = make(map[string]string)
 | |
| 	if !o.AllowAllOrigins && !o.IsOriginAllowed(origin) {
 | |
| 		return
 | |
| 	}
 | |
| 	// verify if requested method is allowed
 | |
| 	for _, method := range o.AllowMethods {
 | |
| 		if method == rMethod {
 | |
| 			headers[headerAllowMethods] = strings.Join(o.AllowMethods, ",")
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// verify if requested headers are allowed
 | |
| 	var allowed []string
 | |
| 	for _, rHeader := range strings.Split(rHeaders, ",") {
 | |
| 		rHeader = strings.TrimSpace(rHeader)
 | |
| 	lookupLoop:
 | |
| 		for _, allowedHeader := range o.AllowHeaders {
 | |
| 			if strings.ToLower(rHeader) == strings.ToLower(allowedHeader) {
 | |
| 				allowed = append(allowed, rHeader)
 | |
| 				break lookupLoop
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	headers[headerAllowCredentials] = strconv.FormatBool(o.AllowCredentials)
 | |
| 	// add allow origin
 | |
| 	if o.AllowAllOrigins {
 | |
| 		headers[headerAllowOrigin] = "*"
 | |
| 	} else {
 | |
| 		headers[headerAllowOrigin] = origin
 | |
| 	}
 | |
| 
 | |
| 	// add allowed headers
 | |
| 	if len(allowed) > 0 {
 | |
| 		headers[headerAllowHeaders] = strings.Join(allowed, ",")
 | |
| 	}
 | |
| 
 | |
| 	// add exposed headers
 | |
| 	if len(o.ExposeHeaders) > 0 {
 | |
| 		headers[headerExposeHeaders] = strings.Join(o.ExposeHeaders, ",")
 | |
| 	}
 | |
| 	// add a max age header
 | |
| 	if o.MaxAge > time.Duration(0) {
 | |
| 		headers[headerMaxAge] = strconv.FormatInt(int64(o.MaxAge/time.Second), 10)
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // IsOriginAllowed looks up if the origin matches one of the patterns
 | |
| // generated from Options.AllowOrigins patterns.
 | |
| func (o *Options) IsOriginAllowed(origin string) (allowed bool) {
 | |
| 	for _, pattern := range allowOriginPatterns {
 | |
| 		allowed, _ = regexp.MatchString(pattern, origin)
 | |
| 		if allowed {
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // Allow enables CORS for requests those match the provided options.
 | |
| func Allow(opts *Options) beego.FilterFunc {
 | |
| 	// Allow default headers if nothing is specified.
 | |
| 	if len(opts.AllowHeaders) == 0 {
 | |
| 		opts.AllowHeaders = defaultAllowHeaders
 | |
| 	}
 | |
| 
 | |
| 	for _, origin := range opts.AllowOrigins {
 | |
| 		pattern := regexp.QuoteMeta(origin)
 | |
| 		pattern = strings.Replace(pattern, "\\*", ".*", -1)
 | |
| 		pattern = strings.Replace(pattern, "\\?", ".", -1)
 | |
| 		allowOriginPatterns = append(allowOriginPatterns, "^"+pattern+"$")
 | |
| 	}
 | |
| 
 | |
| 	return func(ctx *context.Context) {
 | |
| 		var (
 | |
| 			origin           = ctx.Input.Header(headerOrigin)
 | |
| 			requestedMethod  = ctx.Input.Header(headerRequestMethod)
 | |
| 			requestedHeaders = ctx.Input.Header(headerRequestHeaders)
 | |
| 			// additional headers to be added
 | |
| 			// to the response.
 | |
| 			headers map[string]string
 | |
| 		)
 | |
| 
 | |
| 		if ctx.Input.Method() == "OPTIONS" &&
 | |
| 			(requestedMethod != "" || requestedHeaders != "") {
 | |
| 			headers = opts.PreflightHeader(origin, requestedMethod, requestedHeaders)
 | |
| 			for key, value := range headers {
 | |
| 				ctx.Output.Header(key, value)
 | |
| 			}
 | |
| 			return
 | |
| 		}
 | |
| 		headers = opts.Header(origin)
 | |
| 
 | |
| 		for key, value := range headers {
 | |
| 			ctx.Output.Header(key, value)
 | |
| 		}
 | |
| 	}
 | |
| }
 |