Check status on every request to fix #14

This commit is contained in:
David 2018-06-18 10:02:01 -05:00
parent 9ff8e33634
commit b45378c08f
2 changed files with 24 additions and 15 deletions

View File

@ -85,18 +85,7 @@ func (c *Client) Connect() error {
return err return err
} }
if rs.StatusCode == 401 && c.auth.Type() == "NoAuth" { if rs.StatusCode != 200 || (rs.Header.Get("Dav") == "" && rs.Header.Get("DAV") == "") {
if strings.Index(rs.Header.Get("Www-Authenticate"), "Digest") > -1 {
c.auth = &DigestAuth{c.auth.User(), c.auth.Pass(), digestParts(rs)}
} else if strings.Index(rs.Header.Get("Www-Authenticate"), "Basic") > -1 {
c.auth = &BasicAuth{c.auth.User(), c.auth.Pass()}
} else {
return newPathError("Authorize", c.root, rs.StatusCode)
}
return c.Connect()
} else if rs.StatusCode == 401 {
return newPathError("Authorize", c.root, rs.StatusCode)
} else if rs.StatusCode != 200 || (rs.Header.Get("Dav") == "" && rs.Header.Get("DAV") == "") {
return newPathError("Connect", c.root, rs.StatusCode) return newPathError("Connect", c.root, rs.StatusCode)
} }

View File

@ -1,6 +1,7 @@
package gowebdav package gowebdav
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -8,13 +9,17 @@ import (
) )
func (c *Client) req(method, path string, body io.Reader, intercept func(*http.Request)) (req *http.Response, err error) { func (c *Client) req(method, path string, body io.Reader, intercept func(*http.Request)) (req *http.Response, err error) {
r, err := http.NewRequest(method, PathEscape(Join(c.root, path)), body) // Tee the body, because if authorization fails we will need to read from it again.
var ba bytes.Buffer
bb := io.TeeReader(body, &ba)
r, err := http.NewRequest(method, PathEscape(Join(c.root, path)), &ba)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.auth.Authorize(c, method, path) c.auth.Authorize(c, method, path)
for k, vals := range c.headers { for k, vals := range c.headers {
for _, v := range vals { for _, v := range vals {
r.Header.Add(k, v) r.Header.Add(k, v)
@ -25,7 +30,22 @@ func (c *Client) req(method, path string, body io.Reader, intercept func(*http.R
intercept(r) intercept(r)
} }
return c.c.Do(r) rs, err := c.c.Do(r)
if rs.StatusCode == 401 && c.auth.Type() == "NoAuth" {
if strings.Index(rs.Header.Get("Www-Authenticate"), "Digest") > -1 {
c.auth = &DigestAuth{c.auth.User(), c.auth.Pass(), digestParts(rs)}
} else if strings.Index(rs.Header.Get("Www-Authenticate"), "Basic") > -1 {
c.auth = &BasicAuth{c.auth.User(), c.auth.Pass()}
} else {
return rs, newPathError("Authorize", c.root, rs.StatusCode)
}
return c.req(method, path, bb, intercept)
} else if rs.StatusCode == 401 {
return rs, newPathError("Authorize", c.root, rs.StatusCode)
}
return rs, err
} }
func (c *Client) mkcol(path string) int { func (c *Client) mkcol(path string) int {