You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
576 lines
16 KiB
576 lines
16 KiB
// Copyright © 2023 OpenIM SDK. All rights reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package file
|
|
|
|
import (
|
|
"context"
|
|
"crypto/md5"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/openimsdk/openim-sdk-core/v3/internal/util"
|
|
"github.com/openimsdk/openim-sdk-core/v3/pkg/constant"
|
|
"github.com/openimsdk/openim-sdk-core/v3/pkg/db/db_interface"
|
|
"github.com/openimsdk/openim-sdk-core/v3/pkg/db/model_struct"
|
|
"github.com/openimsdk/tools/errs"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/openimsdk/protocol/third"
|
|
"github.com/openimsdk/tools/log"
|
|
)
|
|
|
|
// UploadFileReq describes one file upload request passed to File.UploadFile.
type UploadFileReq struct {
	// Filepath is serialized as "filepath". It is not read directly in
	// UploadFile; presumably Open uses it to locate the file — confirm there.
	Filepath string `json:"filepath"`

	// Name is the object name. UploadFile strips one leading '/' and makes
	// sure it starts with "<loginUserID>/", modifying the request in place.
	Name string `json:"name"`

	// ContentType is sent to the server. When empty, UploadFile replaces it
	// with the type detected while hashing the file.
	ContentType string `json:"contentType"`

	// Cause is forwarded unchanged to the initiate and complete requests.
	Cause string `json:"cause"`

	// Uuid is serialized as "uuid"; UploadFile itself does not read it.
	Uuid string `json:"uuid"`
}
|
|
|
|
// UploadFileResp is the result of a successful File.UploadFile call.
type UploadFileResp struct {
	// URL is the URL reported by the server for the uploaded object.
	URL string `json:"url"`
}
|
|
|
|
// partInfo is the result of hashing a file part by part (see getPartInfo).
type partInfo struct {
	// ContentType is the type detected from the first chunk read.
	ContentType string
	// PartSize is the nominal size of each part; the last part may be smaller.
	PartSize int64
	// PartNum is the number of parts.
	PartNum int
	// FileMd5 is the hex MD5 of the whole file; PartMd5 is the hex MD5 of the
	// comma-joined per-part digests and is used as the upload hash.
	FileMd5, PartMd5 string
	// PartSizes holds the actual size of every part, in order.
	PartSizes []int64
	// PartMd5s holds the hex MD5 of every part, in order.
	PartMd5s []string
}
|
|
|
|
func NewFile(database db_interface.DataBase, loginUserID string) *File {
|
|
return &File{database: database, loginUserID: loginUserID, confLock: &sync.Mutex{}, mapLocker: &sync.Mutex{}, uploading: make(map[string]*lockInfo)}
|
|
}
|
|
|
|
type File struct {
|
|
database db_interface.DataBase
|
|
loginUserID string
|
|
confLock sync.Locker
|
|
partLimit *third.PartLimitResp
|
|
mapLocker sync.Locker
|
|
uploading map[string]*lockInfo
|
|
}
|
|
|
|
// lockInfo is the per-hash entry stored in File.uploading.
type lockInfo struct {
	// count is the number of lockHash calls not yet matched by unlockHash;
	// the entry is removed from the map when it drops to zero.
	count int32
	// locker serializes uploads that share the same hash.
	locker sync.Locker
}
|
|
|
|
func (f *File) lockHash(hash string) {
|
|
f.mapLocker.Lock()
|
|
locker, ok := f.uploading[hash]
|
|
if !ok {
|
|
locker = &lockInfo{count: 0, locker: &sync.Mutex{}}
|
|
f.uploading[hash] = locker
|
|
}
|
|
atomic.AddInt32(&locker.count, 1)
|
|
f.mapLocker.Unlock()
|
|
locker.locker.Lock()
|
|
}
|
|
|
|
func (f *File) unlockHash(hash string) {
|
|
f.mapLocker.Lock()
|
|
locker, ok := f.uploading[hash]
|
|
if !ok {
|
|
f.mapLocker.Unlock()
|
|
return
|
|
}
|
|
if atomic.AddInt32(&locker.count, -1) == 0 {
|
|
delete(f.uploading, hash)
|
|
}
|
|
f.mapLocker.Unlock()
|
|
locker.locker.Unlock()
|
|
}
|
|
|
|
func (f *File) UploadFile(ctx context.Context, req *UploadFileReq, cb UploadFileCallback) (*UploadFileResp, error) {
|
|
if cb == nil {
|
|
cb = emptyUploadCallback{}
|
|
}
|
|
if req.Name == "" {
|
|
return nil, errors.New("name is empty")
|
|
}
|
|
if req.Name[0] == '/' {
|
|
req.Name = req.Name[1:]
|
|
}
|
|
if prefix := f.loginUserID + "/"; !strings.HasPrefix(req.Name, prefix) {
|
|
req.Name = prefix + req.Name
|
|
}
|
|
file, err := Open(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer file.Close()
|
|
fileSize := file.Size()
|
|
cb.Open(fileSize)
|
|
info, err := f.getPartInfo(ctx, file, fileSize, cb)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if req.ContentType == "" {
|
|
req.ContentType = info.ContentType
|
|
}
|
|
partSize := info.PartSize
|
|
partSizes := info.PartSizes
|
|
partMd5s := info.PartMd5s
|
|
partMd5Val := info.PartMd5
|
|
if err := file.StartSeek(0); err != nil {
|
|
return nil, err
|
|
}
|
|
f.lockHash(partMd5Val)
|
|
defer f.unlockHash(partMd5Val)
|
|
maxParts := 20
|
|
if maxParts > len(partSizes) {
|
|
maxParts = len(partSizes)
|
|
}
|
|
uploadInfo, err := f.getUpload(ctx, &third.InitiateMultipartUploadReq{
|
|
Hash: partMd5Val,
|
|
Size: fileSize,
|
|
PartSize: partSize,
|
|
MaxParts: int32(maxParts), // 一次性获取签名数量
|
|
Cause: req.Cause,
|
|
Name: req.Name,
|
|
ContentType: req.ContentType,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if uploadInfo.Resp.Upload == nil {
|
|
cb.Complete(fileSize, uploadInfo.Resp.Url, 0)
|
|
return &UploadFileResp{
|
|
URL: uploadInfo.Resp.Url,
|
|
}, nil
|
|
}
|
|
if uploadInfo.Resp.Upload.PartSize != partSize {
|
|
f.cleanPartLimit()
|
|
return nil, fmt.Errorf("part fileSize not match, expect %d, got %d", partSize, uploadInfo.Resp.Upload.PartSize)
|
|
}
|
|
cb.UploadID(uploadInfo.Resp.Upload.UploadID)
|
|
uploadedSize := fileSize
|
|
for i := 0; i < len(partSizes); i++ {
|
|
if !uploadInfo.Bitmap.Get(i) {
|
|
uploadedSize -= partSizes[i]
|
|
}
|
|
}
|
|
continueUpload := uploadedSize > 0
|
|
for i, currentPartSize := range partSizes {
|
|
partNumber := int32(i + 1)
|
|
md5Reader := NewMd5Reader(io.LimitReader(file, currentPartSize))
|
|
if uploadInfo.Bitmap.Get(i) {
|
|
if _, err := io.Copy(io.Discard, md5Reader); err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
reader := NewProgressReader(md5Reader, func(current int64) {
|
|
cb.UploadComplete(fileSize, uploadedSize+current, uploadedSize)
|
|
})
|
|
urlval, header, err := uploadInfo.GetPartSign(ctx, partNumber)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err := f.doPut(ctx, http.DefaultClient, urlval, header, reader, currentPartSize); err != nil {
|
|
log.ZError(ctx, "doPut", err, "partMd5Val", partMd5Val, "name", req.Name, "partNumber", partNumber)
|
|
return nil, err
|
|
}
|
|
uploadedSize += currentPartSize
|
|
if uploadInfo.DBInfo != nil && uploadInfo.Bitmap != nil {
|
|
uploadInfo.Bitmap.Set(i)
|
|
uploadInfo.DBInfo.UploadInfo = base64.StdEncoding.EncodeToString(uploadInfo.Bitmap.Serialize())
|
|
if err := f.database.UpdateUpload(ctx, uploadInfo.DBInfo); err != nil {
|
|
log.ZError(ctx, "SetUploadPartPush", err, "partMd5Val", partMd5Val, "name", req.Name, "partNumber", partNumber)
|
|
}
|
|
}
|
|
}
|
|
md5val := md5Reader.Md5()
|
|
if md5val != partMd5s[i] {
|
|
return nil, fmt.Errorf("upload part %d failed, md5 not match, expect %s, got %s", i, partMd5s[i], md5val)
|
|
}
|
|
cb.UploadPartComplete(i, currentPartSize, partMd5s[i])
|
|
log.ZDebug(ctx, "upload part success", "partMd5Val", md5val, "name", req.Name, "partNumber", partNumber)
|
|
}
|
|
log.ZDebug(ctx, "upload all part success", "partHash", partMd5Val, "name", req.Name)
|
|
resp, err := f.completeMultipartUpload(ctx, &third.CompleteMultipartUploadReq{
|
|
UploadID: uploadInfo.Resp.Upload.UploadID,
|
|
Parts: partMd5s,
|
|
Name: req.Name,
|
|
ContentType: req.ContentType,
|
|
Cause: req.Cause,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
typ := 1
|
|
if continueUpload {
|
|
typ++
|
|
}
|
|
cb.Complete(fileSize, resp.Url, typ)
|
|
if uploadInfo.DBInfo != nil {
|
|
if err := f.database.DeleteUpload(ctx, info.PartMd5); err != nil {
|
|
log.ZError(ctx, "DeleteUpload", err, "partMd5Val", info.PartMd5, "name", req.Name)
|
|
}
|
|
}
|
|
return &UploadFileResp{
|
|
URL: resp.Url,
|
|
}, nil
|
|
}
|
|
|
|
func (f *File) cleanPartLimit() {
|
|
f.confLock.Lock()
|
|
defer f.confLock.Unlock()
|
|
f.partLimit = nil
|
|
}
|
|
|
|
func (f *File) initiateMultipartUploadResp(ctx context.Context, req *third.InitiateMultipartUploadReq) (*third.InitiateMultipartUploadResp, error) {
|
|
return util.CallApi[third.InitiateMultipartUploadResp](ctx, constant.ObjectInitiateMultipartUpload, req)
|
|
}
|
|
|
|
func (f *File) authSign(ctx context.Context, req *third.AuthSignReq) (*third.AuthSignResp, error) {
|
|
if len(req.PartNumbers) == 0 {
|
|
return nil, errs.ErrArgs.WrapMsg("partNumbers is empty")
|
|
}
|
|
return util.CallApi[third.AuthSignResp](ctx, constant.ObjectAuthSign, req)
|
|
}
|
|
|
|
func (f *File) completeMultipartUpload(ctx context.Context, req *third.CompleteMultipartUploadReq) (*third.CompleteMultipartUploadResp, error) {
|
|
return util.CallApi[third.CompleteMultipartUploadResp](ctx, constant.ObjectCompleteMultipartUpload, req)
|
|
}
|
|
|
|
func (f *File) getPartNum(fileSize int64, partSize int64) int {
|
|
partNum := fileSize / partSize
|
|
if fileSize%partSize != 0 {
|
|
partNum++
|
|
}
|
|
return int(partNum)
|
|
}
|
|
|
|
func (f *File) partSize(ctx context.Context, size int64) (int64, error) {
|
|
f.confLock.Lock()
|
|
defer f.confLock.Unlock()
|
|
if f.partLimit == nil {
|
|
resp, err := util.CallApi[third.PartLimitResp](ctx, constant.ObjectPartLimit, &third.PartLimitReq{})
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
f.partLimit = resp
|
|
}
|
|
if size <= 0 {
|
|
return 0, errors.New("size must be greater than 0")
|
|
}
|
|
if size > f.partLimit.MaxPartSize*int64(f.partLimit.MaxNumSize) {
|
|
return 0, fmt.Errorf("size must be less than %db", f.partLimit.MaxPartSize*int64(f.partLimit.MaxNumSize))
|
|
}
|
|
if size <= f.partLimit.MinPartSize*int64(f.partLimit.MaxNumSize) {
|
|
return f.partLimit.MinPartSize, nil
|
|
}
|
|
partSize := size / int64(f.partLimit.MaxNumSize)
|
|
if size%int64(f.partLimit.MaxNumSize) != 0 {
|
|
partSize++
|
|
}
|
|
return partSize, nil
|
|
}
|
|
|
|
func (f *File) accessURL(ctx context.Context, req *third.AccessURLReq) (*third.AccessURLResp, error) {
|
|
return util.CallApi[third.AccessURLResp](ctx, constant.ObjectAccessURL, req)
|
|
}
|
|
|
|
func (f *File) doHttpReq(req *http.Request) ([]byte, *http.Response, error) {
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
data, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return data, resp, nil
|
|
}
|
|
|
|
func (f *File) partMD5(parts []string) string {
|
|
s := strings.Join(parts, ",")
|
|
md5Sum := md5.Sum([]byte(s))
|
|
return hex.EncodeToString(md5Sum[:])
|
|
}
|
|
|
|
type AuthSignParts struct {
|
|
Sign *third.SignPart
|
|
Times []time.Time
|
|
}
|
|
|
|
type UploadInfo struct {
|
|
PartNum int
|
|
Bitmap *Bitmap
|
|
DBInfo *model_struct.LocalUpload
|
|
Resp *third.InitiateMultipartUploadResp
|
|
//Signs *AuthSignParts
|
|
CreateTime time.Time
|
|
BatchSignNum int32
|
|
f *File
|
|
}
|
|
|
|
func (u *UploadInfo) getIndex(partNumber int32) int {
|
|
if u.Resp.Upload.Sign == nil {
|
|
return -1
|
|
} else {
|
|
if u.CreateTime.IsZero() {
|
|
return -1
|
|
} else {
|
|
if time.Since(u.CreateTime) > time.Minute {
|
|
return -1
|
|
}
|
|
}
|
|
}
|
|
for i, part := range u.Resp.Upload.Sign.Parts {
|
|
if part.PartNumber == partNumber {
|
|
return i
|
|
}
|
|
}
|
|
return -1
|
|
}
|
|
|
|
func (u *UploadInfo) buildRequest(i int) (*url.URL, http.Header, error) {
|
|
sign := u.Resp.Upload.Sign
|
|
part := sign.Parts[i]
|
|
rawURL := sign.Url
|
|
if part.Url != "" {
|
|
rawURL = part.Url
|
|
}
|
|
urlval, err := url.Parse(rawURL)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
if len(sign.Query)+len(part.Query) > 0 {
|
|
query := urlval.Query()
|
|
for i := range sign.Query {
|
|
v := sign.Query[i]
|
|
query[v.Key] = v.Values
|
|
}
|
|
for i := range part.Query {
|
|
v := part.Query[i]
|
|
query[v.Key] = v.Values
|
|
}
|
|
urlval.RawQuery = query.Encode()
|
|
}
|
|
header := make(http.Header)
|
|
for i := range sign.Header {
|
|
v := sign.Header[i]
|
|
header[v.Key] = v.Values
|
|
}
|
|
for i := range part.Header {
|
|
v := part.Header[i]
|
|
header[v.Key] = v.Values
|
|
}
|
|
return urlval, header, nil
|
|
}
|
|
|
|
func (u *UploadInfo) GetPartSign(ctx context.Context, partNumber int32) (*url.URL, http.Header, error) {
|
|
if partNumber < 1 || int(partNumber) > u.PartNum {
|
|
return nil, nil, errors.New("invalid partNumber")
|
|
}
|
|
if index := u.getIndex(partNumber); index >= 0 {
|
|
return u.buildRequest(index)
|
|
}
|
|
partNumbers := make([]int32, 0, u.BatchSignNum)
|
|
for i := int32(0); i < u.BatchSignNum; i++ {
|
|
if int(partNumber+i) > u.PartNum {
|
|
break
|
|
}
|
|
partNumbers = append(partNumbers, partNumber+i)
|
|
}
|
|
authSignResp, err := u.f.authSign(ctx, &third.AuthSignReq{
|
|
UploadID: u.Resp.Upload.UploadID,
|
|
PartNumbers: partNumbers,
|
|
})
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
u.Resp.Upload.Sign.Url = authSignResp.Url
|
|
u.Resp.Upload.Sign.Query = authSignResp.Query
|
|
u.Resp.Upload.Sign.Header = authSignResp.Header
|
|
u.Resp.Upload.Sign.Parts = authSignResp.Parts
|
|
u.CreateTime = time.Now()
|
|
index := u.getIndex(partNumber)
|
|
if index < 0 {
|
|
return nil, nil, errs.ErrInternalServer.WrapMsg("server part sign invalid")
|
|
}
|
|
return u.buildRequest(index)
|
|
}
|
|
|
|
func (f *File) getUpload(ctx context.Context, req *third.InitiateMultipartUploadReq) (*UploadInfo, error) {
|
|
partNum := f.getPartNum(req.Size, req.PartSize)
|
|
var bitmap *Bitmap
|
|
if f.database != nil {
|
|
dbUpload, err := f.database.GetUpload(ctx, req.Hash)
|
|
if err == nil {
|
|
bitmapBytes, err := base64.StdEncoding.DecodeString(dbUpload.UploadInfo)
|
|
if err != nil || len(bitmapBytes) == 0 || partNum <= 1 || dbUpload.ExpireTime-3600*1000 < time.Now().UnixMilli() {
|
|
if err := f.database.DeleteUpload(ctx, req.Hash); err != nil {
|
|
return nil, err
|
|
}
|
|
dbUpload = nil
|
|
}
|
|
if dbUpload == nil {
|
|
bitmap = NewBitmap(partNum)
|
|
} else {
|
|
bitmap = ParseBitmap(bitmapBytes, partNum)
|
|
}
|
|
tUpInfo := &third.UploadInfo{
|
|
PartSize: req.PartSize,
|
|
Sign: &third.AuthSignParts{},
|
|
}
|
|
if dbUpload != nil {
|
|
tUpInfo.UploadID = dbUpload.UploadID
|
|
tUpInfo.ExpireTime = dbUpload.ExpireTime
|
|
}
|
|
return &UploadInfo{
|
|
PartNum: partNum,
|
|
Bitmap: bitmap,
|
|
DBInfo: dbUpload,
|
|
Resp: &third.InitiateMultipartUploadResp{
|
|
Upload: tUpInfo,
|
|
},
|
|
BatchSignNum: req.MaxParts,
|
|
f: f,
|
|
}, nil
|
|
}
|
|
log.ZError(ctx, "get upload db", err, "pratsMd5", req.Hash)
|
|
}
|
|
resp, err := f.initiateMultipartUploadResp(ctx, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.Upload == nil {
|
|
return &UploadInfo{
|
|
Resp: resp,
|
|
}, nil
|
|
}
|
|
bitmap = NewBitmap(partNum)
|
|
var dbUpload *model_struct.LocalUpload
|
|
if f.database != nil {
|
|
dbUpload = &model_struct.LocalUpload{
|
|
PartHash: req.Hash,
|
|
UploadID: resp.Upload.UploadID,
|
|
UploadInfo: base64.StdEncoding.EncodeToString(bitmap.Serialize()),
|
|
ExpireTime: resp.Upload.ExpireTime,
|
|
CreateTime: time.Now().UnixMilli(),
|
|
}
|
|
if err := f.database.InsertUpload(ctx, dbUpload); err != nil {
|
|
log.ZError(ctx, "insert upload db", err, "pratsHash", req.Hash, "name", req.Name)
|
|
}
|
|
}
|
|
if req.MaxParts >= 0 && len(resp.Upload.Sign.Parts) != int(req.MaxParts) {
|
|
resp.Upload.Sign.Parts = nil
|
|
}
|
|
return &UploadInfo{
|
|
PartNum: partNum,
|
|
Bitmap: bitmap,
|
|
DBInfo: dbUpload,
|
|
Resp: resp,
|
|
CreateTime: time.Now(),
|
|
BatchSignNum: req.MaxParts,
|
|
f: f,
|
|
}, nil
|
|
}
|
|
|
|
func (f *File) doPut(ctx context.Context, client *http.Client, url *url.URL, header http.Header, reader io.Reader, size int64) error {
|
|
rawURL := url.String()
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPut, rawURL, reader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for key := range header {
|
|
req.Header[key] = header[key]
|
|
}
|
|
req.ContentLength = size
|
|
log.ZDebug(ctx, "do put req", "url", rawURL, "contentLength", size, "header", req.Header)
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
_ = resp.Body.Close()
|
|
}()
|
|
log.ZDebug(ctx, "do put resp status", "url", rawURL, "status", resp.Status, "contentLength", resp.ContentLength, "header", resp.Header)
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
log.ZDebug(ctx, "do put resp body", "url", rawURL, "body", string(body))
|
|
if resp.StatusCode/200 != 1 {
|
|
return fmt.Errorf("PUT %s failed, status code %d, body %s", rawURL, resp.StatusCode, string(body))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *File) getPartInfo(ctx context.Context, r io.Reader, fileSize int64, cb UploadFileCallback) (*partInfo, error) {
|
|
partSize, err := f.partSize(ctx, fileSize)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
partNum := int(fileSize / partSize)
|
|
if fileSize%partSize != 0 {
|
|
partNum++
|
|
}
|
|
cb.PartSize(partSize, partNum)
|
|
partSizes := make([]int64, partNum)
|
|
for i := 0; i < partNum; i++ {
|
|
partSizes[i] = partSize
|
|
}
|
|
partSizes[partNum-1] = fileSize - partSize*(int64(partNum)-1)
|
|
partMd5s := make([]string, partNum)
|
|
buf := make([]byte, 1024*8)
|
|
fileMd5 := md5.New()
|
|
var contentType string
|
|
for i := 0; i < partNum; i++ {
|
|
h := md5.New()
|
|
r := io.LimitReader(r, partSize)
|
|
for {
|
|
if n, err := r.Read(buf); err == nil {
|
|
if contentType == "" {
|
|
contentType = http.DetectContentType(buf[:n])
|
|
}
|
|
h.Write(buf[:n])
|
|
fileMd5.Write(buf[:n])
|
|
} else if err == io.EOF {
|
|
break
|
|
} else {
|
|
return nil, err
|
|
}
|
|
}
|
|
partMd5s[i] = hex.EncodeToString(h.Sum(nil))
|
|
cb.HashPartProgress(i, partSizes[i], partMd5s[i])
|
|
}
|
|
partMd5Val := f.partMD5(partMd5s)
|
|
fileMd5val := hex.EncodeToString(fileMd5.Sum(nil))
|
|
cb.HashPartComplete(f.partMD5(partMd5s), hex.EncodeToString(fileMd5.Sum(nil)))
|
|
return &partInfo{
|
|
ContentType: contentType,
|
|
PartSize: partSize,
|
|
PartNum: partNum,
|
|
FileMd5: fileMd5val,
|
|
PartMd5: partMd5Val,
|
|
PartSizes: partSizes,
|
|
PartMd5s: partMd5s,
|
|
}, nil
|
|
}
|
|
|