286 lines
5.6 KiB
Go
286 lines
5.6 KiB
Go
package downloader
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"mime"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"sync"
|
|
)
|
|
|
|
const (
	// UserAgent is the browser-like User-Agent value that setNewHeader puts on
	// outgoing requests.
	UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/115.0.0.0 Safari/537.36"
)
|
|
|
|
// PartFile describes one chunk of the file being downloaded: an inclusive
// byte range plus the bytes fetched for it.
type PartFile struct {
	Index    int64  // position of this chunk in the chunk list
	From, To int64  // inclusive byte range [From, To] within the file
	Data     []byte // bytes obtained for this chunk
	Done     bool   // true once the chunk has been fetched or recovered
}
|
|
|
|
type Downloader struct {
|
|
FileSize int64
|
|
Url string
|
|
FileName string
|
|
Path string
|
|
PartNum int64
|
|
donePart []PartFile
|
|
}
|
|
|
|
// NewDownloader 创建下载器
|
|
func NewDownloader(url, outputDir, outputFileName string, partNum int64) *Downloader {
|
|
if outputDir == "" {
|
|
wd, err := os.Getwd() //获取当前工作目录
|
|
if err != nil {
|
|
log.Println(err)
|
|
}
|
|
outputDir = wd
|
|
}
|
|
return &Downloader{
|
|
FileSize: 0,
|
|
Url: url,
|
|
FileName: outputFileName,
|
|
Path: outputDir,
|
|
PartNum: partNum,
|
|
donePart: make([]PartFile, partNum),
|
|
}
|
|
}
|
|
|
|
func (d *Downloader) getNewRequest(method string) (*http.Request, error) {
|
|
req, err := http.NewRequest(
|
|
method,
|
|
d.Url,
|
|
nil)
|
|
|
|
return req, err
|
|
}
|
|
|
|
func (d *Downloader) head() (int64, error) {
|
|
req, err := d.getNewRequest(http.MethodHead)
|
|
req = setNewHeader(req)
|
|
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if resp.StatusCode > 299 {
|
|
return 0, errors.New(fmt.Sprintf("Can't process, response is %v", resp.StatusCode))
|
|
}
|
|
|
|
if resp.Header.Get("Accept-Ranges") != "bytes" {
|
|
return 0, errors.New("服务器不支持文件断点续传")
|
|
}
|
|
|
|
d.FileName = GetFileInfoFromResponse(resp)
|
|
length, err := strconv.Atoi(resp.Header.Get("Content-Length"))
|
|
|
|
return int64(length), err
|
|
}
|
|
|
|
// 下载切片
|
|
func (d *Downloader) downloadPart(c PartFile, f *os.File) error {
|
|
r, err := d.getNewRequest("GET")
|
|
r = setNewHeader(r)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.Printf("开始[%d]下载from:%d to:%d\n", c.Index, c.From, c.To)
|
|
r.Header.Set("Range", fmt.Sprintf("bytes=%v-%v", c.From, c.To))
|
|
resp, err := http.DefaultClient.Do(r)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if resp.StatusCode > 299 {
|
|
return errors.New(fmt.Sprintf("服务器错误状态码: %v", resp.StatusCode))
|
|
}
|
|
defer func(Body io.ReadCloser) {
|
|
_ = Body.Close()
|
|
}(resp.Body)
|
|
|
|
bs, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(bs) != int(c.To-c.From+1) {
|
|
}
|
|
c.Data = bs
|
|
c.Done = true
|
|
|
|
d.donePart[c.Index] = c
|
|
|
|
_, err = f.WriteAt(bs, c.From)
|
|
|
|
if err != nil {
|
|
c.Done = true
|
|
}
|
|
|
|
log.Printf("结束[%d]下载", c.Index)
|
|
return err
|
|
}
|
|
|
|
func (d *Downloader) checkIntegrity(t *os.File) error {
|
|
log.Println("开始合并文件")
|
|
|
|
totalSize := 0
|
|
|
|
for _, s := range d.donePart {
|
|
//hash.Write(s.Data)
|
|
totalSize += len(s.Data)
|
|
}
|
|
|
|
if int64(totalSize) != d.FileSize {
|
|
return errors.New("文件不完整")
|
|
}
|
|
|
|
_ = t.Close()
|
|
return os.Rename(filepath.Join(d.Path, d.FileName+".tmp"), filepath.Join(d.Path, d.FileName))
|
|
}
|
|
|
|
// Run 开始下载任务
|
|
func (d *Downloader) Run() error {
|
|
fileTotalSize, err := d.head()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
d.FileSize = fileTotalSize
|
|
|
|
jobs := make([]PartFile, d.PartNum)
|
|
eachSize := fileTotalSize / d.PartNum
|
|
|
|
path := filepath.Join(d.Path, d.FileName+".tmp")
|
|
|
|
tmpFile := new(os.File)
|
|
|
|
fByte := make([]byte, d.FileSize)
|
|
|
|
if exists(path) {
|
|
tmpFile, err = os.OpenFile(path, os.O_RDWR|os.O_TRUNC|os.O_CREATE, 0)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
fByte, err = io.ReadAll(tmpFile)
|
|
} else {
|
|
tmpFile, err = os.Create(path)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func(tmpFile *os.File) {
|
|
_ = tmpFile.Close()
|
|
}(tmpFile)
|
|
|
|
for i := range jobs {
|
|
i64 := int64(i)
|
|
jobs[i64].Index = i64
|
|
if i == 0 {
|
|
jobs[i64].From = 0
|
|
} else {
|
|
jobs[i64].From = jobs[i64-1].To + 1
|
|
}
|
|
if i64 < d.PartNum-1 {
|
|
jobs[i64].To = jobs[i64].From + eachSize
|
|
} else {
|
|
//the last filePart
|
|
jobs[i64].To = fileTotalSize - 1
|
|
}
|
|
}
|
|
|
|
for i, j := range jobs {
|
|
tmpJob := j
|
|
emptyTmp := make([]byte, tmpJob.To-j.From)
|
|
if bytes.Compare(emptyTmp, fByte[tmpJob.From:j.To]) != 0 {
|
|
tmpJob.Data = fByte[j.From : j.To+1]
|
|
tmpJob.Done = true
|
|
d.donePart[tmpJob.Index] = tmpJob
|
|
} else {
|
|
tmpJob.Done = false
|
|
}
|
|
jobs[i] = tmpJob
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for _, j := range jobs {
|
|
if !j.Done {
|
|
wg.Add(1)
|
|
go func(job PartFile) {
|
|
defer wg.Done()
|
|
err := d.downloadPart(job, tmpFile)
|
|
if err != nil {
|
|
log.Println("下载文件失败:", err, job)
|
|
}
|
|
}(j)
|
|
}
|
|
}
|
|
wg.Wait()
|
|
return d.checkIntegrity(tmpFile)
|
|
}
|
|
|
|
// getRedirectInfo sends a GET request for u with the given User-Agent and
// returns the first response without following redirects. This lets the
// caller inspect a 3xx status and its Location header.
//
// The caller owns the returned response and must close its Body. An
// unparsable u is reported as an error instead of being sent with a nil URL.
func getRedirectInfo(u, userAgent string) (*http.Response, error) {
	log.Println("获取重定向信息")

	target, err := url.Parse(u)
	if err != nil {
		return nil, err
	}

	header := http.Header{}
	header.Add("User-Agent", userAgent)
	request := &http.Request{
		Method: "GET",
		URL:    target,
		Header: header,
	}

	client := &http.Client{
		// ErrUseLastResponse makes the client hand back the redirect response
		// itself instead of following it.
		CheckRedirect: func(*http.Request, []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}
	response, err := client.Do(request)
	if err != nil {
		return nil, err
	}
	return response, nil
}
|
|
|
|
func setNewHeader(r *http.Request) *http.Request {
|
|
r.Header.Add("User-Agent", UserAgent)
|
|
r.Header.Add("Upgrade-Insecure-Requests", "1")
|
|
return r
|
|
}
|
|
|
|
// exists reports whether os.Stat succeeds for path, i.e. whether a file or
// directory is present there and can be stat'ed. Any Stat error yields false,
// including permission errors as well as "does not exist".
func exists(path string) bool {
	_, err := os.Stat(path)
	return err == nil
}
|
|
|
|
// GetFileInfoFromResponse picks a local file name for the download in resp.
//
// It prefers the "filename" parameter of the Content-Disposition header,
// reduced to its last path element. The value comes from the server, so a
// name such as "../../x" must not point outside the output directory.
//
// If the header is absent, malformed, or has no usable filename, it falls
// back to the last element of the request URL's path. If there is no request
// URL either, it returns "". It never panics on a malformed header.
func GetFileInfoFromResponse(resp *http.Response) string {
	if cd := resp.Header.Get("Content-Disposition"); cd != "" {
		if _, params, err := mime.ParseMediaType(cd); err == nil {
			name := filepath.Base(params["filename"])
			// Base maps "" to "."; reject that and other names that are not
			// regular file names.
			if name != "." && name != ".." && name != string(filepath.Separator) {
				return name
			}
		}
	}
	if resp.Request == nil || resp.Request.URL == nil {
		return ""
	}
	return filepath.Base(resp.Request.URL.Path)
}
|