diff --git a/client_test.go b/client_test.go index e8995a9..4e6bb42 100644 --- a/client_test.go +++ b/client_test.go @@ -24,12 +24,7 @@ func TestConnEPSV(t *testing.T) { func testConn(t *testing.T, disableEPSV bool) { - mock, c := openConn(t, "127.0.0.1") - - if disableEPSV { - delete(c.features, "EPSV") - c.DisableEPSV = true - } + mock, c := openConn(t, "127.0.0.1", DialWithTimeout(5*time.Second), DialWithDisabledEPSV(disableEPSV)) err := c.Login("anonymous", "anonymous") if err != nil { diff --git a/conn_test.go b/conn_test.go index 4d4ab89..d5ca406 100644 --- a/conn_test.go +++ b/conn_test.go @@ -285,14 +285,14 @@ func (mock *ftpMock) Close() { } // Helper to return a client connected to a mock server -func openConn(t *testing.T, addr string) (*ftpMock, *ServerConn) { +func openConn(t *testing.T, addr string, options ...DialOption) (*ftpMock, *ServerConn) { mock, err := newFtpMock(t, addr) if err != nil { t.Fatal(err) } defer mock.Close() - c, err := Dial(mock.Addr()) + c, err := DialWithOptions(mock.Addr(), options...) if err != nil { t.Fatal(err) } diff --git a/ftp.go b/ftp.go index cae0c63..afaffca 100644 --- a/ftp.go +++ b/ftp.go @@ -5,6 +5,8 @@ package ftp import ( "bufio" + "context" + "crypto/tls" "errors" "io" "net" @@ -27,19 +29,31 @@ const ( // ServerConn represents the connection to a remote FTP server. // It should be protected from concurrent accesses. type ServerConn struct { - // Do not use EPSV mode - DisableEPSV bool + options *dialOptions + conn *textproto.Conn + host string - // Timezone that the server is in - Location *time.Location - - conn *textproto.Conn - host string - timeout time.Duration + // Server capabilities discovered at runtime features map[string]string + skipEPSV bool mlstSupported bool } +// DialOption represents an option to start a new connection with DialWithOptions +type DialOption struct { + setup func(do *dialOptions) +} + +// dialOptions contains all the options set by DialOption.setup +type dialOptions struct { + context context.Context + dialer net.Dialer + tlsConfig tls.Config + conn net.Conn + disableEPSV bool + location *time.Location +} + // Entry describes a file and is returned by List(). type Entry struct { Name string @@ -55,41 +69,44 @@ type Response struct { closed bool } -// Connect is an alias to Dial, for backward compatibility -func Connect(addr string) (*ServerConn, error) { - return Dial(addr) -} +// DialWithOptions connects to the specified address with optinal options +func DialWithOptions(addr string, options ...DialOption) (*ServerConn, error) { + do := &dialOptions{} + for _, option := range options { + option.setup(do) + } -// Dial is like DialTimeout with no timeout -func Dial(addr string) (*ServerConn, error) { - return DialTimeout(addr, 0) -} + if do.location == nil { + do.location = time.UTC + } -// DialTimeout initializes the connection to the specified ftp server address. -// -// It is generally followed by a call to Login() as most FTP commands require -// an authenticated user. -func DialTimeout(addr string, timeout time.Duration) (*ServerConn, error) { - tconn, err := net.DialTimeout("tcp", addr, timeout) - if err != nil { - return nil, err + tconn := do.conn + if tconn == nil { + ctx := do.context + + if ctx == nil { + ctx = context.Background() + } + + conn, err := do.dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + tconn = conn } // Use the resolved IP address in case addr contains a domain name // If we use the domain name, we might not resolve to the same IP. remoteAddr := tconn.RemoteAddr().(*net.TCPAddr) - conn := textproto.NewConn(tconn) - c := &ServerConn{ - conn: conn, - host: remoteAddr.IP.String(), - timeout: timeout, + options: do, features: make(map[string]string), - Location: time.UTC, + conn: textproto.NewConn(tconn), + host: remoteAddr.IP.String(), } - _, _, err = c.conn.ReadResponse(StatusReady) + _, _, err := c.conn.ReadResponse(StatusReady) if err != nil { c.Quit() return nil, err @@ -108,6 +125,76 @@ func DialTimeout(addr string, timeout time.Duration) (*ServerConn, error) { return c, nil } +// DialWithTimeout returns a DialOption that configures the ServerConn with specified timeout +func DialWithTimeout(timeout time.Duration) DialOption { + return DialOption{func(do *dialOptions) { + do.dialer.Timeout = timeout + }} +} + +// DialWithDialer returns a DialOption that configures the ServerConn with specified net.Dialer +func DialWithDialer(dialer net.Dialer) DialOption { + return DialOption{func(do *dialOptions) { + do.dialer = dialer + }} +} + +// DialWithNetConn returns a DialOption that configures the ServerConn with the underlying net.Conn +func DialWithNetConn(conn net.Conn) DialOption { + return DialOption{func(do *dialOptions) { + do.conn = conn + }} +} + +// DialWithDisabledEPSV returns a DialOption that configures the ServerConn with EPSV disabled +// Note that EPSV is only used when advertised in the server features. +func DialWithDisabledEPSV(disabled bool) DialOption { + return DialOption{func(do *dialOptions) { + do.disableEPSV = disabled + }} +} + +// DialWithLocation returns a DialOption that configures the ServerConn with specified time.Location +// The lococation is used to parse the dates sent by the server which are in server's timezone +func DialWithLocation(location *time.Location) DialOption { + return DialOption{func(do *dialOptions) { + do.location = location + }} +} + +// DialWithContext returns a DialOption that configures the ServerConn with specified context +// The context will be used for the initial connection setup +func DialWithContext(ctx context.Context) DialOption { + return DialOption{func(do *dialOptions) { + do.context = ctx + }} +} + +// DialWithTLS returns a DialOption that configures the ServerConn with specified TLS config +func DialWithTLS(tlsConfig tls.Config) DialOption { + return DialOption{func(do *dialOptions) { + do.tlsConfig = tlsConfig + }} +} + +// Connect is an alias to Dial, for backward compatibility +func Connect(addr string) (*ServerConn, error) { + return Dial(addr) +} + +// Dial is like DialTimeout with no timeout +func Dial(addr string) (*ServerConn, error) { + return DialTimeout(addr, 0) +} + +// DialTimeout initializes the connection to the specified ftp server address. +// +// It is generally followed by a call to Login() as most FTP commands require +// an authenticated user. +func DialTimeout(addr string, timeout time.Duration) (*ServerConn, error) { + return DialWithOptions(addr, DialWithTimeout(timeout)) +} + // Login authenticates the client with specified user and password. // // "anonymous"/"anonymous" is a common user/password scheme for FTP servers @@ -271,13 +358,13 @@ func (c *ServerConn) pasv() (host string, port int, err error) { // getDataConnPort returns a host, port for a new data connection // it uses the best available method to do so func (c *ServerConn) getDataConnPort() (string, int, error) { - if !c.DisableEPSV { + if !c.options.disableEPSV && !c.skipEPSV { if port, err := c.epsv(); err == nil { return c.host, port, nil } - // if there is an error, disable EPSV for the next attempts - c.DisableEPSV = true + // if there is an error, skip EPSV for the next attempts + c.skipEPSV = true } return c.pasv() @@ -290,7 +377,7 @@ func (c *ServerConn) openDataConn() (net.Conn, error) { return nil, err } - return net.DialTimeout("tcp", net.JoinHostPort(host, strconv.Itoa(port)), c.timeout) + return c.options.dialer.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(port))) } // cmd is a helper function to execute a command and check for the expected FTP @@ -383,7 +470,7 @@ func (c *ServerConn) List(path string) (entries []*Entry, err error) { scanner := bufio.NewScanner(r) now := time.Now() for scanner.Scan() { - entry, err := parser(scanner.Text(), now, c.Location) + entry, err := parser(scanner.Text(), now, c.options.location) if err == nil { entries = append(entries, entry) }