golang net/http源代码阅读

文章目录

golang的net/http包同时支持客户端和服务端,而本文主要分析客户端的代码,通过查看源代码可以更好的理解net/http的使用,也能好好的学习一下golang官方的代码风格。

快速入门

常用的使用场景,net/http做了一些快捷方法,更复杂的操作需要构建Request对象,net/http客户端的更详细使用教程可以参考我之前的文章: https://youerning.top/post/go-http-client-tutorial/

package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	resp, err := http.Get("https://baidu.com")
	if err != nil {
		panic(err)
	}
	defer func() {
		err := resp.Body.Close()
		if err != nil {
			fmt.Println("关闭Body失败:", err)
		}
	}()
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s", data)
}

调用链

看源代码的一个推荐方式是先全局在细部,如果在建立全局观之前就陷入细节,可能很久都走不出来甚至放弃,切记切记。

一个看golang源码的惯用技巧是先折叠 if err != nil {...}这样的代码块。

首先看看宏观或者说应用层的调用关系。

var DefaultClient = &Client{}

func Get(url string) (resp *Response, err error) {
    // 1.
	return DefaultClient.Get(url)
}

func (c *Client) Get(url string) (resp *Response, err error) {
    // 2.
	req, err := NewRequest("GET", url, nil)
    // 3.
	return c.Do(req)
}

func (c *Client) Do(req *Request) (*Response, error) {
	return c.do(req)
}

func (c *Client) do(req *Request) (retres *Response, reterr error) {
    // 4.
	for {
		// 5.
		if len(reqs) > 0 {
			err = c.checkRedirect(req, reqs)
            if err == ErrUseLastResponse {
				return resp, nil
			}
		}

		reqs = append(reqs, req)
		var err error
		var didTimeout func() bool
        // 6.
		if resp, didTimeout, err = c.send(req, deadline); err != nil {
			return nil, uerr(err)
		}

        // 7.
		var shouldRedirect bool
		redirectMethod, shouldRedirect, includeBody = redirectBehavior(req.Method, resp, reqs[0])
		if !shouldRedirect {
			return resp, nil
		}
		req.closeBody()
	}
}

func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) {
    // 8.
	resp, didTimeout, err = send(req, c.transport(), deadline)
	return resp, nil, nil
}

func (c *Client) transport() RoundTripper {
    // 9
	if c.Transport != nil {
		return c.Transport
	}
	return DefaultTransport
}

func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, didTimeout func() bool, err error) {
	req := ireq
    // 10.
	stopTimer, didTimeout := setRequestCancel(req, rt, deadline)
    // 11.
	resp, err = rt.RoundTrip(req)
	if err != nil {
        // 12.
		stopTimer()
		return nil, didTimeout, err
	}
    // 13.
	if resp.Body == nil {
		if resp.ContentLength > 0 && req.Method != "HEAD" {
			return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a *Response with content length %d but a nil Body", rt, resp.ContentLength)
		}
		resp.Body = io.NopCloser(strings.NewReader(""))
	}
	return resp, nil, nil
}

代码分解如下:

  1. 使用默认的Client作为client请求内容
  2. 构造一个Request, 一般来说稍微复杂点的请求就需要我们自己构造了
  3. 请求request
  4. 用一个循环来处理重定向的情况,如果没有重定向就直接返回
  5. 请求前检查重新向 如果不是第二次请求就检查是否需要重定向
  6. 较底层的请求方法
  7. 请求后检查重新向 检查是否是3xx的重定向请求
  8. 第6步的细节, 获取RoundTripper对象
  9. 如果客户端没有设置Transport就使用DefaultTransport
  10. 设置超时的取消逻辑
  11. 发送请求的核心逻辑了,涉及到tcp连接的构造和发送数据,这里暂时不深入了,在后面专门介绍
  12. 检查响应体的Body,保证Body总是一个可读的对象

简化的数据流如下:

sequenceDiagram main ->>+ Get: 使用默认客户端DefaultClient Get ->>+ Do: 构造Request Do ->>+ send: 仅控制http协议层上的逻辑 send -->> RoundTrip: 做参数校验等 RoundTrip -->> main: 做tcp连接并返回(*Response, error)

