Use closures for data connections

Breaks the API, but avoids common pitfalls like unclosed connections.
This commit is contained in:
Julian Kornberger 2016-11-18 01:56:24 +01:00
parent 988909ab28
commit 619c7c29e5
2 changed files with 59 additions and 99 deletions

View File

@ -3,6 +3,7 @@ package ftp
import ( import (
"bytes" "bytes"
"io/ioutil" "io/ioutil"
"net"
"net/textproto" "net/textproto"
"testing" "testing"
"time" "time"
@ -67,25 +68,22 @@ func testConn(t *testing.T, disableEPSV bool) {
t.Error(err) t.Error(err)
} }
r, err := c.Retr("tset") err = c.Retr("tset", func(conn net.Conn) error {
if err != nil { buf, err := ioutil.ReadAll(conn)
t.Error(err)
} else {
buf, err := ioutil.ReadAll(r)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
if string(buf) != testData { if string(buf) != testData {
t.Errorf("'%s'", buf) t.Errorf("'%s'", buf)
} }
r.Close() return nil
} })
r, err = c.RetrFrom("tset", 5)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} else { }
buf, err := ioutil.ReadAll(r)
err = c.RetrFrom("tset", 5, func(conn net.Conn) error {
buf, err := ioutil.ReadAll(conn)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -93,7 +91,10 @@ func testConn(t *testing.T, disableEPSV bool) {
if string(buf) != expected { if string(buf) != expected {
t.Errorf("read %q, expected %q", buf, expected) t.Errorf("read %q, expected %q", buf, expected)
} }
r.Close() return nil
})
if err != nil {
t.Error(err)
} }
err = c.Delete("tset") err = c.Delete("tset")

133
ftp.go
View File

