超时处理程序中的竞争条件

时间:2016-11-16 22:18:36

标签: http go race-condition

我可以在下面的示例代码中看到两个主要问题,但我不知道如何正确解决它们。

如果超时处理程序没有通过errCh获得下一个处理程序已完成或发生错误的信号,它将回复" 408请求超时"对请求。

这里的问题是ResponseWriter不被多个goroutine使用是安全的。并且超时处理程序在执行下一个处理程序时启动一个新的goroutine。

的问题:

  1. 当ctx的Done通道在超时处理程序中超时时,如何防止下一个处理程序写入ResponseWriter。

  2. 如何阻止超时处理程序在下一个处理程序写入ResponseWriter时回复408状态代码但尚未完成且ctx的Done通道在超时处理程序中超时。

  3. package main
    
    import (
      "context"
      "fmt"
      "net/http"
      "time"
    )
    
    func main() {
      http.Handle("/race", handlerFunc(timeoutHandler))
      http.ListenAndServe(":8080", nil)
    }
    
    func timeoutHandler(w http.ResponseWriter, r *http.Request) error {
      const seconds = 1
      ctx, cancel := context.WithTimeout(r.Context(), time.Duration(seconds)*time.Second)
      defer cancel()
    
      r = r.WithContext(ctx)
    
      errCh := make(chan error, 1)
      go func() {
        // w is not safe for concurrent use by multiple goroutines
        errCh <- nextHandler(w, r)
      }()
    
      select {
      case err := <-errCh:
        return err
      case <-ctx.Done():
        // w is not safe for concurrent use by multiple goroutines
        http.Error(w, "Request timeout", 408)
        return nil
      }
    }
    
    func nextHandler(w http.ResponseWriter, r *http.Request) error {
      // just for fun to simulate a better race condition
      const seconds = 1
      time.Sleep(time.Duration(seconds) * time.Second)
      fmt.Fprint(w, "nextHandler")
      return nil
    }
    
    type handlerFunc func(w http.ResponseWriter, r *http.Request) error
    
    func (fn handlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
      if err := fn(w, r); err != nil {
        http.Error(w, "Server error", 500)
      }
    }
    

1 个答案:

答案 0 :(得分:0)

这是一个可能的解决方案,它基于@Andy的评论。

responseRecorder将传递给nextHandler,录制的回复将被复制回客户端:

func timeoutHandler(w http.ResponseWriter, r *http.Request) error {
    const seconds = 1
    ctx, cancel := context.WithTimeout(r.Context(),
        time.Duration(seconds)*time.Second)
    defer cancel()

    r = r.WithContext(ctx)

    errCh := make(chan error, 1)
    w2 := newResponseRecorder()
    go func() {
        errCh <- nextHandler(w2, r)
    }()

    select {
    case err := <-errCh:
        if err != nil {
            return err
        }

        w2.cloneHeader(w.Header())
        w.WriteHeader(w2.status)
        w.Write(w2.buf.Bytes())
        return nil
    case <-ctx.Done():
        http.Error(w, "Request timeout", 408)
        return nil
    }
}

这是responseRecorder

type responseRecorder struct {
    http.ResponseWriter
    header http.Header
    buf    *bytes.Buffer
    status int
}

func newResponseRecorder() *responseRecorder {
    return &responseRecorder{
        header: http.Header{},
        buf:    &bytes.Buffer{},
    }
}

func (w *responseRecorder) Header() http.Header {
    return w.header
}

func (w *responseRecorder) cloneHeader(dst http.Header) {
    for k, v := range w.header {
        tmp := make([]string, len(v))
        copy(tmp, v)
        dst[k] = tmp
    }
}

func (w *responseRecorder) Write(data []byte) (int, error) {
    if w.status == 0 {
        w.WriteHeader(http.StatusOK)
    }
    return w.buf.Write(data)
}

func (w *responseRecorder) WriteHeader(status int) {
    w.status = status
}
相关问题