Make ServerConn safe for concurrent use
ServerConn's textproto.Conn is safe for concurrent access. This adds calls to StartResponse() and EndResponse() to put that to use wherever Cmd() is called. https://golang.org/pkg/net/textproto/#Conn.Cmd Additionally, this adds a sync.Mutex on ServerConn to provide concurrency protection for data commands such as STOR and RETR that need to run through synchronous steps.
This commit is contained in:
parent
025815df64
commit
025459f901
103
client_test.go
103
client_test.go
@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -221,3 +222,105 @@ func TestWrongLogin(t *testing.T) {
|
|||||||
t.Fatal("expected error, got nil")
|
t.Fatal("expected error, got nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConcurrentAccess(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping test in short mode.")
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := DialTimeout("localhost:21", 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.Login("anonymous", "anonymous")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.ChangeDir("incoming")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
|
||||||
|
files := []string{"test1", "test2"}
|
||||||
|
for _, f := range files {
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
go func(fn string) {
|
||||||
|
data := bytes.NewBufferString(testData)
|
||||||
|
err := c.Stor(fn, data)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.List(".")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.Rename(fn, fn+"tset")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := c.Retr(fn + "tset")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else {
|
||||||
|
buf, err := ioutil.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if string(buf) != testData {
|
||||||
|
t.Errorf("'%s'", buf)
|
||||||
|
}
|
||||||
|
r.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err = c.RetrFrom(fn+"tset", 5)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
} else {
|
||||||
|
buf, err := ioutil.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
expected := testData[5:]
|
||||||
|
if string(buf) != expected {
|
||||||
|
t.Errorf("read %q, expected %q", buf, expected)
|
||||||
|
}
|
||||||
|
r.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.Delete(fn + "tset")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Done()
|
||||||
|
}(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
err = c.Logout()
|
||||||
|
if err != nil {
|
||||||
|
if protoErr := err.(*textproto.Error); protoErr != nil {
|
||||||
|
if protoErr.Code != StatusNotImplemented {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Quit()
|
||||||
|
|
||||||
|
err = c.NoOp()
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
50
ftp.go
50
ftp.go
@ -9,6 +9,7 @@ import (
|
|||||||
"net/textproto"
|
"net/textproto"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,6 +29,14 @@ type ServerConn struct {
|
|||||||
host string
|
host string
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
features map[string]string
|
features map[string]string
|
||||||
|
|
||||||
|
// mu provides concurrent-safe use of a single instance of this type. It is
|
||||||
|
// locked and unlocked around data actions such as STOR and RETR.
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// id is the current conn Pipeline id in use. This is set and used in every
|
||||||
|
// case that mu is locked and unlocked.
|
||||||
|
id uint
|
||||||
}
|
}
|
||||||
|
|
||||||
// Entry describes a file and is returned by List().
|
// Entry describes a file and is returned by List().
|
||||||
@ -238,17 +247,29 @@ func (c *ServerConn) openDataConn() (net.Conn, error) {
|
|||||||
// cmd is a helper function to execute a command and check for the expected FTP
|
// cmd is a helper function to execute a command and check for the expected FTP
|
||||||
// return code
|
// return code
|
||||||
func (c *ServerConn) cmd(expected int, format string, args ...interface{}) (int, string, error) {
|
func (c *ServerConn) cmd(expected int, format string, args ...interface{}) (int, string, error) {
|
||||||
_, err := c.conn.Cmd(format, args...)
|
id, err := c.conn.Cmd(format, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, "", err
|
return 0, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Utilize the Pipeline on c.conn to be safe for concurrency.
|
||||||
|
c.conn.StartResponse(id)
|
||||||
|
defer c.conn.EndResponse(id)
|
||||||
|
|
||||||
return c.conn.ReadResponse(expected)
|
return c.conn.ReadResponse(expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
// cmdDataConnFrom executes a command which require a FTP data connection.
|
// cmdDataConnFrom executes a command which require a FTP data connection.
|
||||||
// Issues a REST FTP command to specify the number of bytes to skip for the transfer.
|
// Issues a REST FTP command to specify the number of bytes to skip for the transfer.
|
||||||
|
//
|
||||||
|
// Caller MUST call c.conn.EndResponse(c.id) AND c.mu.Unlock() once the returned
|
||||||
|
// connection is closed.
|
||||||
func (c *ServerConn) cmdDataConnFrom(offset uint64, format string, args ...interface{}) (net.Conn, error) {
|
func (c *ServerConn) cmdDataConnFrom(offset uint64, format string, args ...interface{}) (net.Conn, error) {
|
||||||
|
// Lock ServerConn so that the data action can run through its synchronous
|
||||||
|
// sequence of events. Must be Unlock()'d before any other data action can
|
||||||
|
// proceed.
|
||||||
|
c.mu.Lock()
|
||||||
|
|
||||||
conn, err := c.openDataConn()
|
conn, err := c.openDataConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -261,19 +282,27 @@ func (c *ServerConn) cmdDataConnFrom(offset uint64, format string, args ...inter
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = c.conn.Cmd(format, args...)
|
id, err := c.conn.Cmd(format, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// Set ServerConn's Pipeline id and start the response. Must call
|
||||||
|
// c.conn.EndResponse(c.id) once this data action is complete.
|
||||||
|
c.id = id
|
||||||
|
c.conn.StartResponse(c.id)
|
||||||
|
|
||||||
code, msg, err := c.conn.ReadResponse(-1)
|
code, msg, err := c.conn.ReadResponse(-1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
c.conn.EndResponse(c.id)
|
||||||
|
c.mu.Unlock()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if code != StatusAlreadyOpen && code != StatusAboutToSend {
|
if code != StatusAlreadyOpen && code != StatusAboutToSend {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
c.conn.EndResponse(c.id)
|
||||||
|
c.mu.Unlock()
|
||||||
return nil, &textproto.Error{Code: code, Msg: msg}
|
return nil, &textproto.Error{Code: code, Msg: msg}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -591,6 +620,8 @@ func (c *ServerConn) StorFrom(path string, r io.Reader, offset uint64) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
defer c.conn.EndResponse(c.id)
|
||||||
|
|
||||||
_, err = io.Copy(conn, r)
|
_, err = io.Copy(conn, r)
|
||||||
conn.Close()
|
conn.Close()
|
||||||
@ -662,10 +693,15 @@ func (r *response) Read(buf []byte) (int, error) {
|
|||||||
|
|
||||||
// Close implements the io.Closer interface on a FTP data connection.
|
// Close implements the io.Closer interface on a FTP data connection.
|
||||||
func (r *response) Close() error {
|
func (r *response) Close() error {
|
||||||
err := r.conn.Close()
|
defer r.c.mu.Unlock()
|
||||||
_, _, err2 := r.c.conn.ReadResponse(StatusClosingDataConnection)
|
defer r.c.conn.EndResponse(r.c.id)
|
||||||
if err2 != nil {
|
|
||||||
err = err2
|
if err := r.conn.Close(); err != nil {
|
||||||
}
|
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
_, _, err := r.c.conn.ReadResponse(StatusClosingDataConnection)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user