download code and tests

This commit is contained in:
Zlatin Balevsky
2019-05-29 08:10:42 +01:00
parent 8ff8e98424
commit 15d11be16e
5 changed files with 252 additions and 2 deletions

View File

@@ -5,4 +5,7 @@ import net.i2p.crypto.SigType
class Constants {
public static final byte PERSONA_VERSION = (byte)1
public static final SigType SIG_TYPE = SigType.ECDSA_SHA512_P521 // TODO: decide which
public static final int MAX_HEADER_SIZE = 0x1 << 14
public static final int MAX_HEADERS = 16
}

View File

@@ -0,0 +1,139 @@
package com.muwire.core.download;
import net.i2p.data.Base64
import com.muwire.core.Constants
import com.muwire.core.InfoHash
import com.muwire.core.connection.Endpoint
import static com.muwire.core.util.DataUtil.readTillRN
import groovy.util.logging.Log
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.nio.file.StandardOpenOption
import java.security.MessageDigest
import java.security.NoSuchAlgorithmException
@Log
class DownloadSession {
private final Pieces pieces
private final InfoHash infoHash
private final Endpoint endpoint
private final File file
private final int pieceSize
private final long fileLength
private final MessageDigest digest
private ByteBuffer mapped
DownloadSession(Pieces pieces, InfoHash infoHash, Endpoint endpoint, File file,
int pieceSize, long fileLength) {
this.pieces = pieces
this.endpoint = endpoint
this.infoHash = infoHash
this.file = file
this.pieceSize = pieceSize
this.fileLength = fileLength
try {
digest = MessageDigest.getInstance("SHA-256")
} catch (NoSuchAlgorithmException impossible) {
digest = null
System.exit(1)
}
}
public void request() throws IOException {
OutputStream os = endpoint.getOutputStream()
InputStream is = endpoint.getInputStream()
int piece = pieces.getRandomPiece()
long start = piece * pieceSize
long end = Math.min(fileLength, start + pieceSize) - 1
long length = end - start + 1
String root = Base64.encode(infoHash.getRoot())
FileChannel channel
try {
os.write("GET $root\r\n".getBytes(StandardCharsets.US_ASCII))
os.write("Range: $start-$end\r\n\r\n".getBytes(StandardCharsets.US_ASCII))
os.flush()
String code = readTillRN(is)
if (code.startsWith("404 ")) {
log.warning("file not found")
endpoint.close()
return
}
if (code.startsWith("416 ")) {
log.warning("range $start-$end cannot be satisfied")
return // leave endpoint open
}
if (!code.startsWith("200 ")) {
log.warning("unknown code $code")
endpoint.close()
return
}
// parse all headers
Set<String> headers = new HashSet<>()
String header
while((header = readTillRN(is)) != "" && headers.size() < Constants.MAX_HEADERS)
headers.add(header)
long receivedStart = -1
long receivedEnd = -1
for (String receivedHeader : headers) {
def group = (receivedHeader =~ /^Content-Range: (\d+)-(\d+)$/)
if (group.size() != 1) {
log.info("ignoring header $receivedHeader")
continue
}
receivedStart = Long.parseLong(group[0][1])
receivedEnd = Long.parseLong(group[0][2])
}
if (receivedStart != start || receivedEnd != end) {
log.warning("We don't support mismatching ranges yet")
endpoint.close()
return
}
// start the download
channel = Files.newByteChannel(file.toPath(), EnumSet.of(StandardOpenOption.READ, StandardOpenOption.WRITE,
StandardOpenOption.SPARSE, StandardOpenOption.CREATE)) // TODO: double-check, maybe CREATE_NEW
mapped = channel.map(FileChannel.MapMode.READ_WRITE, start, end - start + 1)
byte[] tmp = new byte[0x1 << 13]
while(mapped.hasRemaining()) {
int read = is.read(tmp)
if (read == -1)
throw new IOException()
synchronized(this) {
mapped.put(tmp, 0, read)
}
}
mapped.clear()
digest.update(mapped)
byte [] hash = digest.digest()
byte [] expected = new byte[32]
System.arraycopy(infoHash.getHashList(), piece * 32, expected, 0, 32)
if (hash != expected) {
log.warning("hash mismatch")
endpoint.close()
return
}
pieces.markDownloaded(piece)
} finally {
try { channel?.close() } catch (IOException ignore) {}
}
}
}

