Protect against concurrent Scratch read and write
authorBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Mon, 21 Mar 2016 19:42:27 +0000 (20:42 +0100)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Mon, 21 Mar 2016 19:42:27 +0000 (20:42 +0100)
Fixes #2005

hugolib/scratch.go
hugolib/scratch_test.go

index 29dfd492f359d507ad7422136ebf8e8010cc8d17..ec0b184201eaf3da7c2fd42391f232cfdb38a07c 100644 (file)
@@ -17,11 +17,13 @@ import (
        "github.com/spf13/hugo/helpers"
        "reflect"
        "sort"
+       "sync"
 )
 
 // Scratch is a writable context used for stateful operations in Page/Node rendering.
 type Scratch struct {
        values map[string]interface{}
+       mu     sync.RWMutex
 }
 
 // For single values, Add will add (using the + operator) the addend to the existing addend (if found).
@@ -29,6 +31,9 @@ type Scratch struct {
 //
 // If the first add for a key is an array or slice, then the next value(s) will be appended.
 func (c *Scratch) Add(key string, newAddend interface{}) (string, error) {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+
        var newVal interface{}
        existingAddend, found := c.values[key]
        if found {
@@ -59,18 +64,27 @@ func (c *Scratch) Add(key string, newAddend interface{}) (string, error) {
 // Set stores a value with the given key in the Node context.
 // This value can later be retrieved with Get.
 func (c *Scratch) Set(key string, value interface{}) string {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+
        c.values[key] = value
        return ""
 }
 
 // Get returns a value previously set by Add or Set
 func (c *Scratch) Get(key string) interface{} {
+       c.mu.RLock()
+       defer c.mu.RUnlock()
+
        return c.values[key]
 }
 
 // SetInMap stores a value to a map with the given key in the Node context.
 // This map can later be retrieved with GetSortedMapValues.
 func (c *Scratch) SetInMap(key string, mapKey string, value interface{}) string {
+       c.mu.Lock()
+       defer c.mu.Unlock()
+
        _, found := c.values[key]
        if !found {
                c.values[key] = make(map[string]interface{})
@@ -82,6 +96,9 @@ func (c *Scratch) SetInMap(key string, mapKey string, value interface{}) string
 
 // GetSortedMapValues returns a sorted map previously filled with SetInMap
 func (c *Scratch) GetSortedMapValues(key string) interface{} {
+       c.mu.RLock()
+       defer c.mu.RUnlock()
+
        if c.values[key] == nil {
                return nil
        }
index c90ef733df5578391ace23f3f5d011c99852eb8e..3632ec16cae52d5a8af43798a24d74df36a472e4 100644 (file)
@@ -16,6 +16,7 @@ package hugolib
 import (
        "github.com/stretchr/testify/assert"
        "reflect"
+       "sync"
        "testing"
 )
 
@@ -80,6 +81,41 @@ func TestScratchSet(t *testing.T) {
        assert.Equal(t, "val", scratch.Get("key"))
 }
 
+// Issue #2005
+func TestScratchInParallel(t *testing.T) {
+       var wg sync.WaitGroup
+       scratch := newScratch()
+       key := "counter"
+       scratch.Set(key, 1)
+       for i := 1; i <= 10; i++ {
+               wg.Add(1)
+               go func(j int) {
+                       for k := 0; k < 10; k++ {
+                               newVal := k + j
+
+                               _, err := scratch.Add(key, newVal)
+                               if err != nil {
+                                       t.Errorf("Got err %s", err)
+                               }
+
+                               scratch.Set(key, newVal)
+
+                               val := scratch.Get(key)
+
+                               if counter, ok := val.(int); ok {
+                                       if counter < 1 {
+                                               t.Errorf("Got %d", counter)
+                                       }
+                               } else {
+                                       t.Errorf("Got %T", val)
+                               }
+                       }
+                       wg.Done()
+               }(i)
+       }
+       wg.Wait()
+}
+
 func TestScratchGet(t *testing.T) {
        scratch := newScratch()
        nothing := scratch.Get("nothing")