diff --git a/client_test.go b/client_test.go index ac09cd5..afd37b9 100644 --- a/client_test.go +++ b/client_test.go @@ -4,6 +4,7 @@ import ( "bytes" "io/ioutil" "net/textproto" + "sync" "testing" "time" ) @@ -221,3 +222,105 @@ func TestWrongLogin(t *testing.T) { t.Fatal("expected error, got nil") } } + +func TestConcurrentAccess(t *testing.T) { + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + c, err := DialTimeout("localhost:21", 5*time.Second) + if err != nil { + t.Fatal(err) + } + + err = c.Login("anonymous", "anonymous") + if err != nil { + t.Fatal(err) + } + + err = c.ChangeDir("incoming") + if err != nil { + t.Error(err) + } + + wg := sync.WaitGroup{} + + files := []string{"test1", "test2"} + for _, f := range files { + wg.Add(1) + + go func(fn string) { + data := bytes.NewBufferString(testData) + err := c.Stor(fn, data) + if err != nil { + t.Error(err) + } + + _, err = c.List(".") + if err != nil { + t.Error(err) + } + + err = c.Rename(fn, fn+"tset") + if err != nil { + t.Error(err) + } + + r, err := c.Retr(fn + "tset") + if err != nil { + t.Error(err) + } else { + buf, err := ioutil.ReadAll(r) + if err != nil { + t.Error(err) + } + if string(buf) != testData { + t.Errorf("'%s'", buf) + } + r.Close() + } + + r, err = c.RetrFrom(fn+"tset", 5) + if err != nil { + t.Error(err) + } else { + buf, err := ioutil.ReadAll(r) + if err != nil { + t.Error(err) + } + expected := testData[5:] + if string(buf) != expected { + t.Errorf("read %q, expected %q", buf, expected) + } + r.Close() + } + + err = c.Delete(fn + "tset") + if err != nil { + t.Error(err) + } + + wg.Done() + }(f) + } + + wg.Wait() + + err = c.Logout() + if err != nil { + if protoErr := err.(*textproto.Error); protoErr != nil { + if protoErr.Code != StatusNotImplemented { + t.Error(err) + } + } else { + t.Error(err) + } + } + + c.Quit() + + err = c.NoOp() + if err == nil { + t.Error("Expected error") + } +} diff --git a/ftp.go b/ftp.go index db7ac92..cf3e925 100644 --- a/ftp.go +++ b/ftp.go @@ -9,6 +9,7 @@ import ( "net/textproto" "strconv" "strings" + "sync" "time" ) @@ -28,6 +29,14 @@ type ServerConn struct { host string timeout time.Duration features map[string]string + + // mu provides concurrent-safe use of a single instance of this type. It is + // locked and unlocked around data actions such as STOR and RETR. + mu sync.Mutex + + // id is the current conn Pipeline id in use. This is set and used in every + // case that mu is locked and unlocked. + id uint } // Entry describes a file and is returned by List(). @@ -238,17 +247,29 @@ func (c *ServerConn) openDataConn() (net.Conn, error) { // cmd is a helper function to execute a command and check for the expected FTP // return code func (c *ServerConn) cmd(expected int, format string, args ...interface{}) (int, string, error) { - _, err := c.conn.Cmd(format, args...) + id, err := c.conn.Cmd(format, args...) if err != nil { return 0, "", err } + // Utilize the Pipeline on c.conn to be safe for concurrency. + c.conn.StartResponse(id) + defer c.conn.EndResponse(id) + return c.conn.ReadResponse(expected) } // cmdDataConnFrom executes a command which require a FTP data connection. // Issues a REST FTP command to specify the number of bytes to skip for the transfer. +// +// Caller MUST call c.conn.EndResponse(c.id) AND c.mu.Unlock() once the returned +// connection is closed. func (c *ServerConn) cmdDataConnFrom(offset uint64, format string, args ...interface{}) (net.Conn, error) { + // Lock ServerConn so that the data action can run through its synchronous + // sequence of events. Must be Unlock()'d before any other data action can + // proceed. + c.mu.Lock() + conn, err := c.openDataConn() if err != nil { return nil, err @@ -261,19 +282,27 @@ func (c *ServerConn) cmdDataConnFrom(offset uint64, format string, args ...inter } } - _, err = c.conn.Cmd(format, args...) + id, err := c.conn.Cmd(format, args...) if err != nil { conn.Close() return nil, err } + // Set ServerConn's Pipeline id and start the response. Must call + // c.conn.EndResponse(c.id) once this data action is complete. + c.id = id + c.conn.StartResponse(c.id) code, msg, err := c.conn.ReadResponse(-1) if err != nil { conn.Close() + c.conn.EndResponse(c.id) + c.mu.Unlock() return nil, err } if code != StatusAlreadyOpen && code != StatusAboutToSend { conn.Close() + c.conn.EndResponse(c.id) + c.mu.Unlock() return nil, &textproto.Error{Code: code, Msg: msg} } @@ -591,6 +620,8 @@ func (c *ServerConn) StorFrom(path string, r io.Reader, offset uint64) error { if err != nil { return err } + defer c.mu.Unlock() + defer c.conn.EndResponse(c.id) _, err = io.Copy(conn, r) conn.Close() @@ -662,10 +693,15 @@ func (r *response) Read(buf []byte) (int, error) { // Close implements the io.Closer interface on a FTP data connection. func (r *response) Close() error { - err := r.conn.Close() - _, _, err2 := r.c.conn.ReadResponse(StatusClosingDataConnection) - if err2 != nil { - err = err2 + defer r.c.mu.Unlock() + defer r.c.conn.EndResponse(r.c.id) + + if err := r.conn.Close(); err != nil { + return err } - return err + _, _, err := r.c.conn.ReadResponse(StatusClosingDataConnection) + if err != nil { + return err + } + return nil }