View File

@@ -2,6 +2,7 @@ package com.muwire.core.upload
import java.nio.charset.StandardCharsets
import com.muwire.core.Constants
import com.muwire.core.InfoHash
import groovy.util.logging.Log
@@ -19,8 +20,8 @@ class Request {
static Request parse(InfoHash infoHash, InputStream is) throws IOException {
Map<String,String> headers = new HashMap<>()
byte [] tmp = new byte[0x1 << 14]
while(true) {
byte [] tmp = new byte[Constants.MAX_HEADER_SIZE]
while(headers.size() < Constants.MAX_HEADERS) {
boolean r = false
boolean n = false
int idx = 0

View File

@@ -2,6 +2,8 @@ package com.muwire.core.util
import java.nio.charset.StandardCharsets
import com.muwire.core.Constants
class DataUtil {
private final static int MAX_SHORT = (0x1 << 16) - 1
@@ -61,4 +63,20 @@ class DataUtil {
daos.close()
baos.toByteArray()
}
public static String readTillRN(InputStream is) {
def baos = new ByteArrayOutputStream()
while(baos.size() < (Constants.MAX_HEADER_SIZE)) {
byte read = is.read()
if (read == -1)
throw new IOException()
if (read == '\r') {
if (is.read() != '\n')
throw new IOException("invalid header")
break
}
baos.write(read)
}
new String(baos.toByteArray(), StandardCharsets.US_ASCII)
}
}

View File

@@ -0,0 +1,89 @@
package com.muwire.core.download
import org.junit.After
import org.junit.Test
import com.muwire.core.InfoHash
import com.muwire.core.connection.Endpoint
import com.muwire.core.files.FileHasher
import static com.muwire.core.util.DataUtil.readTillRN
import net.i2p.data.Base64
class DownloadSessionTest {
private File source, target
private InfoHash infoHash
private Endpoint endpoint
private Pieces pieces
private String rootBase64
private DownloadSession session
private Thread downloadThread
private InputStream fromDownloader, fromUploader
private OutputStream toDownloader, toUploader
private void initSession(int size) {
Random r = new Random()
byte [] content = new byte[size]
r.nextBytes(content)
source = File.createTempFile("source", "tmp")
source.deleteOnExit()
def fos = new FileOutputStream(source)
fos.write(content)
fos.close()
def hasher = new FileHasher()
infoHash = hasher.hashFile(source)
rootBase64 = Base64.encode(infoHash.getRoot())
target = File.createTempFile("target", "tmp")
int pieceSize = 1 << FileHasher.getPieceSize(size)
int nPieces
if (size % pieceSize == 0)
nPieces = size / pieceSize
else
nPieces = size / pieceSize + 1
pieces = new Pieces(nPieces)
fromDownloader = new PipedInputStream()
fromUploader = new PipedInputStream()
toDownloader = new PipedOutputStream(fromUploader)
toUploader = new PipedOutputStream(fromDownloader)
endpoint = new Endpoint(null, fromUploader, toUploader, null)
session = new DownloadSession(pieces, infoHash, endpoint, target, pieceSize, size)
downloadThread = new Thread( { session.request() } as Runnable)
downloadThread.setDaemon(true)
downloadThread.start()
}
@After
public void teardown() {
source?.delete()
target?.delete()
downloadThread?.interrupt()
Thread.sleep(50)
}
@Test
public void testSmallFile() {
initSession(20)
assert "GET $rootBase64" == readTillRN(fromDownloader)
assert "Range: 0-19" == readTillRN(fromDownloader)
assert "" == readTillRN(fromDownloader)
toDownloader.write("200 OK\r\n".bytes)
toDownloader.write("Content-Range: 0-19\r\n\r\n".bytes)
toDownloader.write(source.bytes)
toDownloader.flush()
Thread.sleep(150)
assert pieces.isComplete()
assert target.bytes == source.bytes
}
}