download code and tests
This commit is contained in:
@@ -5,4 +5,7 @@ import net.i2p.crypto.SigType
|
||||
class Constants {
|
||||
public static final byte PERSONA_VERSION = (byte)1
|
||||
public static final SigType SIG_TYPE = SigType.ECDSA_SHA512_P521 // TODO: decide which
|
||||
|
||||
public static final int MAX_HEADER_SIZE = 0x1 << 14
|
||||
public static final int MAX_HEADERS = 16
|
||||
}
|
||||
|
@@ -0,0 +1,139 @@
|
||||
package com.muwire.core.download;
|
||||
|
||||
import net.i2p.data.Base64
|
||||
|
||||
import com.muwire.core.Constants
|
||||
import com.muwire.core.InfoHash
|
||||
import com.muwire.core.connection.Endpoint
|
||||
import static com.muwire.core.util.DataUtil.readTillRN
|
||||
|
||||
import groovy.util.logging.Log
|
||||
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.channels.FileChannel
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.nio.file.Files
|
||||
import java.nio.file.StandardOpenOption
|
||||
import java.security.MessageDigest
|
||||
import java.security.NoSuchAlgorithmException
|
||||
|
||||
@Log
|
||||
class DownloadSession {
|
||||
|
||||
private final Pieces pieces
|
||||
private final InfoHash infoHash
|
||||
private final Endpoint endpoint
|
||||
private final File file
|
||||
private final int pieceSize
|
||||
private final long fileLength
|
||||
private final MessageDigest digest
|
||||
|
||||
private ByteBuffer mapped
|
||||
|
||||
DownloadSession(Pieces pieces, InfoHash infoHash, Endpoint endpoint, File file,
|
||||
int pieceSize, long fileLength) {
|
||||
this.pieces = pieces
|
||||
this.endpoint = endpoint
|
||||
this.infoHash = infoHash
|
||||
this.file = file
|
||||
this.pieceSize = pieceSize
|
||||
this.fileLength = fileLength
|
||||
try {
|
||||
digest = MessageDigest.getInstance("SHA-256")
|
||||
} catch (NoSuchAlgorithmException impossible) {
|
||||
digest = null
|
||||
System.exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
public void request() throws IOException {
|
||||
OutputStream os = endpoint.getOutputStream()
|
||||
InputStream is = endpoint.getInputStream()
|
||||
|
||||
int piece = pieces.getRandomPiece()
|
||||
long start = piece * pieceSize
|
||||
long end = Math.min(fileLength, start + pieceSize) - 1
|
||||
long length = end - start + 1
|
||||
|
||||
String root = Base64.encode(infoHash.getRoot())
|
||||
|
||||
FileChannel channel
|
||||
try {
|
||||
os.write("GET $root\r\n".getBytes(StandardCharsets.US_ASCII))
|
||||
os.write("Range: $start-$end\r\n\r\n".getBytes(StandardCharsets.US_ASCII))
|
||||
os.flush()
|
||||
String code = readTillRN(is)
|
||||
if (code.startsWith("404 ")) {
|
||||
log.warning("file not found")
|
||||
endpoint.close()
|
||||
return
|
||||
}
|
||||
|
||||
if (code.startsWith("416 ")) {
|
||||
log.warning("range $start-$end cannot be satisfied")
|
||||
return // leave endpoint open
|
||||
}
|
||||
|
||||
if (!code.startsWith("200 ")) {
|
||||
log.warning("unknown code $code")
|
||||
endpoint.close()
|
||||
return
|
||||
}
|
||||
|
||||
// parse all headers
|
||||
Set<String> headers = new HashSet<>()
|
||||
String header
|
||||
while((header = readTillRN(is)) != "" && headers.size() < Constants.MAX_HEADERS)
|
||||
headers.add(header)
|
||||
|
||||
long receivedStart = -1
|
||||
long receivedEnd = -1
|
||||
for (String receivedHeader : headers) {
|
||||
def group = (receivedHeader =~ /^Content-Range: (\d+)-(\d+)$/)
|
||||
if (group.size() != 1) {
|
||||
log.info("ignoring header $receivedHeader")
|
||||
continue
|
||||
}
|
||||
|
||||
receivedStart = Long.parseLong(group[0][1])
|
||||
receivedEnd = Long.parseLong(group[0][2])
|
||||
}
|
||||
|
||||
if (receivedStart != start || receivedEnd != end) {
|
||||
log.warning("We don't support mismatching ranges yet")
|
||||
endpoint.close()
|
||||
return
|
||||
}
|
||||
|
||||
// start the download
|
||||
channel = Files.newByteChannel(file.toPath(), EnumSet.of(StandardOpenOption.READ, StandardOpenOption.WRITE,
|
||||
StandardOpenOption.SPARSE, StandardOpenOption.CREATE)) // TODO: double-check, maybe CREATE_NEW
|
||||
mapped = channel.map(FileChannel.MapMode.READ_WRITE, start, end - start + 1)
|
||||
|
||||
byte[] tmp = new byte[0x1 << 13]
|
||||
while(mapped.hasRemaining()) {
|
||||
int read = is.read(tmp)
|
||||
if (read == -1)
|
||||
throw new IOException()
|
||||
synchronized(this) {
|
||||
mapped.put(tmp, 0, read)
|
||||
}
|
||||
}
|
||||
|
||||
mapped.clear()
|
||||
digest.update(mapped)
|
||||
byte [] hash = digest.digest()
|
||||
byte [] expected = new byte[32]
|
||||
System.arraycopy(infoHash.getHashList(), piece * 32, expected, 0, 32)
|
||||
if (hash != expected) {
|
||||
log.warning("hash mismatch")
|
||||
endpoint.close()
|
||||
return
|
||||
}
|
||||
|
||||
pieces.markDownloaded(piece)
|
||||
} finally {
|
||||
try { channel?.close() } catch (IOException ignore) {}
|
||||
}
|
||||
}
|
||||
}
|
@@ -2,6 +2,7 @@ package com.muwire.core.upload
|
||||
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
import com.muwire.core.Constants
|
||||
import com.muwire.core.InfoHash
|
||||
|
||||
import groovy.util.logging.Log
|
||||
@@ -19,8 +20,8 @@ class Request {
|
||||
|
||||
static Request parse(InfoHash infoHash, InputStream is) throws IOException {
|
||||
Map<String,String> headers = new HashMap<>()
|
||||
byte [] tmp = new byte[0x1 << 14]
|
||||
while(true) {
|
||||
byte [] tmp = new byte[Constants.MAX_HEADER_SIZE]
|
||||
while(headers.size() < Constants.MAX_HEADERS) {
|
||||
boolean r = false
|
||||
boolean n = false
|
||||
int idx = 0
|
||||
|
@@ -2,6 +2,8 @@ package com.muwire.core.util
|
||||
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
import com.muwire.core.Constants
|
||||
|
||||
class DataUtil {
|
||||
|
||||
private final static int MAX_SHORT = (0x1 << 16) - 1
|
||||
@@ -61,4 +63,20 @@ class DataUtil {
|
||||
daos.close()
|
||||
baos.toByteArray()
|
||||
}
|
||||
|
||||
public static String readTillRN(InputStream is) {
|
||||
def baos = new ByteArrayOutputStream()
|
||||
while(baos.size() < (Constants.MAX_HEADER_SIZE)) {
|
||||
byte read = is.read()
|
||||
if (read == -1)
|
||||
throw new IOException()
|
||||
if (read == '\r') {
|
||||
if (is.read() != '\n')
|
||||
throw new IOException("invalid header")
|
||||
break
|
||||
}
|
||||
baos.write(read)
|
||||
}
|
||||
new String(baos.toByteArray(), StandardCharsets.US_ASCII)
|
||||
}
|
||||
}
|
||||
|
@@ -0,0 +1,89 @@
|
||||
package com.muwire.core.download
|
||||
|
||||
import org.junit.After
|
||||
import org.junit.Test
|
||||
|
||||
import com.muwire.core.InfoHash
|
||||
import com.muwire.core.connection.Endpoint
|
||||
import com.muwire.core.files.FileHasher
|
||||
import static com.muwire.core.util.DataUtil.readTillRN
|
||||
|
||||
import net.i2p.data.Base64
|
||||
|
||||
class DownloadSessionTest {
|
||||
|
||||
private File source, target
|
||||
private InfoHash infoHash
|
||||
private Endpoint endpoint
|
||||
private Pieces pieces
|
||||
private String rootBase64
|
||||
|
||||
private DownloadSession session
|
||||
private Thread downloadThread
|
||||
|
||||
private InputStream fromDownloader, fromUploader
|
||||
private OutputStream toDownloader, toUploader
|
||||
|
||||
private void initSession(int size) {
|
||||
Random r = new Random()
|
||||
byte [] content = new byte[size]
|
||||
r.nextBytes(content)
|
||||
|
||||
source = File.createTempFile("source", "tmp")
|
||||
source.deleteOnExit()
|
||||
def fos = new FileOutputStream(source)
|
||||
fos.write(content)
|
||||
fos.close()
|
||||
|
||||
def hasher = new FileHasher()
|
||||
infoHash = hasher.hashFile(source)
|
||||
rootBase64 = Base64.encode(infoHash.getRoot())
|
||||
|
||||
target = File.createTempFile("target", "tmp")
|
||||
int pieceSize = 1 << FileHasher.getPieceSize(size)
|
||||
|
||||
int nPieces
|
||||
if (size % pieceSize == 0)
|
||||
nPieces = size / pieceSize
|
||||
else
|
||||
nPieces = size / pieceSize + 1
|
||||
pieces = new Pieces(nPieces)
|
||||
|
||||
fromDownloader = new PipedInputStream()
|
||||
fromUploader = new PipedInputStream()
|
||||
toDownloader = new PipedOutputStream(fromUploader)
|
||||
toUploader = new PipedOutputStream(fromDownloader)
|
||||
endpoint = new Endpoint(null, fromUploader, toUploader, null)
|
||||
|
||||
session = new DownloadSession(pieces, infoHash, endpoint, target, pieceSize, size)
|
||||
downloadThread = new Thread( { session.request() } as Runnable)
|
||||
downloadThread.setDaemon(true)
|
||||
downloadThread.start()
|
||||
}
|
||||
|
||||
@After
|
||||
public void teardown() {
|
||||
source?.delete()
|
||||
target?.delete()
|
||||
downloadThread?.interrupt()
|
||||
Thread.sleep(50)
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSmallFile() {
|
||||
initSession(20)
|
||||
assert "GET $rootBase64" == readTillRN(fromDownloader)
|
||||
assert "Range: 0-19" == readTillRN(fromDownloader)
|
||||
assert "" == readTillRN(fromDownloader)
|
||||
|
||||
toDownloader.write("200 OK\r\n".bytes)
|
||||
toDownloader.write("Content-Range: 0-19\r\n\r\n".bytes)
|
||||
toDownloader.write(source.bytes)
|
||||
toDownloader.flush()
|
||||
|
||||
Thread.sleep(150)
|
||||
|
||||
assert pieces.isComplete()
|
||||
assert target.bytes == source.bytes
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user