diff --git a/downloader/torrent_info.go b/downloader/torrent.go similarity index 82% rename from downloader/torrent_info.go rename to downloader/torrent.go index d641975..1dda197 100644 --- a/downloader/torrent_info.go +++ b/downloader/torrent.go @@ -26,7 +26,7 @@ import ( "strconv" "github.com/xgfone/bt/metainfo" - "github.com/xgfone/bt/peerprotocol" + pp "github.com/xgfone/bt/peerprotocol" ) const peerBlockSize = 16384 // 16KiB. @@ -90,8 +90,8 @@ type TorrentDownloader struct { responses chan TorrentResponse ondht func(string, uint16) - ebits peerprotocol.ExtensionBits - ehmsg peerprotocol.ExtendedHandshakeMsg + ebits pp.ExtensionBits + ehmsg pp.ExtendedHandshakeMsg } // NewTorrentDownloader returns a new TorrentDownloader. @@ -107,8 +107,8 @@ func NewTorrentDownloader(c ...TorrentDownloaderConfig) *TorrentDownloader { requests: make(chan request, conf.WorkerNum), responses: make(chan TorrentResponse, 1024), - ehmsg: peerprotocol.ExtendedHandshakeMsg{ - M: map[string]uint8{peerprotocol.ExtendedMessageNameMetadata: 1}, + ehmsg: pp.ExtendedHandshakeMsg{ + M: map[string]uint8{pp.ExtendedMessageNameMetadata: 1}, }, } @@ -116,7 +116,7 @@ func NewTorrentDownloader(c ...TorrentDownloaderConfig) *TorrentDownloader { go d.worker() } - d.ebits.Set(peerprotocol.ExtensionBitExtended) + d.ebits.Set(pp.ExtensionBitExtended) return d } @@ -151,9 +151,9 @@ func (d *TorrentDownloader) Close() { func (d *TorrentDownloader) OnDHTNode(cb func(host string, port uint16)) { d.ondht = cb if cb == nil { - d.ebits.Unset(peerprotocol.ExtensionBitDHT) + d.ebits.Unset(pp.ExtensionBitDHT) } else { - d.ebits.Set(peerprotocol.ExtensionBitDHT) + d.ebits.Set(pp.ExtensionBitDHT) } } @@ -174,18 +174,19 @@ func (d *TorrentDownloader) worker() { func (d *TorrentDownloader) download(host string, port uint16, peerID, infohash metainfo.Hash) (err error) { addr := net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)) - conn, err := peerprotocol.NewPeerConnByDial(d.conf.ID, addr) + conn, err := pp.NewPeerConnByDial(addr, d.conf.ID, infohash) if err != nil { return fmt.Errorf("fail to dial to '%s': %s", addr, err) } defer conn.Close() - conn.ExtensionBits = d.ebits - rmsg, err := conn.Handshake(infohash) - if err != nil || rmsg.InfoHash != infohash || !rmsg.IsSupportExtended() { + conn.ExtBits = d.ebits + if err = conn.Handshake(); err != nil { return - } else if !peerID.IsZero() && peerID != rmsg.PeerID { - return fmt.Errorf("inconsistent peer id '%s'", rmsg.PeerID.HexString()) + } else if !conn.PeerExtBits.IsSupportExtended() { + return fmt.Errorf("the remote peer '%s' does not support Extended", addr) + } else if !peerID.IsZero() && peerID != conn.PeerID { + return fmt.Errorf("inconsistent peer id '%s'", conn.PeerID.HexString()) } if err = conn.SendExtHandshakeMsg(d.ehmsg); err != nil { @@ -196,7 +197,7 @@ func (d *TorrentDownloader) download(host string, port uint16, var piecesNum int var metadataSize int var utmetadataID uint8 - var msg peerprotocol.Message + var msg pp.Message for { if msg, err = conn.ReadMsg(); err != nil { @@ -214,8 +215,8 @@ func (d *TorrentDownloader) download(host string, port uint16, } switch msg.Type { - case peerprotocol.Extended: - case peerprotocol.Port: + case pp.Extended: + case pp.Port: if d.ondht != nil { d.ondht(host, msg.Port) } @@ -225,17 +226,17 @@ func (d *TorrentDownloader) download(host string, port uint16, } switch msg.ExtendedID { - case peerprotocol.ExtendedIDHandshake: + case pp.ExtendedIDHandshake: if utmetadataID > 0 { return fmt.Errorf("rehandshake from the peer '%s'", conn.RemoteAddr().String()) } - var ehmsg peerprotocol.ExtendedHandshakeMsg + var ehmsg pp.ExtendedHandshakeMsg if err = ehmsg.Decode(msg.ExtendedPayload); err != nil { return } - utmetadataID = ehmsg.M[peerprotocol.ExtendedMessageNameMetadata] + utmetadataID = ehmsg.M[pp.ExtendedMessageNameMetadata] if utmetadataID == 0 { return errors.New(`the peer does not support "ut_metadata"`) } @@ -253,12 +254,12 @@ func (d *TorrentDownloader) download(host string, port uint16, return } - var utmsg peerprotocol.UtMetadataExtendedMsg + var utmsg pp.UtMetadataExtendedMsg if err = utmsg.DecodeFromPayload(msg.ExtendedPayload); err != nil { return } - if utmsg.MsgType != peerprotocol.UtMetadataExtendedMsgTypeData { + if utmsg.MsgType != pp.UtMetadataExtendedMsgTypeData { continue } @@ -283,7 +284,7 @@ func (d *TorrentDownloader) download(host string, port uint16, d.responses <- TorrentResponse{ Host: host, Port: port, - PeerID: rmsg.PeerID, + PeerID: conn.PeerID, InfoHash: infohash, InfoBytes: metadataInfo, } @@ -294,18 +295,18 @@ func (d *TorrentDownloader) download(host string, port uint16, } } -func (d *TorrentDownloader) requestPieces(conn *peerprotocol.PeerConn, utMetadataID uint8, piecesNum int) { +func (d *TorrentDownloader) requestPieces(conn *pp.PeerConn, utMetadataID uint8, piecesNum int) { for i := 0; i < piecesNum; i++ { - payload, err := peerprotocol.UtMetadataExtendedMsg{ - MsgType: peerprotocol.UtMetadataExtendedMsgTypeRequest, + payload, err := pp.UtMetadataExtendedMsg{ + MsgType: pp.UtMetadataExtendedMsgTypeRequest, Piece: i, }.EncodeToBytes() if err != nil { panic(err) } - msg := peerprotocol.Message{ - Type: peerprotocol.Extended, + msg := pp.Message{ + Type: pp.Extended, ExtendedID: utMetadataID, ExtendedPayload: payload, }