diff --git a/ftp.go b/ftp.go index 91f7496..7b26050 100644 --- a/ftp.go +++ b/ftp.go @@ -50,6 +50,7 @@ type dialOptions struct { context context.Context dialer net.Dialer tlsConfig *tls.Config + explicitTLS bool conn net.Conn disableEPSV bool location *time.Location @@ -90,7 +91,7 @@ func Dial(addr string, options ...DialOption) (*ServerConn, error) { if do.dialFunc != nil { tconn, err = do.dialFunc("tcp", addr) - } else if do.tlsConfig != nil { + } else if do.tlsConfig != nil && !do.explicitTLS { tconn, err = tls.DialWithDialer(&do.dialer, "tcp", addr, do.tlsConfig) } else { ctx := do.context @@ -111,15 +112,10 @@ func Dial(addr string, options ...DialOption) (*ServerConn, error) { // If we use the domain name, we might not resolve to the same IP. remoteAddr := tconn.RemoteAddr().(*net.TCPAddr) - var sourceConn io.ReadWriteCloser = tconn - if do.debugOutput != nil { - sourceConn = newDebugWrapper(tconn, do.debugOutput) - } - c := &ServerConn{ options: do, features: make(map[string]string), - conn: textproto.NewConn(sourceConn), + conn: textproto.NewConn(do.wrapConn(tconn)), host: remoteAddr.IP.String(), } @@ -129,6 +125,15 @@ func Dial(addr string, options ...DialOption) (*ServerConn, error) { return nil, err } + if do.explicitTLS { + if err := c.authTLS(); err != nil { + _ = c.Quit() + return nil, err + } + tconn = tls.Client(tconn, do.tlsConfig) + c.conn = textproto.NewConn(do.wrapConn(tconn)) + } + err = c.feat() if err != nil { c.Quit() @@ -198,6 +203,15 @@ func DialWithTLS(tlsConfig *tls.Config) DialOption { }} } +// DialWithExplicitTLS returns a DialOption that configures the ServerConn to be upgraded to TLS +// See DialWithTLS for general TLS documentation +func DialWithExplicitTLS(tlsConfig *tls.Config) DialOption { + return DialOption{func(do *dialOptions) { + do.explicitTLS = true + do.tlsConfig = tlsConfig + }} +} + // DialWithDebugOutput returns a DialOption that configures the ServerConn to write to the Writer // everything it reads from the server func DialWithDebugOutput(w io.Writer) DialOption { @@ -218,6 +232,14 @@ func DialWithDialFunc(f func(network, address string) (net.Conn, error)) DialOpt }} } +func (o *dialOptions) wrapConn(netConn net.Conn) io.ReadWriteCloser { + if o.debugOutput == nil { + return netConn + } + + return newDebugWrapper(netConn, o.debugOutput) +} + // Connect is an alias to Dial, for backward compatibility func Connect(addr string) (*ServerConn, error) { return Dial(addr) @@ -269,6 +291,12 @@ func (c *ServerConn) Login(user, password string) error { return err } +// authTLS upgrades the connection to use TLS +func (c *ServerConn) authTLS() error { + _, _, err := c.cmd(StatusAuthOK, "AUTH TLS") + return err +} + // feat issues a FEAT FTP command to list the additional commands supported by // the remote FTP server. // FEAT is described in RFC 2389 diff --git a/status.go b/status.go index 7d281b8..c8ea026 100644 --- a/status.go +++ b/status.go @@ -25,6 +25,7 @@ const ( StatusLoggedIn = 230 StatusLoggedOut = 231 StatusLogoutAck = 232 + StatusAuthOK = 234 StatusRequestedFileActionOK = 250 StatusPathCreated = 257 @@ -73,6 +74,7 @@ var statusText = map[int]string{ StatusLoggedIn: "User logged in, proceed.", StatusLoggedOut: "User logged out; service terminated.", StatusLogoutAck: "Logout command noted, will complete when transfer done.", + StatusAuthOK: "AUTH command OK", StatusRequestedFileActionOK: "Requested file action okay, completed.", StatusPathCreated: "Path created.",