diff --git a/basicAuth.go b/basicAuth.go index 5a69113..bdb86da 100644 --- a/basicAuth.go +++ b/basicAuth.go @@ -2,6 +2,7 @@ package gowebdav import ( "encoding/base64" + "net/http" ) // BasicAuth structure holds our credentials @@ -26,8 +27,8 @@ func (b *BasicAuth) Pass() string { } // Authorize the current request -func (b *BasicAuth) Authorize(c *Client, method string, path string) { +func (b *BasicAuth) Authorize(req *http.Request, method string, path string) { a := b.user + ":" + b.pw auth := "Basic " + base64.StdEncoding.EncodeToString([]byte(a)) - c.headers.Set("Authorization", auth) + req.Header.Set("Authorization", auth) } diff --git a/client.go b/client.go index 17459b9..1c6306e 100644 --- a/client.go +++ b/client.go @@ -9,6 +9,7 @@ import ( "os" pathpkg "path" "strings" + "sync" "time" ) @@ -17,7 +18,9 @@ type Client struct { root string headers http.Header c *http.Client - auth Authenticator + + authMutex sync.Mutex + auth Authenticator } // Authenticator stub @@ -25,7 +28,7 @@ type Authenticator interface { Type() string User() string Pass() string - Authorize(*Client, string, string) + Authorize(*http.Request, string, string) } // NoAuth structure holds our credentials @@ -50,12 +53,12 @@ func (n *NoAuth) Pass() string { } // Authorize the current request -func (n *NoAuth) Authorize(c *Client, method string, path string) { +func (n *NoAuth) Authorize(req *http.Request, method string, path string) { } // NewClient creates a new instance of client func NewClient(uri, user, pw string) *Client { - return &Client{FixSlash(uri), make(http.Header), &http.Client{}, &NoAuth{user, pw}} + return &Client{FixSlash(uri), make(http.Header), &http.Client{}, sync.Mutex{}, &NoAuth{user, pw}} } // SetHeader lets us set arbitrary headers for a given client diff --git a/digestAuth.go b/digestAuth.go index dd5c844..4a5eb62 100644 --- a/digestAuth.go +++ b/digestAuth.go @@ -33,12 +33,12 @@ func (d *DigestAuth) Pass() string { } // Authorize the current request -func (d *DigestAuth) Authorize(c *Client, method string, path string) { +func (d *DigestAuth) Authorize(req *http.Request, method string, path string) { d.digestParts["uri"] = path d.digestParts["method"] = method d.digestParts["username"] = d.user d.digestParts["password"] = d.pw - c.headers.Set("Authorization", getDigestAuthorization(d.digestParts)) + req.Header.Set("Authorization", getDigestAuthorization(d.digestParts)) } func digestParts(resp *http.Response) map[string]string { diff --git a/requests.go b/requests.go index 3b9f9c0..bfa1e94 100644 --- a/requests.go +++ b/requests.go @@ -25,14 +25,20 @@ func (c *Client) req(method, path string, body io.Reader, intercept func(*http.R return nil, err } - c.auth.Authorize(c, method, path) - for k, vals := range c.headers { for _, v := range vals { r.Header.Add(k, v) } } + // make sure we read 'c.auth' only once since it will be substituted below + // and that is unsafe to do when multiple goroutines are running at the same time. + c.authMutex.Lock() + auth := c.auth + c.authMutex.Unlock() + + auth.Authorize(r, method, path) + if intercept != nil { intercept(r) } @@ -42,16 +48,17 @@ func (c *Client) req(method, path string, body io.Reader, intercept func(*http.R return nil, err } - if rs.StatusCode == 401 && c.auth.Type() == "NoAuth" { - + if rs.StatusCode == 401 && auth.Type() == "NoAuth" { wwwAuthenticateHeader := strings.ToLower(rs.Header.Get("Www-Authenticate")) if strings.Index(wwwAuthenticateHeader, "digest") > -1 { - c.auth = &DigestAuth{c.auth.User(), c.auth.Pass(), digestParts(rs)} - + c.authMutex.Lock() + c.auth = &DigestAuth{auth.User(), auth.Pass(), digestParts(rs)} + c.authMutex.Unlock() } else if strings.Index(wwwAuthenticateHeader, "basic") > -1 { - c.auth = &BasicAuth{c.auth.User(), c.auth.Pass()} - + c.authMutex.Lock() + c.auth = &BasicAuth{auth.User(), auth.Pass()} + c.authMutex.Unlock() } else { return rs, newPathError("Authorize", c.root, rs.StatusCode) }