Make ServerConn safe for concurrent use

ServerConn's textproto.Conn is safe for concurrent access. This adds
calls to StartResponse() and EndResponse() to put that to use wherever
Cmd() is called.
https://golang.org/pkg/net/textproto/#Conn.Cmd

Additionally, this adds a sync.Mutex on ServerConn to provide
concurrency protection for data commands such as STOR and RETR that need
to run through synchronous steps.
This commit is contained in:
Brian Foshee 2016-02-29 15:44:03 -05:00
parent 025815df64
commit 025459f901
2 changed files with 146 additions and 7 deletions

View File

@ -4,6 +4,7 @@ import (
"bytes"
"io/ioutil"
"net/textproto"
"sync"
"testing"
"time"
)
@ -221,3 +222,105 @@ func TestWrongLogin(t *testing.T) {
t.Fatal("expected error, got nil")
}
}
func TestConcurrentAccess(t *testing.T) {
if testing.Short() {
t.Skip("skipping test in short mode.")
}
c, err := DialTimeout("localhost:21", 5*time.Second)
if err != nil {
t.Fatal(err)
}
err = c.Login("anonymous", "anonymous")
if err != nil {
t.Fatal(err)
}
err = c.ChangeDir("incoming")
if err != nil {
t.Error(err)
}
wg := sync.WaitGroup{}
files := []string{"test1", "test2"}
for _, f := range files {
wg.Add(1)
go func(fn string) {
data := bytes.NewBufferString(testData)
err := c.Stor(fn, data)
if err != nil {
t.Error(err)
}
_, err = c.List(".")
if err != nil {
t.Error(err)
}
err = c.Rename(fn, fn+"tset")
if err != nil {
t.Error(err)
}
r, err := c.Retr(fn + "tset")
if err != nil {
t.Error(err)
} else {
buf, err := ioutil.ReadAll(r)
if err != nil {
t.Error(err)
}
if string(buf) != testData {
t.Errorf("'%s'", buf)
}
r.Close()
}
r, err = c.RetrFrom(fn+"tset", 5)
if err != nil {
t.Error(err)
} else {
buf, err := ioutil.ReadAll(r)
if err != nil {
t.Error(err)
}
expected := testData[5:]
if string(buf) != expected {
t.Errorf("read %q, expected %q", buf, expected)
}
r.Close()
}
err = c.Delete(fn + "tset")
if err != nil {
t.Error(err)
}
wg.Done()
}(f)
}
wg.Wait()
err = c.Logout()
if err != nil {
if protoErr := err.(*textproto.Error); protoErr != nil {
if protoErr.Code != StatusNotImplemented {
t.Error(err)
}
} else {
t.Error(err)
}
}
c.Quit()
err = c.NoOp()
if err == nil {
t.Error("Expected error")
}
}

50
ftp.go
View File

@ -9,6 +9,7 @@ import (
"net/textproto"
"strconv"
"strings"
"sync"
"time"
)
@ -28,6 +29,14 @@ type ServerConn struct {
host string
timeout time.Duration
features map[string]string
// mu provides concurrent-safe use of a single instance of this type. It is
// locked and unlocked around data actions such as STOR and RETR.
mu sync.Mutex
// id is the current conn Pipeline id in use. This is set and used in every
// case that mu is locked and unlocked.
id uint
}
// Entry describes a file and is returned by List().
@ -238,17 +247,29 @@ func (c *ServerConn) openDataConn() (net.Conn, error) {
// cmd is a helper function to execute a command and check for the expected FTP
// return code
func (c *ServerConn) cmd(expected int, format string, args ...interface{}) (int, string, error) {
_, err := c.conn.Cmd(format, args...)
id, err := c.conn.Cmd(format, args...)
if err != nil {
return 0, "", err
}
// Utilize the Pipeline on c.conn to be safe for concurrency.
c.conn.StartResponse(id)
defer c.conn.EndResponse(id)
return c.conn.ReadResponse(expected)
}
// cmdDataConnFrom executes a command which require a FTP data connection.
// Issues a REST FTP command to specify the number of bytes to skip for the transfer.
//
// Caller MUST call c.conn.EndResponse(c.id) AND c.mu.Unlock() once the returned
// connection is closed.
func (c *ServerConn) cmdDataConnFrom(offset uint64, format string, args ...interface{}) (net.Conn, error) {
// Lock ServerConn so that the data action can run through its synchronous
// sequence of events. Must be Unlock()'d before any other data action can
// proceed.
c.mu.Lock()
conn, err := c.openDataConn()
if err != nil {
return nil, err
@ -261,19 +282,27 @@ func (c *ServerConn) cmdDataConnFrom(offset uint64, format string, args ...inter
}
}
_, err = c.conn.Cmd(format, args...)
id, err := c.conn.Cmd(format, args...)
if err != nil {
conn.Close()
return nil, err
}
// Set ServerConn's Pipeline id and start the response. Must call
// c.conn.EndResponse(c.id) once this data action is complete.
c.id = id
c.conn.StartResponse(c.id)
code, msg, err := c.conn.ReadResponse(-1)
if err != nil {
conn.Close()
c.conn.EndResponse(c.id)
c.mu.Unlock()
return nil, err
}
if code != StatusAlreadyOpen && code != StatusAboutToSend {
conn.Close()
c.conn.EndResponse(c.id)
c.mu.Unlock()
return nil, &textproto.Error{Code: code, Msg: msg}
}
@ -591,6 +620,8 @@ func (c *ServerConn) StorFrom(path string, r io.Reader, offset uint64) error {
if err != nil {
return err
}
defer c.mu.Unlock()
defer c.conn.EndResponse(c.id)
_, err = io.Copy(conn, r)
conn.Close()
@ -662,10 +693,15 @@ func (r *response) Read(buf []byte) (int, error) {
// Close implements the io.Closer interface on a FTP data connection.
func (r *response) Close() error {
err := r.conn.Close()
_, _, err2 := r.c.conn.ReadResponse(StatusClosingDataConnection)
if err2 != nil {
err = err2
defer r.c.mu.Unlock()
defer r.c.conn.EndResponse(r.c.id)
if err := r.conn.Close(); err != nil {
return err
}
return err
_, _, err := r.c.conn.ReadResponse(StatusClosingDataConnection)
if err != nil {
return err
}
return nil
}