You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 
openim-sdk-cpp/go/chao-sdk-core/internal/file/upload.go

576 lines
16 KiB

// Copyright © 2023 OpenIM SDK. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package file
import (
"context"
"crypto/md5"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"github.com/openimsdk/openim-sdk-core/v3/internal/util"
"github.com/openimsdk/openim-sdk-core/v3/pkg/constant"
"github.com/openimsdk/openim-sdk-core/v3/pkg/db/db_interface"
"github.com/openimsdk/openim-sdk-core/v3/pkg/db/model_struct"
"github.com/openimsdk/tools/errs"
"io"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/openimsdk/protocol/third"
"github.com/openimsdk/tools/log"
)
// UploadFileReq describes one file-upload request.
type UploadFileReq struct {
	Filepath    string `json:"filepath"`    // path of the file to upload (consumed by Open)
	Name        string `json:"name"`        // object name; UploadFile prefixes it with "<loginUserID>/" when missing
	ContentType string `json:"contentType"` // MIME type; detected from the file content when empty
	Cause       string `json:"cause"`       // business cause, passed through to the server
	Uuid        string `json:"uuid"`        // caller-supplied identifier (not used within this file)
}
// UploadFileResp carries the result of a successful upload.
type UploadFileResp struct {
	URL string `json:"url"` // access URL of the uploaded object
}
// partInfo is the result of the hashing pass over a file: how the file is
// split into parts, and the hashes identifying each part and the whole file.
type partInfo struct {
	ContentType string   // MIME type detected from the first bytes read
	PartSize    int64    // size of every part except possibly the last
	PartNum     int      // total number of parts
	FileMd5     string   // hex MD5 of the whole file
	PartMd5     string   // MD5 over the comma-joined part MD5s (used as the upload identity)
	PartSizes   []int64  // per-part sizes; the last entry holds the remainder
	PartMd5s    []string // hex MD5 of each part, in order
}
// NewFile constructs a File uploader bound to the given database and the
// current login user. The per-hash lock table starts out empty.
func NewFile(database db_interface.DataBase, loginUserID string) *File {
	f := &File{
		database:    database,
		loginUserID: loginUserID,
		confLock:    &sync.Mutex{},
		mapLocker:   &sync.Mutex{},
	}
	f.uploading = make(map[string]*lockInfo)
	return f
}
// File implements resumable multipart file uploads for the logged-in user.
type File struct {
	database    db_interface.DataBase // local store for resumable upload state
	loginUserID string                // used to prefix object names
	confLock    sync.Locker           // guards partLimit
	partLimit   *third.PartLimitResp  // cached server part-size limits (fetched lazily)
	mapLocker   sync.Locker           // guards uploading
	uploading   map[string]*lockInfo  // per-content-hash locks for in-flight uploads
}
// lockInfo is one entry of File.uploading: a mutex serializing uploads of
// identical content, plus a reference count of current holders and waiters.
type lockInfo struct {
	count  int32       // holders + waiters; the entry is removed when it reaches zero
	locker sync.Locker // exclusive per-hash lock
}
// lockHash blocks until the caller holds the exclusive lock for hash, so
// that at most one upload of the same content runs at a time. The count
// field tracks holders/waiters so unlockHash knows when the map entry can
// be removed.
func (f *File) lockHash(hash string) {
	f.mapLocker.Lock()
	locker, ok := f.uploading[hash]
	if !ok {
		locker = &lockInfo{count: 0, locker: &sync.Mutex{}}
		f.uploading[hash] = locker
	}
	atomic.AddInt32(&locker.count, 1)
	f.mapLocker.Unlock()
	// Acquire the per-hash lock outside mapLocker so uploads of other
	// hashes are not blocked while we wait for this one.
	locker.locker.Lock()
}
// unlockHash releases the per-hash lock taken by lockHash and deletes the
// map entry once the last holder/waiter is gone, keeping the map bounded.
func (f *File) unlockHash(hash string) {
	f.mapLocker.Lock()
	locker, ok := f.uploading[hash]
	if !ok {
		// No matching lockHash entry (already fully released); nothing to do.
		f.mapLocker.Unlock()
		return
	}
	// Drop our reference under mapLocker; the last reference removes the
	// entry so a later lockHash creates a fresh one.
	if atomic.AddInt32(&locker.count, -1) == 0 {
		delete(f.uploading, hash)
	}
	f.mapLocker.Unlock()
	locker.locker.Unlock()
}
// UploadFile uploads the file described by req as a (possibly resumed)
// multipart upload and returns its final URL.
//
// Flow: open the file, hash it part by part, serialize uploads of the same
// content via a per-hash lock, ask the server to initiate (or look up) the
// upload, PUT each missing part with a presigned URL, then complete the
// multipart upload. Progress and lifecycle events are reported through cb;
// a no-op callback is substituted when cb is nil.
func (f *File) UploadFile(ctx context.Context, req *UploadFileReq, cb UploadFileCallback) (*UploadFileResp, error) {
	if cb == nil {
		cb = emptyUploadCallback{}
	}
	if req.Name == "" {
		return nil, errors.New("name is empty")
	}
	if req.Name[0] == '/' {
		req.Name = req.Name[1:]
	}
	// Namespace the object under the current user's ID.
	if prefix := f.loginUserID + "/"; !strings.HasPrefix(req.Name, prefix) {
		req.Name = prefix + req.Name
	}
	file, err := Open(req)
	if err != nil {
		return nil, err
	}
	defer file.Close()
	fileSize := file.Size()
	cb.Open(fileSize)
	// First pass over the file: per-part sizes and MD5s, whole-file MD5,
	// and detected content type.
	info, err := f.getPartInfo(ctx, file, fileSize, cb)
	if err != nil {
		return nil, err
	}
	if req.ContentType == "" {
		req.ContentType = info.ContentType
	}
	partSize := info.PartSize
	partSizes := info.PartSizes
	partMd5s := info.PartMd5s
	partMd5Val := info.PartMd5
	// Rewind for the second pass that actually uploads the bytes.
	if err := file.StartSeek(0); err != nil {
		return nil, err
	}
	// Serialize concurrent uploads of identical content (same part hash).
	f.lockHash(partMd5Val)
	defer f.unlockHash(partMd5Val)
	maxParts := 20
	if maxParts > len(partSizes) {
		maxParts = len(partSizes)
	}
	uploadInfo, err := f.getUpload(ctx, &third.InitiateMultipartUploadReq{
		Hash:        partMd5Val,
		Size:        fileSize,
		PartSize:    partSize,
		MaxParts:    int32(maxParts), // number of part signatures to fetch per batch
		Cause:       req.Cause,
		Name:        req.Name,
		ContentType: req.ContentType,
	})
	if err != nil {
		return nil, err
	}
	// Upload == nil means the server already has this content; done.
	if uploadInfo.Resp.Upload == nil {
		cb.Complete(fileSize, uploadInfo.Resp.Url, 0)
		return &UploadFileResp{
			URL: uploadInfo.Resp.Url,
		}, nil
	}
	// The server's part size must match ours, otherwise our cached part
	// limits are stale — drop them so the next attempt refetches.
	if uploadInfo.Resp.Upload.PartSize != partSize {
		f.cleanPartLimit()
		return nil, fmt.Errorf("part fileSize not match, expect %d, got %d", partSize, uploadInfo.Resp.Upload.PartSize)
	}
	cb.UploadID(uploadInfo.Resp.Upload.UploadID)
	// uploadedSize starts as the bytes already on the server (resume case):
	// total size minus every part the bitmap marks as still missing.
	uploadedSize := fileSize
	for i := 0; i < len(partSizes); i++ {
		if !uploadInfo.Bitmap.Get(i) {
			uploadedSize -= partSizes[i]
		}
	}
	continueUpload := uploadedSize > 0
	for i, currentPartSize := range partSizes {
		partNumber := int32(i + 1)
		// Hash while reading so the part's MD5 can be re-verified below.
		md5Reader := NewMd5Reader(io.LimitReader(file, currentPartSize))
		if uploadInfo.Bitmap.Get(i) {
			// Part already uploaded: still consume it so the reader position
			// advances and the MD5 check below stays meaningful.
			if _, err := io.Copy(io.Discard, md5Reader); err != nil {
				return nil, err
			}
		} else {
			reader := NewProgressReader(md5Reader, func(current int64) {
				cb.UploadComplete(fileSize, uploadedSize+current, uploadedSize)
			})
			urlval, header, err := uploadInfo.GetPartSign(ctx, partNumber)
			if err != nil {
				return nil, err
			}
			if err := f.doPut(ctx, http.DefaultClient, urlval, header, reader, currentPartSize); err != nil {
				log.ZError(ctx, "doPut", err, "partMd5Val", partMd5Val, "name", req.Name, "partNumber", partNumber)
				return nil, err
			}
			uploadedSize += currentPartSize
			// Persist resume state so an interrupted upload can continue
			// later; a persistence failure only disables resume.
			if uploadInfo.DBInfo != nil && uploadInfo.Bitmap != nil {
				uploadInfo.Bitmap.Set(i)
				uploadInfo.DBInfo.UploadInfo = base64.StdEncoding.EncodeToString(uploadInfo.Bitmap.Serialize())
				if err := f.database.UpdateUpload(ctx, uploadInfo.DBInfo); err != nil {
					log.ZError(ctx, "SetUploadPartPush", err, "partMd5Val", partMd5Val, "name", req.Name, "partNumber", partNumber)
				}
			}
		}
		// Whichever path was taken, the bytes read must match the hash from
		// the first pass, otherwise the file changed underneath us.
		md5val := md5Reader.Md5()
		if md5val != partMd5s[i] {
			return nil, fmt.Errorf("upload part %d failed, md5 not match, expect %s, got %s", i, partMd5s[i], md5val)
		}
		cb.UploadPartComplete(i, currentPartSize, partMd5s[i])
		log.ZDebug(ctx, "upload part success", "partMd5Val", md5val, "name", req.Name, "partNumber", partNumber)
	}
	log.ZDebug(ctx, "upload all part success", "partHash", partMd5Val, "name", req.Name)
	resp, err := f.completeMultipartUpload(ctx, &third.CompleteMultipartUploadReq{
		UploadID:    uploadInfo.Resp.Upload.UploadID,
		Parts:       partMd5s,
		Name:        req.Name,
		ContentType: req.ContentType,
		Cause:       req.Cause,
	})
	if err != nil {
		return nil, err
	}
	// typ: 1 = fresh upload completed, 2 = resumed upload completed.
	typ := 1
	if continueUpload {
		typ++
	}
	cb.Complete(fileSize, resp.Url, typ)
	// The resume record is no longer needed once the upload completed.
	if uploadInfo.DBInfo != nil {
		if err := f.database.DeleteUpload(ctx, info.PartMd5); err != nil {
			log.ZError(ctx, "DeleteUpload", err, "partMd5Val", info.PartMd5, "name", req.Name)
		}
	}
	return &UploadFileResp{
		URL: resp.Url,
	}, nil
}
// cleanPartLimit drops the cached server part-size limits so the next call
// to partSize refetches them from the server.
func (f *File) cleanPartLimit() {
	f.confLock.Lock()
	f.partLimit = nil
	f.confLock.Unlock()
}
// initiateMultipartUploadResp asks the server to start (or look up) a
// multipart upload for the given hash/size/name.
func (f *File) initiateMultipartUploadResp(ctx context.Context, req *third.InitiateMultipartUploadReq) (*third.InitiateMultipartUploadResp, error) {
	return util.CallApi[third.InitiateMultipartUploadResp](ctx, constant.ObjectInitiateMultipartUpload, req)
}
// authSign fetches presigned upload credentials for the given part numbers
// of an in-progress multipart upload. At least one part number is required.
func (f *File) authSign(ctx context.Context, req *third.AuthSignReq) (*third.AuthSignResp, error) {
	if len(req.PartNumbers) == 0 {
		return nil, errs.ErrArgs.WrapMsg("partNumbers is empty")
	}
	return util.CallApi[third.AuthSignResp](ctx, constant.ObjectAuthSign, req)
}
// completeMultipartUpload tells the server that all parts are uploaded so
// it can assemble the final object.
func (f *File) completeMultipartUpload(ctx context.Context, req *third.CompleteMultipartUploadReq) (*third.CompleteMultipartUploadResp, error) {
	return util.CallApi[third.CompleteMultipartUploadResp](ctx, constant.ObjectCompleteMultipartUpload, req)
}
// getPartNum reports how many parts of partSize are needed to cover
// fileSize, counting one extra short part when it does not divide evenly.
func (f *File) getPartNum(fileSize int64, partSize int64) int {
	n := fileSize / partSize
	if fileSize%partSize != 0 {
		n++
	}
	return int(n)
}
// partSize picks the multipart chunk size for a file of the given size,
// based on the server-advertised limits (fetched once and cached under
// confLock).
//
// It returns an error when size is non-positive or exceeds the maximum the
// server can accept (MaxPartSize * MaxNumSize).
func (f *File) partSize(ctx context.Context, size int64) (int64, error) {
	// Validate the argument before taking the lock or doing any network
	// round trip (previously an invalid size still triggered the API call).
	if size <= 0 {
		return 0, errors.New("size must be greater than 0")
	}
	f.confLock.Lock()
	defer f.confLock.Unlock()
	if f.partLimit == nil {
		// Lazily fetch and cache the server's part limits.
		resp, err := util.CallApi[third.PartLimitResp](ctx, constant.ObjectPartLimit, &third.PartLimitReq{})
		if err != nil {
			return 0, err
		}
		f.partLimit = resp
	}
	maxTotal := f.partLimit.MaxPartSize * int64(f.partLimit.MaxNumSize)
	if size > maxTotal {
		return 0, fmt.Errorf("size must be less than %db", maxTotal)
	}
	// Small files: the minimum part size already keeps the part count
	// within MaxNumSize.
	if size <= f.partLimit.MinPartSize*int64(f.partLimit.MaxNumSize) {
		return f.partLimit.MinPartSize, nil
	}
	// Large files: spread evenly over the maximum part count, rounding up.
	partSize := size / int64(f.partLimit.MaxNumSize)
	if size%int64(f.partLimit.MaxNumSize) != 0 {
		partSize++
	}
	return partSize, nil
}
// accessURL asks the server for an access URL of an already-stored object.
func (f *File) accessURL(ctx context.Context, req *third.AccessURLReq) (*third.AccessURLResp, error) {
	return util.CallApi[third.AccessURLResp](ctx, constant.ObjectAccessURL, req)
}
// doHttpReq executes req with the default HTTP client and returns the
// fully-read response body along with the response itself. The body is
// always closed before returning; on error both results are nil.
func (f *File) doHttpReq(req *http.Request) ([]byte, *http.Response, error) {
	var (
		resp *http.Response
		err  error
	)
	if resp, err = http.DefaultClient.Do(req); err != nil {
		return nil, nil, err
	}
	defer resp.Body.Close()
	var payload []byte
	if payload, err = io.ReadAll(resp.Body); err != nil {
		return nil, nil, err
	}
	return payload, resp, nil
}
// partMD5 derives a single hash identifying the whole upload from the
// ordered per-part MD5 strings: it joins them with commas and returns the
// hex-encoded MD5 of the joined string.
func (f *File) partMD5(parts []string) string {
	joined := strings.Join(parts, ",")
	sum := md5.Sum([]byte(joined))
	return hex.EncodeToString(sum[:])
}
// AuthSignParts pairs a part signature with associated timestamps.
// NOTE(review): this type is not referenced anywhere in this file —
// verify external callers before removing.
type AuthSignParts struct {
	Sign  *third.SignPart
	Times []time.Time
}
// UploadInfo aggregates everything needed to drive one multipart upload:
// the server (or locally reconstructed) initiation response, the bitmap of
// already-uploaded parts, the persisted resume record, and the cached
// batch of part signatures.
type UploadInfo struct {
	PartNum int                       // total number of parts
	Bitmap  *Bitmap                   // marks which parts are already uploaded
	DBInfo  *model_struct.LocalUpload // persisted resume record (nil when not persisted)
	Resp    *third.InitiateMultipartUploadResp // initiation result; Resp.Upload.Sign caches the signature batch
	//Signs *AuthSignParts
	CreateTime   time.Time // when the cached signature batch was fetched (see getIndex)
	BatchSignNum int32     // how many part signatures to request per batch
	f            *File     // back-reference for issuing authSign calls
}
// getIndex returns the position of partNumber inside the currently cached
// batch of part signatures, or -1 when there is no usable cache: no sign
// present, the batch was never fetched, or it is older than one minute.
func (u *UploadInfo) getIndex(partNumber int32) int {
	sign := u.Resp.Upload.Sign
	if sign == nil || u.CreateTime.IsZero() || time.Since(u.CreateTime) > time.Minute {
		return -1
	}
	for idx := range sign.Parts {
		if sign.Parts[idx].PartNumber == partNumber {
			return idx
		}
	}
	return -1
}
// buildRequest assembles the presigned URL and headers for the i-th entry
// of the cached signature batch. Part-level URL, query values, and headers
// are layered on top of the batch-level ones, the part winning on
// conflicting keys.
func (u *UploadInfo) buildRequest(i int) (*url.URL, http.Header, error) {
	sign := u.Resp.Upload.Sign
	part := sign.Parts[i]
	// A part may carry its own URL; otherwise fall back to the batch URL.
	raw := sign.Url
	if part.Url != "" {
		raw = part.Url
	}
	parsed, err := url.Parse(raw)
	if err != nil {
		return nil, nil, err
	}
	if len(sign.Query) > 0 || len(part.Query) > 0 {
		q := parsed.Query()
		for _, kv := range sign.Query {
			q[kv.Key] = kv.Values
		}
		for _, kv := range part.Query {
			q[kv.Key] = kv.Values
		}
		parsed.RawQuery = q.Encode()
	}
	hdr := http.Header{}
	for _, kv := range sign.Header {
		hdr[kv.Key] = kv.Values
	}
	for _, kv := range part.Header {
		hdr[kv.Key] = kv.Values
	}
	return parsed, hdr, nil
}
// GetPartSign returns the presigned URL and headers for uploading part
// partNumber (1-based). It serves from the cached signature batch when the
// part is present and fresh; otherwise it requests a new batch of up to
// BatchSignNum consecutive signatures starting at partNumber and caches it
// on u.Resp.Upload.Sign.
//
// NOTE(review): u.Resp.Upload.Sign is assumed non-nil here (both getUpload
// construction paths set it) — confirm before adding new paths.
func (u *UploadInfo) GetPartSign(ctx context.Context, partNumber int32) (*url.URL, http.Header, error) {
	if partNumber < 1 || int(partNumber) > u.PartNum {
		return nil, nil, errors.New("invalid partNumber")
	}
	if index := u.getIndex(partNumber); index >= 0 {
		return u.buildRequest(index)
	}
	// Cache miss: batch-sign the next BatchSignNum parts, clamped at the
	// last part of the upload.
	partNumbers := make([]int32, 0, u.BatchSignNum)
	for i := int32(0); i < u.BatchSignNum; i++ {
		if int(partNumber+i) > u.PartNum {
			break
		}
		partNumbers = append(partNumbers, partNumber+i)
	}
	authSignResp, err := u.f.authSign(ctx, &third.AuthSignReq{
		UploadID:    u.Resp.Upload.UploadID,
		PartNumbers: partNumbers,
	})
	if err != nil {
		return nil, nil, err
	}
	// Replace the cached batch and stamp it for the one-minute validity
	// window enforced by getIndex.
	u.Resp.Upload.Sign.Url = authSignResp.Url
	u.Resp.Upload.Sign.Query = authSignResp.Query
	u.Resp.Upload.Sign.Header = authSignResp.Header
	u.Resp.Upload.Sign.Parts = authSignResp.Parts
	u.CreateTime = time.Now()
	index := u.getIndex(partNumber)
	if index < 0 {
		return nil, nil, errs.ErrInternalServer.WrapMsg("server part sign invalid")
	}
	return u.buildRequest(index)
}
// getUpload resolves the multipart-upload state for req, preferring a
// locally persisted (resumable) record over initiating a new upload.
//
// When the database holds a record for req.Hash it is reused, unless its
// bitmap is unreadable, the file fits in a single part, or the record
// expires within the next hour — then the stale record is deleted and a
// fresh empty bitmap is used. Only when no database is configured, or the
// lookup fails, does it call the server to initiate the upload and persist
// a new resume record.
//
// NOTE(review): when the stale DB record is discarded this still returns a
// locally-built UploadInfo whose UploadID is empty instead of falling
// through to the server initiation — verify the server-side calls accept
// this, or that callers handle it.
func (f *File) getUpload(ctx context.Context, req *third.InitiateMultipartUploadReq) (*UploadInfo, error) {
	partNum := f.getPartNum(req.Size, req.PartSize)
	var bitmap *Bitmap
	if f.database != nil {
		dbUpload, err := f.database.GetUpload(ctx, req.Hash)
		if err == nil {
			bitmapBytes, err := base64.StdEncoding.DecodeString(dbUpload.UploadInfo)
			// Discard the record when its bitmap is unusable, the upload is
			// single-part, or it expires in less than an hour.
			if err != nil || len(bitmapBytes) == 0 || partNum <= 1 || dbUpload.ExpireTime-3600*1000 < time.Now().UnixMilli() {
				if err := f.database.DeleteUpload(ctx, req.Hash); err != nil {
					return nil, err
				}
				dbUpload = nil
			}
			if dbUpload == nil {
				bitmap = NewBitmap(partNum)
			} else {
				bitmap = ParseBitmap(bitmapBytes, partNum)
			}
			// Reconstruct the initiation response from local state; Sign is
			// always non-nil so GetPartSign can populate it on demand.
			tUpInfo := &third.UploadInfo{
				PartSize: req.PartSize,
				Sign:     &third.AuthSignParts{},
			}
			if dbUpload != nil {
				tUpInfo.UploadID = dbUpload.UploadID
				tUpInfo.ExpireTime = dbUpload.ExpireTime
			}
			return &UploadInfo{
				PartNum: partNum,
				Bitmap:  bitmap,
				DBInfo:  dbUpload,
				Resp: &third.InitiateMultipartUploadResp{
					Upload: tUpInfo,
				},
				BatchSignNum: req.MaxParts,
				f:            f,
			}, nil
		}
		log.ZError(ctx, "get upload db", err, "pratsMd5", req.Hash)
	}
	resp, err := f.initiateMultipartUploadResp(ctx, req)
	if err != nil {
		return nil, err
	}
	// Upload == nil: the server already has this content; nothing to resume.
	if resp.Upload == nil {
		return &UploadInfo{
			Resp: resp,
		}, nil
	}
	bitmap = NewBitmap(partNum)
	var dbUpload *model_struct.LocalUpload
	if f.database != nil {
		// Persist a resume record; a failed insert only disables resume.
		dbUpload = &model_struct.LocalUpload{
			PartHash:   req.Hash,
			UploadID:   resp.Upload.UploadID,
			UploadInfo: base64.StdEncoding.EncodeToString(bitmap.Serialize()),
			ExpireTime: resp.Upload.ExpireTime,
			CreateTime: time.Now().UnixMilli(),
		}
		if err := f.database.InsertUpload(ctx, dbUpload); err != nil {
			log.ZError(ctx, "insert upload db", err, "pratsHash", req.Hash, "name", req.Name)
		}
	}
	// Drop the pre-signed batch if the server returned a different count
	// than requested; signatures will then be fetched on demand.
	if req.MaxParts >= 0 && len(resp.Upload.Sign.Parts) != int(req.MaxParts) {
		resp.Upload.Sign.Parts = nil
	}
	return &UploadInfo{
		PartNum:      partNum,
		Bitmap:       bitmap,
		DBInfo:       dbUpload,
		Resp:         resp,
		CreateTime:   time.Now(),
		BatchSignNum: req.MaxParts,
		f:            f,
	}, nil
}
// doPut uploads size bytes from reader to url with an HTTP PUT, attaching
// the given presigned headers. Any response status outside 2xx is treated
// as failure and reported together with the response body.
func (f *File) doPut(ctx context.Context, client *http.Client, url *url.URL, header http.Header, reader io.Reader, size int64) error {
	rawURL := url.String()
	req, err := http.NewRequestWithContext(ctx, http.MethodPut, rawURL, reader)
	if err != nil {
		return err
	}
	// Copy the presigned headers verbatim (keys must keep their signed form).
	for key := range header {
		req.Header[key] = header[key]
	}
	// The signed request must declare the exact part length.
	req.ContentLength = size
	log.ZDebug(ctx, "do put req", "url", rawURL, "contentLength", size, "header", req.Header)
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer func() {
		_ = resp.Body.Close()
	}()
	log.ZDebug(ctx, "do put resp status", "url", rawURL, "status", resp.Status, "contentLength", resp.ContentLength, "header", resp.Header)
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	log.ZDebug(ctx, "do put resp body", "url", rawURL, "body", string(body))
	// Fix: the previous check (resp.StatusCode/200 != 1) also accepted
	// 3xx responses (e.g. 301/302), which do not indicate a stored part.
	// Only 2xx counts as success.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return fmt.Errorf("PUT %s failed, status code %d, body %s", rawURL, resp.StatusCode, string(body))
	}
	return nil
}
// getPartInfo makes one pass over the stream r (of fileSize bytes),
// splitting it into parts and hashing each part and the whole file.
//
// It returns the part layout (sizes, count), the per-part MD5s, the
// whole-file MD5, and the aggregate part hash (MD5 over the comma-joined
// part MD5s) used as the upload identity. The content type is detected
// from the first chunk read. Progress is reported through cb.
func (f *File) getPartInfo(ctx context.Context, r io.Reader, fileSize int64, cb UploadFileCallback) (*partInfo, error) {
	partSize, err := f.partSize(ctx, fileSize)
	if err != nil {
		return nil, err
	}
	// Reuse the shared rounding-up helper instead of duplicating it.
	partNum := f.getPartNum(fileSize, partSize)
	cb.PartSize(partSize, partNum)
	partSizes := make([]int64, partNum)
	for i := range partSizes {
		partSizes[i] = partSize
	}
	// The last part holds whatever remains.
	partSizes[partNum-1] = fileSize - partSize*(int64(partNum)-1)
	partMd5s := make([]string, partNum)
	buf := make([]byte, 1024*8)
	fileMd5 := md5.New()
	var contentType string
	for i := 0; i < partNum; i++ {
		h := md5.New()
		lr := io.LimitReader(r, partSize)
		for {
			n, err := lr.Read(buf)
			if n > 0 {
				// Process the bytes BEFORE inspecting err: per the io.Reader
				// contract a Read may return n > 0 together with io.EOF, and
				// the old code (err check first) silently dropped those
				// final bytes from both hashes.
				if contentType == "" {
					contentType = http.DetectContentType(buf[:n])
				}
				h.Write(buf[:n])
				fileMd5.Write(buf[:n])
			}
			if err == io.EOF {
				break
			}
			if err != nil {
				return nil, err
			}
		}
		partMd5s[i] = hex.EncodeToString(h.Sum(nil))
		cb.HashPartProgress(i, partSizes[i], partMd5s[i])
	}
	partMd5Val := f.partMD5(partMd5s)
	fileMd5Val := hex.EncodeToString(fileMd5.Sum(nil))
	// Report the already-computed values instead of recomputing them.
	cb.HashPartComplete(partMd5Val, fileMd5Val)
	return &partInfo{
		ContentType: contentType,
		PartSize:    partSize,
		PartNum:     partNum,
		FileMd5:     fileMd5Val,
		PartMd5:     partMd5Val,
		PartSizes:   partSizes,
		PartMd5s:    partMd5s,
	}, nil
}