From b45378c08f94337804b6fa1c9525746937678c8f Mon Sep 17 00:00:00 2001 From: David Date: Mon, 18 Jun 2018 10:02:01 -0500 Subject: [PATCH] Check status on every request to fix #14 --- client.go | 13 +------------ requests.go | 26 +++++++++++++++++++++++--- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index 9bdbf21..59525e3 100644 --- a/client.go +++ b/client.go @@ -85,18 +85,7 @@ func (c *Client) Connect() error { return err } - if rs.StatusCode == 401 && c.auth.Type() == "NoAuth" { - if strings.Index(rs.Header.Get("Www-Authenticate"), "Digest") > -1 { - c.auth = &DigestAuth{c.auth.User(), c.auth.Pass(), digestParts(rs)} - } else if strings.Index(rs.Header.Get("Www-Authenticate"), "Basic") > -1 { - c.auth = &BasicAuth{c.auth.User(), c.auth.Pass()} - } else { - return newPathError("Authorize", c.root, rs.StatusCode) - } - return c.Connect() - } else if rs.StatusCode == 401 { - return newPathError("Authorize", c.root, rs.StatusCode) - } else if rs.StatusCode != 200 || (rs.Header.Get("Dav") == "" && rs.Header.Get("DAV") == "") { + if rs.StatusCode != 200 || (rs.Header.Get("Dav") == "" && rs.Header.Get("DAV") == "") { return newPathError("Connect", c.root, rs.StatusCode) } diff --git a/requests.go b/requests.go index 17b5515..8444b49 100644 --- a/requests.go +++ b/requests.go @@ -1,6 +1,7 @@ package gowebdav import ( + "bytes" "fmt" "io" "net/http" @@ -8,13 +9,17 @@ import ( ) func (c *Client) req(method, path string, body io.Reader, intercept func(*http.Request)) (req *http.Response, err error) { - r, err := http.NewRequest(method, PathEscape(Join(c.root, path)), body) + // Tee the body, because if authorization fails we will need to read from it again. + var ba bytes.Buffer + bb := io.TeeReader(body, &ba) + + r, err := http.NewRequest(method, PathEscape(Join(c.root, path)), &ba) if err != nil { return nil, err } c.auth.Authorize(c, method, path) - + for k, vals := range c.headers { for _, v := range vals { r.Header.Add(k, v) @@ -25,7 +30,22 @@ func (c *Client) req(method, path string, body io.Reader, intercept func(*http.R intercept(r) } - return c.c.Do(r) + rs, err := c.c.Do(r) + + if rs.StatusCode == 401 && c.auth.Type() == "NoAuth" { + if strings.Index(rs.Header.Get("Www-Authenticate"), "Digest") > -1 { + c.auth = &DigestAuth{c.auth.User(), c.auth.Pass(), digestParts(rs)} + } else if strings.Index(rs.Header.Get("Www-Authenticate"), "Basic") > -1 { + c.auth = &BasicAuth{c.auth.User(), c.auth.Pass()} + } else { + return rs, newPathError("Authorize", c.root, rs.StatusCode) + } + return c.req(method, path, bb, intercept) + } else if rs.StatusCode == 401 { + return rs, newPathError("Authorize", c.root, rs.StatusCode) + } + + return rs, err } func (c *Client) mkcol(path string) int {