Skip to content

Commit

Permalink
Create functional option for ctx.SetCookie (#208)
Browse files Browse the repository at this point in the history
Co-authored-by: ᴜɴᴋɴᴡᴏɴ <u@gogs.io>
Co-authored-by: 6543 <6543@obermui.de>
  • Loading branch information
3 people committed Nov 13, 2020
1 parent d229aed commit 6f0734a
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 23 deletions.
61 changes: 43 additions & 18 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ type Request struct {
*http.Request
}

// Body returns a RequestBody for the request
func (r *Request) Body() *RequestBody {
return &RequestBody{r.Request.Body}
}

// ContextInvoker is an inject.FastInvoker wrapper of func(ctx *Context).
type ContextInvoker func(ctx *Context)

// Invoke implements inject.FastInvoker which simplifies calls of `func(ctx *Context)` function.
func (invoke ContextInvoker) Invoke(params []interface{}) ([]reflect.Value, error) {
invoke(params[0].(*Context))
return nil, nil
Expand All @@ -97,41 +99,43 @@ type Context struct {
Data map[string]interface{}
}

func (c *Context) handler() Handler {
if c.index < len(c.handlers) {
return c.handlers[c.index]
func (ctx *Context) handler() Handler {
if ctx.index < len(ctx.handlers) {
return ctx.handlers[ctx.index]
}
if c.index == len(c.handlers) {
return c.action
if ctx.index == len(ctx.handlers) {
return ctx.action
}
panic("invalid index for context handler")
}

func (c *Context) Next() {
c.index += 1
c.run()
// Next runs the next handler in the context chain
func (ctx *Context) Next() {
ctx.index++
ctx.run()
}

func (c *Context) Written() bool {
return c.Resp.Written()
// Written returns whether the context response has been written to
func (ctx *Context) Written() bool {
return ctx.Resp.Written()
}

func (c *Context) run() {
for c.index <= len(c.handlers) {
vals, err := c.Invoke(c.handler())
func (ctx *Context) run() {
for ctx.index <= len(ctx.handlers) {
vals, err := ctx.Invoke(ctx.handler())
if err != nil {
panic(err)
}
c.index += 1
ctx.index++

// if the handler returned something, write it to the http response
if len(vals) > 0 {
ev := c.GetVal(reflect.TypeOf(ReturnHandler(nil)))
ev := ctx.GetVal(reflect.TypeOf(ReturnHandler(nil)))
handleReturn := ev.Interface().(ReturnHandler)
handleReturn(c, vals)
handleReturn(ctx, vals)
}

if c.Written() {
if ctx.Written() {
return
}
}
Expand Down Expand Up @@ -172,6 +176,7 @@ func (ctx *Context) HTMLSet(status int, setName, tplName string, data ...interfa
ctx.renderHTML(status, setName, tplName, data...)
}

// Redirect sends a redirect response
func (ctx *Context) Redirect(location string, status ...int) {
code := http.StatusFound
if len(status) == 1 {
Expand All @@ -181,7 +186,7 @@ func (ctx *Context) Redirect(location string, status ...int) {
http.Redirect(ctx.Resp, ctx.Req.Request, location, code)
}

// Maximum amount of memory to use when parsing a multipart form.
// MaxMemory is the maximum amount of memory to use when parsing a multipart form.
// Set this to whatever value you prefer; default is 10 MB.
var MaxMemory = int64(1024 * 1024 * 10)

Expand Down Expand Up @@ -341,26 +346,34 @@ func (ctx *Context) SetCookie(name string, value string, others ...interface{})
cookie.MaxAge = int(v)
case int32:
cookie.MaxAge = int(v)
case func(*http.Cookie):
v(&cookie)
}
}

cookie.Path = "/"
if len(others) > 1 {
if v, ok := others[1].(string); ok && len(v) > 0 {
cookie.Path = v
} else if v, ok := others[1].(func(*http.Cookie)); ok {
v(&cookie)
}
}

if len(others) > 2 {
if v, ok := others[2].(string); ok && len(v) > 0 {
cookie.Domain = v
} else if v, ok := others[1].(func(*http.Cookie)); ok {
v(&cookie)
}
}

if len(others) > 3 {
switch v := others[3].(type) {
case bool:
cookie.Secure = v
case func(*http.Cookie):
v(&cookie)
default:
if others[3] != nil {
cookie.Secure = true
Expand All @@ -371,13 +384,25 @@ func (ctx *Context) SetCookie(name string, value string, others ...interface{})
if len(others) > 4 {
if v, ok := others[4].(bool); ok && v {
cookie.HttpOnly = true
} else if v, ok := others[1].(func(*http.Cookie)); ok {
v(&cookie)
}
}

if len(others) > 5 {
if v, ok := others[5].(time.Time); ok {
cookie.Expires = v
cookie.RawExpires = v.Format(time.UnixDate)
} else if v, ok := others[1].(func(*http.Cookie)); ok {
v(&cookie)
}
}

if len(others) > 6 {
for _, other := range others[6:] {
if v, ok := other.(func(*http.Cookie)); ok {
v(&cookie)
}
}
}

Expand Down
14 changes: 13 additions & 1 deletion context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"time"

"github.com/unknwon/com"
"gopkg.in/macaron.v1/cookie"

. "github.com/smartystreets/goconvey/convey"
)
Expand Down Expand Up @@ -209,7 +210,18 @@ func Test_Context(t *testing.T) {
So(err, ShouldBeNil)
ctx.SetCookie("user", "Unknwon", 1, "/", "localhost", true, true, t)
ctx.SetCookie("user", "Unknwon", int32(1), "/", "localhost", 1)
ctx.SetCookie("user", "Unknwon", int64(1))
called := false
ctx.SetCookie("user", "Unknwon", int64(1), func(c *http.Cookie) {
called = true
})
So(called, ShouldBeTrue)
ctx.SetCookie("user", "Unknown",
cookie.Secure(true),
cookie.HttpOnly(true),
cookie.Path("/"),
cookie.MaxAge(1),
cookie.Domain("localhost"),
)
})

resp := httptest.NewRecorder()
Expand Down
78 changes: 78 additions & 0 deletions cookie/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright 2020 The Macaron Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

// Package cookie contains helper functions for setting cookie values.
package cookie

import (
"net/http"
"time"
)

// MaxAge sets the maximum age for a provided cookie
func MaxAge(maxAge int) func(*http.Cookie) {
return func(c *http.Cookie) {
c.MaxAge = maxAge
}
}

// Path sets the path for a provided cookie
func Path(path string) func(*http.Cookie) {
return func(c *http.Cookie) {
c.Path = path
}
}

// Domain sets the domain for a provided cookie
func Domain(domain string) func(*http.Cookie) {
return func(c *http.Cookie) {
c.Domain = domain
}
}

// Secure sets the secure setting for a provided cookie
func Secure(secure bool) func(*http.Cookie) {
return func(c *http.Cookie) {
c.Secure = secure
}
}

// HttpOnly sets the HttpOnly setting for a provided cookie
func HttpOnly(httpOnly bool) func(*http.Cookie) {
return func(c *http.Cookie) {
c.HttpOnly = httpOnly
}
}

// HTTPOnly sets the HttpOnly setting for a provided cookie
func HTTPOnly(httpOnly bool) func(*http.Cookie) {
return func(c *http.Cookie) {
c.HttpOnly = httpOnly
}
}

// Expires sets the expires and rawexpires for a provided cookie
func Expires(expires time.Time) func(*http.Cookie) {
return func(c *http.Cookie) {
c.Expires = expires
c.RawExpires = expires.Format(time.UnixDate)
}
}

// SameSite sets the SameSite for a provided cookie
func SameSite(sameSite http.SameSite) func(*http.Cookie) {
return func(c *http.Cookie) {
c.SameSite = sameSite
}
}
8 changes: 4 additions & 4 deletions tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ func Test_getWildcards(t *testing.T) {
":id([0-9]+)_:name": result{"([0-9]+)_(.+)", ":id :name"},
"article_:id_:page.html": result{"article_(.+)_(.+).html", ":id :page"},
"article_:id:int_:page:string.html": result{"article_([0-9]+)_([\\w]+).html", ":id :page"},
"*": result{"*", ""},
"*.*": result{"*.*", ""},
"*": result{"*", ""},
"*.*": result{"*.*", ""},
}
Convey("Get wildcards", t, func() {
for key, result := range cases {
Expand All @@ -56,8 +56,8 @@ func Test_getRawPattern(t *testing.T) {
"article_:id_:page.html": "article_:id_:page.html",
"article_:id:int_:page:string.html": "article_:id_:page.html",
"article_:id([0-9]+)_:page([\\w]+).html": "article_:id_:page.html",
"*": "*",
"*.*": "*.*",
"*": "*",
"*.*": "*.*",
}
Convey("Get raw pattern", t, func() {
for k, v := range cases {
Expand Down

0 comments on commit 6f0734a

Please sign in to comment.