基于上面的流程分解,我们可以逐个攻破请求过程中各个比较重要的对象和方法。

Request

Request对象作为请求的抽象,包含了一个请求应该包含的所有信息,如URL, Headers, Body等。

func NewRequest(method, url string, body io.Reader) (*Request, error) {
	return NewRequestWithContext(context.Background(), method, url, body)
}

func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*Request, error) {
	if method == "" {
		method = "GET"
	}
	if !validMethod(method) {
		return nil, fmt.Errorf("net/http: invalid method %q", method)
	}
	if ctx == nil {
		return nil, errors.New("net/http: nil Context")
	}
	u, err := urlpkg.Parse(url)
	rc, ok := body.(io.ReadCloser)
	if !ok && body != nil {
		rc = io.NopCloser(body)
	}
	u.Host = removeEmptyPort(u.Host)
	req := &Request{
		ctx:        ctx,
		Method:     method,
		URL:        u,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     make(Header),
		Body:       rc,
		Host:       u.Host,
	}
	if body != nil {
		switch v := body.(type) {
		case *bytes.Buffer:
			req.ContentLength = int64(v.Len())
			buf := v.Bytes()
			req.GetBody = func() (io.ReadCloser, error) {
				r := bytes.NewReader(buf)
				return io.NopCloser(r), nil
			}
		case *bytes.Reader:
			req.ContentLength = int64(v.Len())
			snapshot := *v
			req.GetBody = func() (io.ReadCloser, error) {
				r := snapshot
				return io.NopCloser(&r), nil
			}
		case *strings.Reader:
			req.ContentLength = int64(v.Len())
			snapshot := *v
			req.GetBody = func() (io.ReadCloser, error) {
				r := snapshot
				return io.NopCloser(&r), nil
			}
		default:
		}
		if req.GetBody != nil && req.ContentLength == 0 {
			req.Body = NoBody
			req.GetBody = func() (io.ReadCloser, error) { return NoBody, nil }
		}
	}

	return req, nil
}

可以看到Request的构造主要集中在参数校验, 比如方法名是否合法,ctx不可以为空,后面就是构造一个GetBody方法用于复制请求体内容。

Client

var DefaultClient = &Client{}

net/http的默认客户端构造并不复杂,如果我们需要更精细的控制,可以设置各个字段,比如CheckRedirect, Jar, Timeout等, 它们的作用从名字就可以看到,而最重要的字段Transport后文再介绍。

RoundTripper

从前文可以知道,最终的请求会落到client的Transport, 使用它的RoundTrip方法, 而DefaultClient是不设置这个字段的,所以会使用DefaultTransport

var DefaultTransport RoundTripper = &Transport{
	Proxy: ProxyFromEnvironment,
	DialContext: defaultTransportDialContext(&net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}),
	ForceAttemptHTTP2:     true,
	MaxIdleConns:          100,
	IdleConnTimeout:       90 * time.Second,
	TLSHandshakeTimeout:   10 * time.Second,
	ExpectContinueTimeout: 1 * time.Second,
}

net/http将传输层的逻辑交给了RoundTripper这种分层的设计还是比较常见的。除了Transport, 还有fileTransport, http2Transport, 用于不同的用途。

RoundTrip

这是一个比较长的函数

