Files
kotlin-fork/libraries/stdlib/jvm/test/io/Base64IOStreamTest.kt
T
2023-01-16 11:24:49 +00:00

372 lines
11 KiB
Kotlin

/*
* Copyright 2010-2022 JetBrains s.r.o. and Kotlin Programming Language contributors.
* Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
*/
package test.io
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import kotlin.io.encoding.Base64
import kotlin.io.encoding.decodingWith
import kotlin.io.encoding.encodingWith
import kotlin.test.*
class Base64IOStreamTest {
private fun testCoding(base64: Base64, text: String, encodedText: String) {
val encodedBytes = ByteArray(encodedText.length) { encodedText[it].code.toByte() }
val bytes = ByteArray(text.length) { text[it].code.toByte() }
encodedBytes.inputStream().decodingWith(base64).use { inputStream ->
assertEquals(text, inputStream.reader().readText())
}
encodedBytes.inputStream().decodingWith(base64).use { inputStream ->
assertContentEquals(bytes, inputStream.readBytes())
}
ByteArrayOutputStream().let { outputStream ->
outputStream.encodingWith(base64).use {
it.write(bytes)
}
assertContentEquals(encodedBytes, outputStream.toByteArray())
}
}
@Test
fun base64() {
fun testBase64(text: String, encodedText: String) {
testCoding(Base64, text, encodedText)
testCoding(Base64.Mime, text, encodedText)
}
testBase64("", "")
testBase64("f", "Zg==")
testBase64("fo", "Zm8=")
testBase64("foo", "Zm9v")
testBase64("foob", "Zm9vYg==")
testBase64("fooba", "Zm9vYmE=")
testBase64("foobar", "Zm9vYmFy")
}
@Test
fun readDifferentOffsetAndLength() {
val repeat = 10_000
val symbols = "Zm9vYmFy".repeat(repeat) + "Zm8="
val expected = "foobar".repeat(repeat) + "fo"
val bytes = ByteArray(expected.length)
symbols.byteInputStream().decodingWith(Base64).use { input ->
var read = 0
repeat(6) {
bytes[read++] = input.read().toByte()
}
var toRead = 1
while (read < bytes.size) {
val length = minOf(toRead, bytes.size - read)
val result = input.read(bytes, read, length)
assertEquals(length, result)
read += result
toRead += toRead * 10 / 9
}
assertEquals(-1, input.read(bytes))
assertEquals(-1, input.read())
assertEquals(expected, bytes.decodeToString())
}
}
@Test
fun readDifferentOffsetAndLengthMime() {
val repeat = 10_000
val symbols = ("Zm9vYmFy".repeat(repeat) + "Zm8=").chunked(76).joinToString(separator = "\r\n")
val expected = "foobar".repeat(repeat) + "fo"
val bytes = ByteArray(expected.length)
symbols.byteInputStream().decodingWith(Base64.Mime).use { input ->
var read = 0
repeat(6) {
bytes[read++] = input.read().toByte()
}
var toRead = 1
while (read < bytes.size) {
val length = minOf(toRead, bytes.size - read)
val result = input.read(bytes, read, length)
assertEquals(length, result)
read += result
toRead += toRead * 10 / 9
}
assertEquals(-1, input.read(bytes))
assertEquals(-1, input.read())
assertEquals(expected, bytes.decodeToString())
}
}
@Test
fun writeDifferentOffsetAndLength() {
val repeat = 10_000
val bytes = ("foobar".repeat(repeat) + "fo").encodeToByteArray()
val expected = "Zm9vYmFy".repeat(repeat) + "Zm8="
val underlying = ByteArrayOutputStream()
underlying.encodingWith(Base64).use { output ->
var written = 0
repeat(8) {
output.write(bytes[written++].toInt())
}
var toWrite = 1
while (written < bytes.size) {
val length = minOf(toWrite, bytes.size - written)
output.write(bytes, written, length)
written += length
toWrite += toWrite * 10 / 9
}
}
assertEquals(expected, underlying.toString())
}
@Test
fun writeDifferentOffsetAndLengthMime() {
val repeat = 10_000
val bytes = ("foobar".repeat(repeat) + "fo").encodeToByteArray()
val expected = ("Zm9vYmFy".repeat(repeat) + "Zm8=").chunked(76).joinToString(separator = "\r\n")
val underlying = ByteArrayOutputStream()
underlying.encodingWith(Base64.Mime).use { output ->
var written = 0
repeat(8) {
output.write(bytes[written++].toInt())
}
var toWrite = 1
while (written < bytes.size) {
val length = minOf(toWrite, bytes.size - written)
output.write(bytes, written, length)
written += length
toWrite += toWrite * 10 / 9
}
}
assertEquals(expected, underlying.toString())
}
@Test
fun inputStreamClosesUnderlying() {
val underlying = object : InputStream() {
var isClosed: Boolean = false
override fun close() {
isClosed = true
super.close()
}
override fun read(): Int {
return 0
}
}
val wrapper = underlying.decodingWith(Base64)
wrapper.close()
assertTrue(underlying.isClosed)
}
@Test
fun outputStreamClosesUnderlying() {
val underlying = object : OutputStream() {
var isClosed: Boolean = false
override fun close() {
isClosed = true
super.close()
}
override fun write(b: Int) {
// ignore
}
}
val wrapper = underlying.encodingWith(Base64)
wrapper.close()
assertTrue(underlying.isClosed)
}
@Test
fun correctPadding() {
val inputStream = "Zg==Zg==".byteInputStream()
val wrapper = inputStream.decodingWith(Base64)
wrapper.use {
assertEquals('f'.code, it.read())
assertEquals(-1, it.read())
assertEquals(-1, it.read())
// in the wrapped IS the chars after the padding are not consumed
assertContentEquals("Zg==".toByteArray(), inputStream.readBytes())
}
assertFailsWith<IOException> {
wrapper.read()
}
}
@Test
fun correctPaddingMime() {
val inputStream = "Zg==Zg==".byteInputStream()
val wrapper = inputStream.decodingWith(Base64.Mime)
wrapper.use {
assertEquals('f'.code, it.read())
assertEquals(-1, it.read())
assertEquals(-1, it.read())
// in the wrapped IS the chars after the padding are not consumed
assertContentEquals("Zg==".toByteArray(), inputStream.readBytes())
}
// closed
assertFailsWith<IOException> {
wrapper.read()
}
}
@Test
fun illegalSymbol() {
val inputStream = "Zm\u00FF9vYg==".byteInputStream()
val wrapper = inputStream.decodingWith(Base64)
wrapper.use {
// one group of 4 symbols is read for decoding, that group includes illegal '\u00FF'
assertFailsWith<IllegalArgumentException> {
it.read()
}
}
// closed
assertFailsWith<IOException> {
wrapper.read()
}
}
@Test
fun illegalSymbolMime() {
val inputStream = "Zm\u00FF9vYg==".byteInputStream()
val wrapper = inputStream.decodingWith(Base64.Mime)
wrapper.use {
assertEquals('f'.code, it.read())
assertEquals('o'.code, it.read())
assertEquals('o'.code, it.read())
assertEquals('b'.code, it.read())
assertEquals(-1, it.read())
assertEquals(-1, it.read())
}
// closed
assertFailsWith<IOException> {
wrapper.read()
}
}
@Test
fun incorrectPadding() {
for (base64 in listOf(Base64, Base64.Mime)) {
val inputStream = "Zm9vZm=9v".byteInputStream()
val wrapper = inputStream.decodingWith(base64)
wrapper.use {
assertEquals('f'.code, it.read())
assertEquals('o'.code, it.read())
assertEquals('o'.code, it.read())
// the second group is incorrectly padded
assertFailsWith<IllegalArgumentException> {
it.read()
}
}
// closed
assertFailsWith<IOException> {
wrapper.read()
}
}
}
@Test
fun withoutPadding() {
for (base64 in listOf(Base64, Base64.Mime)) {
val inputStream = "Zm9vYg".byteInputStream()
val wrapper = inputStream.decodingWith(base64)
wrapper.use {
assertEquals('f'.code, it.read())
assertEquals('o'.code, it.read())
assertEquals('o'.code, it.read())
assertEquals('b'.code, it.read())
assertEquals(-1, it.read())
assertEquals(-1, it.read())
}
// closed
assertFailsWith<IOException> {
wrapper.read()
}
}
}
@Test
fun separatedPadSymbols() {
val inputStream = "Zm9vYg=[,.|^&*@#]=".byteInputStream()
val wrapper = inputStream.decodingWith(Base64)
wrapper.use {
assertEquals('f'.code, it.read())
assertEquals('o'.code, it.read())
assertEquals('o'.code, it.read())
// the second group contains illegal symbols
assertFailsWith<IllegalArgumentException> {
it.read()
}
}
// closed
assertFailsWith<IOException> {
wrapper.read()
}
}
@Test
fun separatedPadSymbolsMime() {
val inputStream = "Zm9vYg=[,.|^&*@#]=".byteInputStream()
val wrapper = inputStream.decodingWith(Base64.Mime)
wrapper.use {
assertEquals('f'.code, it.read())
assertEquals('o'.code, it.read())
assertEquals('o'.code, it.read())
assertEquals('b'.code, it.read())
assertEquals(-1, it.read())
assertEquals(-1, it.read())
}
// closed
assertFailsWith<IOException> {
wrapper.read()
}
}
}