From 619c7c29e581c4e6b5d073f162c039c63959b06a Mon Sep 17 00:00:00 2001 From: Julian Kornberger Date: Fri, 18 Nov 2016 01:56:24 +0100 Subject: [PATCH] Use closures for data connections Breaks the API, but avoids common pitfalls like unclosed connections. --- client_test.go | 25 +++++----- ftp.go | 133 +++++++++++++++++-------------------------------- 2 files changed, 59 insertions(+), 99 deletions(-) diff --git a/client_test.go b/client_test.go index 92d4779..4b09f4c 100644 --- a/client_test.go +++ b/client_test.go @@ -3,6 +3,7 @@ package ftp import ( "bytes" "io/ioutil" + "net" "net/textproto" "testing" "time" @@ -67,25 +68,22 @@ func testConn(t *testing.T, disableEPSV bool) { t.Error(err) } - r, err := c.Retr("tset") - if err != nil { - t.Error(err) - } else { - buf, err := ioutil.ReadAll(r) + err = c.Retr("tset", func(conn net.Conn) error { + buf, err := ioutil.ReadAll(conn) if err != nil { t.Error(err) } if string(buf) != testData { t.Errorf("'%s'", buf) } - r.Close() - } - - r, err = c.RetrFrom("tset", 5) + return nil + }) if err != nil { t.Error(err) - } else { - buf, err := ioutil.ReadAll(r) + } + + err = c.RetrFrom("tset", 5, func(conn net.Conn) error { + buf, err := ioutil.ReadAll(conn) if err != nil { t.Error(err) } @@ -93,7 +91,10 @@ func testConn(t *testing.T, disableEPSV bool) { if string(buf) != expected { t.Errorf("read %q, expected %q", buf, expected) } - r.Close() + return nil + }) + if err != nil { + t.Error(err) } err = c.Delete("tset") diff --git a/ftp.go b/ftp.go index cfa1711..26480f6 100644 --- a/ftp.go +++ b/ftp.go @@ -39,11 +39,7 @@ type Entry struct { Time time.Time } -// response represent a data-connection -type response struct { - conn net.Conn - c *ServerConn -} +type connHandler func(net.Conn) error // Connect is an alias to Dial, for backward compatibility func Connect(addr string) (*ServerConn, error) { @@ -258,37 +254,44 @@ func (c *ServerConn) cmd(expected int, format string, args ...interface{}) (int, // cmdDataConnFrom executes a command which require a FTP data connection. // Issues a REST FTP command to specify the number of bytes to skip for the transfer. -func (c *ServerConn) cmdDataConnFrom(offset uint64, format string, args ...interface{}) (net.Conn, error) { +func (c *ServerConn) cmdDataConnFrom(offset uint64, format string, arg interface{}, h connHandler) error { conn, err := c.openDataConn() if err != nil { - return nil, err + return err } if offset != 0 { - _, _, err := c.cmd(StatusRequestFilePending, "REST %d", offset) - if err != nil { + if _, _, err := c.cmd(StatusRequestFilePending, "REST %d", offset); err != nil { conn.Close() - return nil, err + return err } } - _, err = c.conn.Cmd(format, args...) - if err != nil { + if _, err := c.conn.Cmd(format, arg); err != nil { conn.Close() - return nil, err + return err } - code, msg, err := c.conn.ReadResponse(-1) - if err != nil { + if code, msg, err := c.conn.ReadResponse(-1); err != nil { conn.Close() - return nil, err - } - if code != StatusAlreadyOpen && code != StatusAboutToSend { + return err + } else if code != StatusAlreadyOpen && code != StatusAboutToSend { conn.Close() - return nil, &textproto.Error{Code: code, Msg: msg} + return &textproto.Error{Code: code, Msg: msg} } - return conn, nil + // Execute handler + if err = h(conn); err != nil { + conn.Close() + return err + } + + err = conn.Close() + _, _, err2 := c.conn.ReadResponse(StatusClosingDataConnection) + if err2 != nil { + err = err2 + } + return err } var errUnsupportedListLine = errors.New("Unsupported LIST line") @@ -508,45 +511,30 @@ func (e *Entry) setTime(fields []string) (err error) { // NameList issues an NLST FTP command. func (c *ServerConn) NameList(path string) (entries []string, err error) { - conn, err := c.cmdDataConnFrom(0, "NLST %s", path) - if err != nil { - return - } + err = c.cmdDataConnFrom(0, "NLST %s", path, func(conn net.Conn) error { - r := &response{conn, c} - defer r.Close() - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - entries = append(entries, scanner.Text()) - } - if err = scanner.Err(); err != nil { - return entries, err - } + scanner := bufio.NewScanner(conn) + for scanner.Scan() { + entries = append(entries, scanner.Text()) + } + return scanner.Err() + }) return } // List issues a LIST FTP command. func (c *ServerConn) List(path string) (entries []*Entry, err error) { - conn, err := c.cmdDataConnFrom(0, "LIST %s", path) - if err != nil { - return - } + err = c.cmdDataConnFrom(0, "LIST %s", path, func(conn net.Conn) error { - r := &response{conn, c} - defer r.Close() - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - entry, err := parseListLine(line) - if err == nil { - entries = append(entries, entry) + scanner := bufio.NewScanner(conn) + for scanner.Scan() { + line := scanner.Text() + if entry, err := parseListLine(line); err == nil { + entries = append(entries, entry) + } } - } - if err := scanner.Err(); err != nil { - return nil, err - } + return scanner.Err() + }) return } @@ -587,21 +575,16 @@ func (c *ServerConn) CurrentDir() (string, error) { // FTP server. // // The returned ReadCloser must be closed to cleanup the FTP data connection. -func (c *ServerConn) Retr(path string) (io.ReadCloser, error) { - return c.RetrFrom(path, 0) +func (c *ServerConn) Retr(path string, h connHandler) error { + return c.RetrFrom(path, 0, h) } // RetrFrom issues a RETR FTP command to fetch the specified file from the remote // FTP server, the server will not send the offset first bytes of the file. // // The returned ReadCloser must be closed to cleanup the FTP data connection. -func (c *ServerConn) RetrFrom(path string, offset uint64) (io.ReadCloser, error) { - conn, err := c.cmdDataConnFrom(offset, "RETR %s", path) - if err != nil { - return nil, err - } - - return &response{conn, c}, nil +func (c *ServerConn) RetrFrom(path string, offset uint64, h connHandler) error { + return c.cmdDataConnFrom(offset, "RETR %s", path, h) } // Stor issues a STOR FTP command to store a file to the remote FTP server. @@ -618,19 +601,10 @@ func (c *ServerConn) Stor(path string, r io.Reader) error { // // Hint: io.Pipe() can be used if an io.Writer is required. func (c *ServerConn) StorFrom(path string, r io.Reader, offset uint64) error { - conn, err := c.cmdDataConnFrom(offset, "STOR %s", path) - if err != nil { + return c.cmdDataConnFrom(offset, "STOR %s", path, func(conn net.Conn) error { + _, err := io.Copy(conn, r) return err - } - - _, err = io.Copy(conn, r) - conn.Close() - if err != nil { - return err - } - - _, _, err = c.conn.ReadResponse(StatusClosingDataConnection) - return err + }) } // Rename renames a file on the remote FTP server. @@ -685,18 +659,3 @@ func (c *ServerConn) Quit() error { c.conn.Cmd("QUIT") return c.conn.Close() } - -// Read implements the io.Reader interface on a FTP data connection. -func (r *response) Read(buf []byte) (int, error) { - return r.conn.Read(buf) -} - -// Close implements the io.Closer interface on a FTP data connection. -func (r *response) Close() error { - err := r.conn.Close() - _, _, err2 := r.c.conn.ReadResponse(StatusClosingDataConnection) - if err2 != nil { - err = err2 - } - return err -}