diff --git a/ftp.go b/ftp.go index 02136ae..71cd4fe 100644 --- a/ftp.go +++ b/ftp.go @@ -8,6 +8,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" "io" "net" "net/textproto" @@ -144,21 +145,33 @@ func Dial(addr string, options ...DialOption) (*ServerConn, error) { } } - tconn, err := dialFunc("tcp", addr) + // Use the resolved IP address in case addr contains a domain name + // If we use the domain name, we might not resolve to the same IP. + dnsResolver := net.Resolver{ + Dial: func(ctx context.Context, network string, address string) (net.Conn, error) { + return dialFunc(network, address) + }, + } + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + addrs, err := dnsResolver.LookupIPAddr(context.Background(), host) if err != nil { return nil, err } - // Use the resolved IP address in case addr contains a domain name - // If we use the domain name, we might not resolve to the same IP. - remoteAddr := tconn.RemoteAddr().(*net.TCPAddr) + tconn, err := dialFunc("tcp", fmt.Sprintf("%s:%s", addrs[0].IP.String(), port)) + if err != nil { + return nil, err + } c := &ServerConn{ options: do, features: make(map[string]string), conn: textproto.NewConn(do.wrapConn(tconn)), netConn: tconn, - host: remoteAddr.IP.String(), + host: addrs[0].IP.String(), } _, _, err = c.conn.ReadResponse(StatusReady) @@ -528,26 +541,9 @@ func (c *ServerConn) pasv() (host string, port int, err error) { // Make the IP address to connect to host = strings.Join(pasvData[0:4], ".") - - if c.host != host { - if cmdIP := net.ParseIP(c.host); cmdIP != nil { - if dataIP := net.ParseIP(host); dataIP != nil { - if isBogusDataIP(cmdIP, dataIP) { - return c.host, port, nil - } - } - } - } return host, port, nil } -func isBogusDataIP(cmdIP, dataIP net.IP) bool { - // Logic stolen from lftp (https://github.com/lavv17/lftp/blob/d67fc14d085849a6b0418bb3e912fea2e94c18d1/src/ftpclass.cc#L769) - return dataIP.IsMulticast() || - cmdIP.IsPrivate() != dataIP.IsPrivate() || - cmdIP.IsLoopback() != dataIP.IsLoopback() -} - // getDataConnPort returns a host, port for a new data connection // it uses the best available method to do so func (c *ServerConn) getDataConnPort() (string, int, error) {