Add HTTP header support for the dev server
authorBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Sun, 8 Mar 2020 15:33:15 +0000 (16:33 +0100)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Sun, 8 Mar 2020 18:57:30 +0000 (19:57 +0100)
Fixes #7031

commands/commandeer.go
commands/server.go
config/commonConfig.go
config/commonConfig_test.go

index 3054ffb7412c2b32a5891ae99b0a75194a16e504..547bf8bf3262516f21e66926e43765a4b081ed07 100644 (file)
@@ -18,6 +18,8 @@ import (
        "errors"
        "sync"
 
+       hconfig "github.com/gohugoio/hugo/config"
+
        "golang.org/x/sync/semaphore"
 
        "io/ioutil"
@@ -58,7 +60,8 @@ type commandeerHugoState struct {
 type commandeer struct {
        *commandeerHugoState
 
-       logger *loggers.Logger
+       logger       *loggers.Logger
+       serverConfig *config.Server
 
        // Currently only set when in "fast render mode". But it seems to
        // be fast enough that we could maybe just add it for all server modes.
@@ -343,6 +346,7 @@ func (c *commandeer) loadConfig(mustHaveConfigFile, running bool) error {
 
        cfg.Logger = logger
        c.logger = logger
+       c.serverConfig = hconfig.DecodeServer(cfg.Cfg)
 
        createMemFs := config.GetBool("renderToMemory")
 
index 72884749277ddc06c96e5d00c2f42dc5af465ffe..a22a7a69a97f0d3d8ebb78417d9b8ab45cf081d9 100644 (file)
@@ -355,6 +355,10 @@ func (f *fileServer) createEndpoint(i int) (*http.ServeMux, string, string, erro
                                w.Header().Set("Pragma", "no-cache")
                        }
 
+                       for _, header := range f.c.serverConfig.Match(r.RequestURI) {
+                               w.Header().Set(header.Key, header.Value)
+                       }
+
                        if f.c.fastRenderMode && f.c.buildErr == nil {
                                p := r.RequestURI
                                if strings.HasSuffix(p, "/") || strings.HasSuffix(p, "html") || strings.HasSuffix(p, "htm") {
index ab2cfe80b7fd6a3a451df61178d81a4ec08a2502..17d5619bb127a7922b4b9715def669130601c70d 100644 (file)
 package config
 
 import (
+       "sort"
        "strings"
+       "sync"
 
+       "github.com/gohugoio/hugo/common/types"
+
+       "github.com/gobwas/glob"
        "github.com/gohugoio/hugo/common/herrors"
        "github.com/mitchellh/mapstructure"
        "github.com/spf13/cast"
@@ -88,3 +93,57 @@ func DecodeSitemap(prototype Sitemap, input map[string]interface{}) Sitemap {
 
        return prototype
 }
+
+// Config for the dev server.
+type Server struct {
+       Headers []Headers
+
+       compiledInit sync.Once
+       compiled     []glob.Glob
+}
+
+func (s *Server) Match(pattern string) []types.KeyValueStr {
+       s.compiledInit.Do(func() {
+               for _, h := range s.Headers {
+                       s.compiled = append(s.compiled, glob.MustCompile(h.For))
+               }
+       })
+
+       if s.compiled == nil {
+               return nil
+       }
+
+       var matches []types.KeyValueStr
+
+       for i, g := range s.compiled {
+               if g.Match(pattern) {
+                       h := s.Headers[i]
+                       for k, v := range h.Values {
+                               matches = append(matches, types.KeyValueStr{Key: k, Value: cast.ToString(v)})
+                       }
+               }
+       }
+
+       sort.Slice(matches, func(i, j int) bool {
+               return matches[i].Key < matches[j].Key
+       })
+
+       return matches
+
+}
+
+type Headers struct {
+       For    string
+       Values map[string]interface{}
+}
+
+func DecodeServer(cfg Provider) *Server {
+       m := cfg.GetStringMap("server")
+       s := &Server{}
+       if m == nil {
+               return s
+       }
+
+       _ = mapstructure.WeakDecode(m, s)
+       return s
+}
index 281d2b0b6ead9802b34f74528fe3464ba6372ee4..41b2721bc4688d0d3b81af219759a1ba8f311e14 100644 (file)
@@ -18,6 +18,7 @@ import (
        "testing"
 
        "github.com/gohugoio/hugo/common/herrors"
+       "github.com/gohugoio/hugo/common/types"
 
        qt "github.com/frankban/quicktest"
 
@@ -58,3 +59,26 @@ func TestBuild(t *testing.T) {
        c.Assert(b.UseResourceCache(nil), qt.Equals, false)
 
 }
+
+func TestServer(t *testing.T) {
+       c := qt.New(t)
+
+       cfg, err := FromConfigString(`[[server.headers]]
+for = "/*.jpg"
+
+[server.headers.values]
+X-Frame-Options = "DENY"
+X-XSS-Protection = "1; mode=block"
+X-Content-Type-Options = "nosniff"
+`, "toml")
+
+       c.Assert(err, qt.IsNil)
+
+       s := DecodeServer(cfg)
+
+       c.Assert(s.Match("/foo.jpg"), qt.DeepEquals, []types.KeyValueStr{
+               {Key: "X-Content-Type-Options", Value: "nosniff"},
+               {Key: "X-Frame-Options", Value: "DENY"},
+               {Key: "X-XSS-Protection", Value: "1; mode=block"}})
+
+}