media: Also consider extension in FromContent
authorBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Tue, 21 Dec 2021 09:35:33 +0000 (10:35 +0100)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Wed, 22 Dec 2021 10:35:53 +0000 (11:35 +0100)
As used in `resources.GetRemote`.

This will now reject image files with text and text files with images.

hugolib/resource_chain_test.go
media/mediaType.go
media/mediaType_test.go
media/testdata/fake.js [new file with mode: 0644]
media/testdata/fake.png [new file with mode: 0644]
media/testdata/resource.jpe [new file with mode: 0644]
resources/resource_factories/create/remote.go

index 131bce40fb6102ea65273b34c671f1c54fcd933a..8b17b01a4845b0eb9a5f11eab845200921af6374 100644 (file)
@@ -22,16 +22,15 @@ import (
        "net/http"
        "net/http/httptest"
        "os"
-
-       "github.com/gohugoio/hugo/config"
-
-       "github.com/gohugoio/hugo/resources/resource_transformers/tocss/dartsass"
-
        "path/filepath"
        "strings"
        "testing"
        "time"
 
+       "github.com/gohugoio/hugo/config"
+
+       "github.com/gohugoio/hugo/resources/resource_transformers/tocss/dartsass"
+
        jww "github.com/spf13/jwalterweatherman"
 
        "github.com/gohugoio/hugo/common/herrors"
@@ -57,7 +56,6 @@ func TestSCSSWithIncludePaths(t *testing.T) {
                {"libsass", func() bool { return scss.Supports() }},
                {"dartsass", func() bool { return dartsass.Supports() }},
        } {
-
                c.Run(test.name, func(c *qt.C) {
                        if !test.supports() {
                                c.Skip(fmt.Sprintf("Skip %s", test.name))
@@ -107,9 +105,7 @@ T1: {{ $r.Content }}
 
                        b.AssertFileContent(filepath.Join(workDir, "public/index.html"), `T1: moo{color:#fff}`)
                })
-
        }
-
 }
 
 func TestSCSSWithRegularCSSImport(t *testing.T) {
@@ -122,7 +118,6 @@ func TestSCSSWithRegularCSSImport(t *testing.T) {
                {"libsass", func() bool { return scss.Supports() }},
                {"dartsass", func() bool { return dartsass.Supports() }},
        } {
-
                c.Run(test.name, func(c *qt.C) {
                        if !test.supports() {
                                c.Skip(fmt.Sprintf("Skip %s", test.name))
@@ -202,11 +197,9 @@ moo {
 }
 
 /* foo */`)
-
                        }
                })
        }
-
 }
 
 func TestSCSSWithThemeOverrides(t *testing.T) {
@@ -219,7 +212,6 @@ func TestSCSSWithThemeOverrides(t *testing.T) {
                {"libsass", func() bool { return scss.Supports() }},
                {"dartsass", func() bool { return dartsass.Supports() }},
        } {
-
                c.Run(test.name, func(c *qt.C) {
                        if !test.supports() {
                                c.Skip(fmt.Sprintf("Skip %s", test.name))
@@ -319,7 +311,6 @@ T1: {{ $r.Content }}
                        )
                })
        }
-
 }
 
 // https://github.com/gohugoio/hugo/issues/6274
@@ -333,7 +324,6 @@ func TestSCSSWithIncludePathsSass(t *testing.T) {
                {"libsass", func() bool { return scss.Supports() }},
                {"dartsass", func() bool { return dartsass.Supports() }},
        } {
-
                c.Run(test.name, func(c *qt.C) {
                        if !test.supports() {
                                c.Skip(fmt.Sprintf("Skip %s", test.name))
@@ -620,6 +610,7 @@ func TestResourceChains(t *testing.T) {
                        return
 
                case "/authenticated/":
+                       w.Header().Set("Content-Type", "text/plain")
                        if r.Header.Get("Authorization") == "Bearer abcd" {
                                w.Write([]byte(`Welcome`))
                                return
@@ -628,6 +619,7 @@ func TestResourceChains(t *testing.T) {
                        return
 
                case "/post":
+                       w.Header().Set("Content-Type", "text/plain")
                        if r.Method == http.MethodPost {
                                body, err := ioutil.ReadAll(r.Body)
                                if err != nil {
@@ -1247,8 +1239,8 @@ class-in-b {
        // TODO(bep) for some reason, we have starting to get
        // execute of template failed: template: index.html:5:25
        // on CI (GitHub action).
-       //b.Assert(fe.Position().LineNumber, qt.Equals, 5)
-       //b.Assert(fe.Error(), qt.Contains, filepath.Join(workDir, "assets/css/components/b.css:4:1"))
+       // b.Assert(fe.Position().LineNumber, qt.Equals, 5)
+       // b.Assert(fe.Error(), qt.Contains, filepath.Join(workDir, "assets/css/components/b.css:4:1"))
 
        // Remove PostCSS
        b.Assert(os.RemoveAll(filepath.Join(workDir, "node_modules")), qt.IsNil)
index 819de9d80281aaf2caba4dc676d988a421db15f1..47a74ec567c382bc5c55327e8006d6a08894a78b 100644 (file)
@@ -28,6 +28,8 @@ import (
        "github.com/mitchellh/mapstructure"
 )
 
+var zero Type
+
 const (
        defaultDelimiter = "."
 )
@@ -64,16 +66,14 @@ type SuffixInfo struct {
 // FromContent resolve the Type primarily using http.DetectContentType.
 // If http.DetectContentType resolves to application/octet-stream, a zero Type is returned.
 // If http.DetectContentType  resolves to text/plain or application/xml, we try to get more specific using types and ext.
-func FromContent(types Types, ext string, content []byte) Type {
-       ext = strings.TrimPrefix(ext, ".")
+func FromContent(types Types, extensionHints []string, content []byte) Type {
        t := strings.Split(http.DetectContentType(content), ";")[0]
-       var m Type
        if t == "application/octet-stream" {
-               return m
+               return zero
        }
 
        var found bool
-       m, found = types.GetByType(t)
+       m, found := types.GetByType(t)
        if !found {
                if t == "text/xml" {
                        // This is how it's configured in Hugo by default.
@@ -81,19 +81,36 @@ func FromContent(types Types, ext string, content []byte) Type {
                }
        }
 
-       if !found || ext == "" {
-               return m
+       if !found {
+               return zero
+       }
+
+       var mm Type
+
+       for _, extension := range extensionHints {
+               extension = strings.TrimPrefix(extension, ".")
+               mm, _, found = types.GetFirstBySuffix(extension)
+               if found {
+                       break
+               }
        }
 
-       if m.Type() == "text/plain" || m.Type() == "application/xml" {
-               // http.DetectContentType isn't brilliant when it comes to common text formats, so we need to do better.
-               // For now we say that if it's detected to be a text format and the extension/content type in header reports
-               // it to be a text format, then we use that.
-               mm, _, found := types.GetFirstBySuffix(ext)
-               if found && mm.IsText() {
+       if found {
+               if m == mm {
+                       return m
+               }
+
+               if m.IsText() && mm.IsText() {
+                       // http.DetectContentType isn't brilliant when it comes to common text formats, so we need to do better.
+                       // For now we say that if it's detected to be a text format and the extension/content type in header reports
+                       // it to be a text format, then we use that.
                        return mm
                }
+
+               // E.g. an image with a *.js extension.
+               return zero
        }
+
        return m
 }
 
index cd4439fe79bd8216d96dd11ca27c09e977d05309..2e32568f1168f4f631ff72a825be0b2aed678e70 100644 (file)
@@ -15,7 +15,6 @@ package media
 
 import (
        "encoding/json"
-       "fmt"
        "io/ioutil"
        "path/filepath"
        "sort"
@@ -194,15 +193,39 @@ func TestFromContent(t *testing.T) {
                        content, err := ioutil.ReadFile(filename)
                        c.Assert(err, qt.IsNil)
                        ext := strings.TrimPrefix(paths.Ext(filename), ".")
-                       fmt.Println("=>", ext)
+                       var exts []string
+                       if ext == "jpg" {
+                               exts = append(exts, "foo", "bar", "jpg")
+                       } else {
+                               exts = []string{ext}
+                       }
                        expected, _, found := mtypes.GetFirstBySuffix(ext)
                        c.Assert(found, qt.IsTrue)
-                       got := FromContent(mtypes, ext, content)
+                       got := FromContent(mtypes, exts, content)
                        c.Assert(got, qt.Equals, expected)
                })
        }
 }
 
+func TestFromContentFakes(t *testing.T) {
+       c := qt.New(t)
+
+       files, err := filepath.Glob("./testdata/fake.*")
+       c.Assert(err, qt.IsNil)
+       mtypes := DefaultTypes
+
+       for _, filename := range files {
+               name := filepath.Base(filename)
+               c.Run(name, func(c *qt.C) {
+                       content, err := ioutil.ReadFile(filename)
+                       c.Assert(err, qt.IsNil)
+                       ext := strings.TrimPrefix(paths.Ext(filename), ".")
+                       got := FromContent(mtypes, []string{ext}, content)
+                       c.Assert(got, qt.Equals, zero)
+               })
+       }
+}
+
 func TestDecodeTypes(t *testing.T) {
        c := qt.New(t)
 
diff --git a/media/testdata/fake.js b/media/testdata/fake.js
new file mode 100644 (file)
index 0000000..08ae570
Binary files /dev/null and b/media/testdata/fake.js differ
diff --git a/media/testdata/fake.png b/media/testdata/fake.png
new file mode 100644 (file)
index 0000000..75ba3b7
--- /dev/null
@@ -0,0 +1,3 @@
+function foo() {
+    return "foo";
+}
\ No newline at end of file
diff --git a/media/testdata/resource.jpe b/media/testdata/resource.jpe
new file mode 100644 (file)
index 0000000..a9049e8
Binary files /dev/null and b/media/testdata/resource.jpe differ
index f6d3f13dd25663a230ea8474e6b71b73bfd019c2..f127f8edc5ab331f08f65c3ca4a0176350b7f876 100644 (file)
@@ -110,21 +110,30 @@ func (c *Client) FromRemote(uri string, optionsm map[string]interface{}) (resour
                }
        }
 
-       var extensionHint string
-
-       if arr, _ := mime.ExtensionsByType(res.Header.Get("Content-Type")); len(arr) == 1 {
-               extensionHint = arr[0]
+       var extensionHints []string
+
+       contentType := res.Header.Get("Content-Type")
+
+       // mime.ExtensionsByType gives a long list of extensions for text/plain,
+       // just use ".txt".
+       if strings.HasPrefix(contentType, "text/plain") {
+               extensionHints = []string{".txt"}
+       } else {
+               exts, _ := mime.ExtensionsByType(contentType)
+               if exts != nil {
+                       extensionHints = exts
+               }
        }
 
-       // Look for a file extention
-       if extensionHint == "" {
+       // Look for a file extention. If it's .txt, look for a more specific.
+       if extensionHints == nil || extensionHints[0] == ".txt" {
                if ext := path.Ext(filename); ext != "" {
-                       extensionHint = ext
+                       extensionHints = []string{ext}
                }
        }
 
        // Now resolve the media type primarily using the content.
-       mediaType := media.FromContent(c.rs.MediaTypes, extensionHint, body)
+       mediaType := media.FromContent(c.rs.MediaTypes, extensionHints, body)
        if mediaType.IsZero() {
                return nil, errors.Errorf("failed to resolve media type for remote resource %q", uri)
        }
@@ -140,7 +149,6 @@ func (c *Client) FromRemote(uri string, optionsm map[string]interface{}) (resour
                        },
                        RelTargetFilename: filepath.Clean(resourceID),
                })
-
 }
 
 func (c *Client) validateFromRemoteArgs(uri string, options fromRemoteOptions) error {
@@ -213,7 +221,7 @@ func (o fromRemoteOptions) BodyReader() io.Reader {
 }
 
 func decodeRemoteOptions(optionsm map[string]interface{}) (fromRemoteOptions, error) {
-       var options = fromRemoteOptions{
+       options := fromRemoteOptions{
                Method: "GET",
        }
 
@@ -224,5 +232,4 @@ func decodeRemoteOptions(optionsm map[string]interface{}) (fromRemoteOptions, er
        options.Method = strings.ToUpper(options.Method)
 
        return options, nil
-
 }