func (t *Transport) roundTrip(req *Request) (*Response, error) {
    // 1. 
	t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
	ctx := req.Context()
	scheme := req.URL.Scheme
	isHTTP := scheme == "http" || scheme == "https"
	if isHTTP {
		// 校验http header字段和值是否合法
	}

    // 2.
	if altRT := t.alternateRoundTripper(req); altRT != nil {
		if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
			return resp, err
		}
		var err error
		req, err = rewindBody(req)
		if err != nil {
			return nil, err
		}
	}
	// 不是http请求报错,以及检查方法和Host,
    
	// 3.
	for {
        // 判断context是否取消或者超时
		select {
		case <-ctx.Done():
			req.closeBody()
			return nil, ctx.Err()
		default:
		}

		// 4.
		treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey}
		cm, err := t.connectMethodForRequest(treq)

		// 5.
		pconn, err := t.getConn(treq, cm)
        
        
        // 6.
		var resp *Response
		if pconn.alt != nil {
			// HTTP/2 path.
			t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
			resp, err = pconn.alt.RoundTrip(req)
		} else {
			resp, err = pconn.roundTrip(treq)
		}
        
        // 7.
		if err == nil {
			resp.Request = origReq
			return resp, nil
		}

        // 8.
		if http2isNoCachedConnError(err) {
			if t.removeIdleConn(pconn) {
				t.decConnsPerHost(pconn.cacheKey)
			}
        // 9.
		} else if !pconn.shouldRetryRequest(req, err) {
			// 判断各种错误
			return nil, err
		}

		// 10.
		req, err = rewindBody(req)
	}
}

代码分解如下

  1. 设置可选的RoundTripper, 主要是设置协议的切换规则,默认https会尝试使用http2
  2. 看当前协议是否有可选的RoundTripper, 如果是https, 会使用http2Transport
  3. 一个for循环用来处理重试逻辑
  4. 因为roundTrip会修改treq, 所以每次重试都重新创建
  5. 获取缓存的持久化的tcp连接对象,如果没有就创建
  6. 通过持久化的tcp连接对象发送请求,将请求转成二进制数据发送,并读取对方的响应,最后将响应转成Response对象返回
  7. 如果没有问题就跳出循环
  8. http2的重试机制检查
  9. 非http2的重试检查
  10. 重置请求体并重试

上面的交互大致如下:

sequenceDiagram Transport.roundTrip ->>+ getConn: 获取持久化连接 getConn -->>- Transport.roundTrip: 返回pconn对象 Transport.roundTrip ->>+ pconn.roundTrip: 发送请求 pconn.roundTrip -->>+ Transport.roundTrip: 返回(*Response, error)

persistConn

建立TCP重复请求的时候是比较大的开销,所以缓存起来是一个不错的选择,大多数http请求库都有连接池的概念,net/http也实现了这种机制。

获取连接

func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persistConn, err error) {
	req := treq.Request
	trace := treq.trace
	ctx := req.Context()
    // 
	w := &wantConn{
		cm:         cm,
		key:        cm.key(),
		ctx:        ctx,
		ready:      make(chan struct{}, 1),
		beforeDial: testHookPrePendingDial,
		afterDial:  testHookPostPendingDial,
	}
	defer func() {
		if err != nil {
			w.cancel(t, err)
		}
	}()

	// 1.
	if delivered := t.queueForIdleConn(w); delivered {
		// 如果有空闲的连接,检查没有问题就返回
	}

	cancelc := make(chan error, 1)
	t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err })

	// 2.
	t.queueForDial(w)

	// Wait for completion or cancellation.
	select {
    // 3.
	case <-w.ready:
		// 判断是否出错并处理
		return w.pc, w.err
	case <-req.Cancel:
		return nil, errRequestCanceledConn
	case <-req.Context().Done():
		return nil, req.Context().Err()
	case err := <-cancelc:
		if err == errRequestCanceled {
			err = errRequestCanceledConn
		}
		return nil, err
	}
}

func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) {
    // 4.
	if t.DisableKeepAlives {
		return false
	}

	// 5.
	if list, ok := t.idleConn[w.key]; ok {
		// 检查连接是否可用
	}
	return false
}

func (t *Transport) queueForDial(w *wantConn) {
	w.beforeDial()
    // 6.
	if t.MaxConnsPerHost <= 0 {
		go t.dialConnFor(w)
		return
	}
    // 当MaxConnsPerHost>0时,依次建立请求
}

代码分解如下:

  1. 首先从空闲连接的队列里面检查是否有空闲的连接可以使用
  2. 没有空闲的链接,就将建立连接的请求排队
  3. 通过chan来等待连接建立, 当请求建立之后就会关闭w.ready
  4. 如果禁用keepalive就不会有空闲连接,因为使用过后就会断开
  5. 获取空闲连接的队列并判断是否有空闲连接
  6. 开启一个协程建立连接

