diff --git a/client_test.go b/client_test.go index 65724fb..016d9ac 100644 --- a/client_test.go +++ b/client_test.go @@ -39,6 +39,25 @@ func basicAuth(h http.Handler) http.HandlerFunc { } } +func basicAuthWithPostHandlerFunc(h http.Handler, postHandlerFunc http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user, passwd, ok := r.BasicAuth() + if !ok { + w.Header().Set("WWW-Authenticate", `Basic realm="x"`) + w.WriteHeader(401) + return + } + + if user != "user" || passwd != "password" { + http.Error(w, "not authorized", 403) + return + } + + h.ServeHTTP(w, r) + postHandlerFunc(w, r) + } +} + func multipleAuth(h http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { notAuthed := false @@ -130,6 +149,38 @@ func newAuthSrv(t *testing.T, auth func(h http.Handler) http.HandlerFunc) (*http return srv, fs, ctx } +func newAuthServerAcquireContentLength(t *testing.T) (*Client, *httptest.Server, webdav.FileSystem, context.Context) { + srv, fs, ctx := newAuthSrvAcquireContentLength(t, basicAuthWithPostHandlerFunc) + cli := NewClient(srv.URL, "user", "password") + return cli, srv, fs, ctx +} + +func newAuthSrvAcquireContentLength(t *testing.T, authWithPostHandlerFunc func(h http.Handler, postHandlerFunc http.HandlerFunc) http.HandlerFunc) (*httptest.Server, webdav.FileSystem, context.Context) { + mux := http.NewServeMux() + fs := webdav.NewMemFS() + ctx := fillFs(t, fs) + mux.HandleFunc("/", authWithPostHandlerFunc(&webdav.Handler{ + FileSystem: fs, + LockSystem: webdav.NewMemLS(), + }, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + return + } + + fileName := strings.TrimPrefix(r.URL.Path, "/") + stat, err := fs.Stat(ctx, fileName) + if err != nil { + t.Fatalf("got: %v, want nil", err) + } + + if r.ContentLength != stat.Size() { + t.Fatalf("acquire content length got: %v, want %v", r.ContentLength, stat.Size()) + } + })) + srv := httptest.NewServer(mux) + return srv, fs, ctx +} + func TestConnect(t *testing.T) { cli, srv, _, _ := newServer(t) defer srv.Close() @@ -572,3 +623,21 @@ func TestWriteStreamFromPipe(t *testing.T) { t.Fatalf("got: %v, want file size: %d bytes", info.Size(), 8) } } + +func TestWriteToServerAcquireContentLength(t *testing.T) { + cli, srv, _, _ := newAuthServerAcquireContentLength(t) + defer srv.Close() + + if err := cli.Write("/newfile.txt", []byte("foo bar\n"), 0660); err != nil { + t.Fatalf("got: %v, want nil", err) + } +} + +func TestWriteStreamToServerAcquireContentLength(t *testing.T) { + cli, srv, _, _ := newAuthServerAcquireContentLength(t) + defer srv.Close() + + if err := cli.WriteStream("/newfile.txt", strings.NewReader("foo bar\n"), 0660); err != nil { + t.Fatalf("got: %v, want nil", err) + } +}