From c6155fbcaeb1e467b2e01e64beea30c5ce7a3dc6 Mon Sep 17 00:00:00 2001 From: Ben Tam Date: Sun, 29 Oct 2023 21:53:10 +0800 Subject: [PATCH] feat: provide content length for put method Signed-off-by: Ben Tam --- client.go | 24 +++++++++++++++++++++--- requests.go | 6 ++++-- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 656e003..42805fc 100644 --- a/client.go +++ b/client.go @@ -387,7 +387,7 @@ func (c *Client) ReadStreamRange(path string, offset, length int64) (io.ReadClos // Write writes data to a given path func (c *Client) Write(path string, data []byte, _ os.FileMode) (err error) { - s, err := c.put(path, bytes.NewReader(data)) + s, err := c.put(path, bytes.NewReader(data), int64(len(data))) if err != nil { return } @@ -403,7 +403,7 @@ func (c *Client) Write(path string, data []byte, _ os.FileMode) (err error) { return } - s, err = c.put(path, bytes.NewReader(data)) + s, err = c.put(path, bytes.NewReader(data), int64(len(data))) if err != nil { return } @@ -423,7 +423,25 @@ func (c *Client) WriteStream(path string, stream io.Reader, _ os.FileMode) (err return err } - s, err := c.put(path, stream) + contentLength := int64(0) + if seeker, ok := stream.(io.Seeker); ok { + contentLength, err = seeker.Seek(0, io.SeekEnd) + if err != nil { + return err + } + + _, err = seeker.Seek(0, io.SeekStart) + if err != nil { + return err + } + } else { + contentLength, err = io.Copy(io.Discard, stream) + if err != nil { + return err + } + } + + s, err := c.put(path, stream, contentLength) if err != nil { return err } diff --git a/requests.go b/requests.go index 8e362e8..b51e5c0 100644 --- a/requests.go +++ b/requests.go @@ -160,8 +160,10 @@ func (c *Client) copymove(method string, oldpath string, newpath string, overwri return NewPathError(method, oldpath, s) } -func (c *Client) put(path string, stream io.Reader) (status int, err error) { - rs, err := c.req("PUT", path, stream, nil) +func (c *Client) put(path string, stream io.Reader, contentLength int64) (status int, err error) { + rs, err := c.req("PUT", path, stream, func(r *http.Request) { + r.ContentLength = contentLength + }) if err != nil { return }