@ -39,11 +39,7 @@ type Entry struct {
Time time.Time Time time.Time
} }
// response represent a data-connection type connHandler func(net.Conn) error
type response struct {
conn net.Conn
c *ServerConn
}
// Connect is an alias to Dial, for backward compatibility // Connect is an alias to Dial, for backward compatibility
func Connect(addr string) (*ServerConn, error) { func Connect(addr string) (*ServerConn, error) {
@ -258,37 +254,44 @@ func (c *ServerConn) cmd(expected int, format string, args ...interface{}) (int,
// cmdDataConnFrom executes a command which require a FTP data connection. // cmdDataConnFrom executes a command which require a FTP data connection.
// Issues a REST FTP command to specify the number of bytes to skip for the transfer. // Issues a REST FTP command to specify the number of bytes to skip for the transfer.
func (c *ServerConn) cmdDataConnFrom(offset uint64, format string, args ...interface{}) (net.Conn, error) { func (c *ServerConn) cmdDataConnFrom(offset uint64, format string, arg interface{}, h connHandler) error {
conn, err := c.openDataConn() conn, err := c.openDataConn()
if err != nil { if err != nil {
return nil, err return err
} }
if offset != 0 { if offset != 0 {
_, _, err := c.cmd(StatusRequestFilePending, "REST %d", offset) if _, _, err := c.cmd(StatusRequestFilePending, "REST %d", offset); err != nil {
if err != nil {
conn.Close() conn.Close()
return nil, err return err
} }
} }
_, err = c.conn.Cmd(format, args...) if _, err := c.conn.Cmd(format, arg); err != nil {
if err != nil {
conn.Close() conn.Close()
return nil, err return err
} }
code, msg, err := c.conn.ReadResponse(-1) if code, msg, err := c.conn.ReadResponse(-1); err != nil {
if err != nil {
conn.Close() conn.Close()
return nil, err return err
} } else if code != StatusAlreadyOpen && code != StatusAboutToSend {
if code != StatusAlreadyOpen && code != StatusAboutToSend {
conn.Close() conn.Close()
return nil, &textproto.Error{Code: code, Msg: msg} return &textproto.Error{Code: code, Msg: msg}
} }
return conn, nil // Execute handler
if err = h(conn); err != nil {
conn.Close()
return err
}
err = conn.Close()
_, _, err2 := c.conn.ReadResponse(StatusClosingDataConnection)
if err2 != nil {
err = err2
}
return err
} }
var errUnsupportedListLine = errors.New("Unsupported LIST line") var errUnsupportedListLine = errors.New("Unsupported LIST line")
@ -508,45 +511,30 @@ func (e *Entry) setTime(fields []string) (err error) {
// NameList issues an NLST FTP command. // NameList issues an NLST FTP command.
func (c *ServerConn) NameList(path string) (entries []string, err error) { func (c *ServerConn) NameList(path string) (entries []string, err error) {
conn, err := c.cmdDataConnFrom(0, "NLST %s", path) err = c.cmdDataConnFrom(0, "NLST %s", path, func(conn net.Conn) error {
if err != nil {
return
}
r := &response{conn, c} scanner := bufio.NewScanner(conn)
defer r.Close() for scanner.Scan() {
entries = append(entries, scanner.Text())
scanner := bufio.NewScanner(r) }
for scanner.Scan() { return scanner.Err()
entries = append(entries, scanner.Text()) })
}
if err = scanner.Err(); err != nil {
return entries, err
}
return return
} }
// List issues a LIST FTP command. // List issues a LIST FTP command.
func (c *ServerConn) List(path string) (entries []*Entry, err error) { func (c *ServerConn) List(path string) (entries []*Entry, err error) {
conn, err := c.cmdDataConnFrom(0, "LIST %s", path) err = c.cmdDataConnFrom(0, "LIST %s", path, func(conn net.Conn) error {
if err != nil {
return
}
r := &response{conn, c} scanner := bufio.NewScanner(conn)
defer r.Close() for scanner.Scan() {
line := scanner.Text()
scanner := bufio.NewScanner(r) if entry, err := parseListLine(line); err == nil {
for scanner.Scan() { entries = append(entries, entry)
line := scanner.Text() }
entry, err := parseListLine(line)
if err == nil {
entries = append(entries, entry)
} }
} return scanner.Err()
if err := scanner.Err(); err != nil { })
return nil, err
}
return return
} }
@ -587,21 +575,16 @@ func (c *ServerConn) CurrentDir() (string, error) {
// FTP server. // FTP server.
// //
// The returned ReadCloser must be closed to cleanup the FTP data connection. // The returned ReadCloser must be closed to cleanup the FTP data connection.
func (c *ServerConn) Retr(path string) (io.ReadCloser, error) { func (c *ServerConn) Retr(path string, h connHandler) error {
return c.RetrFrom(path, 0) return c.RetrFrom(path, 0, h)
} }
// RetrFrom issues a RETR FTP command to fetch the specified file from the remote // RetrFrom issues a RETR FTP command to fetch the specified file from the remote
// FTP server, the server will not send the offset first bytes of the file. // FTP server, the server will not send the offset first bytes of the file.
// //
// The returned ReadCloser must be closed to cleanup the FTP data connection. // The returned ReadCloser must be closed to cleanup the FTP data connection.
func (c *ServerConn) RetrFrom(path string, offset uint64) (io.ReadCloser, error) { func (c *ServerConn) RetrFrom(path string, offset uint64, h connHandler) error {
conn, err := c.cmdDataConnFrom(offset, "RETR %s", path) return c.cmdDataConnFrom(offset, "RETR %s", path, h)
if err != nil {
return nil, err
}
return &response{conn, c}, nil
} }
// Stor issues a STOR FTP command to store a file to the remote FTP server. // Stor issues a STOR FTP command to store a file to the remote FTP server.
@ -618,19 +601,10 @@ func (c *ServerConn) Stor(path string, r io.Reader) error {
// //
// Hint: io.Pipe() can be used if an io.Writer is required. // Hint: io.Pipe() can be used if an io.Writer is required.
func (c *ServerConn) StorFrom(path string, r io.Reader, offset uint64) error { func (c *ServerConn) StorFrom(path string, r io.Reader, offset uint64) error {
conn, err := c.cmdDataConnFrom(offset, "STOR %s", path) return c.cmdDataConnFrom(offset, "STOR %s", path, func(conn net.Conn) error {
if err != nil { _, err := io.Copy(conn, r)
return err return err
} })
_, err = io.Copy(conn, r)
conn.Close()
if err != nil {
return err
}
_, _, err = c.conn.ReadResponse(StatusClosingDataConnection)
return err
} }
// Rename renames a file on the remote FTP server. // Rename renames a file on the remote FTP server.
@ -685,18 +659,3 @@ func (c *ServerConn) Quit() error {
c.conn.Cmd("QUIT") c.conn.Cmd("QUIT")
return c.conn.Close() return c.conn.Close()
} }
// Read implements the io.Reader interface on a FTP data connection.
func (r *response) Read(buf []byte) (int, error) {
return r.conn.Read(buf)
}
// Close implements the io.Closer interface on a FTP data connection.
func (r *response) Close() error {
err := r.conn.Close()
_, _, err2 := r.c.conn.ReadResponse(StatusClosingDataConnection)
if err2 != nil {
err = err2
}
return err
}