diff --git a/client.go b/client.go index 76d24dc..4b3bc80 100644 --- a/client.go +++ b/client.go @@ -269,9 +269,12 @@ func (c *Client) RemoveAll(path string) error { } // Mkdir makes a directory -func (c *Client) Mkdir(path string, _ os.FileMode) error { +func (c *Client) Mkdir(path string, _ os.FileMode) (err error) { path = FixSlashes(path) - status := c.mkcol(path) + status, err := c.mkcol(path) + if err != nil { + return + } if status == 201 { return nil } @@ -280,12 +283,16 @@ func (c *Client) Mkdir(path string, _ os.FileMode) error { } // MkdirAll like mkdir -p, but for webdav -func (c *Client) MkdirAll(path string, _ os.FileMode) error { +func (c *Client) MkdirAll(path string, _ os.FileMode) (err error) { path = FixSlashes(path) - status := c.mkcol(path) + status, err := c.mkcol(path) + if err != nil { + return + } if status == 201 { return nil - } else if status == 409 { + } + if status == 409 { paths := strings.Split(path, "/") sub := "/" for _, e := range paths { @@ -293,7 +300,10 @@ func (c *Client) MkdirAll(path string, _ os.FileMode) error { continue } sub += e + "/" - status = c.mkcol(sub) + status, err = c.mkcol(sub) + if err != nil { + return + } if status != 201 { return newPathError("MkdirAll", sub, status) } @@ -385,22 +395,29 @@ func (c *Client) ReadStreamRange(path string, offset, length int64) (io.ReadClos } // Write writes data to a given path -func (c *Client) Write(path string, data []byte, _ os.FileMode) error { - s := c.put(path, bytes.NewReader(data)) +func (c *Client) Write(path string, data []byte, _ os.FileMode) (err error) { + s, err := c.put(path, bytes.NewReader(data)) + if err != nil { + return + } + switch s { case 200, 201, 204: return nil case 409: - err := c.createParentCollection(path) + err = c.createParentCollection(path) if err != nil { - return err + return } - s = c.put(path, bytes.NewReader(data)) + s, err = c.put(path, bytes.NewReader(data)) + if err != nil { + return + } if s == 200 || s == 201 || s == 204 { - return nil + return } } @@ -408,14 +425,17 @@ func (c *Client) Write(path string, data []byte, _ os.FileMode) error { } // WriteStream writes a stream -func (c *Client) WriteStream(path string, stream io.Reader, _ os.FileMode) error { +func (c *Client) WriteStream(path string, stream io.Reader, _ os.FileMode) (err error) { - err := c.createParentCollection(path) + err = c.createParentCollection(path) if err != nil { return err } - s := c.put(path, stream) + s, err := c.put(path, stream) + if err != nil { + return err + } switch s { case 200, 201, 204: diff --git a/cmd/gowebdav/main.go b/cmd/gowebdav/main.go index 97c1010..ef1b7ed 100644 --- a/cmd/gowebdav/main.go +++ b/cmd/gowebdav/main.go @@ -4,13 +4,16 @@ import ( "errors" "flag" "fmt" - d "github.com/studio-b12/gowebdav" "io" + "io/fs" "os" "os/user" + "path" "path/filepath" "runtime" "strings" + + d "github.com/studio-b12/gowebdav" ) func main() { @@ -190,8 +193,18 @@ func cmdCp(c *d.Client, p0, p1 string) (err error) { func cmdPut(c *d.Client, p0, p1 string) (err error) { if p1 == "" { - p1 = filepath.Join(".", p0) + p1 = path.Join(".", p0) + } else { + var fi fs.FileInfo + fi, err = c.Stat(p0) + if err != nil && !d.IsErrNotFound(err) { + return + } + if !d.IsErrNotFound(err) && fi.IsDir() { + p0 = path.Join(p0, p1) + } } + stream, err := getStream(p1) if err != nil { return diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..bbf1e92 --- /dev/null +++ b/errors.go @@ -0,0 +1,49 @@ +package gowebdav + +import ( + "fmt" + "os" +) + +// StatusError implements error and wraps +// an erroneous status code. +type StatusError struct { + Status int +} + +func (se StatusError) Error() string { + return fmt.Sprintf("%d", se.Status) +} + +// IsErrCode returns true if the given error +// is an os.PathError wrapping a StatusError +// with the given status code. +func IsErrCode(err error, code int) bool { + if pe, ok := err.(*os.PathError); ok { + se, ok := pe.Err.(StatusError) + return ok && se.Status == code + } + return false +} + +// IsErrNotFound is shorthand for IsErrCode +// for status 404. +func IsErrNotFound(err error) bool { + return IsErrCode(err, 404) +} + +func newPathError(op string, path string, statusCode int) error { + return &os.PathError{ + Op: op, + Path: path, + Err: StatusError{statusCode}, + } +} + +func newPathErrorErr(op string, path string, err error) error { + return &os.PathError{ + Op: op, + Path: path, + Err: err, + } +} diff --git a/requests.go b/requests.go index 7e13b9b..ac33470 100644 --- a/requests.go +++ b/requests.go @@ -91,18 +91,19 @@ func (c *Client) req(method, path string, body io.Reader, intercept func(*http.R return rs, err } -func (c *Client) mkcol(path string) int { +func (c *Client) mkcol(path string) (status int, err error) { rs, err := c.req("MKCOL", path, nil, nil) if err != nil { - return 400 + return } defer rs.Body.Close() - if rs.StatusCode == 201 || rs.StatusCode == 405 { - return 201 + status = rs.StatusCode + if status == 405 { + status = 201 } - return rs.StatusCode + return } func (c *Client) options(path string) (*http.Response, error) { @@ -130,13 +131,22 @@ func (c *Client) propfind(path string, self bool, body string, resp interface{}, defer rs.Body.Close() if rs.StatusCode != 207 { - return fmt.Errorf("%s - %s %s", rs.Status, "PROPFIND", path) + return newPathError("PROPFIND", path, rs.StatusCode) } return parseXML(rs.Body, resp, parse) } -func (c *Client) doCopyMove(method string, oldpath string, newpath string, overwrite bool) (int, io.ReadCloser) { +func (c *Client) doCopyMove( + method string, + oldpath string, + newpath string, + overwrite bool, +) ( + status int, + r io.ReadCloser, + err error, +) { rs, err := c.req(method, oldpath, nil, func(rq *http.Request) { rq.Header.Add("Destination", PathEscape(Join(c.root, newpath))) if overwrite { @@ -146,13 +156,18 @@ func (c *Client) doCopyMove(method string, oldpath string, newpath string, overw } }) if err != nil { - return 400, nil + return } - return rs.StatusCode, rs.Body + status = rs.StatusCode + r = rs.Body + return } -func (c *Client) copymove(method string, oldpath string, newpath string, overwrite bool) error { - s, data := c.doCopyMove(method, oldpath, newpath, overwrite) +func (c *Client) copymove(method string, oldpath string, newpath string, overwrite bool) (err error) { + s, data, err := c.doCopyMove(method, oldpath, newpath, overwrite) + if err != nil { + return + } if data != nil { defer data.Close() } @@ -177,14 +192,15 @@ func (c *Client) copymove(method string, oldpath string, newpath string, overwri return newPathError(method, oldpath, s) } -func (c *Client) put(path string, stream io.Reader) int { +func (c *Client) put(path string, stream io.Reader) (status int, err error) { rs, err := c.req("PUT", path, stream, nil) if err != nil { - return 400 + return } defer rs.Body.Close() - return rs.StatusCode + status = rs.StatusCode + return } func (c *Client) createParentCollection(itemPath string) (err error) { diff --git a/utils.go b/utils.go index f82592a..c7a65ad 100644 --- a/utils.go +++ b/utils.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "net/url" - "os" "strconv" "strings" "time" @@ -16,22 +15,6 @@ func log(msg interface{}) { fmt.Println(msg) } -func newPathError(op string, path string, statusCode int) error { - return &os.PathError{ - Op: op, - Path: path, - Err: fmt.Errorf("%d", statusCode), - } -} - -func newPathErrorErr(op string, path string, err error) error { - return &os.PathError{ - Op: op, - Path: path, - Err: err, - } -} - // PathEscape escapes all segments of a given path func PathEscape(path string) string { s := strings.Split(path, "/")