queueForDial的队列概念其实是当MaxConnsPerHost>0的时候需要排队

建立连接

func (t *Transport) dialConnFor(w *wantConn) {
	defer w.afterDial()
	pc, err := t.dialConn(w.ctx, w.cm)
}

func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) {
	pconn = &persistConn{
		t:             t,
		cacheKey:      cm.key(),
		reqch:         make(chan requestAndChan, 1),
		writech:       make(chan writeRequest, 1),
		closech:       make(chan struct{}),
		writeErrCh:    make(chan error, 1),
		writeLoopDone: make(chan struct{}),
	}

	if cm.scheme() == "https" && t.hasCustomTLSDialer() {
		var err error
        // 1.
		pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr())
	} else {
        // 2.
		conn, err := t.dial(ctx, "tcp", cm.addr())
		pconn.conn = conn
        // 3.
		if cm.scheme() == "https" {
			if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil {
				return nil, wrapErr(err)
			}
		}
	}

	// 如果有代理就设置代理
	// 4.
	pconn.br = bufio.NewReaderSize(pconn, t.readBufferSize())
	pconn.bw = bufio.NewWriterSize(persistConnWriter{pconn}, t.writeBufferSize())

    // 5.
	go pconn.readLoop()
	go pconn.writeLoop()
	return pconn, nil
}

代码分解如下:

  1. 如果是https协议并设置了自定义的TLS拨号器,就基于此建立请求
  2. 首先建立tcp连接
  3. 如果是https协议就在外面包一层https连接
  4. 用bufio包装一下连接对象,提升读写速度
  5. 同时开始读写循环

发送和接受请求

func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
	
	startBytesWritten := pc.nwrite
	writeErrCh := make(chan error, 1)
    // 1.
	pc.writech <- writeRequest{req, writeErrCh, continueCh}

	resc := make(chan responseAndError)
    // 2.
	pc.reqch <- requestAndChan{
		req:        req.Request,
		cancelKey:  req.cancelKey,
		ch:         resc,
		addedGzip:  requestedGzip,
		continueCh: continueCh,
		callerGone: gone,
	}

	for {
		testHookWaitResLoop()
		select {
        // 3.
		case re := <-resc:
			return re.res, nil
		case <-cancelChan:
			canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled)
			cancelChan = nil
		case <-ctxDoneChan:
			canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err())
			cancelChan = nil
			ctxDoneChan = nil
		}
	}
}


func (pc *persistConn) writeLoop() {
	defer close(pc.writeLoopDone)
	for {
		select {
        // 4.
		case wr := <-pc.writech:
			startBytesWritten := pc.nwrite
            // 5.
			err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh))
		case <-pc.closech:
			return
		}
	}
}

func (pc *persistConn) readLoop() {
    // 6.
	alive := true
	for alive {
		var resp *Response
		if err == nil {
            // 7.
			resp, err = pc.readResponse(rc, trace)
		} else {
			err = transportReadFromServerError{err}
			closeErr = err
		}

		select {
        // 8.
		case rc.ch <- responseAndError{res: resp}:
		case <-rc.callerGone:
			return
		}
	}
}

代码分解如下:

  1. 通过writech将请求发送给WriteLoop, writeloop负责将请求序列化成二进制数据流发送给对端
  2. 通过reqch将请求发送给ReadLoop, ReadLoop基于请求来解析对端发送过来的二进制数据流
  3. 等待结果,结果有ReadLoop发送
  4. WriteLoop接受到请求后开始工作
  5. 发送请求
  6. for循环用于重复读,不太理解
  7. 读取并解析对端发送过来的二进制数据流
  8. 将结果通过rc.ch发送给roundTrip

http请求序列化和解析http请求的内容还是比较复杂的,这里就略过了。

总结

通过阅读net/http源代码发现,细节真多,很多异常的处理看不懂,再者就是golang比较喜欢用chan来同步数据,这篇文章主要是梳理了一下net/http的调用流程,很多细节都忽略了,比如请求的序列化,响应的解析,超时的处理等,这些东西都比较细,值得单独写一篇文章,以后有机会在写吧。