tpl/data: Allows user-defined HTTP headers with getJSON and getCSV
authorPaul Chamberlain <pauljc@umich.edu>
Mon, 26 Apr 2021 20:41:37 +0000 (16:41 -0400)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Sun, 6 Jun 2021 11:32:12 +0000 (13:32 +0200)
Updates #5617

tpl/data/data.go
tpl/data/data_test.go

index d383447aca0764944f99b5f5b66c5b8b2c42035c..4cb8b5e785250c4194de805be254e3a40b1ecf80 100644 (file)
@@ -58,8 +58,8 @@ type Namespace struct {
 // The data separator can be a comma, semi-colon, pipe, etc, but only one character.
 // If you provide multiple parts for the URL they will be joined together to the final URL.
 // GetCSV returns nil or a slice slice to use in a short code.
-func (ns *Namespace) GetCSV(sep string, urlParts ...interface{}) (d [][]string, err error) {
-       url := joinURL(urlParts)
+func (ns *Namespace) GetCSV(sep string, args ...interface{}) (d [][]string, err error) {
+       url := joinURL(args)
        cache := ns.cacheGetCSV
 
        unmarshal := func(b []byte) (bool, error) {
@@ -85,6 +85,15 @@ func (ns *Namespace) GetCSV(sep string, urlParts ...interface{}) (d [][]string,
        req.Header.Add("Accept", "text/csv")
        req.Header.Add("Accept", "text/plain")
 
+       // Add custom user headers to the get request
+       finalArg := args[len(args)-1]
+
+       if userHeaders, ok := finalArg.(map[string]interface{}); ok {
+               for key, val := range userHeaders {
+                       req.Header.Add(key, val.(string))
+               }
+       }
+
        err = ns.getResource(cache, unmarshal, req)
        if err != nil {
                ns.deps.Log.(loggers.IgnorableLogger).Errorsf(constants.ErrRemoteGetCSV, "Failed to get CSV resource %q: %s", url, err)
@@ -97,9 +106,9 @@ func (ns *Namespace) GetCSV(sep string, urlParts ...interface{}) (d [][]string,
 // GetJSON expects one or n-parts of a URL to a resource which can either be a local or a remote one.
 // If you provide multiple parts they will be joined together to the final URL.
 // GetJSON returns nil or parsed JSON to use in a short code.
-func (ns *Namespace) GetJSON(urlParts ...interface{}) (interface{}, error) {
+func (ns *Namespace) GetJSON(args ...interface{}) (interface{}, error) {
        var v interface{}
-       url := joinURL(urlParts)
+       url := joinURL(args)
        cache := ns.cacheGetJSON
 
        req, err := http.NewRequest("GET", url, nil)
@@ -118,6 +127,15 @@ func (ns *Namespace) GetJSON(urlParts ...interface{}) (interface{}, error) {
        req.Header.Add("Accept", "application/json")
        req.Header.Add("User-Agent", "Hugo Static Site Generator")
 
+       // Add custom user headers to the get request
+       finalArg := args[len(args)-1]
+
+       if userHeaders, ok := finalArg.(map[string]interface{}); ok {
+               for key, val := range userHeaders {
+                       req.Header.Add(key, val.(string))
+               }
+       }
+
        err = ns.getResource(cache, unmarshal, req)
        if err != nil {
                ns.deps.Log.(loggers.IgnorableLogger).Errorsf(constants.ErrRemoteGetJSON, "Failed to get JSON resource %q: %s", url, err)
index f9e8621f24c5c5e046cdad3688364cd3717a0ef2..6b62a2b0d40aeae3fa23bdda715f2b5129f9a64e 100644 (file)
@@ -119,6 +119,20 @@ func TestGetCSV(t *testing.T) {
                c.Assert(got, qt.Not(qt.IsNil), msg)
                c.Assert(got, qt.DeepEquals, test.expect, msg)
 
+               // Test user-defined headers as well
+               gotHeader, _ := ns.GetCSV(test.sep, test.url, map[string]interface{}{"Accept-Charset": "utf-8", "Max-Forwards": "10"})
+
+               if _, ok := test.expect.(bool); ok {
+                       c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 1)
+                       // c.Assert(err, msg, qt.Not(qt.IsNil))
+                       c.Assert(got, qt.IsNil)
+                       continue
+               }
+
+               c.Assert(err, qt.IsNil, msg)
+               c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0)
+               c.Assert(gotHeader, qt.Not(qt.IsNil), msg)
+               c.Assert(gotHeader, qt.DeepEquals, test.expect, msg)
        }
 }
 
@@ -206,6 +220,19 @@ func TestGetJSON(t *testing.T) {
                c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0, msg)
                c.Assert(got, qt.Not(qt.IsNil), msg)
                c.Assert(got, qt.DeepEquals, test.expect)
+
+               // Test user-defined headers as well
+               gotHeader, _ := ns.GetJSON(test.url, map[string]interface{}{"Accept-Charset": "utf-8", "Max-Forwards": "10"})
+
+               if _, ok := test.expect.(bool); ok {
+                       c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 1)
+                       // c.Assert(err, msg, qt.Not(qt.IsNil))
+                       continue
+               }
+
+               c.Assert(int(ns.deps.Log.LogCounters().ErrorCounter.Count()), qt.Equals, 0, msg)
+               c.Assert(gotHeader, qt.Not(qt.IsNil), msg)
+               c.Assert(gotHeader, qt.DeepEquals, test.expect)
        }
 }