Add SafeHtmlAttr, SafeCSS template function
authorTatsushi Demachi <tdemachi@gmail.com>
Mon, 19 Jan 2015 23:55:16 +0000 (08:55 +0900)
committerAnthony Fok <foka@debian.org>
Tue, 20 Jan 2015 02:47:37 +0000 (19:47 -0700)
This allows a template user to keep a safe HTML attribute or CSS string
as is in a template.

This is implementation of @anthonyfok great insight

Fix #784, #347

tpl/template.go
tpl/template_test.go

index 8d55fc897ddb00091dea9174e8025d881f19973b..819343a97aae62c782c9647e0a61e23217ad1488 100644 (file)
@@ -910,6 +910,14 @@ func SafeHtml(text string) template.HTML {
        return template.HTML(text)
 }
 
+func SafeHtmlAttr(text string) template.HTMLAttr {
+       return template.HTMLAttr(text)
+}
+
+func SafeCSS(text string) template.CSS {
+       return template.CSS(text)
+}
+
 func doArithmetic(a, b interface{}, op rune) (interface{}, error) {
        av := reflect.ValueOf(a)
        bv := reflect.ValueOf(b)
@@ -1230,41 +1238,43 @@ func (t *GoHtmlTemplate) LoadTemplates(absPath string) {
 
 func init() {
        funcMap = template.FuncMap{
-               "urlize":      helpers.Urlize,
-               "sanitizeurl": helpers.SanitizeUrl,
-               "eq":          Eq,
-               "ne":          Ne,
-               "gt":          Gt,
-               "ge":          Ge,
-               "lt":          Lt,
-               "le":          Le,
-               "in":          In,
-               "intersect":   Intersect,
-               "isset":       IsSet,
-               "echoParam":   ReturnWhenSet,
-               "safeHtml":    SafeHtml,
-               "markdownify": Markdownify,
-               "first":       First,
-               "where":       Where,
-               "delimit":     Delimit,
-               "sort":        Sort,
-               "highlight":   Highlight,
-               "add":         func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '+') },
-               "sub":         func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '-') },
-               "div":         func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '/') },
-               "mod":         Mod,
-               "mul":         func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '*') },
-               "modBool":     ModBool,
-               "lower":       func(a string) string { return strings.ToLower(a) },
-               "upper":       func(a string) string { return strings.ToUpper(a) },
-               "title":       func(a string) string { return strings.Title(a) },
-               "partial":     Partial,
-               "ref":         Ref,
-               "relref":      RelRef,
-               "apply":       Apply,
-               "chomp":       Chomp,
-               "replace":     Replace,
-               "trim":        Trim,
+               "urlize":       helpers.Urlize,
+               "sanitizeurl":  helpers.SanitizeUrl,
+               "eq":           Eq,
+               "ne":           Ne,
+               "gt":           Gt,
+               "ge":           Ge,
+               "lt":           Lt,
+               "le":           Le,
+               "in":           In,
+               "intersect":    Intersect,
+               "isset":        IsSet,
+               "echoParam":    ReturnWhenSet,
+               "safeHtml":     SafeHtml,
+               "safeHtmlAttr": SafeHtmlAttr,
+               "safeCSS":      SafeCSS,
+               "markdownify":  Markdownify,
+               "first":        First,
+               "where":        Where,
+               "delimit":      Delimit,
+               "sort":         Sort,
+               "highlight":    Highlight,
+               "add":          func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '+') },
+               "sub":          func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '-') },
+               "div":          func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '/') },
+               "mod":          Mod,
+               "mul":          func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '*') },
+               "modBool":      ModBool,
+               "lower":        func(a string) string { return strings.ToLower(a) },
+               "upper":        func(a string) string { return strings.ToUpper(a) },
+               "title":        func(a string) string { return strings.Title(a) },
+               "partial":      Partial,
+               "ref":          Ref,
+               "relref":       RelRef,
+               "apply":        Apply,
+               "chomp":        Chomp,
+               "replace":      Replace,
+               "trim":         Trim,
        }
 
        chompRegexp = regexp.MustCompile("[\r\n]+$")
index 98cf2d061dc119045b938f02ba663fc2aaf170fc..f857e6341439c81b4df03f0ad28ca51683be233d 100644 (file)
@@ -1,6 +1,7 @@
 package tpl
 
 import (
+       "bytes"
        "errors"
        "fmt"
        "html/template"
@@ -826,3 +827,107 @@ func TestMarkdownify(t *testing.T) {
                t.Errorf("Markdownify: got '%s', expected '%s'", result, expect)
        }
 }
