Add DialWithOptions
DialWithOptions accept a variadic number of options, allowing to introduce more options in the fure without breaking the API.
This commit is contained in:
		
							parent
							
								
									55546487cf
								
							
						
					
					
						commit
						04b1878733
					
				| @ -24,12 +24,7 @@ func TestConnEPSV(t *testing.T) { | ||||
| 
 | ||||
| func testConn(t *testing.T, disableEPSV bool) { | ||||
| 
 | ||||
| 	mock, c := openConn(t, "127.0.0.1") | ||||
| 
 | ||||
| 	if disableEPSV { | ||||
| 		delete(c.features, "EPSV") | ||||
| 		c.DisableEPSV = true | ||||
| 	} | ||||
| 	mock, c := openConn(t, "127.0.0.1", DialWithTimeout(5*time.Second), DialWithDisabledEPSV(disableEPSV)) | ||||
| 
 | ||||
| 	err := c.Login("anonymous", "anonymous") | ||||
| 	if err != nil { | ||||
|  | ||||
| @ -285,14 +285,14 @@ func (mock *ftpMock) Close() { | ||||
| } | ||||
| 
 | ||||
| // Helper to return a client connected to a mock server | ||||
| func openConn(t *testing.T, addr string) (*ftpMock, *ServerConn) { | ||||
| func openConn(t *testing.T, addr string, options ...DialOption) (*ftpMock, *ServerConn) { | ||||
| 	mock, err := newFtpMock(t, addr) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer mock.Close() | ||||
| 
 | ||||
| 	c, err := Dial(mock.Addr()) | ||||
| 	c, err := DialWithOptions(mock.Addr(), options...) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
							
								
								
									
										159
									
								
								ftp.go
									
									
									
									
									
								
							
							
						
						
									
										159
									
								
								ftp.go
									
									
									
									
									
								
							| @ -5,6 +5,8 @@ package ftp | ||||
| 
 | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net" | ||||
| @ -27,19 +29,31 @@ const ( | ||||
| // ServerConn represents the connection to a remote FTP server. | ||||
| // It should be protected from concurrent accesses. | ||||
| type ServerConn struct { | ||||
| 	// Do not use EPSV mode | ||||
| 	DisableEPSV bool | ||||
| 	options *dialOptions | ||||
| 	conn    *textproto.Conn | ||||
| 	host    string | ||||
| 
 | ||||
| 	// Timezone that the server is in | ||||
| 	Location *time.Location | ||||
| 
 | ||||
| 	conn          *textproto.Conn | ||||
| 	host          string | ||||
| 	timeout       time.Duration | ||||
| 	// Server capabilities discovered at runtime | ||||
| 	features      map[string]string | ||||
| 	skipEPSV      bool | ||||
| 	mlstSupported bool | ||||
| } | ||||
| 
 | ||||
| // DialOption represents an option to start a new connection with DialWithOptions | ||||
| type DialOption struct { | ||||
| 	setup func(do *dialOptions) | ||||
| } | ||||
| 
 | ||||
| // dialOptions contains all the options set by DialOption.setup | ||||
| type dialOptions struct { | ||||
| 	context     context.Context | ||||
| 	dialer      net.Dialer | ||||
| 	tlsConfig   tls.Config | ||||
| 	conn        net.Conn | ||||
| 	disableEPSV bool | ||||
| 	location    *time.Location | ||||
| } | ||||
| 
 | ||||
| // Entry describes a file and is returned by List(). | ||||
| type Entry struct { | ||||
| 	Name string | ||||
| @ -55,41 +69,44 @@ type Response struct { | ||||
| 	closed bool | ||||
| } | ||||
| 
 | ||||
| // Connect is an alias to Dial, for backward compatibility | ||||
| func Connect(addr string) (*ServerConn, error) { | ||||
| 	return Dial(addr) | ||||
| } | ||||
| // DialWithOptions connects to the specified address with optinal options | ||||
| func DialWithOptions(addr string, options ...DialOption) (*ServerConn, error) { | ||||
| 	do := &dialOptions{} | ||||
| 	for _, option := range options { | ||||
| 		option.setup(do) | ||||
| 	} | ||||
| 
 | ||||
| // Dial is like DialTimeout with no timeout | ||||
| func Dial(addr string) (*ServerConn, error) { | ||||
| 	return DialTimeout(addr, 0) | ||||
| } | ||||
| 	if do.location == nil { | ||||
| 		do.location = time.UTC | ||||
| 	} | ||||
| 
 | ||||
| // DialTimeout initializes the connection to the specified ftp server address. | ||||
| // | ||||
| // It is generally followed by a call to Login() as most FTP commands require | ||||
| // an authenticated user. | ||||
| func DialTimeout(addr string, timeout time.Duration) (*ServerConn, error) { | ||||
| 	tconn, err := net.DialTimeout("tcp", addr, timeout) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	tconn := do.conn | ||||
| 	if tconn == nil { | ||||
| 		ctx := do.context | ||||
| 
 | ||||
| 		if ctx == nil { | ||||
| 			ctx = context.Background() | ||||
| 		} | ||||
| 
 | ||||
| 		conn, err := do.dialer.DialContext(ctx, "tcp", addr) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tconn = conn | ||||
| 	} | ||||
| 
 | ||||
| 	// Use the resolved IP address in case addr contains a domain name | ||||
| 	// If we use the domain name, we might not resolve to the same IP. | ||||
| 	remoteAddr := tconn.RemoteAddr().(*net.TCPAddr) | ||||
| 
 | ||||
| 	conn := textproto.NewConn(tconn) | ||||
| 
 | ||||
| 	c := &ServerConn{ | ||||
| 		conn:     conn, | ||||
| 		host:     remoteAddr.IP.String(), | ||||
| 		timeout:  timeout, | ||||
| 		options:  do, | ||||
| 		features: make(map[string]string), | ||||
| 		Location: time.UTC, | ||||
| 		conn:     textproto.NewConn(tconn), | ||||
| 		host:     remoteAddr.IP.String(), | ||||
| 	} | ||||
| 
 | ||||
| 	_, _, err = c.conn.ReadResponse(StatusReady) | ||||
| 	_, _, err := c.conn.ReadResponse(StatusReady) | ||||
| 	if err != nil { | ||||
| 		c.Quit() | ||||
| 		return nil, err | ||||
| @ -108,6 +125,76 @@ func DialTimeout(addr string, timeout time.Duration) (*ServerConn, error) { | ||||
| 	return c, nil | ||||
| } | ||||
| 
 | ||||
| // DialWithTimeout returns a DialOption that configures the ServerConn with specified timeout | ||||
| func DialWithTimeout(timeout time.Duration) DialOption { | ||||
| 	return DialOption{func(do *dialOptions) { | ||||
| 		do.dialer.Timeout = timeout | ||||
| 	}} | ||||
| } | ||||
| 
 | ||||
| // DialWithDialer returns a DialOption that configures the ServerConn with specified net.Dialer | ||||
| func DialWithDialer(dialer net.Dialer) DialOption { | ||||
| 	return DialOption{func(do *dialOptions) { | ||||
| 		do.dialer = dialer | ||||
| 	}} | ||||
| } | ||||
| 
 | ||||
| // DialWithNetConn returns a DialOption that configures the ServerConn with the underlying net.Conn | ||||
| func DialWithNetConn(conn net.Conn) DialOption { | ||||
| 	return DialOption{func(do *dialOptions) { | ||||
| 		do.conn = conn | ||||
| 	}} | ||||
| } | ||||
| 
 | ||||
| // DialWithDisabledEPSV returns a DialOption that configures the ServerConn with EPSV disabled | ||||
| // Note that EPSV is only used when advertised in the server features. | ||||
| func DialWithDisabledEPSV(disabled bool) DialOption { | ||||
| 	return DialOption{func(do *dialOptions) { | ||||
| 		do.disableEPSV = disabled | ||||
| 	}} | ||||
| } | ||||
| 
 | ||||
| // DialWithLocation returns a DialOption that configures the ServerConn with specified time.Location | ||||
| // The lococation is used to parse the dates sent by the server which are in server's timezone | ||||
| func DialWithLocation(location *time.Location) DialOption { | ||||
| 	return DialOption{func(do *dialOptions) { | ||||
| 		do.location = location | ||||
| 	}} | ||||
| } | ||||
| 
 | ||||
| // DialWithContext returns a DialOption that configures the ServerConn with specified context | ||||
| // The context will be used for the initial connection setup | ||||
| func DialWithContext(ctx context.Context) DialOption { | ||||
| 	return DialOption{func(do *dialOptions) { | ||||
| 		do.context = ctx | ||||
| 	}} | ||||
| } | ||||
| 
 | ||||
| // DialWithTLS returns a DialOption that configures the ServerConn with specified TLS config | ||||
| func DialWithTLS(tlsConfig tls.Config) DialOption { | ||||
| 	return DialOption{func(do *dialOptions) { | ||||
| 		do.tlsConfig = tlsConfig | ||||
| 	}} | ||||
| } | ||||
| 
 | ||||
| // Connect is an alias to Dial, for backward compatibility | ||||
| func Connect(addr string) (*ServerConn, error) { | ||||
| 	return Dial(addr) | ||||
| } | ||||
| 
 | ||||
| // Dial is like DialTimeout with no timeout | ||||
| func Dial(addr string) (*ServerConn, error) { | ||||
| 	return DialTimeout(addr, 0) | ||||
| } | ||||
| 
 | ||||
| // DialTimeout initializes the connection to the specified ftp server address. | ||||
| // | ||||
| // It is generally followed by a call to Login() as most FTP commands require | ||||
| // an authenticated user. | ||||
| func DialTimeout(addr string, timeout time.Duration) (*ServerConn, error) { | ||||
| 	return DialWithOptions(addr, DialWithTimeout(timeout)) | ||||
| } | ||||
| 
 | ||||
| // Login authenticates the client with specified user and password. | ||||
| // | ||||
| // "anonymous"/"anonymous" is a common user/password scheme for FTP servers | ||||
| @ -271,13 +358,13 @@ func (c *ServerConn) pasv() (host string, port int, err error) { | ||||
| // getDataConnPort returns a host, port for a new data connection | ||||
| // it uses the best available method to do so | ||||
| func (c *ServerConn) getDataConnPort() (string, int, error) { | ||||
| 	if !c.DisableEPSV { | ||||
| 	if !c.options.disableEPSV && !c.skipEPSV { | ||||
| 		if port, err := c.epsv(); err == nil { | ||||
| 			return c.host, port, nil | ||||
| 		} | ||||
| 
 | ||||
| 		// if there is an error, disable EPSV for the next attempts | ||||
| 		c.DisableEPSV = true | ||||
| 		// if there is an error, skip EPSV for the next attempts | ||||
| 		c.skipEPSV = true | ||||
| 	} | ||||
| 
 | ||||
| 	return c.pasv() | ||||
| @ -290,7 +377,7 @@ func (c *ServerConn) openDataConn() (net.Conn, error) { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return net.DialTimeout("tcp", net.JoinHostPort(host, strconv.Itoa(port)), c.timeout) | ||||
| 	return c.options.dialer.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(port))) | ||||
| } | ||||
| 
 | ||||
| // cmd is a helper function to execute a command and check for the expected FTP | ||||
| @ -383,7 +470,7 @@ func (c *ServerConn) List(path string) (entries []*Entry, err error) { | ||||
| 	scanner := bufio.NewScanner(r) | ||||
| 	now := time.Now() | ||||
| 	for scanner.Scan() { | ||||
| 		entry, err := parser(scanner.Text(), now, c.Location) | ||||
| 		entry, err := parser(scanner.Text(), now, c.options.location) | ||||
| 		if err == nil { | ||||
| 			entries = append(entries, entry) | ||||
| 		} | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user