Add NewBufferedWriter, and add Flush and Close methods to Writer.
Deprecate NewWriter.
See the discussion on
https://groups.google.com/d/topic/golang-dev/nXp12KmMSvM/discussion
diff --git a/encode.go b/encode.go
index d80185d..d68d441 100644
--- a/encode.go
+++ b/encode.go
@@ -6,6 +6,7 @@
import (
"encoding/binary"
+ "errors"
"io"
)
@@ -175,23 +176,61 @@
return 32 + srcLen + srcLen/6
}
-// NewWriter returns a new Writer that compresses to w, using the framing
-// format described at
-// https://github.com/google/snappy/blob/master/framing_format.txt
+var errClosed = errors.New("snappy: Writer is closed")
+
+// NewWriter returns a new Writer that compresses to w.
+//
+// The Writer returned does not buffer writes. There is no need to Flush or
+// Close such a Writer.
+//
+// Deprecated: the Writer returned is not suitable for many small writes, only
+// for a few large writes. Use NewBufferedWriter instead, which is efficient
+// regardless of the frequency and shape of the writes, and remember to Close
+// that Writer when done.
func NewWriter(w io.Writer) *Writer {
return &Writer{
- w: w,
- enc: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)),
+ w: w,
+ obuf: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)),
+ }
+}
+
+// NewBufferedWriter returns a new Writer that compresses to w, using the
+// framing format described at
+// https://github.com/google/snappy/blob/master/framing_format.txt
+//
+// The Writer returned buffers writes. Users must call Close to guarantee all
+// data has been forwarded to the underlying io.Writer. They may also call
+// Flush zero or more times before calling Close.
+func NewBufferedWriter(w io.Writer) *Writer {
+ return &Writer{
+ w: w,
+ ibuf: make([]byte, 0, maxUncompressedChunkLen),
+ obuf: make([]byte, MaxEncodedLen(maxUncompressedChunkLen)),
}
}
// Writer is an io.Writer than can write Snappy-compressed bytes.
type Writer struct {
- w io.Writer
- err error
- enc []byte
- buf [checksumSize + chunkHeaderSize]byte
- wroteHeader bool
+ w io.Writer
+ err error
+
+ // ibuf is a buffer for the incoming (uncompressed) bytes.
+ //
+ // Its use is optional. For backwards compatibility, Writers created by the
+ // NewWriter function have ibuf == nil, do not buffer incoming bytes, and
+ // therefore do not need to be Flush'ed or Close'd.
+ ibuf []byte
+
+ // obuf is a buffer for the outgoing (compressed) bytes.
+ obuf []byte
+
+ // chunkHeaderBuf is a buffer for the per-chunk header (chunk type, length
+ // and checksum), not to be confused with the magic string that forms the
+ // stream header.
+ chunkHeaderBuf [checksumSize + chunkHeaderSize]byte
+
+ // wroteStreamHeader is whether we have written the stream header.
+ wroteStreamHeader bool
}
// Reset discards the writer's state and switches the Snappy writer to write to
@@ -199,21 +238,61 @@
func (w *Writer) Reset(writer io.Writer) {
w.w = writer
w.err = nil
- w.wroteHeader = false
+ if w.ibuf != nil {
+ w.ibuf = w.ibuf[:0]
+ }
+ w.wroteStreamHeader = false
}
// Write satisfies the io.Writer interface.
-func (w *Writer) Write(p []byte) (n int, errRet error) {
+func (w *Writer) Write(p []byte) (nRet int, errRet error) {
+ if w.ibuf == nil {
+ // Do not buffer incoming bytes. This does not perform or compress well
+ // if the caller of Writer.Write writes many small slices. This
+ // behavior is therefore deprecated, but still supported for backwards
+ // compatibility with code that doesn't explicitly Flush or Close.
+ return w.write(p)
+ }
+
+ // The remainder of this method is based on bufio.Writer.Write from the
+ // standard library.
+
+ for len(p) > (cap(w.ibuf)-len(w.ibuf)) && w.err == nil {
+ var n int
+ if len(w.ibuf) == 0 {
+ // Large write, empty buffer.
+ // Write directly from p to avoid copy.
+ n, _ = w.write(p)
+ } else {
+ n = copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
+ w.ibuf = w.ibuf[:len(w.ibuf)+n]
+ w.Flush()
+ }
+ nRet += n
+ p = p[n:]
+ }
+ if w.err != nil {
+ return nRet, w.err
+ }
+ n := copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
+ w.ibuf = w.ibuf[:len(w.ibuf)+n]
+ nRet += n
+ return nRet, nil
+}
+
+func (w *Writer) write(p []byte) (nRet int, errRet error) {
if w.err != nil {
return 0, w.err
}
- if !w.wroteHeader {
- copy(w.enc, magicChunk)
- if _, err := w.w.Write(w.enc[:len(magicChunk)]); err != nil {
- w.err = err
- return n, err
+ if !w.wroteStreamHeader {
+ if copy(w.obuf, magicChunk) != len(magicChunk) {
+ panic("unreachable")
}
- w.wroteHeader = true
+ if _, err := w.w.Write(w.obuf[:len(magicChunk)]); err != nil {
+ w.err = err
+ return nRet, err
+ }
+ w.wroteStreamHeader = true
}
for len(p) > 0 {
var uncompressed []byte
@@ -227,29 +306,52 @@
// Compress the buffer, discarding the result if the improvement
// isn't at least 12.5%.
chunkType := uint8(chunkTypeCompressedData)
- chunkBody := Encode(w.enc, uncompressed)
+ chunkBody := Encode(w.obuf, uncompressed)
if len(chunkBody) >= len(uncompressed)-len(uncompressed)/8 {
chunkType, chunkBody = chunkTypeUncompressedData, uncompressed
}
chunkLen := 4 + len(chunkBody)
- w.buf[0] = chunkType
- w.buf[1] = uint8(chunkLen >> 0)
- w.buf[2] = uint8(chunkLen >> 8)
- w.buf[3] = uint8(chunkLen >> 16)
- w.buf[4] = uint8(checksum >> 0)
- w.buf[5] = uint8(checksum >> 8)
- w.buf[6] = uint8(checksum >> 16)
- w.buf[7] = uint8(checksum >> 24)
- if _, err := w.w.Write(w.buf[:]); err != nil {
+ w.chunkHeaderBuf[0] = chunkType
+ w.chunkHeaderBuf[1] = uint8(chunkLen >> 0)
+ w.chunkHeaderBuf[2] = uint8(chunkLen >> 8)
+ w.chunkHeaderBuf[3] = uint8(chunkLen >> 16)
+ w.chunkHeaderBuf[4] = uint8(checksum >> 0)
+ w.chunkHeaderBuf[5] = uint8(checksum >> 8)
+ w.chunkHeaderBuf[6] = uint8(checksum >> 16)
+ w.chunkHeaderBuf[7] = uint8(checksum >> 24)
+ if _, err := w.w.Write(w.chunkHeaderBuf[:]); err != nil {
w.err = err
- return n, err
+ return nRet, err
}
if _, err := w.w.Write(chunkBody); err != nil {
w.err = err
- return n, err
+ return nRet, err
}
- n += len(uncompressed)
+ nRet += len(uncompressed)
}
- return n, nil
+ return nRet, nil
+}
+
+// Flush flushes the Writer to its underlying io.Writer.
+func (w *Writer) Flush() error {
+ if w.err != nil {
+ return w.err
+ }
+ if len(w.ibuf) == 0 {
+ return nil
+ }
+ w.write(w.ibuf)
+ w.ibuf = w.ibuf[:0]
+ return w.err
+}
+
+// Close calls Flush and then closes the Writer.
+func (w *Writer) Close() error {
+ w.Flush()
+ ret := w.err
+ if w.err == nil {
+ w.err = errClosed
+ }
+ return ret
}
diff --git a/snappy_test.go b/snappy_test.go
index 815fc0b..905dba0 100644
--- a/snappy_test.go
+++ b/snappy_test.go
@@ -141,6 +141,67 @@
}
}
+func TestNewBufferedWriter(t *testing.T) {
+ // Test all 32 possible sub-sequences of these 5 input slices.
+ //
+ // Their lengths sum to 400,000, which is over 6 times the Writer ibuf
+ // capacity: 6 * maxUncompressedChunkLen is 393,216.
+ inputs := [][]byte{
+ bytes.Repeat([]byte{'a'}, 40000),
+ bytes.Repeat([]byte{'b'}, 150000),
+ bytes.Repeat([]byte{'c'}, 60000),
+ bytes.Repeat([]byte{'d'}, 120000),
+ bytes.Repeat([]byte{'e'}, 30000),
+ }
+loop:
+ for i := 0; i < 1<<uint(len(inputs)); i++ {
+ var want []byte
+ buf := new(bytes.Buffer)
+ w := NewBufferedWriter(buf)
+ for j, input := range inputs {
+ if i&(1<<uint(j)) == 0 {
+ continue
+ }
+ if _, err := w.Write(input); err != nil {
+ t.Errorf("i=%#02x: j=%d: Write: %v", i, j, err)
+ continue loop
+ }
+ want = append(want, input...)
+ }
+ if err := w.Close(); err != nil {
+ t.Errorf("i=%#02x: Close: %v", i, err)
+ continue
+ }
+ got, err := ioutil.ReadAll(NewReader(buf))
+ if err != nil {
+ t.Errorf("i=%#02x: ReadAll: %v", i, err)
+ continue
+ }
+ if err := cmp(got, want); err != nil {
+ t.Errorf("i=%#02x: %v", i, err)
+ continue
+ }
+ }
+}
+
+func TestFlush(t *testing.T) {
+ buf := new(bytes.Buffer)
+ w := NewBufferedWriter(buf)
+ defer w.Close()
+ if _, err := w.Write(bytes.Repeat([]byte{'x'}, 20)); err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ if n := buf.Len(); n != 0 {
+ t.Fatalf("before Flush: %d bytes were written to the underlying io.Writer, want 0", n)
+ }
+ if err := w.Flush(); err != nil {
+ t.Fatalf("Flush: %v", err)
+ }
+ if n := buf.Len(); n == 0 {
+ t.Fatalf("after Flush: %d bytes were written to the underlying io.Writer, want non-0", n)
+ }
+}
+
func TestReaderReset(t *testing.T) {
gold := bytes.Repeat([]byte("All that is gold does not glitter,\n"), 10000)
buf := new(bytes.Buffer)
@@ -181,34 +242,75 @@
func TestWriterReset(t *testing.T) {
gold := bytes.Repeat([]byte("Not all those who wander are lost;\n"), 10000)
- var gots, wants [][]byte
const n = 20
- w, failed := NewWriter(nil), false
- for i := 0; i <= n; i++ {
- buf := new(bytes.Buffer)
- w.Reset(buf)
- want := gold[:len(gold)*i/n]
- if _, err := w.Write(want); err != nil {
- t.Errorf("#%d: Write: %v", i, err)
- failed = true
+ for _, buffered := range []bool{false, true} {
+ var w *Writer
+ if buffered {
+ w = NewBufferedWriter(nil)
+ defer w.Close()
+ } else {
+ w = NewWriter(nil)
+ }
+
+ var gots, wants [][]byte
+ failed := false
+ for i := 0; i <= n; i++ {
+ buf := new(bytes.Buffer)
+ w.Reset(buf)
+ want := gold[:len(gold)*i/n]
+ if _, err := w.Write(want); err != nil {
+ t.Errorf("#%d: Write: %v", i, err)
+ failed = true
+ continue
+ }
+ if buffered {
+ if err := w.Flush(); err != nil {
+ t.Errorf("#%d: Flush: %v", i, err)
+ failed = true
+ continue
+ }
+ }
+ got, err := ioutil.ReadAll(NewReader(buf))
+ if err != nil {
+ t.Errorf("#%d: ReadAll: %v", i, err)
+ failed = true
+ continue
+ }
+ gots = append(gots, got)
+ wants = append(wants, want)
+ }
+ if failed {
continue
}
- got, err := ioutil.ReadAll(NewReader(buf))
- if err != nil {
- t.Errorf("#%d: ReadAll: %v", i, err)
- failed = true
- continue
+ for i := range gots {
+ if err := cmp(gots[i], wants[i]); err != nil {
+ t.Errorf("#%d: %v", i, err)
+ }
}
- gots = append(gots, got)
- wants = append(wants, want)
}
- if failed {
- return
+}
+
+func TestWriterResetWithoutFlush(t *testing.T) {
+ buf0 := new(bytes.Buffer)
+ buf1 := new(bytes.Buffer)
+ w := NewBufferedWriter(buf0)
+ if _, err := w.Write([]byte("xxx")); err != nil {
+ t.Fatalf("Write #0: %v", err)
}
- for i := range gots {
- if err := cmp(gots[i], wants[i]); err != nil {
- t.Errorf("#%d: %v", i, err)
- }
+ // Note that we don't Flush the Writer before calling Reset.
+ w.Reset(buf1)
+ if _, err := w.Write([]byte("yyy")); err != nil {
+ t.Fatalf("Write #1: %v", err)
+ }
+ if err := w.Flush(); err != nil {
+ t.Fatalf("Flush: %v", err)
+ }
+ got, err := ioutil.ReadAll(NewReader(buf1))
+ if err != nil {
+ t.Fatalf("ReadAll: %v", err)
+ }
+ if err := cmp(got, []byte("yyy")); err != nil {
+ t.Fatal(err)
}
}