Add DialWithDialFunc to specify dial function used for both control and data connections (#140)

Add DialWithDialFunc to specify dial function used for both control and data connections

If used DialWithNetConn, DialWithNetConn takes precedence for
the control connection, while data connections will be established
using function specified with the DialWithDialFunc option
This commit is contained in:
Alexander Pevzner 2019-04-27 19:36:46 +03:00 committed by Julien Laffaye
parent e6de3d35bf
commit 6a014d5e22

30
ftp.go
View File

@ -54,6 +54,7 @@ type dialOptions struct {
disableEPSV bool disableEPSV bool
location *time.Location location *time.Location
debugOutput io.Writer debugOutput io.Writer
dialFunc func(network, address string) (net.Conn, error)
} }
// Entry describes a file and is returned by List(). // Entry describes a file and is returned by List().
@ -84,17 +85,23 @@ func Dial(addr string, options ...DialOption) (*ServerConn, error) {
tconn := do.conn tconn := do.conn
if tconn == nil { if tconn == nil {
var err error
if do.dialFunc != nil {
tconn, err = do.dialFunc("tcp", addr)
} else {
ctx := do.context ctx := do.context
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
conn, err := do.dialer.DialContext(ctx, "tcp", addr) tconn, err = do.dialer.DialContext(ctx, "tcp", addr)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
tconn = conn
} }
// Use the resolved IP address in case addr contains a domain name // Use the resolved IP address in case addr contains a domain name
@ -192,6 +199,18 @@ func DialWithDebugOutput(w io.Writer) DialOption {
}} }}
} }
// DialWithDialFunc returns a DialOption that configures the ServerConn to use the
// specified function to establish both control and data connections
//
// If used together with the DialWithNetConn option, the DialWithNetConn
// takes precedence for the control connection, while data connections will
// be established using function specified with the DialWithDialFunc option
func DialWithDialFunc(f func(network, address string) (net.Conn, error)) DialOption {
return DialOption{func(do *dialOptions) {
do.dialFunc = f
}}
}
// Connect is an alias to Dial, for backward compatibility // Connect is an alias to Dial, for backward compatibility
func Connect(addr string) (*ServerConn, error) { func Connect(addr string) (*ServerConn, error) {
return Dial(addr) return Dial(addr)
@ -387,7 +406,12 @@ func (c *ServerConn) openDataConn() (net.Conn, error) {
return nil, err return nil, err
} }
return c.options.dialer.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(port))) addr := net.JoinHostPort(host, strconv.Itoa(port))
if c.options.dialFunc != nil {
return c.options.dialFunc("tcp", addr)
}
return c.options.dialer.Dial("tcp", addr)
} }
// cmd is a helper function to execute a command and check for the expected FTP // cmd is a helper function to execute a command and check for the expected FTP