Files

164 lines
3.4 KiB
Go

package copier
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
)
// maxRetries is the total number of times CopyS3 runs s5cmd for a single
// upload, counting the first attempt — not the number of retries after it.
const maxRetries = 3
// FileJob represents one file to copy: a local copy from Source to LocalDest,
// optionally followed by an upload of LocalDest to S3.
type FileJob struct {
	// Source is the path of the file to read.
	Source string
	// LocalDest is the path the file is copied to; missing parent
	// directories are created by CopyLocal.
	LocalDest string
	// S3Bucket is the bucket LocalDest is uploaded to. When empty,
	// Pool.processJob skips the upload entirely.
	S3Bucket string
	// S3Key is the object key within S3Bucket.
	S3Key string
	// AWSProfile, when non-empty, is exported to s5cmd as AWS_PROFILE.
	AWSProfile string
}
// FileResult is the result of copying one file.
//
// As produced by Pool.processJob, SHA256 and Bytes are set whenever the local
// copy succeeded — including when the subsequent S3 upload failed.
type FileResult struct {
	// Job is the job this result describes.
	Job FileJob
	// SHA256 is the hex-encoded SHA-256 digest of the bytes written to
	// Job.LocalDest; empty when the local copy failed.
	SHA256 string
	// Bytes is the number of bytes written to Job.LocalDest; zero when the
	// local copy failed.
	Bytes int64
	// Error is nil on success. Otherwise it wraps the underlying failure,
	// prefixed with "local copy: " or "s3: " depending on which stage failed.
	Error error
}
// ProgressFn receives progress updates from CopyLocal after each chunk is
// written: copied is the running count of bytes written to the destination so
// far, and total is the size the source file reported when it was opened.
type ProgressFn func(copied, total int64)
// CopyLocal copies the file at src to dst and returns the hex-encoded SHA-256
// of the bytes written together with the number of bytes written.
//
// Parent directories of dst are created (mode 0755) as needed, and an
// existing dst is truncated. If progress is non-nil it is called after every
// chunk with the cumulative byte count and the size src reported when opened.
//
// Copying a file onto itself (same file, including through a symlink or hard
// link) is rejected: creating dst would otherwise truncate src before it is
// read. On any error the returned digest is empty, the byte count is how much
// was written before the failure, and dst may be left partially written.
func CopyLocal(src, dst string, progress ProgressFn) (string, int64, error) {
	if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
		return "", 0, err
	}
	in, err := os.Open(src)
	if err != nil {
		return "", 0, err
	}
	defer in.Close() // read-only handle: nothing useful to report on Close
	info, err := in.Stat()
	if err != nil {
		return "", 0, err
	}
	// os.Stat follows symlinks, so a dst that links back to src is caught too.
	if dstInfo, statErr := os.Stat(dst); statErr == nil && os.SameFile(info, dstInfo) {
		return "", 0, fmt.Errorf("copy %q to %q: source and destination are the same file", src, dst)
	}
	total := info.Size()

	out, err := os.Create(dst)
	if err != nil {
		return "", 0, err
	}
	h := sha256.New()
	var written int64
	buf := make([]byte, 1<<20) // 1 MiB chunks
	for {
		nr, er := in.Read(buf)
		if nr > 0 {
			nw, ew := out.Write(buf[:nr])
			if nw > 0 {
				written += int64(nw)
				h.Write(buf[:nw]) // hash.Hash.Write never returns an error
				if progress != nil {
					progress(written, total)
				}
			}
			if ew == nil && nw != nr {
				ew = io.ErrShortWrite
			}
			if ew != nil {
				_ = out.Close() // already failing; the write error is the one to report
				return "", written, ew
			}
		}
		if er == io.EOF {
			break
		}
		if er != nil {
			_ = out.Close() // already failing; the read error is the one to report
			return "", written, er
		}
	}
	// Close explicitly instead of deferring: the OS may report buffered write
	// failures only at close time, and ignoring them would hide data loss.
	if err := out.Close(); err != nil {
		return "", written, err
	}
	return hex.EncodeToString(h.Sum(nil)), written, nil
}
// CopyS3 uploads localPath to s3://bucket/key by running the external s5cmd
// binary as "s5cmd cp --concurrency 8 <localPath> s3://<bucket>/<key>".
// When awsProfile is non-empty it is passed to s5cmd as AWS_PROFILE on top
// of the current process environment.
//
// Up to maxRetries attempts are made in total, waiting 1<<attempt seconds
// (2s, 4s, ...) before each retry. It returns without retrying when:
//   - s5cmd cannot be found on PATH;
//   - ctx is done (ctx.Err() is returned, whether it ends during the
//     backoff or kills a running s5cmd);
//   - s5cmd's output mentions NoSuchBucket.
//
// When every attempt fails, the returned error wraps the last exec error and
// includes s5cmd's combined output from that attempt.
func CopyS3(ctx context.Context, localPath, bucket, key, awsProfile string) error {
	// A missing binary will not appear between attempts; fail now instead of
	// sleeping through the whole backoff schedule.
	if _, err := exec.LookPath("s5cmd"); err != nil {
		return fmt.Errorf("s5cmd not available: %w", err)
	}
	s3uri := fmt.Sprintf("s3://%s/%s", bucket, key)
	var lastErr error
	var lastOut []byte
	for attempt := 0; attempt < maxRetries; attempt++ {
		if attempt > 0 {
			wait := time.Duration(1<<attempt) * time.Second
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-time.After(wait):
			}
		}
		cmd := exec.CommandContext(ctx, "s5cmd", "cp", "--concurrency", "8", localPath, s3uri)
		if awsProfile != "" {
			cmd.Env = append(os.Environ(), "AWS_PROFILE="+awsProfile)
		}
		out, err := cmd.CombinedOutput()
		if err == nil {
			return nil
		}
		// CommandContext kills s5cmd when ctx ends; report that as the context
		// error rather than as an s5cmd failure.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		if strings.Contains(string(out), "NoSuchBucket") {
			return fmt.Errorf("bucket not found: %s", bucket)
		}
		lastErr, lastOut = err, out
	}
	if msg := strings.TrimSpace(string(lastOut)); msg != "" {
		return fmt.Errorf("s5cmd failed after %d attempts: %w: %s", maxRetries, lastErr, msg)
	}
	return fmt.Errorf("s5cmd failed after %d attempts: %w", maxRetries, lastErr)
}
// Pool runs file copy jobs with worker concurrency.
type Pool struct {
	// Workers sets how many worker goroutines Run starts.
	Workers int
	// OnResult, if non-nil, is called with the FileResult of every job a
	// worker processes. It is invoked from the worker goroutines, possibly
	// concurrently with itself, so it must be safe for concurrent use.
	OnResult func(FileResult)
}
// Run processes jobs from the jobs channel until it is closed or ctx is done,
// and blocks until every worker goroutine has exited.
//
// It starts p.Workers workers (at least one, so a zero-value Pool still
// drains the channel). Each job is processed with ctx, and its result is
// passed to OnResult if that is non-nil; OnResult may be called concurrently.
//
// Workers wait on both the channel and ctx, so cancellation ends Run even if
// the producer has neither sent another job nor closed the channel. Run does
// not interrupt a job that is already being processed, and a job received at
// the moment of cancellation is dropped without a result. Jobs still queued
// stay in the channel, so a producer should also select on ctx.Done() when
// sending.
func (p *Pool) Run(ctx context.Context, jobs <-chan FileJob) {
	workers := p.Workers
	if workers < 1 {
		workers = 1
	}
	var wg sync.WaitGroup
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go func() {
			defer wg.Done()
			for {
				select {
				case <-ctx.Done():
					return
				case job, ok := <-jobs:
					if !ok {
						return
					}
					// select picks at random when both cases are ready;
					// re-check so no new job starts after cancellation.
					if ctx.Err() != nil {
						return
					}
					result := p.processJob(ctx, job)
					if p.OnResult != nil {
						p.OnResult(result)
					}
				}
			}
		}()
	}
	wg.Wait()
}
// processJob copies job.Source to job.LocalDest and, when job.S3Bucket is
// set, uploads the local copy with CopyS3. The result always carries the job;
// SHA256 and Bytes are filled in as soon as the local copy succeeds, even if
// the upload later fails. Errors are wrapped with "local copy: " or "s3: ".
func (p *Pool) processJob(ctx context.Context, job FileJob) FileResult {
	res := FileResult{Job: job}

	digest, n, copyErr := CopyLocal(job.Source, job.LocalDest, nil)
	if copyErr != nil {
		res.Error = fmt.Errorf("local copy: %w", copyErr)
		return res
	}
	res.SHA256, res.Bytes = digest, n

	if job.S3Bucket == "" {
		return res
	}
	if uploadErr := CopyS3(ctx, job.LocalDest, job.S3Bucket, job.S3Key, job.AWSProfile); uploadErr != nil {
		res.Error = fmt.Errorf("s3: %w", uploadErr)
	}
	return res
}