lazy: Fix concurrent initialization order
authorBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Mon, 29 Apr 2019 17:05:28 +0000 (19:05 +0200)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Wed, 1 May 2019 14:00:31 +0000 (16:00 +0200)
Fixes #5901

lazy/init.go
lazy/init_test.go

index 5c1bee6095ab7107a57b91e6aa9b334f0d7ceb8c..a54fda96afd6d259c5ca63d5517e34d5a8ab3c3e 100644 (file)
@@ -77,30 +77,19 @@ func (ini *Init) Do() (interface{}, error) {
        }
 
        ini.init.Do(func() {
-               var (
-                       dependencies []*Init
-                       children     []*Init
-               )
-
                prev := ini.prev
-               for prev != nil {
+               if prev != nil {
+                       // A branch. Initialize the ancestors.
                        if prev.shouldInitialize() {
-                               dependencies = append(dependencies, prev)
-                       }
-                       prev = prev.prev
-               }
-
-               for _, child := range ini.children {
-                       if child.shouldInitialize() {
-                               children = append(children, child)
-                       }
-               }
-
-               for _, dep := range dependencies {
-                       _, err := dep.Do()
-                       if err != nil {
-                               ini.err = err
-                               return
+                               _, err := prev.Do()
+                               if err != nil {
+                                       ini.err = err
+                                       return
+                               }
+                       } else if prev.inProgress() {
+                               // Concurrent initialization. The following init func
+                               // may depend on earlier state, so wait.
+                               prev.wait()
                        }
                }
 
@@ -108,16 +97,25 @@ func (ini *Init) Do() (interface{}, error) {
                        ini.out, ini.err = ini.f()
                }
 
-               for _, dep := range children {
-                       _, err := dep.Do()
-                       if err != nil {
-                               ini.err = err
-                               return
+               for _, child := range ini.children {
+                       if child.shouldInitialize() {
+                               _, err := child.Do()
+                               if err != nil {
+                                       ini.err = err
+                                       return
+                               }
                        }
                }
-
        })
 
+       ini.wait()
+
+       return ini.out, ini.err
+
+}
+
+// TODO(bep) investigate if we can use sync.Cond for this.
+func (ini *Init) wait() {
        var counter time.Duration
        for !ini.init.Done() {
                counter += 10
@@ -126,8 +124,10 @@ func (ini *Init) Do() (interface{}, error) {
                }
                time.Sleep(counter * time.Microsecond)
        }
+}
 
-       return ini.out, ini.err
+func (ini *Init) inProgress() bool {
+       return ini != nil && ini.init.InProgress()
 }
 
 func (ini *Init) shouldInitialize() bool {
@@ -147,20 +147,19 @@ func (ini *Init) add(branch bool, initFn func() (interface{}, error)) *Init {
        ini.mu.Lock()
        defer ini.mu.Unlock()
 
-       if !branch {
-               ini.checkDone()
-       }
-
-       init := &Init{
-               f:    initFn,
-               prev: ini,
+       if branch {
+               return &Init{
+                       f:    initFn,
+                       prev: ini,
+               }
        }
 
-       if !branch {
-               ini.children = append(ini.children, init)
-       }
+       ini.checkDone()
+       ini.children = append(ini.children, &Init{
+               f: initFn,
+       })
 
-       return init
+       return ini
 }
 
 func (ini *Init) checkDone() {
index bcb57acb32ae522f01323cfbe0e6f1ac90b76366..ea1b22fe95fd56a40cf3b760a19d9868d8f0b55b 100644 (file)
@@ -25,32 +25,41 @@ import (
        "github.com/stretchr/testify/require"
 )
 
+var (
+       rnd        = rand.New(rand.NewSource(time.Now().UnixNano()))
+       bigOrSmall = func() int {
+               if rnd.Intn(10) < 5 {
+                       return 10000 + rnd.Intn(100000)
+               }
+               return 1 + rnd.Intn(50)
+       }
+)
+
+func doWork() {
+       doWorkOfSize(bigOrSmall())
+}
+
+func doWorkOfSize(size int) {
+       _ = strings.Repeat("Hugo Rocks! ", size)
+}
+
 func TestInit(t *testing.T) {
        assert := require.New(t)
 
        var result string
 
-       bigOrSmall := func() int {
-               if rand.Intn(10) < 3 {
-                       return 10000 + rand.Intn(100000)
-               }
-               return 1 + rand.Intn(50)
-       }
-
        f1 := func(name string) func() (interface{}, error) {
                return func() (interface{}, error) {
                        result += name + "|"
-                       size := bigOrSmall()
-                       _ = strings.Repeat("Hugo Rocks! ", size)
+                       doWork()
                        return name, nil
                }
        }
 
        f2 := func() func() (interface{}, error) {
                return func() (interface{}, error) {
-                       size := bigOrSmall()
-                       _ = strings.Repeat("Hugo Rocks! ", size)
-                       return size, nil
+                       doWork()
+                       return nil, nil
                }
        }
 
@@ -73,16 +82,15 @@ func TestInit(t *testing.T) {
                go func(i int) {
                        defer wg.Done()
                        var err error
-                       if rand.Intn(10) < 5 {
+                       if rnd.Intn(10) < 5 {
                                _, err = root.Do()
                                assert.NoError(err)
                        }
 
                        // Add a new branch on the fly.
-                       if rand.Intn(10) > 5 {
+                       if rnd.Intn(10) > 5 {
                                branch := branch1_2.Branch(f2())
-                               init := branch.Add(f2())
-                               _, err = init.Do()
+                               _, err = branch.Do()
                                assert.NoError(err)
                        } else {
                                _, err = branch1_2_1.Do()
@@ -148,3 +156,71 @@ func TestInitAddWithTimeoutError(t *testing.T) {
 
        assert.Error(err)
 }
+
+type T struct {
+       sync.Mutex
+       V1 string
+       V2 string
+}
+
+func (t *T) Add1(v string) {
+       t.Lock()
+       t.V1 += v
+       t.Unlock()
+}
+
+func (t *T) Add2(v string) {
+       t.Lock()
+       t.V2 += v
+       t.Unlock()
+}
+
+// https://github.com/gohugoio/hugo/issues/5901
+func TestInitBranchOrder(t *testing.T) {
+       assert := require.New(t)
+
+       base := New()
+
+       work := func(size int, f func()) func() (interface{}, error) {
+               return func() (interface{}, error) {
+                       doWorkOfSize(size)
+                       if f != nil {
+                               f()
+                       }
+
+                       return nil, nil
+               }
+       }
+
+       state := &T{}
+
+       base = base.Add(work(10000, func() {
+               state.Add1("A")
+       }))
+
+       inits := make([]*Init, 2)
+       for i := range inits {
+               inits[i] = base.Branch(work(i+1*100, func() {
+                       // V1 is A
+                       ab := state.V1 + "B"
+                       state.Add2(ab)
+
+               }))
+       }
+
+       var wg sync.WaitGroup
+
+       for _, v := range inits {
+               v := v
+               wg.Add(1)
+               go func() {
+                       defer wg.Done()
+                       _, err := v.Do()
+                       assert.NoError(err)
+               }()
+       }
+
+       wg.Wait()
+
+       assert.Equal("ABAB", state.V2)
+}