use std::convert::TryFrom;
use std::error;
use std::fmt;
use std::io;
use std::mem::size_of;
use std::result;
use std::slice;
/// Smallest block size, in bytes, accepted by the public functions of this crate.
pub const MIN_BLOCK_SIZE: usize = 32;
/// Largest block size, in bytes, accepted by the public functions of this crate (1 MiB).
pub const MAX_BLOCK_SIZE: usize = 1024 * 1024;
/// Trailing bytes of every block reserved for metadata: the payload length
/// (big-endian `i32`) followed by the checksum (big-endian `i64`).
pub const META_SIZE: usize = size_of::<i32>() + size_of::<i64>();
/// Exclusive upper bound used when drawing one random mask byte (256 possible values).
const BYTE: i32 = (u8::MAX as i32) + 1;
/// Source of raw 32-bit random integers that drives both the byte masks and the
/// block shuffle.
///
/// Bounded values are derived from it with the same algorithm as
/// `java.util.Random::nextInt(int)`. An implementation reproducing
/// `java.util.Random::nextInt()` (as the tests' `JavaRandom` does) is therefore
/// expected to interoperate with that generator. Encryption and decryption only
/// agree when both sides draw the identical sequence.
pub trait IntRng {
    /// Returns the next value of the sequence (the crate discards its lowest bit).
    fn next_int(&mut self) -> i32;
}
/// Running checksum over a block's plaintext payload.
///
/// For every block, `reset` is called first. Then `update` is called once per
/// payload byte, in order. Finally `get_value` is stored in the block's metadata
/// (encryption) or compared against it (decryption).
pub trait Checksum {
    /// Restores the initial state before a new block.
    fn reset(&mut self);
    /// Feeds one plaintext payload byte.
    fn update(&mut self, v: u8);
    /// Expected to return the checksum of the bytes fed since the last `reset`.
    fn get_value(&self) -> i64;
}
/// Result type returned by this crate's fallible functions.
pub type Result<T> = result::Result<T, Error>;
/// Errors reported by this crate.
#[derive(Debug)]
pub enum Error {
    /// Failure reported by the underlying reader or writer.
    IoError(io::Error),
    /// The input or a parameter was rejected; [`Cause`] gives the reason.
    UnkocryptoError(Cause),
}
/// Reason attached to [`Error::UnkocryptoError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Cause {
    /// `block_size` lies outside `MIN_BLOCK_SIZE..=MAX_BLOCK_SIZE`.
    InvalidBlockSize,
    /// Decryption input is empty or ends in the middle of a block. Also reported
    /// when a size computation (`block_count * block_size`) overflows `u64`.
    InvalidSourceSize,
    /// A decrypted block's stored payload length is negative, exceeds the block's
    /// payload capacity, or is zero while more input follows.
    InvalidDataCount,
    /// The padding after a decrypted block's payload contains a non-zero byte.
    InvalidData,
    /// A decrypted block's stored checksum does not match its payload.
    InvalidChecksum,
}
impl From<io::Error> for Error {
fn from(error: io::Error) -> Self {
Error::IoError(error)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use Error::*;
match self {
IoError(ref e) => fmt::Display::fmt(e, f),
UnkocryptoError(ref c) => write!(f, "UnkocryptoError({})", c),
}
}
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
use Error::*;
match self {
IoError(ref e) => Some(e),
UnkocryptoError(_) => None,
}
}
}
/// Displays the variant name, identical to its `Debug` output.
impl fmt::Display for Cause {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}
/// Draws a value uniformly from `0..bound` using the same algorithm as
/// `java.util.Random::nextInt(int)`, so a Java-compatible `rng` yields identical
/// results.
///
/// `bound` is expected to be positive; callers pass `BYTE` or a remaining shuffle
/// length.
fn next_int<T: IntRng>(rng: &mut T, bound: i32) -> i32 {
    // Java's `next(31)`: a 32-bit draw with its lowest bit dropped.
    let mut next31 = || ((rng.next_int() as u32) >> 1) as i32;
    if (bound & -bound) == bound {
        // Power of two: scale into range by keeping the high bits of the draw.
        let high = next31() as u64;
        return ((bound as u64 * high) >> 31) as i32;
    }
    loop {
        let bits = next31();
        let candidate = bits % bound;
        // Reject draws from the incomplete top bucket. Java detects this by the sum
        // overflowing to a negative `int`, hence the wrapping add.
        if (bits - candidate).wrapping_add(bound - 1) >= 0 {
            return candidate;
        }
    }
}
/// Reads at most one byte from `src` into `one_byte`.
///
/// Returns the byte count reported by `src.read` (`0` at end of input).
/// `io::ErrorKind::Interrupted` is retried, because `io::Read` documents it as
/// non-fatal. Any other I/O failure is returned as [`Error::IoError`].
fn read<R: io::Read>(src: &mut R, one_byte: &mut u8) -> Result<usize> {
    loop {
        match src.read(slice::from_mut(one_byte)) {
            Ok(size) => return Ok(size),
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e.into()),
        }
    }
}
fn read_int(src: &[u8], pos: usize) -> i32 {
let src = &src[pos..pos + size_of::<i32>()];
i32::from_be_bytes(TryFrom::try_from(src).unwrap())
}
fn read_long(src: &[u8], pos: usize) -> i64 {
let src = &src[pos..pos + size_of::<i64>()];
i64::from_be_bytes(TryFrom::try_from(src).unwrap())
}
/// Stores `v` big-endian into `dst[pos..pos + 4]`; panics if that range lies outside `dst`.
fn write_int(dst: &mut [u8], pos: usize, v: i32) {
    let encoded = v.to_be_bytes();
    dst[pos..pos + encoded.len()].copy_from_slice(&encoded);
}
/// Stores `v` big-endian into `dst[pos..pos + 8]`; panics if that range lies outside `dst`.
fn write_long(dst: &mut [u8], pos: usize, v: i64) {
    let encoded = v.to_be_bytes();
    dst[pos..pos + encoded.len()].copy_from_slice(&encoded);
}
pub fn decrypt<C, T, R, W>(
block_size: usize,
mut checksum: C,
rng: &mut T,
src: &mut R,
dst: &mut W,
) -> Result<(u64, u64)>
where
C: Checksum,
T: IntRng,
R: io::Read,
W: io::Write,
{
if !(MIN_BLOCK_SIZE..=MAX_BLOCK_SIZE).contains(&block_size) {
return Err(Error::UnkocryptoError(Cause::InvalidBlockSize));
}
let data_size: usize = block_size - META_SIZE;
let mut mask: Vec<u8> = vec![0; block_size];
let mut indexes: Vec<usize> = vec![0; block_size];
let mut data: Vec<u8> = vec![0; block_size];
let mut src_size: u64 = 0;
let mut dst_size: u64 = 0;
let mut one_byte: u8 = 0;
let mut read_size: usize = read(src, &mut one_byte)?;
if read_size == 0 {
return Err(Error::UnkocryptoError(Cause::InvalidSourceSize));
}
while read_size > 0 {
for i in 0..block_size {
mask[i] = next_int(rng, BYTE) as u8;
indexes[i] = i;
}
for i in 0..block_size {
let j = next_int(rng, (block_size - i) as i32) as usize + i;
indexes.swap(i, j);
}
for j in indexes.iter() {
if read_size == 0 {
return Err(Error::UnkocryptoError(Cause::InvalidSourceSize));
}
data[*j] = one_byte ^ mask[*j];
read_size = read(src, &mut one_byte)?;
}
src_size += block_size as u64;
let count = read_int(&data, data_size);
let code = read_long(&data, data_size + size_of::<i32>());
if count < 0 || (count == 0 && read_size > 0) || data_size < count as usize {
return Err(Error::UnkocryptoError(Cause::InvalidDataCount));
}
let count = count as usize;
dst_size += count as u64;
checksum.reset();
for (i, d) in data.iter().enumerate().take(data_size) {
if i < count {
dst.write_all(slice::from_ref(d))?;
checksum.update(*d);
} else if *d != 0 {
return Err(Error::UnkocryptoError(Cause::InvalidData));
}
}
if code != checksum.get_value() {
return Err(Error::UnkocryptoError(Cause::InvalidChecksum));
}
}
Ok((src_size, dst_size))
}
/// Encrypts all of `src` into `dst` as a sequence of `block_size`-byte blocks.
///
/// Each block works as follows:
/// * It carries up to `block_size - META_SIZE` plaintext bytes, each XORed with a
///   mask byte drawn from `rng`.
/// * Unused payload slots receive a fresh random byte (a zero plaintext byte XOR its
///   mask).
/// * The trailing metadata is masked as well: the payload length as a big-endian
///   `i32`, then the plaintext checksum as a big-endian `i64`.
/// * The whole block is shuffled with values drawn from `rng` and written out.
///
/// At least one block is always produced, even for empty input.
///
/// Returns `(bytes read from src, bytes written to dst)`.
///
/// # Errors
///
/// * `InvalidBlockSize` if `block_size` is outside `MIN_BLOCK_SIZE..=MAX_BLOCK_SIZE`.
/// * `IoError` if reading `src` or writing `dst` fails.
pub fn encrypt<C, T, R, W>(
    block_size: usize,
    mut checksum: C,
    rng: &mut T,
    src: &mut R,
    dst: &mut W,
) -> Result<(u64, u64)>
where
    C: Checksum,
    T: IntRng,
    R: io::Read,
    W: io::Write,
{
    if !(MIN_BLOCK_SIZE..=MAX_BLOCK_SIZE).contains(&block_size) {
        return Err(Error::UnkocryptoError(Cause::InvalidBlockSize));
    }
    let payload_cap = block_size - META_SIZE;
    let mut block = vec![0_u8; block_size];
    let (mut consumed, mut produced) = (0_u64, 0_u64);
    // Look-ahead byte; `has_next` turns false once `src` reports end of input.
    let mut next_byte = 0_u8;
    let mut has_next = read(src, &mut next_byte)? > 0;
    loop {
        checksum.reset();
        let mut filled = 0;
        while filled < payload_cap && has_next {
            checksum.update(next_byte);
            block[filled] = next_byte ^ next_int(rng, BYTE) as u8;
            filled += 1;
            has_next = read(src, &mut next_byte)? > 0;
        }
        // Padding: zero plaintext, so the stored value is the mask byte itself.
        for slot in &mut block[filled..payload_cap] {
            *slot = next_int(rng, BYTE) as u8;
        }
        write_int(&mut block, payload_cap, filled as i32);
        write_long(&mut block, payload_cap + size_of::<i32>(), checksum.get_value());
        for slot in &mut block[payload_cap..] {
            *slot ^= next_int(rng, BYTE) as u8;
        }
        // Forward Fisher-Yates shuffle over the whole block.
        for i in 0..block_size {
            let j = i + next_int(rng, (block_size - i) as i32) as usize;
            block.swap(i, j);
        }
        dst.write_all(&block)?;
        consumed += filled as u64;
        produced += block_size as u64;
        if !has_next {
            break;
        }
    }
    Ok((consumed, produced))
}
/// Returns how many blocks [`encrypt`] emits for `src_len` bytes of input.
///
/// Each block holds up to `block_size - META_SIZE` plaintext bytes. `encrypt` always
/// emits at least one block, so an empty input counts as one block with a
/// zero-length payload.
///
/// # Errors
///
/// `InvalidBlockSize` if `block_size` is outside `MIN_BLOCK_SIZE..=MAX_BLOCK_SIZE`.
pub fn calc_block_count(block_size: usize, src_len: u64) -> Result<u64> {
    if !(MIN_BLOCK_SIZE..=MAX_BLOCK_SIZE).contains(&block_size) {
        return Err(Error::UnkocryptoError(Cause::InvalidBlockSize));
    }
    let data_size: u64 = (block_size - META_SIZE) as u64;
    // Ceiling division that avoids `src_len + data_size - 1`, which overflows near u64::MAX.
    let block_count: u64 = src_len / data_size + u64::from(src_len % data_size != 0);
    Ok(block_count.max(1))
}
pub fn calc_encrypted_size(block_size: usize, src_len: u64) -> Result<u64> {
calc_block_count(block_size, src_len)?
.checked_mul(block_size as u64)
.ok_or_else(|| Error::UnkocryptoError(Cause::InvalidSourceSize))
}
pub fn consume<T: IntRng>(rng: &mut T, block_size: usize, block_count: u64) -> Result<()> {
if !(MIN_BLOCK_SIZE..=MAX_BLOCK_SIZE).contains(&block_size) {
return Err(Error::UnkocryptoError(Cause::InvalidBlockSize));
}
let _ = block_count
.checked_mul(block_size as u64)
.ok_or_else(|| Error::UnkocryptoError(Cause::InvalidSourceSize))?;
for _ in 0..block_count {
for _ in 0..block_size {
let _ = next_int(rng, BYTE);
}
for i in 0..block_size {
let _ = next_int(rng, (block_size - i) as i32);
}
}
Ok(())
}
// The module covers three scenarios:
// * a fixed known-answer vector for `decrypt` and `encrypt`;
// * a check of `calc_encrypted_size` against the output size `encrypt` reports;
// * a multi-segment stream in which every other segment is skipped with `consume`.
#[cfg(test)]
mod tests {
    // Port of `java.util.Random`'s 48-bit linear congruential generator.
    struct JavaRandom {
        seed: i64,
    }
    impl JavaRandom {
        fn new(seed: i64) -> JavaRandom {
            let mut jr = JavaRandom { seed: 0 };
            jr.set_seed(seed);
            jr
        }
        // Same scrambling as `java.util.Random::setSeed`: XOR with the multiplier,
        // then keep the low 48 bits.
        fn set_seed(&mut self, seed: i64) {
            self.seed = (seed ^ 0x5DEECE66D) & ((1 << 48) - 1);
        }
        // One LCG step (seed * 0x5DEECE66D + 0xB, mod 2^48); returns the top `bits`
        // bits of the new 48-bit state, like Java's `next(int)`.
        fn next(&mut self, bits: i32) -> i32 {
            self.seed = self.seed.wrapping_mul(0x5DEECE66D).wrapping_add(0xB) & ((1 << 48) - 1);
            (self.seed as u64 >> (48 - bits)) as i32
        }
    }
    impl crate::IntRng for JavaRandom {
        // Java's `nextInt()` is `next(32)`.
        fn next_int(&mut self) -> i32 {
            self.next(32)
        }
    }
    // Byte lookup table for the reflected CRC-32 polynomial 0xEDB88320, computed in
    // the static's const initializer.
    static CRC32_TABLE: [i64; 256] = {
        // One bit step of the reflected CRC-32 division.
        const fn f(c: i64) -> i64 {
            (c >> 1) ^ (0xedb88320 * (c & 1))
        }
        // Table entry for byte value `v`: eight bit steps.
        const fn calc(v: i64) -> i64 {
            f(f(f(f(f(f(f(f(v))))))))
        }
        let mut table = [0_i64; 256];
        // `m!` unrolls the table fill. Each level emits the two child indices
        // `e << 1` and `(e << 1) | 1` and drops one trailing argument; the trailing
        // arguments only count the depth. Starting from 0 with eight trailing
        // arguments, the leaves cover every index 0..=255.
        // NOTE(review): presumably unrolled because loops were not yet allowed in
        // const initializers when this was written — confirm.
        macro_rules! m {
            ($e:expr) => {
                table[$e] = calc($e);
            };
            ($e:expr, $a:expr $(,$b:expr)*) => {
                m!(($e<<1) $(,$b)*);
                m!((($e<<1)|1) $(,$b)*);
            };
        }
        m!(0, 1, 2, 3, 4, 5, 6, 7, 8);
        table
    };
    // Standard CRC-32: initial register 0xffffffff, final XOR 0xffffffff.
    struct Crc32 {
        value: i64,
    }
    impl Crc32 {
        fn new() -> Self {
            Self { value: 0xffffffff }
        }
    }
    impl crate::Checksum for Crc32 {
        fn reset(&mut self) {
            self.value = 0xffffffff;
        }
        // Table-driven update: index by (low byte of register) XOR input, then
        // combine with the register shifted right by 8.
        fn update(&mut self, v: u8) {
            let b = CRC32_TABLE[(self.value & 0xff) as usize ^ v as usize];
            let c = self.value >> 8;
            self.value = b ^ c;
        }
        fn get_value(&self) -> i64 {
            self.value ^ 0xffffffff
        }
    }
    #[test]
    fn it_works() {
        let block_size: usize = crate::MIN_BLOCK_SIZE;
        let seed: i64 = 123456789;
        let data_src: [u8; 10] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        // Expected ciphertext of `data_src` for this seed and block size (one block).
        let secret_src: [u8; 32] = [
            0xa2, 0xfc, 0x45, 0x90, 0x64, 0x80, 0x77, 0x46, 0x3f, 0x7e, 0x1d, 0x7c, 0x64, 0xfe,
            0x5c, 0x98, 0x7a, 0x00, 0x79, 0xa8, 0x64, 0xf2, 0x7d, 0xc1, 0xe3, 0x66, 0x31, 0x31,
            0x1e, 0x62, 0xb6, 0x04,
        ];
        // Known answer: `secret_src` decrypts to `data_src`.
        {
            let crc = Crc32::new();
            let mut rng = JavaRandom::new(seed);
            let mut cur = secret_src.as_slice();
            let mut dst = Vec::new();
            let (src_len, dst_len) =
                crate::decrypt(block_size, crc, &mut rng, &mut cur, &mut dst).unwrap();
            assert_eq!(secret_src.len() as u64, src_len);
            assert_eq!(data_src.len() as u64, dst_len);
            assert_eq!(data_src, dst.as_ref());
        }
        // Known answer: encrypting `data_src` reproduces `secret_src` byte for byte.
        {
            let crc = Crc32::new();
            let mut rng = JavaRandom::new(seed);
            let mut cur = data_src.as_slice();
            let mut dst = Vec::new();
            let (src_len, dst_len) =
                crate::encrypt(block_size, crc, &mut rng, &mut cur, &mut dst).unwrap();
            assert_eq!(data_src.len() as u64, src_len);
            assert_eq!(secret_src.len() as u64, dst_len);
            assert_eq!(secret_src, dst.as_ref());
        }
        // `calc_encrypted_size` agrees with the byte count `encrypt` reports.
        {
            const LARGE_SRC_LEN: u64 = 12345;
            let crc = Crc32::new();
            let mut rng = JavaRandom::new(seed);
            let mut cur = std::io::Read::take(std::io::repeat(123), LARGE_SRC_LEN);
            let mut dst = std::io::sink();
            let (_, len) = crate::encrypt(block_size, crc, &mut rng, &mut cur, &mut dst).unwrap();
            let guess_size = crate::calc_encrypted_size(block_size, LARGE_SRC_LEN).unwrap();
            assert_eq!(len, guess_size);
        }
        // Segment i (i = 0..5) is `data_src` shifted by i and cycled to 10 * (i + 10)
        // bytes. All segments are encrypted back to back with one generator, each
        // prefixed by its block count. Decryption decodes the even-indexed segments
        // and skips the odd-indexed ones with `consume`, keeping the generator in step.
        {
            let multi_src: Vec<Vec<u8>> = vec![data_src.as_slice(); 5]
                .iter()
                .enumerate()
                .map(|(i, a)| {
                    a.iter()
                        .map(|e| i as u8 + *e)
                        .cycle()
                        .take(10 * (i + 10))
                        .collect()
                })
                .collect();
            let encrypted = {
                let mut rng = JavaRandom::new(seed);
                let mut dst = Vec::new();
                for mut src in multi_src.iter().map(|a| a.as_slice()) {
                    let block_count =
                        crate::calc_block_count(block_size, src.len() as u64).unwrap();
                    dst.push(block_count as u8);
                    let crc = Crc32::new();
                    crate::encrypt(block_size, crc, &mut rng, &mut src, &mut dst).unwrap();
                }
                dst
            };
            {
                let mut rng = JavaRandom::new(seed);
                let mut cur = encrypted.as_slice();
                for (i, src) in multi_src.iter().enumerate() {
                    let block_count = cur[0] as u64;
                    let size = block_size * block_count as usize;
                    cur = &cur[1..];
                    if i % 2 == 1 {
                        crate::consume(&mut rng, block_size, block_count).unwrap();
                    } else {
                        let crc = Crc32::new();
                        let mut chunk = &cur[..size];
                        let mut dst = Vec::new();
                        crate::decrypt(block_size, crc, &mut rng, &mut chunk, &mut dst).unwrap();
                        assert_eq!(src, &dst);
                    }
                    cur = &cur[size..];
                }
            }
        }
    }
}