mirror of
https://github.com/coder/coder.git
synced 2025-07-08 11:39:50 +00:00
feat: return better error if file size is too big to upload (#7775)
* feat: return better error if file size is too big to upload * Use a limit writer to capture actual tar size
This commit is contained in:
43
coderd/util/xio/limitwriter.go
Normal file
43
coderd/util/xio/limitwriter.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package xio
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrLimitReached = xerrors.Errorf("i/o limit reached")
|
||||||
|
|
||||||
|
// LimitWriter will only write bytes to the underlying writer until the limit is reached.
|
||||||
|
type LimitWriter struct {
|
||||||
|
Limit int64
|
||||||
|
N int64
|
||||||
|
W io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLimitWriter(w io.Writer, n int64) *LimitWriter {
|
||||||
|
// If anyone tries this, just make a 0 writer.
|
||||||
|
if n < 0 {
|
||||||
|
n = 0
|
||||||
|
}
|
||||||
|
return &LimitWriter{
|
||||||
|
Limit: n,
|
||||||
|
N: 0,
|
||||||
|
W: w,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LimitWriter) Write(p []byte) (int, error) {
|
||||||
|
if l.N >= l.Limit {
|
||||||
|
return 0, ErrLimitReached
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write 0 bytes if the limit is to be exceeded.
|
||||||
|
if int64(len(p)) > l.Limit-l.N {
|
||||||
|
return 0, ErrLimitReached
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := l.W.Write(p)
|
||||||
|
l.N += int64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
141
coderd/util/xio/limitwriter_test.go
Normal file
141
coderd/util/xio/limitwriter_test.go
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
package xio_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
cryptorand "crypto/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/util/xio"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLimitWriter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
type writeCase struct {
|
||||||
|
N int
|
||||||
|
ExpN int
|
||||||
|
Err bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// testCases will do multiple writes to the same limit writer and check the output.
|
||||||
|
testCases := []struct {
|
||||||
|
Name string
|
||||||
|
L int64
|
||||||
|
Writes []writeCase
|
||||||
|
N int
|
||||||
|
ExpN int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "Empty",
|
||||||
|
L: 1000,
|
||||||
|
Writes: []writeCase{
|
||||||
|
// A few empty writes
|
||||||
|
{N: 0, ExpN: 0}, {N: 0, ExpN: 0}, {N: 0, ExpN: 0},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "NotFull",
|
||||||
|
L: 1000,
|
||||||
|
Writes: []writeCase{
|
||||||
|
{N: 250, ExpN: 250},
|
||||||
|
{N: 250, ExpN: 250},
|
||||||
|
{N: 250, ExpN: 250},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Short",
|
||||||
|
L: 1000,
|
||||||
|
Writes: []writeCase{
|
||||||
|
{N: 250, ExpN: 250},
|
||||||
|
{N: 250, ExpN: 250},
|
||||||
|
{N: 250, ExpN: 250},
|
||||||
|
{N: 250, ExpN: 250},
|
||||||
|
{N: 250, ExpN: 0, Err: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Exact",
|
||||||
|
L: 1000,
|
||||||
|
Writes: []writeCase{
|
||||||
|
{
|
||||||
|
N: 1000,
|
||||||
|
ExpN: 1000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
N: 1000,
|
||||||
|
Err: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Over",
|
||||||
|
L: 1000,
|
||||||
|
Writes: []writeCase{
|
||||||
|
{
|
||||||
|
N: 5000,
|
||||||
|
ExpN: 0,
|
||||||
|
Err: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
N: 5000,
|
||||||
|
Err: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
N: 5000,
|
||||||
|
Err: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Strange",
|
||||||
|
L: -1,
|
||||||
|
Writes: []writeCase{
|
||||||
|
{
|
||||||
|
N: 5,
|
||||||
|
ExpN: 0,
|
||||||
|
Err: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
N: 0,
|
||||||
|
ExpN: 0,
|
||||||
|
Err: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range testCases {
|
||||||
|
c := c
|
||||||
|
t.Run(c.Name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
buf := bytes.NewBuffer([]byte{})
|
||||||
|
allBuff := bytes.NewBuffer([]byte{})
|
||||||
|
w := xio.NewLimitWriter(buf, c.L)
|
||||||
|
|
||||||
|
for _, wc := range c.Writes {
|
||||||
|
data := make([]byte, wc.N)
|
||||||
|
|
||||||
|
n, err := cryptorand.Read(data)
|
||||||
|
require.NoError(t, err, "crand read")
|
||||||
|
require.Equal(t, wc.N, n, "correct bytes read")
|
||||||
|
max := data[:wc.ExpN]
|
||||||
|
n, err = w.Write(data)
|
||||||
|
if wc.Err {
|
||||||
|
require.Error(t, err, "exp error")
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err, "write")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Need to use this to compare across multiple writes.
|
||||||
|
// Each write appends to the expected output.
|
||||||
|
allBuff.Write(max)
|
||||||
|
|
||||||
|
require.Equal(t, wc.ExpN, n, "correct bytes written")
|
||||||
|
require.Equal(t, allBuff.Bytes(), buf.Bytes(), "expected data")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -8,6 +8,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"github.com/coder/coder/coderd/util/xio"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -32,8 +34,9 @@ func dirHasExt(dir string, ext string) (bool, error) {
|
|||||||
|
|
||||||
// Tar archives a Terraform directory.
|
// Tar archives a Terraform directory.
|
||||||
func Tar(w io.Writer, directory string, limit int64) error {
|
func Tar(w io.Writer, directory string, limit int64) error {
|
||||||
|
// The total bytes written must be under the limit, so use -1
|
||||||
|
w = xio.NewLimitWriter(w, limit-1)
|
||||||
tarWriter := tar.NewWriter(w)
|
tarWriter := tar.NewWriter(w)
|
||||||
totalSize := int64(0)
|
|
||||||
|
|
||||||
const tfExt = ".tf"
|
const tfExt = ".tf"
|
||||||
hasTf, err := dirHasExt(directory, tfExt)
|
hasTf, err := dirHasExt(directory, tfExt)
|
||||||
@ -95,22 +98,26 @@ func Tar(w io.Writer, directory string, limit int64) error {
|
|||||||
if !fileInfo.Mode().IsRegular() {
|
if !fileInfo.Mode().IsRegular() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := os.Open(file)
|
data, err := os.Open(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer data.Close()
|
defer data.Close()
|
||||||
wrote, err := io.Copy(tarWriter, data)
|
_, err = io.Copy(tarWriter, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if xerrors.Is(err, xio.ErrLimitReached) {
|
||||||
|
return xerrors.Errorf("Archive too big. Must be <= %d bytes", limit)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
totalSize += wrote
|
|
||||||
if limit != 0 && totalSize >= limit {
|
|
||||||
return xerrors.Errorf("Archive too big. Must be <= %d bytes", limit)
|
|
||||||
}
|
|
||||||
return data.Close()
|
return data.Close()
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if xerrors.Is(err, xio.ErrLimitReached) {
|
||||||
|
return xerrors.Errorf("Archive too big. Must be <= %d bytes", limit)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = tarWriter.Flush()
|
err = tarWriter.Flush()
|
||||||
|
@ -15,6 +15,32 @@ import (
|
|||||||
|
|
||||||
func TestTar(t *testing.T) {
|
func TestTar(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
t.Run("HeaderBreakLimit", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
dir := t.TempDir()
|
||||||
|
file, err := os.CreateTemp(dir, "*.tf")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_ = file.Close()
|
||||||
|
// A header is 512 bytes
|
||||||
|
err = provisionersdk.Tar(io.Discard, dir, 100)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
t.Run("HeaderAndContent", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
dir := t.TempDir()
|
||||||
|
file, err := os.CreateTemp(dir, "*.tf")
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, _ = file.Write(make([]byte, 100))
|
||||||
|
_ = file.Close()
|
||||||
|
// Pay + header is 1024 bytes (padding)
|
||||||
|
err = provisionersdk.Tar(io.Discard, dir, 1025)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Limit is 1 byte too small (n == limit is a failure, must be under)
|
||||||
|
err = provisionersdk.Tar(io.Discard, dir, 1024)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("NoTF", func(t *testing.T) {
|
t.Run("NoTF", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
@ -97,7 +123,8 @@ func TestTar(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
archive := new(bytes.Buffer)
|
archive := new(bytes.Buffer)
|
||||||
err := provisionersdk.Tar(archive, dir, 1024)
|
// Headers are chonky so raise the limit to something reasonable
|
||||||
|
err := provisionersdk.Tar(archive, dir, 1024<<2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
dir = t.TempDir()
|
dir = t.TempDir()
|
||||||
err = provisionersdk.Untar(dir, archive)
|
err = provisionersdk.Untar(dir, archive)
|
||||||
|
Reference in New Issue
Block a user