+
+func TestSafeHtml(t *testing.T) {
+       for i, this := range []struct {
+               str                 string
+               tmplStr             string
+               expectWithoutEscape string
+               expectWithEscape    string
+       }{
+               {`<div></div>`, `{{ . }}`, `&lt;div&gt;&lt;/div&gt;`, `<div></div>`},
+       } {
+               tmpl, err := template.New("test").Parse(this.tmplStr)
+               if err != nil {
+                       t.Errorf("[%d] unable to create new html template %q: %s", this.tmplStr, err)
+                       continue
+               }
+
+               buf := new(bytes.Buffer)
+               err = tmpl.Execute(buf, this.str)
+               if err != nil {
+                       t.Errorf("[%d] execute template with a raw string value returns unexpected error: %s", i, err)
+               }
+               if buf.String() != this.expectWithoutEscape {
+                       t.Errorf("[%d] execute template with a raw string value, got %v but expected %v", i, buf.String(), this.expectWithoutEscape)
+               }
+
+               buf.Reset()
+               err = tmpl.Execute(buf, SafeHtml(this.str))
+               if err != nil {
+                       t.Errorf("[%d] execute template with an escaped string value by SafeHtml returns unexpected error: %s", i, err)
+               }
+               if buf.String() != this.expectWithEscape {
+                       t.Errorf("[%d] execute template with an escaped string value by SafeHtml, got %v but expected %v", i, buf.String(), this.expectWithEscape)
+               }
+       }
+}
+
+func TestSafeHtmlAttr(t *testing.T) {
+       for i, this := range []struct {
+               str                 string
+               tmplStr             string
+               expectWithoutEscape string
+               expectWithEscape    string
+       }{
+               {`href="irc://irc.freenode.net/#golang"`, `<a {{ . }}>irc</a>`, `<a ZgotmplZ>irc</a>`, `<a href="irc://irc.freenode.net/#golang">irc</a>`},
+       } {
+               tmpl, err := template.New("test").Parse(this.tmplStr)
+               if err != nil {
+                       t.Errorf("[%d] unable to create new html template %q: %s", this.tmplStr, err)
+                       continue
+               }
+
+               buf := new(bytes.Buffer)
+               err = tmpl.Execute(buf, this.str)
+               if err != nil {
+                       t.Errorf("[%d] execute template with a raw string value returns unexpected error: %s", i, err)
+               }
+               if buf.String() != this.expectWithoutEscape {
+                       t.Errorf("[%d] execute template with a raw string value, got %v but expected %v", i, buf.String(), this.expectWithoutEscape)
+               }
+
+               buf.Reset()
+               err = tmpl.Execute(buf, SafeHtmlAttr(this.str))
+               if err != nil {
+                       t.Errorf("[%d] execute template with an escaped string value by SafeHtmlAttr returns unexpected error: %s", i, err)
+               }
+               if buf.String() != this.expectWithEscape {
+                       t.Errorf("[%d] execute template with an escaped string value by SafeHtmlAttr, got %v but expected %v", i, buf.String(), this.expectWithEscape)
+               }
+       }
+}
+
+func TestSafeCSS(t *testing.T) {
+       for i, this := range []struct {
+               str                 string
+               tmplStr             string
+               expectWithoutEscape string
+               expectWithEscape    string
+       }{
+               {`width: 60px;`, `<div style="{{ . }}"></div>`, `<div style="ZgotmplZ"></div>`, `<div style="width: 60px;"></div>`},
+       } {
+               tmpl, err := template.New("test").Parse(this.tmplStr)
+               if err != nil {
+                       t.Errorf("[%d] unable to create new html template %q: %s", this.tmplStr, err)
+               }
+
+               buf := new(bytes.Buffer)
+               err = tmpl.Execute(buf, this.str)
+               if err != nil {
+                       t.Errorf("[%d] execute template with a raw string value returns unexpected error: %s", i, err)
+               }
+               if buf.String() != this.expectWithoutEscape {
+                       t.Errorf("[%d] execute template with a raw string value, got %v but expected %v", i, buf.String(), this.expectWithoutEscape)
+               }
+
+               buf.Reset()
+               err = tmpl.Execute(buf, SafeCSS(this.str))
+               if err != nil {
+                       t.Errorf("[%d] execute template with an escaped string value by SafeCSS returns unexpected error: %s", i, err)
+               }
+               if buf.String() != this.expectWithEscape {
+                       t.Errorf("[%d] execute template with an escaped string value by SafeCSS, got %v but expected %v", i, buf.String(), this.expectWithEscape)
+               }
+       }
+}