From 975d5e7fb7a33a7785d9cce643d83dabb2acd03f Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Tue, 8 Dec 2020 16:35:23 -0800 Subject: [PATCH 01/15] [RELAY,TOPI] Threefry PRNG: splittable and stateless --- include/tvm/relay/attrs/algorithm.h | 8 + python/tvm/relay/op/_algorithm.py | 7 + python/tvm/relay/op/algorithm.py | 114 +++++- python/tvm/relay/op/op_attrs.py | 5 + python/tvm/relay/op/strategy/cuda.py | 1 - python/tvm/relay/op/strategy/generic.py | 44 +++ python/tvm/topi/generic/__init__.py | 1 + python/tvm/topi/generic/algorithm.py | 401 +++++++++++++++++++++ src/relay/op/algorithm/prng.cc | 85 +++++ tests/python/relay/test_prng.py | 61 ++++ tests/python/topi/python/test_topi_prng.py | 116 ++++++ 11 files changed, 841 insertions(+), 2 deletions(-) create mode 100644 python/tvm/topi/generic/algorithm.py create mode 100644 src/relay/op/algorithm/prng.cc create mode 100644 tests/python/relay/test_prng.py create mode 100644 tests/python/topi/python/test_topi_prng.py diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index 83b4ddaead43..6ecc5c23935c 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -76,6 +76,14 @@ struct TopKAttrs : public tvm::AttrsNode { } }; +struct ThreefryGenerateAttrs : public tvm::AttrsNode { + Array out_shape; + + TVM_DECLARE_ATTRS(ThreefryGenerateAttrs, "relay.attrs.ThreefryGenerateAttrs") { + TVM_ATTR_FIELD(out_shape).describe("Shape of random numbers to generate"); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_ALGORITHM_H_ diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 732d5016755a..5124f7112049 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -73,3 +73,10 @@ def topk_shape_func(attrs, inputs, _): ret = [indices_out] return ret + + +# threefry +register_strategy("threefry_generate", strategy.threefry_generate_strategy) +register_pattern("threefry_generate", OpPattern.OPAQUE) +register_strategy("threefry_split", strategy.threefry_split_strategy) +register_pattern("threefry_split", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index e0550543f4b8..7b4d1b48978b 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -17,9 +17,14 @@ """Classic algorithm operation""" from __future__ import absolute_import as _abs +import sys + +import numpy as np + +from ... import nd +from ..expr import Constant, Expr, TupleWrapper from . import _make from .dyn import _make as _dyn_make -from ..expr import TupleWrapper, Expr, Constant def argsort(data, axis=-1, is_ascend=1, dtype="int32"): @@ -93,3 +98,110 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): if ret_type == "both": return TupleWrapper(out, 2) return out + + +def threefry_seed(seed): + """Create a new Threefry random number generator. + + Example + ------- + + .. code-block:: python + + gen = threefry_seed(0) + _, random_number = threefry_generate(gen, (1,)) + + Parameters + ---------- + seed : int + Starting seed for the generator + + Returns + ------- + gen : relay.Expr + New generator to pass to future uses of :py:func:`threefry_split` or + :py:func:`threefry_generate`. + """ + s = np.frombuffer(seed.to_bytes(32, sys.byteorder), dtype="uint64") + a = np.concatenate((s, np.array([0, 0, 0, 0, 1 << 63, 0], dtype="uint64"))) + return Constant(nd.array(a)) + + +def threefry_generate(gen, shape): + """Generate an array of random numbers using the Threefry algorithm + + Example + ------- + + .. code-block:: python + + gen = threefry_seed(0) + new_gen, random1 = threefry_generate(gen, (1,)) + _, random2 = threefry_generate(new_gen, (1,)) + # random1 and random2 are different random numbers + + Parameters + ---------- + gen : relay.Expr + generator that uniquely determines the random values. Multiple uses with the + same generator will generate the same random values. This generator should be + treated as an opaque pointer. You can create one from calling + :py:func:`threefry_seed`, :py:func:`threefry_split`, or + :py:func:`threefry_generate`. _Do not use this generator again after calling + this function_. + + shape : Sequence[int] + Desired outputs shape of random numbers + + Returns + ------- + new_gen : relay.Expr + New generator to pass to future uses of :py:func:`threefry_split` or + :py:func:`threefry_generate`. + + random_array : relay.Expr + Array of random numbers. Has shape `shape`. + """ + return _make.threefry_generate(gen, shape) + + +def threefry_split(gen): + """Split an existing threefry generator into two new ones. + + This is useful if you have to subsequent calls which each need their own + random number generation. + + Example + ------- + + .. code-block:: python + + def foo(gen): + new_gen, num = threefry_generate(gen, (1,)) + return num + + gen = threefry_seed(0) + gen1, gen2 = threefry_split(gen) + assert foo(gen1) != foo(gen2) + + Parameters + ---------- + gen : relay.Expr + generator that uniquely determines the random values. Multiple uses with the + same generator will generate the same random values. This generator should be + treated as an opaque pointer. You can create one from calling + :py:func:`threefry_seed`, :py:func:`threefry_split`, or + :py:func:`threefry_generate`. _Do not use this generator again after calling + this function_. + + Returns + ------- + new_gen_1 : relay.Expr + New generator to pass to future uses of :py:func:`threefry_split` or + :py:func:`threefry_generate`. + + new_gen_2 : relay.Expr + New generator to pass to future uses of :py:func:`threefry_split` or + :py:func:`threefry_generate`. + """ + return _make.threefry_split(gen) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index cb837b192a6c..41076817b374 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -552,3 +552,8 @@ class SpaceToBatchNDAttrs(Attrs): @tvm._ffi.register_object("relay.attrs.BatchToSpaceNDAttrs") class BatchToSpaceNDAttrs(Attrs): """Attributes used in BatchToSpaceNDAttrs operators""" + + +@tvm._ffi.register_object("relay.attrs.ThreefryGenerateAttrs") +class ThreefryGenerateAttrs(Attrs): + """Attributes used in ThreefryGenerateAttrs operators""" diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index a060a5f6eb68..326a184579e0 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -17,7 +17,6 @@ """Definition of CUDA/GPU operator strategy.""" # pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import from tvm import topi -import tvm from tvm.auto_scheduler import is_auto_scheduler_enabled from tvm.te import SpecializedCondition from tvm.contrib import nvcc diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index e888eb4d037b..42e882dd212b 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1264,3 +1264,47 @@ def argwhere_strategy(attrs, inputs, out_type, target): name="argwhere.generic", ) return strategy + + +# threefry_generate +def wrap_compute_threefry_generate(topi_compute): + """Wrap threefry_generate topi compute""" + + def _compute_threefry_generate(attrs, inputs, _): + return topi_compute(inputs[0], attrs.out_shape) + + return _compute_threefry_generate + + +@override_native_generic_func("threefry_generate_strategy") +def threefry_generate_strategy(attrs, inputs, out_type, target): + """threefry_generate generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_threefry_generate(topi.generic.threefry_generate), + wrap_topi_schedule(topi.generic.schedule_extern), + name="threefry_generate.generic", + ) + return strategy + + +# threefry_split +def wrap_compute_threefry_split(topi_compute): + """Wrap threefry_split topi compute""" + + def _compute_threefry_split(attrs, inputs, _): + return topi_compute(inputs[0]) + + return _compute_threefry_split + + +@override_native_generic_func("threefry_split_strategy") +def threefry_split_strategy(attrs, inputs, out_type, target): + """threefry_split generic strategy""" + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_threefry_split(topi.generic.threefry_split), + wrap_topi_schedule(topi.generic.schedule_extern), + name="threefry_split.generic", + ) + return strategy diff --git a/python/tvm/topi/generic/__init__.py b/python/tvm/topi/generic/__init__.py index cc64abab8ed8..8bfa73542c09 100644 --- a/python/tvm/topi/generic/__init__.py +++ b/python/tvm/topi/generic/__init__.py @@ -39,3 +39,4 @@ from .sort import * from .search import * from .image import * +from .algorithm import * diff --git a/python/tvm/topi/generic/algorithm.py b/python/tvm/topi/generic/algorithm.py new file mode 100644 index 000000000000..aa9257673f1b --- /dev/null +++ b/python/tvm/topi/generic/algorithm.py @@ -0,0 +1,401 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Threefry PRNG with splitting based on +- J. K. Salmon, M. A. Moraes, R. O. Dror and D. E. Shaw, "Parallel random numbers: As easy as 1, 2, + 3," SC '11: Proceedings of 2011 International Conference for High Performance Computing, + Networking, Storage and Analysis, Seattle, WA, 2011, pp. 1-12, doi: 10.1145/2063384.2063405. +- Claessen, K. ; Palka, M. (2013) "Splittable Pseudorandom Number Generators using Cryptographic + Hashing". Proceedings of Haskell Symposium 2013 pp. 47-58. MLA +- Ferguson, Niels, et al. "The Skein hash function family." Submission to NIST (round 3) 7.7.5 + (2010): 3. + + +Threefry is a counter based PRNG: given a unique input, it generates a unique random number. As +there is no state to maintain, we can apply it to a sequence of numbers (0..N) to generate a +sequence of random numbers in parallel. In order to make the PRNG splittable (that is we can +generate a sequence of random numbers in one place, and another sequence in another), we add a path +and key in addition to the counter. The path allows us to encode a sequence of splits (a 0 in the +path indicates the left result of a split, a 1 indicates the right). To avoid continuously growing +the path, we can compress an existing path into the key portion of the generator by hashing the +current key, path, and counter to create the new key (this same technique is used if we run out of +room for the counter). + +This module use encoding e4 from the appendix of "Splittable Pseudorandom Number Generators using +Cryptographic Hashing" (confusingly, the definition in the paper uses e3 to define the encoding +function). This encoding uses a 10 element uint64 tensor where each byte has the following meaning: + +.. code-block: + + gen: + words: 0 1 2 3 | 4 5 | 6 7 | 8 9 + usage: key | path | counter | position of next step in path encoded in binary + ex: 0b00010 -> next path entry goes one from the right + +Right now, counter only uses the rightmost word. +""" +import tvm +import tvm.topi +from ... import tir +from ...tir import ir_builder + +# Threefry rotation constants from the Skein paper ("The Skein Hash Function Family" +# https://www.schneier.com/wp-content/uploads/2015/01/skein.pdf) +_ROTATIONS = { + 4: [[14, 16], [52, 57], [23, 40], [5, 37], [25, 33], [46, 12], [58, 22], [32, 32]], + 8: [ + [46, 36, 19, 37], + [33, 27, 14, 42], + [17, 49, 36, 39], + [44, 9, 54, 56], + [39, 30, 34, 24], + [13, 50, 10, 17], + [25, 29, 39, 43], + [8, 35, 56, 22], + ], + 16: [ + [24, 13, 8, 47, 8, 17, 22, 37], + [38, 19, 10, 55, 49, 18, 23, 52], + [33, 4, 51, 13, 34, 41, 59, 17], + [5, 20, 48, 41, 47, 28, 16, 25], + [41, 9, 37, 31, 12, 47, 44, 30], + [16, 34, 56, 51, 4, 53, 42, 41], + [31, 44, 47, 46, 19, 42, 44, 25], + [9, 48, 35, 52, 23, 31, 37, 20], + ], +} + +# Threefry permutation constants from the Skein paper ("The Skein Hash Function Family" +# https://www.schneier.com/wp-content/uploads/2015/01/skein.pdf) +_PERMUTATIONS = { + 4: [0, 3, 2, 1], + 8: [2, 1, 4, 7, 6, 5, 0, 3], + 16: [0, 9, 2, 13, 6, 11, 4, 15, 10, 7, 12, 3, 14, 5, 8, 1], +} + + +def _threefry( + irb, key_buf, key_offset, counter_buf, counter_offset, out_buf, out_offset, out_shape +): + """IRBuilder code for running Threefry + + Parameters + ---------- + irb: IRBuilder + IRBuilder that this code will be generated for. + + key_buf: BufferVar + Buffer to read the key from. + + key_offset: number + Threefry will write to key_buf[key_offset:key_offset+4] + + counter_buf: BufferVar + Buffer to read the counter from. + + counter_offset: number + Threefry will write to counter_buf[counter_offset:counter_offset+4] + + out_buf: BufferVar + Buffer to read the counter from. + + counter_offset: number + Threefry will write to out_buf[out_offset:out_offset+4*product(out_shape)] + + out_shape: number + Determines the number of ouput states to generate. state[i] will correspond to counter+i. + """ + nrounds = 20 + nwords = 4 + iwidth = 64 + assert nrounds % 4 == 0 + assert nwords in [4, 8, 16] + + assert key_buf.dtype == "uint64" # TODO: support 32 bit inputs + assert key_buf.dtype == counter_buf.dtype + + def mix(a, b, rotation): + x = a + b # TODO should be wrapping + y = x ^ ((b << rotation) | (b >> (iwidth - rotation))) + return [x, y] + + # temporary buffer for holding the results of _PERMUTATIONS + tmp = irb.allocate(out_buf.dtype, out_shape, name="tmp", scope="global") + tmp_offset = 0 + + # Initialize entire key. It is composed of the original key with one + # element appended. The appended element is the xor of all key words plus a + # constant. + full_key = irb.allocate("uint64", nwords + 1, name="full_key", scope="global") + for i in range(nwords): + full_key[i] = key_buf[key_offset + i] + # initial key constant, full_key[nwords] is equivalent to k_{N_W} in the Skein paper. + full_key[nwords] = tvm.tir.const(0x1BD11BDAA9FC1A22, dtype="uint64") + for i in range(nwords): + full_key[nwords] ^= key_buf[key_offset + i] # TODO: wrapping + + # TODO: overwrite counter instead? + with irb.for_range(0, out_shape, dtype="uint64", name="i") as i: + for j in range(nwords): + out_buf[out_offset + i * nwords + j] = counter_buf[counter_offset + j] + i + + def key_schedule(s, i): + # Threefry uses no tweak, so the key schedule is simple + if i == nwords - 1: + return full_key[(s + i) % (nwords + 1)] + tvm.tir.const(s, dtype="uint64") + return full_key[(s + i) % (nwords + 1)] + + with irb.for_range(0, out_shape, name="l") as l: # pylint: disable=invalid-name + for i in range(nrounds // 4): + for j in range(nwords): + out_buf[out_offset + l * nwords + j] += key_schedule(i, j) # TODO wrapping + for k in range(4): + for j in range(nwords // 2): + ( + out_buf[out_offset + l * nwords + j * 2 + 0], + out_buf[out_offset + l * nwords + j * 2 + 1], + ) = mix( + out_buf[out_offset + l * nwords + j * 2 + 0], + out_buf[out_offset + l * nwords + j * 2 + 1], + _ROTATIONS[nwords][(i * 4 + k) % 8][j], + ) + for j in range(nwords): + tmp[tmp_offset + l * nwords + j] = out_buf[ + out_offset + l * nwords + _PERMUTATIONS[nwords][j] + ] + # number of rounds is even, so out always contains the result + (out_buf, tmp) = (tmp, out_buf) + (out_offset, tmp_offset) = (tmp_offset, out_offset) + + +def threefry_generate(gen, out_shape): + """Generate a series of random values + + Notes + ----- + This function uses the counter portion of the generator state to generate a series of random + numbers in parallel. Random number `i` is generated by applying Threefry to the current + generator state with the counter portion incremented by `i`. This means that each random number + is generated independently from each other random number, so we can compute them in parallel. + + If there is not enough room left in the counter to generate the desired shape of random values, + then a new generator is created by applying Threefry to the current key, path, and counter. + This new generator will have a reset counter. + + Parameters + ---------- + gen : Tensor[10, uint64] + Generator state. Can be create with :py:func:`threefry_seed`. This should not be used in + another function, otherwise random numbers will be repeated. + + out_shape : Sequence[int] + Output shape of the random numbers. Product of all dimensions must be a multiple of 4. + + Returns + ------- + rand : Tensor[out_shape, uint64] + Tensor of random numbers with shape `out_shape`. + """ + out_len = 1 + for s in out_shape: + out_len *= s + assert ( + out_len.value % 4 == 0 + ), f"Threefry can only generate arrays who's size is a multiple of 4 ({out_len} was provided)." + assert ( + out_len.value <= 2 ** 64 - 1 + ), f"Can only generate up to 2^64 random numbers, but {out_len} were requested." + + def gen_ir(gen_ptr, out_gen_ptr, out_array_ptr): + irb = ir_builder.create() + gen = irb.buffer_ptr(gen_ptr) + out_gen = irb.buffer_ptr(out_gen_ptr) + out_array = irb.buffer_ptr(out_array_ptr) + + # Create a temporary array to hold the generator state we will use to create the random + # numbers. We cannot use gen because we may need to update the key + path if there is not + # enough room in the counter. + tmp = irb.allocate(gen.dtype, 10, name="tmp", scope="global") + + # TODO(tkonolige): for now we only use the last word of the counter for counting. Its too + # much work to figure out how to do 128 bit addition. + + # Max value for counter should be 2**64-2 because we need to reserve a special value to + # indicate the counter is used up. + with irb.if_scope(gen[7] < tir.const(2 ** 64 - 1, dtype=gen.dtype) - out_len): + for i in range(10): + tmp[i] = gen[i] + with irb.else_scope(): + # no room left in the counter, we have to change the path or key + with irb.if_scope(gen[8] == 0 and gen[9] == 0): + # out of room in the path, have to generate new key + + # The paper says the counter that we will be hashing should be a special value of + # all ones. We need to allocate some space for it because we cannot overwrite gen. + tmp_counter = irb.allocate(gen.dtype, 2, name="tmp_counter", scope="global") + tmp_counter[0] = tir.const(0xFFFFFFFFFFFFFFFF, dtype=gen.dtype) + tmp_counter[1] = tir.const(0xFFFFFFFFFFFFFFFF, dtype=gen.dtype) + _threefry(irb, gen, 0, tmp_counter, 0, tmp, 0, 1) + tmp[4] = tir.const(0, dtype=gen.dtype) # zero path, i.e. no path + tmp[5] = tir.const(0, dtype=gen.dtype) + tmp[6] = tir.const(0, dtype=gen.dtype) # zero counter + tmp[7] = tir.const(0, dtype=gen.dtype) + tmp[8] = tir.const(1 << 63, dtype=gen.dtype) # one in the leftmost position + tmp[9] = tir.const(0, dtype=gen.dtype) + with irb.else_scope(): + tmp[0] = gen[0] + tmp[1] = gen[1] + tmp[2] = gen[2] + tmp[3] = gen[3] + tmp[4] = gen[4] | gen[8] # add a 1 to the path + tmp[5] = gen[5] | gen[9] + tmp[6] = tir.const(0, dtype=gen.dtype) # zero counter + tmp[7] = tir.const(0, dtype=gen.dtype) + _shift_right(irb, gen[8], gen[9], tmp, 8, tmp, 9) + + # Compute random values + _threefry(irb, tmp, 0, tmp, 4, out_array, 0, out_len // 4) + + # Update generator state + out_gen[0] = tmp[0] # key stays the same + out_gen[1] = tmp[1] + out_gen[2] = tmp[2] + out_gen[3] = tmp[3] + out_gen[4] = tmp[4] # path stays the same + out_gen[5] = tmp[5] + out_gen[6] = tir.const(0, dtype=gen.dtype) # unused, leave it as 0 + out_gen[7] = tmp[7] + tir.Cast(gen.dtype, out_len) # increment counter + out_gen[8] = tmp[8] # path unchanged, so no update here + + return irb.get() + + out_gen = tvm.tir.decl_buffer((10,), name="out_gen", dtype="uint64") + out_array = tvm.tir.decl_buffer(out_shape, name="out_array", dtype="uint64") + return tvm.te.extern( + [out_gen.shape, out_array.shape], + [gen], + lambda ins, outs: gen_ir(ins[0], outs[0], outs[1]), + out_buffers=[out_gen, out_array], + name="threefry_generate", + tag="threefry_generate", + ) + + +def _shift_right(irb, a, b, out_a, a_off, out_b, b_off): + """Shift a 128bit number composed of two 64 bit words right by one""" + with irb.if_scope(a == 1): + out_a[a_off] = tir.const(0, dtype=a.dtype) + out_b[b_off] = tir.const(0x8000000000000000, dtype=a.dtype) + with irb.else_scope(): + with irb.if_scope(a == 0): + out_a[a_off] = tir.const(0, dtype=a.dtype) + out_b[b_off] = b >> 1 + with irb.else_scope(): + out_a[a_off] = a >> 1 + out_b[b_off] = tir.const(0, dtype=a.dtype) + + +def threefry_split(gen): + """Split a single generator state into two new ones + + Notes + ----- + The new generator is created by appending a one (for the right output) or a zero (for the left + output) to the end of the path portion of the generator If there is no longer and room in the + path, then we create a new key portion of the generator by applying Threefry to the old state, + path, and counter. i.e. :code:`new_key = threefry(old_key, [old_path, old_counter])`. This + resets the path portion of the new generator. + + Parameters + ---------- + gen : Tensor[10, uint64] + Generator state. Can be create with :py:func:`threefry_seed`. This should not be used in + another function, otherwise random numbers will be repeated. + + Returns + ------- + out_gen_left : Tensor[10, uint64] + New generator state that is distinct from `out_gen_right`. + + out_gen_right : Tensor[10, uint64] + New generator state that is distinct from `out_gen_left`. + """ + + def gen_ir(gen_ptr, out_left_ptr, out_right_ptr): + irb = ir_builder.create() + gen = irb.buffer_ptr(gen_ptr) + out_left = irb.buffer_ptr(out_left_ptr) + out_right = irb.buffer_ptr(out_right_ptr) + + with irb.if_scope(gen[8] == 0 and gen[9] == 0): + # Generate new key because we have run out of room to extend the path + _threefry(irb, gen, 0, gen, 4, out_left, 0, 1) + out_left[4] = tir.const(0, dtype=gen.dtype) + out_left[5] = tir.const(0, dtype=gen.dtype) + out_left[6] = tir.const(0, dtype=gen.dtype) # counter gets zeroed + out_left[7] = tir.const(0, dtype=gen.dtype) # counter gets zeroed + out_left[8] = tir.const( + 1 << 62, dtype=gen.dtype + ) # one in the second from the leftmost position + out_left[9] = tir.const(0, dtype=gen.dtype) + + out_right[0] = out_left[0] + out_right[1] = out_left[1] + out_right[2] = out_left[2] + out_right[3] = out_left[3] + out_right[4] = tir.const(1 << 63, dtype=gen.dtype) # one in the leftmost position + out_right[5] = tir.const(0, dtype=gen.dtype) + out_right[6] = tir.const(0, dtype=gen.dtype) + out_right[7] = tir.const(0, dtype=gen.dtype) + out_right[8] = tir.const( + 1 << 62, dtype=gen.dtype + ) # one in the second from the leftmost position + out_right[9] = tir.const(0, dtype=gen.dtype) + with irb.else_scope(): + out_left[0] = gen[0] + out_left[1] = gen[1] + out_left[2] = gen[2] + out_left[3] = gen[3] + out_left[4] = gen[4] # adding a zero here, but its already zero padded + out_left[5] = gen[5] + out_left[6] = gen[6] + out_left[7] = gen[7] + # move path position over one bit + _shift_right(irb, gen[8], gen[9], out_left, 8, out_left, 9) + + out_right[0] = gen[0] + out_right[1] = gen[1] + out_right[2] = gen[2] + out_right[3] = gen[3] + out_right[4] = gen[4] | gen[8] # add a one to the path + out_right[5] = gen[5] | gen[9] + out_right[6] = gen[6] + out_right[7] = gen[7] + _shift_right(irb, gen[8], gen[9], out_right, 8, out_right, 9) + + return irb.get() + + out_left = tvm.tir.decl_buffer((10,), name="out_left", dtype="uint64") + out_right = tvm.tir.decl_buffer((10,), name="out_right", dtype="uint64") + return tvm.te.extern( + [out_left.shape, out_right.shape], + [gen], + lambda ins, outs: gen_ir(ins[0], outs[0], outs[1]), + out_buffers=[out_left, out_right], + name="threefry_split", + tag="threefry_split", + ) diff --git a/src/relay/op/algorithm/prng.cc b/src/relay/op/algorithm/prng.cc new file mode 100644 index 000000000000..d95640fab85d --- /dev/null +++ b/src/relay/op/algorithm/prng.cc @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(ThreefryGenerateAttrs); + +bool ThreefryGenerateRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const ThreefryGenerateAttrs* param = attrs.as(); + ICHECK_EQ(types.size(), 2) << "ThreefryGenerate should have one input and one output"; + const auto* gen = types[0].as(); + + std::vector oshape; + for (auto& x : param->out_shape) { + oshape.push_back(x); + } + // generate returns the next gen and an array of random values + reporter->Assign(types[1], + TupleType({TensorType(gen->shape, gen->dtype), TensorType(oshape, gen->dtype)})); + return true; +} + +Expr MakeThreefryGenerate(Expr gen, Array out_shape) { + auto attrs = make_object(); + attrs->out_shape = out_shape; + static const Op& op = Op::Get("threefry_generate"); + return Call(op, {gen}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.threefry_generate").set_body_typed(MakeThreefryGenerate); + +RELAY_REGISTER_OP("threefry_generate") + .describe( + R"doc(Generate an array of random numbers using the Threefry algorithm.)doc" TVM_ADD_FILELINE) + .set_num_inputs(1) + .set_attrs_type() + .add_argument("gen", "Tensor", "Input generator") + .add_type_rel("ThreefryGenerate", ThreefryGenerateRel); + +bool ThreefrySplitRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 2) << "ThreefrySplit should have one input and one output"; + const auto* gen = types[0].as(); + reporter->Assign(types[1], TupleType({TensorType(gen->shape, gen->dtype), + TensorType(gen->shape, gen->dtype)})); + return true; +} + +Expr MakeThreefrySplit(Expr gen) { + static const Op& op = Op::Get("threefry_split"); + return Call(op, {gen}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.threefry_split").set_body_typed(MakeThreefrySplit); + +RELAY_REGISTER_OP("threefry_split") + .describe( + R"doc(Split an array of random numbers using the Threefry algorithm.)doc" TVM_ADD_FILELINE) + .set_num_inputs(1) + .add_argument("gen", "Tensor", "Input generator") + .add_type_rel("ThreefrySplit", ThreefrySplitRel); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_prng.py b/tests/python/relay/test_prng.py new file mode 100644 index 000000000000..2b014cd5aa74 --- /dev/null +++ b/tests/python/relay/test_prng.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.relay +import tvm.testing + + +@tvm.testing.parametrize_targets +def test_threefry_repeatablity(target, ctx): + seed1 = tvm.relay.threefry_seed(1) + rand1 = tvm.relay.threefry_generate(seed1, (12,)) + out_gen1, out1 = tvm.relay.create_executor( + "vm", tvm.IRModule.from_expr(tvm.relay.Function([], rand1)), target=target, ctx=ctx + ).evaluate()() + + seed2 = tvm.relay.threefry_seed(1) + rand2 = tvm.relay.threefry_generate(seed1, (12,)) + out_gen2, out2 = tvm.relay.create_executor( + "vm", tvm.IRModule.from_expr(tvm.relay.Function([], rand2)), target=target, ctx=ctx + ).evaluate()() + + assert ( + out1.asnumpy() == out2.asnumpy() + ).all(), "Generate on same seed should have the same output" + + +@tvm.testing.parametrize_targets +def test_threefry_split(target, ctx): + seed = tvm.relay.threefry_seed(1) + left, right = tvm.relay.TupleWrapper(tvm.relay.threefry_split(seed), 2) + _, rand1 = tvm.relay.TupleWrapper(tvm.relay.threefry_generate(left, (12,)), 2) + _, rand2 = tvm.relay.TupleWrapper(tvm.relay.threefry_generate(right, (12,)), 2) + out1, out2 = tvm.relay.create_executor( + "vm", + tvm.IRModule.from_expr(tvm.relay.Function([], tvm.relay.Tuple((rand1, rand2)))), + target=target, + ctx=ctx, + ).evaluate()() + + assert ( + out1.asnumpy() != out2.asnumpy() + ).any(), "Generate after split should not have the same output" + + +if __name__ == "__main__": + test_threefry_repeatablity(tvm.target.Target("llvm"), tvm.context("cpu")) + test_threefry_split(tvm.target.Target("llvm"), tvm.context("cpu")) diff --git a/tests/python/topi/python/test_topi_prng.py b/tests/python/topi/python/test_topi_prng.py new file mode 100644 index 000000000000..76f167043650 --- /dev/null +++ b/tests/python/topi/python/test_topi_prng.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.relay +import tvm.testing +import tvm.topi +import numpy as np + + +def threefry_split(target, ctx, gen): + gen_placeholder = tvm.te.placeholder(gen.shape, name="gen", dtype="uint64") + left_placeholder, right_placeholder = tvm.topi.generic.threefry_split(gen_placeholder) + s = tvm.topi.generic.schedule_extern([left_placeholder, right_placeholder]) + f = tvm.build(s, [gen_placeholder, left_placeholder, right_placeholder]) + left = tvm.nd.array(np.zeros(gen.shape, dtype="uint64")) + right = tvm.nd.array(np.zeros(gen.shape, dtype="uint64")) + f(tvm.nd.array(gen), left, right) + return left.asnumpy(), right.asnumpy() + + +def threefry_generate(target, ctx, gen, size): + gen_placeholder = tvm.te.placeholder(gen.shape, name="gen", dtype="uint64") + left_placeholder, right_placeholder = tvm.topi.generic.threefry_generate(gen_placeholder, size) + s = tvm.topi.generic.schedule_extern([left_placeholder, right_placeholder]) + f = tvm.build(s, [gen_placeholder, left_placeholder, right_placeholder]) + out_gen = tvm.nd.array(np.zeros(gen.shape, dtype="uint64")) + rands = tvm.nd.array(np.zeros(size, dtype="uint64")) + f(tvm.nd.array(gen), out_gen, rands) + return out_gen.asnumpy(), rands.asnumpy() + + +@tvm.testing.parametrize_targets +def test_threefry_split(target, ctx): + # test that results of split do not equal eachother or the input + gen = tvm.relay.threefry_seed(0).data.asnumpy() + a, b = threefry_split(target, ctx, gen) + assert (a != b).any() and ( + a != gen + ).any(), "Splitting a gen should result in different output gens" + # unittest some split inputs + assert (a == np.array([0, 0, 0, 0, 0, 0, 0, 0, 1 << 62, 0], dtype="uint64")).all() + assert (b == np.array([0, 0, 0, 0, 1 << 63, 0, 0, 0, 1 << 62, 0], dtype="uint64")).all() + + # test enough splits to go over path length + for i in range(129): + a, b = threefry_split(target, ctx, b) + assert (a[0:4] == b[0:4]).all(), "State part of split should be the same" + assert (b[0:4] != np.zeros(4, dtype="uint64")).any() + + # check that split then generate does not generate the same for both sides + a, a_rands = threefry_generate(target, ctx, a, (100,)) + b, b_rands = threefry_generate(target, ctx, b, (100,)) + assert ( + a_rands != b_rands + ).all(), "Numbers generated from different initial states should be different" + + # check repeatability + _, rands1 = threefry_generate(target, ctx, a, (100,)) + _, rands2 = threefry_generate(target, ctx, a, (100,)) + assert ( + rands1 == rands2 + ).all(), "Numbers generated from the same initial state should be the same" + + a1, b1 = threefry_split(target, ctx, a) + a2, b2 = threefry_split(target, ctx, a) + assert (a1 == a2).all() and ( + b1 == b2 + ).all(), "Split called on the same input should return the same result" + + +@tvm.testing.parametrize_targets +def test_threefry_generate(target, ctx): + gen = tvm.relay.threefry_seed(0).data.asnumpy() + + # check that we can generate some data + a, rands = threefry_generate(target, ctx, gen, (100,)) + assert ( + rands.shape[0] == 100 and len(rands.shape) == 1 + ), "Output shape should match requested shape" + + # check that gen out does not equal input + assert (a != gen).any(), "Output generator should be different from input generator" + + # test enough generates to go over generate limit + gen = np.array( + [0, 0, 0, 0, 0, 0, 0, 2 ** 64 - 2, 1 << 63, 0], dtype="uint64" + ) # make counter large + a, rands = threefry_generate(target, ctx, gen, (100,)) + assert gen[4] != a[4], "Overflow of counter should trigger path change" + assert a[7] == 100, "Overflow of counter should still update counter" + + # check generate with path at length limit + gen = np.array([0, 0, 0, 0, 0, 0, 0, 2 ** 64 - 2, 0, 0], dtype="uint64") # make counter large + a, rands = threefry_generate(target, ctx, gen, (100,)) + assert ( + gen[0:4] != a[0:4] + ).any(), "Overflowing counter with no space left in path should change state" + + +if __name__ == "__main__": + test_threefry_split(tvm.target.Target("llvm"), tvm.context("cpu")) + test_threefry_generate(tvm.target.Target("llvm"), tvm.context("cpu")) From c87658b8f6e3e50ddc37dcf2e88c9f189d39cc35 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 10 Dec 2020 15:16:46 -0800 Subject: [PATCH 02/15] Fix sphinx? --- python/tvm/topi/generic/algorithm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/topi/generic/algorithm.py b/python/tvm/topi/generic/algorithm.py index aa9257673f1b..389805b221c8 100644 --- a/python/tvm/topi/generic/algorithm.py +++ b/python/tvm/topi/generic/algorithm.py @@ -102,22 +102,22 @@ def _threefry( Buffer to read the key from. key_offset: number - Threefry will write to key_buf[key_offset:key_offset+4] + Threefry will write to :code:`key_buf[key_offset:key_offset+4]` counter_buf: BufferVar Buffer to read the counter from. counter_offset: number - Threefry will write to counter_buf[counter_offset:counter_offset+4] + Threefry will write to :code:`counter_buf[counter_offset:counter_offset+4]` out_buf: BufferVar Buffer to read the counter from. counter_offset: number - Threefry will write to out_buf[out_offset:out_offset+4*product(out_shape)] + Threefry will write to :code:`out_buf[out_offset:out_offset+4*product(out_shape)]` out_shape: number - Determines the number of ouput states to generate. state[i] will correspond to counter+i. + Determines the number of ouput states to generate. :code:`state[i]` will correspond to counter+i. """ nrounds = 20 nwords = 4 @@ -199,7 +199,7 @@ def threefry_generate(gen, out_shape): Parameters ---------- gen : Tensor[10, uint64] - Generator state. Can be create with :py:func:`threefry_seed`. This should not be used in + Generator state. Can be create with :py:func:`tvm.relay.threefry_seed`. This should not be used in another function, otherwise random numbers will be repeated. out_shape : Sequence[int] @@ -323,7 +323,7 @@ def threefry_split(gen): Parameters ---------- gen : Tensor[10, uint64] - Generator state. Can be create with :py:func:`threefry_seed`. This should not be used in + Generator state. Can be create with :py:func:`tvm.relay.threefry_seed`. This should not be used in another function, otherwise random numbers will be repeated. Returns From 4c230feeba5cef77224f7a28ab3a7c48c3e1d513 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Thu, 10 Dec 2020 16:24:40 -0800 Subject: [PATCH 03/15] Lint fixes --- python/tvm/topi/generic/algorithm.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/tvm/topi/generic/algorithm.py b/python/tvm/topi/generic/algorithm.py index 389805b221c8..165729ab8ea7 100644 --- a/python/tvm/topi/generic/algorithm.py +++ b/python/tvm/topi/generic/algorithm.py @@ -117,7 +117,8 @@ def _threefry( Threefry will write to :code:`out_buf[out_offset:out_offset+4*product(out_shape)]` out_shape: number - Determines the number of ouput states to generate. :code:`state[i]` will correspond to counter+i. + Determines the number of ouput states to generate. :code:`state[i]` will correspond to + counter+i. """ nrounds = 20 nwords = 4 @@ -199,8 +200,8 @@ def threefry_generate(gen, out_shape): Parameters ---------- gen : Tensor[10, uint64] - Generator state. Can be create with :py:func:`tvm.relay.threefry_seed`. This should not be used in - another function, otherwise random numbers will be repeated. + Generator state. Can be create with :py:func:`tvm.relay.threefry_seed`. This should not be + used in another function, otherwise random numbers will be repeated. out_shape : Sequence[int] Output shape of the random numbers. Product of all dimensions must be a multiple of 4. @@ -323,8 +324,8 @@ def threefry_split(gen): Parameters ---------- gen : Tensor[10, uint64] - Generator state. Can be create with :py:func:`tvm.relay.threefry_seed`. This should not be used in - another function, otherwise random numbers will be repeated. + Generator state. Can be create with :py:func:`tvm.relay.threefry_seed`. This should not be + used in another function, otherwise random numbers will be repeated. Returns ------- From eaf07d6dc865395d769981b589fe5f25a77fdff6 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 14 Dec 2020 09:09:54 -0800 Subject: [PATCH 04/15] sphinx fixes round 2 --- python/tvm/relay/op/algorithm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 7b4d1b48978b..4bb81e4491b5 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -147,8 +147,8 @@ def threefry_generate(gen, shape): same generator will generate the same random values. This generator should be treated as an opaque pointer. You can create one from calling :py:func:`threefry_seed`, :py:func:`threefry_split`, or - :py:func:`threefry_generate`. _Do not use this generator again after calling - this function_. + :py:func:`threefry_generate`. **Do not use this generator again after calling + this function.** shape : Sequence[int] Desired outputs shape of random numbers @@ -191,8 +191,8 @@ def foo(gen): same generator will generate the same random values. This generator should be treated as an opaque pointer. You can create one from calling :py:func:`threefry_seed`, :py:func:`threefry_split`, or - :py:func:`threefry_generate`. _Do not use this generator again after calling - this function_. + :py:func:`threefry_generate`. **Do not use this generator again after calling + this function.** Returns ------- From d858d5e4dd5a6ae87d9b3c051b1460ef7d23bd43 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Mon, 14 Dec 2020 11:38:48 -0800 Subject: [PATCH 05/15] fix inputs for tests --- python/tvm/topi/generic/algorithm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/generic/algorithm.py b/python/tvm/topi/generic/algorithm.py index 165729ab8ea7..5f1588a69243 100644 --- a/python/tvm/topi/generic/algorithm.py +++ b/python/tvm/topi/generic/algorithm.py @@ -211,7 +211,7 @@ def threefry_generate(gen, out_shape): rand : Tensor[out_shape, uint64] Tensor of random numbers with shape `out_shape`. """ - out_len = 1 + out_len = tir.const(1) for s in out_shape: out_len *= s assert ( From abe300b87c0a40b2e945b972b184deb31c648fc6 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 16 Dec 2020 15:21:21 -0800 Subject: [PATCH 06/15] reorganize to random, fix uninitialized memory bug --- include/tvm/relay/attrs/algorithm.h | 8 -- include/tvm/relay/attrs/random.h | 43 ++++++ python/tvm/relay/__init__.py | 1 + python/tvm/relay/op/__init__.py | 1 + python/tvm/relay/op/_algorithm.py | 7 - python/tvm/relay/op/algorithm.py | 111 --------------- python/tvm/relay/op/random/__init__.py | 20 +++ python/tvm/relay/op/random/_kernel.py | 29 ++++ python/tvm/relay/op/random/_make.py | 20 +++ python/tvm/relay/op/random/kernel.py | 134 ++++++++++++++++++ python/tvm/relay/op/strategy/generic.py | 4 +- python/tvm/topi/__init__.py | 1 + python/tvm/topi/generic/__init__.py | 1 - python/tvm/topi/random/__init__.py | 22 +++ .../algorithm.py => random/kernel.py} | 83 ++++++----- .../{algorithm/prng.cc => random/kernel.cc} | 45 +++--- tests/python/relay/test_prng.py | 31 ++-- tests/python/topi/python/test_topi_prng.py | 8 +- 18 files changed, 364 insertions(+), 205 deletions(-) create mode 100644 include/tvm/relay/attrs/random.h create mode 100644 python/tvm/relay/op/random/__init__.py create mode 100644 python/tvm/relay/op/random/_kernel.py create mode 100644 python/tvm/relay/op/random/_make.py create mode 100644 python/tvm/relay/op/random/kernel.py create mode 100644 python/tvm/topi/random/__init__.py rename python/tvm/topi/{generic/algorithm.py => random/kernel.py} (85%) rename src/relay/op/{algorithm/prng.cc => random/kernel.cc} (61%) diff --git a/include/tvm/relay/attrs/algorithm.h b/include/tvm/relay/attrs/algorithm.h index 6ecc5c23935c..83b4ddaead43 100644 --- a/include/tvm/relay/attrs/algorithm.h +++ b/include/tvm/relay/attrs/algorithm.h @@ -76,14 +76,6 @@ struct TopKAttrs : public tvm::AttrsNode { } }; -struct ThreefryGenerateAttrs : public tvm::AttrsNode { - Array out_shape; - - TVM_DECLARE_ATTRS(ThreefryGenerateAttrs, "relay.attrs.ThreefryGenerateAttrs") { - TVM_ATTR_FIELD(out_shape).describe("Shape of random numbers to generate"); - } -}; - } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_ALGORITHM_H_ diff --git a/include/tvm/relay/attrs/random.h b/include/tvm/relay/attrs/random.h new file mode 100644 index 000000000000..303f4069163d --- /dev/null +++ b/include/tvm/relay/attrs/random.h @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/attrs/vision.h + * \brief Auxiliary attributes for random operators. + */ +#ifndef TVM_RELAY_ATTRS_RANDOM_H_ +#define TVM_RELAY_ATTRS_RANDOM_H_ + +#include + + +namespace tvm { +namespace relay { + +struct ThreefryGenerateAttrs : public tvm::AttrsNode { + Array out_shape; + + TVM_DECLARE_ATTRS(ThreefryGenerateAttrs, "relay.attrs.ThreefryGenerateAttrs") { + TVM_ATTR_FIELD(out_shape).describe("Shape of random numbers to generate"); + } +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ATTRS_RANDOM_H_ diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index cd96ecc7ee33..97f6d1cb60c0 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -45,6 +45,7 @@ from .op import vision from .op import contrib from .op import dyn +from .op import random from .op.reduce import * from .op.tensor import * from .op.transform import * diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index f6afa443d280..1f267abedc1a 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -43,6 +43,7 @@ from . import image from . import vision from . import op_attrs +from . import random # operator registry diff --git a/python/tvm/relay/op/_algorithm.py b/python/tvm/relay/op/_algorithm.py index 5124f7112049..732d5016755a 100644 --- a/python/tvm/relay/op/_algorithm.py +++ b/python/tvm/relay/op/_algorithm.py @@ -73,10 +73,3 @@ def topk_shape_func(attrs, inputs, _): ret = [indices_out] return ret - - -# threefry -register_strategy("threefry_generate", strategy.threefry_generate_strategy) -register_pattern("threefry_generate", OpPattern.OPAQUE) -register_strategy("threefry_split", strategy.threefry_split_strategy) -register_pattern("threefry_split", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 4bb81e4491b5..22aac6890fff 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -17,10 +17,6 @@ """Classic algorithm operation""" from __future__ import absolute_import as _abs -import sys - -import numpy as np - from ... import nd from ..expr import Constant, Expr, TupleWrapper from . import _make @@ -98,110 +94,3 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): if ret_type == "both": return TupleWrapper(out, 2) return out - - -def threefry_seed(seed): - """Create a new Threefry random number generator. - - Example - ------- - - .. code-block:: python - - gen = threefry_seed(0) - _, random_number = threefry_generate(gen, (1,)) - - Parameters - ---------- - seed : int - Starting seed for the generator - - Returns - ------- - gen : relay.Expr - New generator to pass to future uses of :py:func:`threefry_split` or - :py:func:`threefry_generate`. - """ - s = np.frombuffer(seed.to_bytes(32, sys.byteorder), dtype="uint64") - a = np.concatenate((s, np.array([0, 0, 0, 0, 1 << 63, 0], dtype="uint64"))) - return Constant(nd.array(a)) - - -def threefry_generate(gen, shape): - """Generate an array of random numbers using the Threefry algorithm - - Example - ------- - - .. code-block:: python - - gen = threefry_seed(0) - new_gen, random1 = threefry_generate(gen, (1,)) - _, random2 = threefry_generate(new_gen, (1,)) - # random1 and random2 are different random numbers - - Parameters - ---------- - gen : relay.Expr - generator that uniquely determines the random values. Multiple uses with the - same generator will generate the same random values. This generator should be - treated as an opaque pointer. You can create one from calling - :py:func:`threefry_seed`, :py:func:`threefry_split`, or - :py:func:`threefry_generate`. **Do not use this generator again after calling - this function.** - - shape : Sequence[int] - Desired outputs shape of random numbers - - Returns - ------- - new_gen : relay.Expr - New generator to pass to future uses of :py:func:`threefry_split` or - :py:func:`threefry_generate`. - - random_array : relay.Expr - Array of random numbers. Has shape `shape`. - """ - return _make.threefry_generate(gen, shape) - - -def threefry_split(gen): - """Split an existing threefry generator into two new ones. - - This is useful if you have to subsequent calls which each need their own - random number generation. - - Example - ------- - - .. code-block:: python - - def foo(gen): - new_gen, num = threefry_generate(gen, (1,)) - return num - - gen = threefry_seed(0) - gen1, gen2 = threefry_split(gen) - assert foo(gen1) != foo(gen2) - - Parameters - ---------- - gen : relay.Expr - generator that uniquely determines the random values. Multiple uses with the - same generator will generate the same random values. This generator should be - treated as an opaque pointer. You can create one from calling - :py:func:`threefry_seed`, :py:func:`threefry_split`, or - :py:func:`threefry_generate`. **Do not use this generator again after calling - this function.** - - Returns - ------- - new_gen_1 : relay.Expr - New generator to pass to future uses of :py:func:`threefry_split` or - :py:func:`threefry_generate`. - - new_gen_2 : relay.Expr - New generator to pass to future uses of :py:func:`threefry_split` or - :py:func:`threefry_generate`. - """ - return _make.threefry_split(gen) diff --git a/python/tvm/relay/op/random/__init__.py b/python/tvm/relay/op/random/__init__.py new file mode 100644 index 000000000000..8366f4a06dac --- /dev/null +++ b/python/tvm/relay/op/random/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""PRNG related operators.""" +from .kernel import * +from . import _kernel diff --git a/python/tvm/relay/op/random/_kernel.py b/python/tvm/relay/op/random/_kernel.py new file mode 100644 index 000000000000..8be3397008d5 --- /dev/null +++ b/python/tvm/relay/op/random/_kernel.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Splittable and parallelizable PRNG kernels.""" +# pylint: disable=invalid-name,unused-argument +from __future__ import absolute_import + +from .. import strategy +from ..op import register_strategy, register_pattern, OpPattern + + +# Threefry +register_strategy("random.threefry_generate", strategy.threefry_generate_strategy) +register_pattern("random.threefry_generate", OpPattern.OPAQUE) +register_strategy("random.threefry_split", strategy.threefry_split_strategy) +register_pattern("random.threefry_split", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/random/_make.py b/python/tvm/relay/op/random/_make.py new file mode 100644 index 000000000000..51a8a6aa9339 --- /dev/null +++ b/python/tvm/relay/op/random/_make.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +import tvm._ffi + +tvm._ffi._init_api("relay.op.random._make", __name__) diff --git a/python/tvm/relay/op/random/kernel.py b/python/tvm/relay/op/random/kernel.py new file mode 100644 index 000000000000..d9dd796d24b7 --- /dev/null +++ b/python/tvm/relay/op/random/kernel.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Splittable and parallelizable PRNG kernels.""" +# pylint: disable=invalid-name,unused-argument +from __future__ import absolute_import + +import sys +import numpy as np + +from ...expr import Constant +from .... import nd +from . import _make + + +def threefry_key(seed): + """Create a new Threefry random number generator key. + + Example + ------- + + .. code-block:: python + + gen = threefry_key(0) + _, random_number = threefry_generate(gen, (4,)) + + Parameters + ---------- + seed : int + Starting seed for the key + + Returns + ------- + key : relay.Expr + New key to pass to future uses of :py:func:`threefry_split` or + :py:func:`threefry_generate`. + """ + s = np.frombuffer(seed.to_bytes(32, sys.byteorder), dtype="uint64") + a = np.concatenate((s, np.array([0, 0, 0, 0, 1 << 63, 0], dtype="uint64"))) + return Constant(nd.array(a)) + + +def threefry_generate(key, shape): + """Generate an array of random bits (`uint64`) using the Threefry algorithm + + Example + ------- + + .. code-block:: python + + key = threefry_key(0) + new_key, random1 = threefry_generate(key, (4,)) + _, random2 = threefry_generate(new_key, (4,)) + # random1 and random2 are different random numbers + + Parameters + ---------- + key : relay.Expr + key that uniquely determines the random values. Multiple uses with the + same key will generate the same random values. This key should be + treated as an opaque pointer. You can create one from calling + :py:func:`threefry_key`, :py:func:`threefry_split`, or + :py:func:`threefry_generate`. **Do not use this key again after calling + this function.** + + shape : Sequence[int] + Desired outputs shape of random numbers. **Currently the total + number of elements must be a multiple of 4.** + + Returns + ------- + new_key : relay.Expr + New key to pass to future uses of :py:func:`threefry_split` or + :py:func:`threefry_generate`. + + random_array : relay.Expr + Array of random numbers. Has shape `shape`. + """ + return _make.threefry_generate(key, shape) + + +def threefry_split(key): + """Split an existing Threefry key into two new ones. + + This is useful if you have to subsequent calls which each need their own + independent random number generation. + + Example + ------- + + .. code-block:: python + + def foo(key): + new_key, num = threefry_generate(key, (1,)) + return num + + key = threefry_key(0) + key1, key2 = threefry_split(key) + assert foo(key1) != foo(key2) + + Parameters + ---------- + key : relay.Expr + key that uniquely determines the random values. Multiple uses with the + same generator will generate the same random values. This generator should be + treated as an opaque pointer. You can create one from calling + :py:func:`threefry_key`, :py:func:`threefry_split`, or + :py:func:`threefry_generate`. **Do not use this generator again after calling + this function.** + + Returns + ------- + new_key_1 : relay.Expr + New key to pass to future uses of :py:func:`threefry_split` or + :py:func:`threefry_generate`. + + new_key_2 : relay.Expr + New key to pass to future uses of :py:func:`threefry_split` or + :py:func:`threefry_generate`. + """ + return _make.threefry_split(key) diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 42e882dd212b..c111676e6ea8 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -1281,7 +1281,7 @@ def threefry_generate_strategy(attrs, inputs, out_type, target): """threefry_generate generic strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_threefry_generate(topi.generic.threefry_generate), + wrap_compute_threefry_generate(topi.random.threefry_generate), wrap_topi_schedule(topi.generic.schedule_extern), name="threefry_generate.generic", ) @@ -1303,7 +1303,7 @@ def threefry_split_strategy(attrs, inputs, out_type, target): """threefry_split generic strategy""" strategy = _op.OpStrategy() strategy.add_implementation( - wrap_compute_threefry_split(topi.generic.threefry_split), + wrap_compute_threefry_split(topi.random.threefry_split), wrap_topi_schedule(topi.generic.schedule_extern), name="threefry_split.generic", ) diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py index 97951d941f64..cb94b5b86c9e 100644 --- a/python/tvm/topi/__init__.py +++ b/python/tvm/topi/__init__.py @@ -54,6 +54,7 @@ from . import image from . import sparse from . import hls +from . import random # error reporting from .utils import InvalidShapeError diff --git a/python/tvm/topi/generic/__init__.py b/python/tvm/topi/generic/__init__.py index 8bfa73542c09..cc64abab8ed8 100644 --- a/python/tvm/topi/generic/__init__.py +++ b/python/tvm/topi/generic/__init__.py @@ -39,4 +39,3 @@ from .sort import * from .search import * from .image import * -from .algorithm import * diff --git a/python/tvm/topi/random/__init__.py b/python/tvm/topi/random/__init__.py new file mode 100644 index 000000000000..ee8d1d6385b7 --- /dev/null +++ b/python/tvm/topi/random/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=wildcard-import +"""Pseudorandom generator kernels and operators.""" +from __future__ import absolute_import + +from .kernel import * diff --git a/python/tvm/topi/generic/algorithm.py b/python/tvm/topi/random/kernel.py similarity index 85% rename from python/tvm/topi/generic/algorithm.py rename to python/tvm/topi/random/kernel.py index 5f1588a69243..f1aac3e68d14 100644 --- a/python/tvm/topi/generic/algorithm.py +++ b/python/tvm/topi/random/kernel.py @@ -14,45 +14,46 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -Threefry PRNG with splitting based on -- J. K. Salmon, M. A. Moraes, R. O. Dror and D. E. Shaw, "Parallel random numbers: As easy as 1, 2, - 3," SC '11: Proceedings of 2011 International Conference for High Performance Computing, - Networking, Storage and Analysis, Seattle, WA, 2011, pp. 1-12, doi: 10.1145/2063384.2063405. -- Claessen, K. ; Palka, M. (2013) "Splittable Pseudorandom Number Generators using Cryptographic - Hashing". Proceedings of Haskell Symposium 2013 pp. 47-58. MLA -- Ferguson, Niels, et al. "The Skein hash function family." Submission to NIST (round 3) 7.7.5 - (2010): 3. - - -Threefry is a counter based PRNG: given a unique input, it generates a unique random number. As -there is no state to maintain, we can apply it to a sequence of numbers (0..N) to generate a -sequence of random numbers in parallel. In order to make the PRNG splittable (that is we can -generate a sequence of random numbers in one place, and another sequence in another), we add a path -and key in addition to the counter. The path allows us to encode a sequence of splits (a 0 in the -path indicates the left result of a split, a 1 indicates the right). To avoid continuously growing -the path, we can compress an existing path into the key portion of the generator by hashing the -current key, path, and counter to create the new key (this same technique is used if we run out of -room for the counter). - -This module use encoding e4 from the appendix of "Splittable Pseudorandom Number Generators using -Cryptographic Hashing" (confusingly, the definition in the paper uses e3 to define the encoding -function). This encoding uses a 10 element uint64 tensor where each byte has the following meaning: - -.. code-block: - - gen: - words: 0 1 2 3 | 4 5 | 6 7 | 8 9 - usage: key | path | counter | position of next step in path encoded in binary - ex: 0b00010 -> next path entry goes one from the right - -Right now, counter only uses the rightmost word. -""" +"""Pseudorandom number kernels.""" import tvm import tvm.topi from ... import tir from ...tir import ir_builder + +# Threefry PRNG with splitting based on +# - J. K. Salmon, M. A. Moraes, R. O. Dror and D. E. Shaw, "Parallel random numbers: As easy as 1, 2, +# 3," SC '11: Proceedings of 2011 International Conference for High Performance Computing, +# Networking, Storage and Analysis, Seattle, WA, 2011, pp. 1-12, doi: 10.1145/2063384.2063405. +# - Claessen, K. ; Palka, M. (2013) "Splittable Pseudorandom Number Generators using Cryptographic +# Hashing". Proceedings of Haskell Symposium 2013 pp. 47-58. MLA +# - Ferguson, Niels, et al. "The Skein hash function family." Submission to NIST (round 3) 7.7.5 +# (2010): 3. + + +# Threefry is a counter based PRNG: given a unique input, it generates a unique random number. As +# there is no state to maintain, we can apply it to a sequence of numbers (0..N) to generate a +# sequence of random numbers in parallel. In order to make the PRNG splittable (that is we can +# generate a sequence of random numbers in one place, and another sequence in another), we add a path +# and key in addition to the counter. The path allows us to encode a sequence of splits (a 0 in the +# path indicates the left result of a split, a 1 indicates the right). To avoid continuously growing +# the path, we can compress an existing path into the key portion of the generator by hashing the +# current key, path, and counter to create the new key (this same technique is used if we run out of +# room for the counter). + +# This module use encoding e4 from the appendix of "Splittable Pseudorandom Number Generators using +# Cryptographic Hashing" (confusingly, the definition in the paper uses e3 to define the encoding +# function). This encoding uses a 10 element uint64 tensor where each byte has the following meaning: + +# .. code-block: + +# gen: +# words: 0 1 2 3 | 4 5 | 6 7 | 8 9 +# usage: key | path | counter | position of next step in path encoded in binary +# ex: 0b00010 -> next path entry goes one from the right + +# Right now, counter only uses the rightmost word. + # Threefry rotation constants from the Skein paper ("The Skein Hash Function Family" # https://www.schneier.com/wp-content/uploads/2015/01/skein.pdf) _ROTATIONS = { @@ -113,7 +114,7 @@ def _threefry( out_buf: BufferVar Buffer to read the counter from. - counter_offset: number + out_offset: number Threefry will write to :code:`out_buf[out_offset:out_offset+4*product(out_shape)]` out_shape: number @@ -200,14 +201,17 @@ def threefry_generate(gen, out_shape): Parameters ---------- gen : Tensor[10, uint64] - Generator state. Can be create with :py:func:`tvm.relay.threefry_seed`. This should not be - used in another function, otherwise random numbers will be repeated. + Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be + reused in another function, otherwise random numbers will be repeated. out_shape : Sequence[int] Output shape of the random numbers. Product of all dimensions must be a multiple of 4. Returns ------- + new_gen : Tensor[10, uint64] + The new generator state to be used in subsequent calls. + rand : Tensor[out_shape, uint64] Tensor of random numbers with shape `out_shape`. """ @@ -281,6 +285,7 @@ def gen_ir(gen_ptr, out_gen_ptr, out_array_ptr): out_gen[6] = tir.const(0, dtype=gen.dtype) # unused, leave it as 0 out_gen[7] = tmp[7] + tir.Cast(gen.dtype, out_len) # increment counter out_gen[8] = tmp[8] # path unchanged, so no update here + out_gen[9] = tmp[9] return irb.get() @@ -324,8 +329,8 @@ def threefry_split(gen): Parameters ---------- gen : Tensor[10, uint64] - Generator state. Can be create with :py:func:`tvm.relay.threefry_seed`. This should not be - used in another function, otherwise random numbers will be repeated. + Generator state. Can be create with :py:func:`tvm.relay.threefry_key`. This should not be + reused in another function, otherwise random numbers will be repeated. Returns ------- diff --git a/src/relay/op/algorithm/prng.cc b/src/relay/op/random/kernel.cc similarity index 61% rename from src/relay/op/algorithm/prng.cc rename to src/relay/op/random/kernel.cc index d95640fab85d..a955d29caa38 100644 --- a/src/relay/op/algorithm/prng.cc +++ b/src/relay/op/random/kernel.cc @@ -17,7 +17,7 @@ * under the License. */ -#include +#include #include namespace tvm { @@ -29,56 +29,61 @@ bool ThreefryGenerateRel(const Array& types, int num_inputs, const Attrs& const TypeReporter& reporter) { const ThreefryGenerateAttrs* param = attrs.as(); ICHECK_EQ(types.size(), 2) << "ThreefryGenerate should have one input and one output"; - const auto* gen = types[0].as(); + const auto* key = types[0].as(); + + if (key == nullptr) return false; std::vector oshape; for (auto& x : param->out_shape) { oshape.push_back(x); } - // generate returns the next gen and an array of random values + // generate returns the next key and an array of random values reporter->Assign(types[1], - TupleType({TensorType(gen->shape, gen->dtype), TensorType(oshape, gen->dtype)})); + TupleType({TensorType(key->shape, key->dtype), TensorType(oshape, key->dtype)})); return true; } -Expr MakeThreefryGenerate(Expr gen, Array out_shape) { +Expr MakeThreefryGenerate(Expr key, Array out_shape) { auto attrs = make_object(); attrs->out_shape = out_shape; - static const Op& op = Op::Get("threefry_generate"); - return Call(op, {gen}, Attrs(attrs), {}); + static const Op& op = Op::Get("random.threefry_generate"); + return Call(op, {key}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.threefry_generate").set_body_typed(MakeThreefryGenerate); +TVM_REGISTER_GLOBAL("relay.op.random._make.threefry_generate").set_body_typed(MakeThreefryGenerate); -RELAY_REGISTER_OP("threefry_generate") +RELAY_REGISTER_OP("random.threefry_generate") .describe( R"doc(Generate an array of random numbers using the Threefry algorithm.)doc" TVM_ADD_FILELINE) .set_num_inputs(1) .set_attrs_type() - .add_argument("gen", "Tensor", "Input generator") + .add_argument("key", "Tensor", "Input Threefry key") .add_type_rel("ThreefryGenerate", ThreefryGenerateRel); bool ThreefrySplitRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2) << "ThreefrySplit should have one input and one output"; - const auto* gen = types[0].as(); - reporter->Assign(types[1], TupleType({TensorType(gen->shape, gen->dtype), - TensorType(gen->shape, gen->dtype)})); + const auto* key = types[0].as(); + + if (key == nullptr) return false; + + reporter->Assign(types[1], TupleType({TensorType(key->shape, key->dtype), + TensorType(key->shape, key->dtype)})); return true; } -Expr MakeThreefrySplit(Expr gen) { - static const Op& op = Op::Get("threefry_split"); - return Call(op, {gen}, Attrs(), {}); +Expr MakeThreefrySplit(Expr key) { + static const Op& op = Op::Get("random.threefry_split"); + return Call(op, {key}, Attrs(), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.threefry_split").set_body_typed(MakeThreefrySplit); +TVM_REGISTER_GLOBAL("relay.op.random._make.threefry_split").set_body_typed(MakeThreefrySplit); -RELAY_REGISTER_OP("threefry_split") +RELAY_REGISTER_OP("random.threefry_split") .describe( - R"doc(Split an array of random numbers using the Threefry algorithm.)doc" TVM_ADD_FILELINE) + R"doc(Split the input Threefry key into two new ones.)doc" TVM_ADD_FILELINE) .set_num_inputs(1) - .add_argument("gen", "Tensor", "Input generator") + .add_argument("key", "Tensor", "Input Threefry key") .add_type_rel("ThreefrySplit", ThreefrySplitRel); } // namespace relay diff --git a/tests/python/relay/test_prng.py b/tests/python/relay/test_prng.py index 2b014cd5aa74..8770d1baa9c8 100644 --- a/tests/python/relay/test_prng.py +++ b/tests/python/relay/test_prng.py @@ -20,30 +20,35 @@ @tvm.testing.parametrize_targets -def test_threefry_repeatablity(target, ctx): - seed1 = tvm.relay.threefry_seed(1) - rand1 = tvm.relay.threefry_generate(seed1, (12,)) - out_gen1, out1 = tvm.relay.create_executor( +def test_threefry_repeatability(target, ctx): + target, ctx = "llvm", tvm.cpu(0) + key1 = tvm.relay.random.threefry_key(1) + rand1 = tvm.relay.random.threefry_generate(key1, (12,)) + out_key1, out1 = tvm.relay.create_executor( "vm", tvm.IRModule.from_expr(tvm.relay.Function([], rand1)), target=target, ctx=ctx ).evaluate()() - seed2 = tvm.relay.threefry_seed(1) - rand2 = tvm.relay.threefry_generate(seed1, (12,)) - out_gen2, out2 = tvm.relay.create_executor( + key2 = tvm.relay.random.threefry_key(1) + rand2 = tvm.relay.random.threefry_generate(key2, (12,)) + out_key2, out2 = tvm.relay.create_executor( "vm", tvm.IRModule.from_expr(tvm.relay.Function([], rand2)), target=target, ctx=ctx ).evaluate()() assert ( out1.asnumpy() == out2.asnumpy() - ).all(), "Generate on same seed should have the same output" + ).all(), "Generate on same seed should have the same output random numbers" + + assert ( + out_key1.asnumpy() == out_key2.asnumpy() + ).all(), "Generate on same seed should have the same next keys" @tvm.testing.parametrize_targets def test_threefry_split(target, ctx): - seed = tvm.relay.threefry_seed(1) - left, right = tvm.relay.TupleWrapper(tvm.relay.threefry_split(seed), 2) - _, rand1 = tvm.relay.TupleWrapper(tvm.relay.threefry_generate(left, (12,)), 2) - _, rand2 = tvm.relay.TupleWrapper(tvm.relay.threefry_generate(right, (12,)), 2) + key = tvm.relay.random.threefry_key(1) + left, right = tvm.relay.TupleWrapper(tvm.relay.random.threefry_split(key), 2) + _, rand1 = tvm.relay.TupleWrapper(tvm.relay.random.threefry_generate(left, (12,)), 2) + _, rand2 = tvm.relay.TupleWrapper(tvm.relay.random.threefry_generate(right, (12,)), 2) out1, out2 = tvm.relay.create_executor( "vm", tvm.IRModule.from_expr(tvm.relay.Function([], tvm.relay.Tuple((rand1, rand2)))), @@ -57,5 +62,5 @@ def test_threefry_split(target, ctx): if __name__ == "__main__": - test_threefry_repeatablity(tvm.target.Target("llvm"), tvm.context("cpu")) + test_threefry_repeatability(tvm.target.Target("llvm"), tvm.context("cpu")) test_threefry_split(tvm.target.Target("llvm"), tvm.context("cpu")) diff --git a/tests/python/topi/python/test_topi_prng.py b/tests/python/topi/python/test_topi_prng.py index 76f167043650..43b0494ee6f5 100644 --- a/tests/python/topi/python/test_topi_prng.py +++ b/tests/python/topi/python/test_topi_prng.py @@ -23,7 +23,7 @@ def threefry_split(target, ctx, gen): gen_placeholder = tvm.te.placeholder(gen.shape, name="gen", dtype="uint64") - left_placeholder, right_placeholder = tvm.topi.generic.threefry_split(gen_placeholder) + left_placeholder, right_placeholder = tvm.topi.random.threefry_split(gen_placeholder) s = tvm.topi.generic.schedule_extern([left_placeholder, right_placeholder]) f = tvm.build(s, [gen_placeholder, left_placeholder, right_placeholder]) left = tvm.nd.array(np.zeros(gen.shape, dtype="uint64")) @@ -34,7 +34,7 @@ def threefry_split(target, ctx, gen): def threefry_generate(target, ctx, gen, size): gen_placeholder = tvm.te.placeholder(gen.shape, name="gen", dtype="uint64") - left_placeholder, right_placeholder = tvm.topi.generic.threefry_generate(gen_placeholder, size) + left_placeholder, right_placeholder = tvm.topi.random.threefry_generate(gen_placeholder, size) s = tvm.topi.generic.schedule_extern([left_placeholder, right_placeholder]) f = tvm.build(s, [gen_placeholder, left_placeholder, right_placeholder]) out_gen = tvm.nd.array(np.zeros(gen.shape, dtype="uint64")) @@ -46,7 +46,7 @@ def threefry_generate(target, ctx, gen, size): @tvm.testing.parametrize_targets def test_threefry_split(target, ctx): # test that results of split do not equal eachother or the input - gen = tvm.relay.threefry_seed(0).data.asnumpy() + gen = tvm.relay.random.threefry_key(0).data.asnumpy() a, b = threefry_split(target, ctx, gen) assert (a != b).any() and ( a != gen @@ -84,7 +84,7 @@ def test_threefry_split(target, ctx): @tvm.testing.parametrize_targets def test_threefry_generate(target, ctx): - gen = tvm.relay.threefry_seed(0).data.asnumpy() + gen = tvm.relay.random.threefry_key(0).data.asnumpy() # check that we can generate some data a, rands = threefry_generate(target, ctx, gen, (100,)) From 54a7de026db7fffde1664a52175cc002adbbc9fa Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 16 Dec 2020 15:32:07 -0800 Subject: [PATCH 07/15] silence linter --- include/tvm/relay/attrs/random.h | 1 - src/relay/op/random/kernel.cc | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/include/tvm/relay/attrs/random.h b/include/tvm/relay/attrs/random.h index 303f4069163d..8238f102dab8 100644 --- a/include/tvm/relay/attrs/random.h +++ b/include/tvm/relay/attrs/random.h @@ -26,7 +26,6 @@ #include - namespace tvm { namespace relay { diff --git a/src/relay/op/random/kernel.cc b/src/relay/op/random/kernel.cc index a955d29caa38..6412e1c082fd 100644 --- a/src/relay/op/random/kernel.cc +++ b/src/relay/op/random/kernel.cc @@ -80,8 +80,7 @@ Expr MakeThreefrySplit(Expr key) { TVM_REGISTER_GLOBAL("relay.op.random._make.threefry_split").set_body_typed(MakeThreefrySplit); RELAY_REGISTER_OP("random.threefry_split") - .describe( - R"doc(Split the input Threefry key into two new ones.)doc" TVM_ADD_FILELINE) + .describe(R"doc(Split the input Threefry key into two new ones.)doc" TVM_ADD_FILELINE) .set_num_inputs(1) .add_argument("key", "Tensor", "Input Threefry key") .add_type_rel("ThreefrySplit", ThreefrySplitRel); From 6694be3b5c23ed3bfdd7fba5498d9b533de2003e Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 16 Dec 2020 15:48:10 -0800 Subject: [PATCH 08/15] silence linter even further --- python/tvm/topi/random/kernel.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/tvm/topi/random/kernel.py b/python/tvm/topi/random/kernel.py index f1aac3e68d14..d0c2d5c3dd34 100644 --- a/python/tvm/topi/random/kernel.py +++ b/python/tvm/topi/random/kernel.py @@ -22,8 +22,8 @@ # Threefry PRNG with splitting based on -# - J. K. Salmon, M. A. Moraes, R. O. Dror and D. E. Shaw, "Parallel random numbers: As easy as 1, 2, -# 3," SC '11: Proceedings of 2011 International Conference for High Performance Computing, +# - J. K. Salmon, M. A. Moraes, R. O. Dror and D. E. Shaw, "Parallel random numbers: As easy as 1, +# 2, 3," SC '11: Proceedings of 2011 International Conference for High Performance Computing, # Networking, Storage and Analysis, Seattle, WA, 2011, pp. 1-12, doi: 10.1145/2063384.2063405. # - Claessen, K. ; Palka, M. (2013) "Splittable Pseudorandom Number Generators using Cryptographic # Hashing". Proceedings of Haskell Symposium 2013 pp. 47-58. MLA @@ -34,16 +34,16 @@ # Threefry is a counter based PRNG: given a unique input, it generates a unique random number. As # there is no state to maintain, we can apply it to a sequence of numbers (0..N) to generate a # sequence of random numbers in parallel. In order to make the PRNG splittable (that is we can -# generate a sequence of random numbers in one place, and another sequence in another), we add a path -# and key in addition to the counter. The path allows us to encode a sequence of splits (a 0 in the -# path indicates the left result of a split, a 1 indicates the right). To avoid continuously growing -# the path, we can compress an existing path into the key portion of the generator by hashing the -# current key, path, and counter to create the new key (this same technique is used if we run out of -# room for the counter). +# generate a sequence of random numbers in one place, and another sequence in another), we add a +# path and key in addition to the counter. The path allows us to encode a sequence of splits (a 0 in +# the path indicates the left result of a split, a 1 indicates the right). To avoid continuously +# growing the path, we can compress an existing path into the key portion of the generator by +# hashing the current key, path, and counter to create the new key (this same technique is used if +# we run out of room for the counter). # This module use encoding e4 from the appendix of "Splittable Pseudorandom Number Generators using # Cryptographic Hashing" (confusingly, the definition in the paper uses e3 to define the encoding -# function). This encoding uses a 10 element uint64 tensor where each byte has the following meaning: +# function). This encoding uses a 10 element uint64 tensor where each byte means the following: # .. code-block: From 374e944458cea03f12324253ce439f868ff0f0a7 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 16 Dec 2020 15:51:31 -0800 Subject: [PATCH 09/15] s --- python/tvm/relay/op/algorithm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index 22aac6890fff..e2609cdceaff 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -17,7 +17,6 @@ """Classic algorithm operation""" from __future__ import absolute_import as _abs -from ... import nd from ..expr import Constant, Expr, TupleWrapper from . import _make from .dyn import _make as _dyn_make From ee949a74a8f0e7fb8366d3601f83efb7b2b1ca3e Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Thu, 17 Dec 2020 13:49:01 -0800 Subject: [PATCH 10/15] strengthen Threefry key type checking, add tests --- src/relay/op/random/kernel.cc | 14 +++++------ tests/python/relay/test_prng.py | 42 +++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/src/relay/op/random/kernel.cc b/src/relay/op/random/kernel.cc index 6412e1c082fd..2a86955b662f 100644 --- a/src/relay/op/random/kernel.cc +++ b/src/relay/op/random/kernel.cc @@ -25,21 +25,23 @@ namespace relay { TVM_REGISTER_NODE_TYPE(ThreefryGenerateAttrs); +static const TensorType THREEFRY_KEY_TYPE = TensorType({10}, tvm::DataType::UInt(64)); + bool ThreefryGenerateRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { const ThreefryGenerateAttrs* param = attrs.as(); ICHECK_EQ(types.size(), 2) << "ThreefryGenerate should have one input and one output"; - const auto* key = types[0].as(); - if (key == nullptr) return false; + reporter->Assign(types[0], THREEFRY_KEY_TYPE); std::vector oshape; for (auto& x : param->out_shape) { oshape.push_back(x); } // generate returns the next key and an array of random values + // TODO(@tkonolige, @altanh): support other output dtypes? reporter->Assign(types[1], - TupleType({TensorType(key->shape, key->dtype), TensorType(oshape, key->dtype)})); + TupleType({THREEFRY_KEY_TYPE, TensorType(oshape, tvm::DataType::UInt(64))})); return true; } @@ -63,12 +65,10 @@ RELAY_REGISTER_OP("random.threefry_generate") bool ThreefrySplitRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2) << "ThreefrySplit should have one input and one output"; - const auto* key = types[0].as(); - if (key == nullptr) return false; + reporter->Assign(types[0], THREEFRY_KEY_TYPE); + reporter->Assign(types[1], TupleType({THREEFRY_KEY_TYPE, THREEFRY_KEY_TYPE})); - reporter->Assign(types[1], TupleType({TensorType(key->shape, key->dtype), - TensorType(key->shape, key->dtype)})); return true; } diff --git a/tests/python/relay/test_prng.py b/tests/python/relay/test_prng.py index 8770d1baa9c8..63d526415ca3 100644 --- a/tests/python/relay/test_prng.py +++ b/tests/python/relay/test_prng.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm import tvm.relay import tvm.testing +from tvm.relay.testing import run_infer_type @tvm.testing.parametrize_targets @@ -61,6 +63,46 @@ def test_threefry_split(target, ctx): ).any(), "Generate after split should not have the same output" +def test_threefry_generate_infer(): + oshape = (12,) + key_type = tvm.relay.TensorType([10], dtype="uint64") + gen_type = tvm.relay.TensorType(oshape, dtype="uint64") + expected_type = tvm.relay.TupleType([key_type, gen_type]) + + key = tvm.relay.random.threefry_key(1) + rand1 = tvm.relay.random.threefry_generate(key, oshape) + f = tvm.relay.Function([], rand1) + f = run_infer_type(f) + assert tvm.ir.structural_equal(f.ret_type, expected_type) + + +def test_threefry_split_infer(): + key_type = tvm.relay.TensorType([10], dtype="uint64") + expected_type = tvm.relay.TupleType([key_type, key_type]) + + key = tvm.relay.random.threefry_key(1) + out_keys = tvm.relay.random.threefry_split(key) + f = tvm.relay.Function([], out_keys) + f = run_infer_type(f) + assert tvm.ir.structural_equal(f.ret_type, expected_type) + + +@pytest.mark.xfail(raises=tvm.error.TVMError) +def test_threefry_generate_infer_fail(): + fake_key = tvm.relay.const([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="uint64") + rand1 = tvm.relay.random.threefry_generate(fake_key, (12,)) + f = tvm.relay.Function([], rand1) + f = run_infer_type(f) + + +@pytest.mark.xfail(raises=tvm.error.TVMError) +def test_threefry_split_infer_fail(): + fake_key = tvm.relay.const([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="uint64") + out_keys = tvm.relay.random.threefry_split(fake_key) + f = tvm.relay.Function([], out_keys) + f = run_infer_type(f) + + if __name__ == "__main__": test_threefry_repeatability(tvm.target.Target("llvm"), tvm.context("cpu")) test_threefry_split(tvm.target.Target("llvm"), tvm.context("cpu")) From 930282f9f9b1d181bafd1e9fb6bd0842c9605ccc Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Thu, 17 Dec 2020 14:48:05 -0800 Subject: [PATCH 11/15] replace static variable with function for Threefry key type --- src/relay/op/random/kernel.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/relay/op/random/kernel.cc b/src/relay/op/random/kernel.cc index 2a86955b662f..df1f4ab66af3 100644 --- a/src/relay/op/random/kernel.cc +++ b/src/relay/op/random/kernel.cc @@ -25,14 +25,16 @@ namespace relay { TVM_REGISTER_NODE_TYPE(ThreefryGenerateAttrs); -static const TensorType THREEFRY_KEY_TYPE = TensorType({10}, tvm::DataType::UInt(64)); +static TensorType ThreefryKeyType() { + return TensorType({10}, tvm::DataType::UInt(64)); +} bool ThreefryGenerateRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { const ThreefryGenerateAttrs* param = attrs.as(); ICHECK_EQ(types.size(), 2) << "ThreefryGenerate should have one input and one output"; - reporter->Assign(types[0], THREEFRY_KEY_TYPE); + reporter->Assign(types[0], ThreefryKeyType()); std::vector oshape; for (auto& x : param->out_shape) { @@ -41,7 +43,7 @@ bool ThreefryGenerateRel(const Array& types, int num_inputs, const Attrs& // generate returns the next key and an array of random values // TODO(@tkonolige, @altanh): support other output dtypes? reporter->Assign(types[1], - TupleType({THREEFRY_KEY_TYPE, TensorType(oshape, tvm::DataType::UInt(64))})); + TupleType({ThreefryKeyType(), TensorType(oshape, tvm::DataType::UInt(64))})); return true; } @@ -66,8 +68,8 @@ bool ThreefrySplitRel(const Array& types, int num_inputs, const Attrs& att const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2) << "ThreefrySplit should have one input and one output"; - reporter->Assign(types[0], THREEFRY_KEY_TYPE); - reporter->Assign(types[1], TupleType({THREEFRY_KEY_TYPE, THREEFRY_KEY_TYPE})); + reporter->Assign(types[0], ThreefryKeyType()); + reporter->Assign(types[1], TupleType({ThreefryKeyType(), ThreefryKeyType()})); return true; } From ed30d60595ab18e65d7ca7c23120a7fcd676fd22 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Thu, 17 Dec 2020 14:52:46 -0800 Subject: [PATCH 12/15] lint fix --- src/relay/op/random/kernel.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/relay/op/random/kernel.cc b/src/relay/op/random/kernel.cc index df1f4ab66af3..ec092a7e05f2 100644 --- a/src/relay/op/random/kernel.cc +++ b/src/relay/op/random/kernel.cc @@ -25,9 +25,7 @@ namespace relay { TVM_REGISTER_NODE_TYPE(ThreefryGenerateAttrs); -static TensorType ThreefryKeyType() { - return TensorType({10}, tvm::DataType::UInt(64)); -} +static TensorType ThreefryKeyType() { return TensorType({10}, tvm::DataType::UInt(64)); } bool ThreefryGenerateRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { From ed59bfb7ade4518e36d91f3ecf2f0fac83bf1a6d Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Tue, 5 Jan 2021 16:00:14 -0800 Subject: [PATCH 13/15] Remove old todos, improve assert messages --- python/tvm/relay/op/random/kernel.py | 2 +- python/tvm/topi/random/kernel.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/op/random/kernel.py b/python/tvm/relay/op/random/kernel.py index d9dd796d24b7..96634943128d 100644 --- a/python/tvm/relay/op/random/kernel.py +++ b/python/tvm/relay/op/random/kernel.py @@ -104,7 +104,7 @@ def threefry_split(key): .. code-block:: python def foo(key): - new_key, num = threefry_generate(key, (1,)) + new_key, num = threefry_generate(key, (4,)) return num key = threefry_key(0) diff --git a/python/tvm/topi/random/kernel.py b/python/tvm/topi/random/kernel.py index d0c2d5c3dd34..1cf052df24d8 100644 --- a/python/tvm/topi/random/kernel.py +++ b/python/tvm/topi/random/kernel.py @@ -127,8 +127,10 @@ def _threefry( assert nrounds % 4 == 0 assert nwords in [4, 8, 16] - assert key_buf.dtype == "uint64" # TODO: support 32 bit inputs - assert key_buf.dtype == counter_buf.dtype + # The paper has constants for 32 bit threefry, but we keep the implementation simple by only + # using 64-bit words. + assert key_buf.dtype == "uint64", "threefry only supports 64-bit keys" + assert key_buf.dtype == counter_buf.dtype, "threefry key and counter must be the same dtype" def mix(a, b, rotation): x = a + b # TODO should be wrapping @@ -148,9 +150,8 @@ def mix(a, b, rotation): # initial key constant, full_key[nwords] is equivalent to k_{N_W} in the Skein paper. full_key[nwords] = tvm.tir.const(0x1BD11BDAA9FC1A22, dtype="uint64") for i in range(nwords): - full_key[nwords] ^= key_buf[key_offset + i] # TODO: wrapping + full_key[nwords] ^= key_buf[key_offset + i] - # TODO: overwrite counter instead? with irb.for_range(0, out_shape, dtype="uint64", name="i") as i: for j in range(nwords): out_buf[out_offset + i * nwords + j] = counter_buf[counter_offset + j] + i @@ -236,7 +237,7 @@ def gen_ir(gen_ptr, out_gen_ptr, out_array_ptr): # enough room in the counter. tmp = irb.allocate(gen.dtype, 10, name="tmp", scope="global") - # TODO(tkonolige): for now we only use the last word of the counter for counting. Its too + # TODO(tkonolige): for now we only use the last word of the counter for counting. It is too # much work to figure out how to do 128 bit addition. # Max value for counter should be 2**64-2 because we need to reserve a special value to @@ -302,7 +303,7 @@ def gen_ir(gen_ptr, out_gen_ptr, out_array_ptr): def _shift_right(irb, a, b, out_a, a_off, out_b, b_off): - """Shift a 128bit number composed of two 64 bit words right by one""" + """Binary shift a 128bit number composed of two 64 bit words right by one.""" with irb.if_scope(a == 1): out_a[a_off] = tir.const(0, dtype=a.dtype) out_b[b_off] = tir.const(0x8000000000000000, dtype=a.dtype) From 9c4ac11479d11796ca8827d4c65031d069e75c6c Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Fri, 8 Jan 2021 12:54:17 -0800 Subject: [PATCH 14/15] describe how random number is generated --- python/tvm/topi/random/kernel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/random/kernel.py b/python/tvm/topi/random/kernel.py index 1cf052df24d8..576fd9254a79 100644 --- a/python/tvm/topi/random/kernel.py +++ b/python/tvm/topi/random/kernel.py @@ -39,7 +39,9 @@ # the path indicates the left result of a split, a 1 indicates the right). To avoid continuously # growing the path, we can compress an existing path into the key portion of the generator by # hashing the current key, path, and counter to create the new key (this same technique is used if -# we run out of room for the counter). +# we run out of room for the counter). They key is initialized with a unique initial state. +# +# Random numbers are generated by applying the Threefry hash to the current key, path, and counter. # This module use encoding e4 from the appendix of "Splittable Pseudorandom Number Generators using # Cryptographic Hashing" (confusingly, the definition in the paper uses e3 to define the encoding From e3e8af2ba3c737adf38c66880d1a611f5245a2ed Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Fri, 8 Jan 2021 13:05:26 -0800 Subject: [PATCH 15/15] add tests for incorrect output size. also vary test sizes --- tests/python/relay/test_prng.py | 38 +++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_prng.py b/tests/python/relay/test_prng.py index 63d526415ca3..2109d3b30a82 100644 --- a/tests/python/relay/test_prng.py +++ b/tests/python/relay/test_prng.py @@ -49,8 +49,8 @@ def test_threefry_repeatability(target, ctx): def test_threefry_split(target, ctx): key = tvm.relay.random.threefry_key(1) left, right = tvm.relay.TupleWrapper(tvm.relay.random.threefry_split(key), 2) - _, rand1 = tvm.relay.TupleWrapper(tvm.relay.random.threefry_generate(left, (12,)), 2) - _, rand2 = tvm.relay.TupleWrapper(tvm.relay.random.threefry_generate(right, (12,)), 2) + _, rand1 = tvm.relay.TupleWrapper(tvm.relay.random.threefry_generate(left, (16,)), 2) + _, rand2 = tvm.relay.TupleWrapper(tvm.relay.random.threefry_generate(right, (16,)), 2) out1, out2 = tvm.relay.create_executor( "vm", tvm.IRModule.from_expr(tvm.relay.Function([], tvm.relay.Tuple((rand1, rand2)))), @@ -63,6 +63,23 @@ def test_threefry_split(target, ctx): ).any(), "Generate after split should not have the same output" +@tvm.testing.parametrize_targets +def test_threefry_sequential_generate(target, ctx): + key = tvm.relay.random.threefry_key(1) + key, rand1 = tvm.relay.TupleWrapper(tvm.relay.random.threefry_generate(key, (4,)), 2) + _, rand2 = tvm.relay.TupleWrapper(tvm.relay.random.threefry_generate(key, (4,)), 2) + out1, out2 = tvm.relay.create_executor( + "vm", + tvm.IRModule.from_expr(tvm.relay.Function([], tvm.relay.Tuple((rand1, rand2)))), + target=target, + ctx=ctx, + ).evaluate()() + + assert ( + out1.asnumpy() != out2.asnumpy() + ).any(), "Sequential generates should not have the same output" + + def test_threefry_generate_infer(): oshape = (12,) key_type = tvm.relay.TensorType([10], dtype="uint64") @@ -89,6 +106,7 @@ def test_threefry_split_infer(): @pytest.mark.xfail(raises=tvm.error.TVMError) def test_threefry_generate_infer_fail(): + # xfail: key size should be 10 fake_key = tvm.relay.const([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="uint64") rand1 = tvm.relay.random.threefry_generate(fake_key, (12,)) f = tvm.relay.Function([], rand1) @@ -97,12 +115,28 @@ def test_threefry_generate_infer_fail(): @pytest.mark.xfail(raises=tvm.error.TVMError) def test_threefry_split_infer_fail(): + # xfail: key size should be 10 fake_key = tvm.relay.const([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype="uint64") out_keys = tvm.relay.random.threefry_split(fake_key) f = tvm.relay.Function([], out_keys) f = run_infer_type(f) +@tvm.testing.requires_llvm +@pytest.mark.xfail(raises=tvm.error.TVMError) +def test_threefry_generate_incorrect_out_size(): + key = tvm.relay.random.threefry_key(1) + # xfail: output size should be multiple of 4 + key, rand1 = tvm.relay.TupleWrapper(tvm.relay.random.threefry_generate(key, (5,)), 2) + out1, out2 = tvm.relay.create_executor( + "vm", + tvm.IRModule.from_expr(tvm.relay.Function([], rand1)), + target=tvm.target.Target("llvm"), + ctx=tvm.context("cpu"), + ).evaluate()() + + if __name__ == "__main__": test_threefry_repeatability(tvm.target.Target("llvm"), tvm.context("cpu")) test_threefry_split(tvm.target.Target("llvm"), tvm.context("cpu")) + test_threefry_sequential_generate(tvm.target.Target("llvm"), tvm.context("cpu"))