Feat: Authentication API

The changes simplify the `req` method by moving the
authentication-related code into the API.
This makes it easy to add additional authentication methods.

The API introduces an `Authorizer` that acts as an
authenticator factory. The authentication flow itself
is divided down into `Authorize` and `Verify` steps in order
to encapsulate and control complex authentication challenges.

The default `NewAutoAuth` negotiates the algorithms.
Under the hood, it creates an authenticator shim per request,
which delegates the authentication flow to our authenticators.

The `NewEmptyAuth` and `NewPreemptiveAuth` authorizers
allow you to have more control over algorithms and resources.

The API also allows interception of the redirect mechanism by setting
the `XInhibitRedirect` header.

This closes: #15 #24 #38
This commit is contained in:
Christoph Polcin
2023-02-03 10:18:35 +01:00
committed by Christoph Polcin
parent 3282f94193
commit ca40e2802e
12 changed files with 994 additions and 285 deletions

View File

@@ -1,7 +1,6 @@
package gowebdav
import (
"bytes"
"io"
"log"
"net/http"
@@ -9,83 +8,54 @@ import (
"strings"
)
func (c *Client) req(method, path string, body io.Reader, intercept func(*http.Request)) (req *http.Response, err error) {
func (c *Client) req(method, path string, body io.Reader, intercept func(*http.Request)) (rs *http.Response, err error) {
var redo bool
var r *http.Request
var retryBuf io.Reader
var uri = PathEscape(Join(c.root, path))
auth, body := c.auth.NewAuthenticator(body)
defer auth.Close()
if body != nil {
// If the authorization fails, we will need to restart reading
// from the passed body stream.
// When body is seekable, use seek to reset the streams
// cursor to the start.
// Otherwise, copy the stream into a buffer while uploading
// and use the buffers content on retry.
if sk, ok := body.(io.Seeker); ok {
if _, err = sk.Seek(0, io.SeekStart); err != nil {
return
for { // TODO auth.continue() strategy(true|n times|until)?
if r, err = http.NewRequest(method, uri, body); err != nil {
return
}
for k, vals := range c.headers {
for _, v := range vals {
r.Header.Add(k, v)
}
retryBuf = body
} else {
buff := &bytes.Buffer{}
retryBuf = buff
body = io.TeeReader(body, buff)
}
r, err = http.NewRequest(method, PathEscape(Join(c.root, path)), body)
} else {
r, err = http.NewRequest(method, PathEscape(Join(c.root, path)), nil)
}
if err != nil {
return nil, err
}
for k, vals := range c.headers {
for _, v := range vals {
r.Header.Add(k, v)
}
}
// make sure we read 'c.auth' only once since it will be substituted below
// and that is unsafe to do when multiple goroutines are running at the same time.
c.authMutex.Lock()
auth := c.auth
c.authMutex.Unlock()
auth.Authorize(r, method, path)
if intercept != nil {
intercept(r)
}
if c.interceptor != nil {
c.interceptor(method, r)
}
rs, err := c.c.Do(r)
if err != nil {
return nil, err
}
if rs.StatusCode == 401 && auth.Type() == "NoAuth" {
wwwAuthenticateHeader := strings.ToLower(rs.Header.Get("Www-Authenticate"))
if strings.Index(wwwAuthenticateHeader, "digest") > -1 {
c.authMutex.Lock()
c.auth = &DigestAuth{auth.User(), auth.Pass(), digestParts(rs)}
c.authMutex.Unlock()
} else if strings.Index(wwwAuthenticateHeader, "basic") > -1 {
c.authMutex.Lock()
c.auth = &BasicAuth{auth.User(), auth.Pass()}
c.authMutex.Unlock()
} else {
return rs, newPathError("Authorize", c.root, rs.StatusCode)
}
// retryBuf will be nil if body was nil initially so no check
// for body == nil is required here.
return c.req(method, path, retryBuf, intercept)
} else if rs.StatusCode == 401 {
return rs, newPathError("Authorize", c.root, rs.StatusCode)
if err = auth.Authorize(c.c, r, path); err != nil {
return
}
if intercept != nil {
intercept(r)
}
if c.interceptor != nil {
c.interceptor(method, r)
}
if rs, err = c.c.Do(r); err != nil {
return
}
if redo, err = auth.Verify(c.c, rs, path); err != nil {
io.Copy(io.Discard, rs.Body)
rs.Body.Close()
return nil, err
}
if redo {
io.Copy(io.Discard, rs.Body)
rs.Body.Close()
if body, err = r.GetBody(); err != nil {
return nil, err
}
continue
}
break
}
return rs, err
@@ -131,7 +101,7 @@ func (c *Client) propfind(path string, self bool, body string, resp interface{},
defer rs.Body.Close()
if rs.StatusCode != 207 {
return newPathError("PROPFIND", path, rs.StatusCode)
return NewPathError("PROPFIND", path, rs.StatusCode)
}
return parseXML(rs.Body, resp, parse)
@@ -189,7 +159,7 @@ func (c *Client) copymove(method string, oldpath string, newpath string, overwri
return c.copymove(method, oldpath, newpath, overwrite)
}
return newPathError(method, oldpath, s)
return NewPathError(method, oldpath, s)
}
func (c *Client) put(path string, stream io.Reader) (status int, err error) {