Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions source/mir/ndslice/algorithm.d
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,25 @@ private void checkShapesMatch(
}
}


private auto ref frontOf(alias slice)() { return slice.front; };
template frontOf(size_t N)
{
static if (N == 0)
enum frontOf = "";
else
{
enum i = N - 1;
enum frontOf = frontOf!i ~ "slices[" ~ i.stringof ~ "].front, ";
}
}

S reduceImpl(alias fun, S, Slices...)(S seed, Slices slices)
{
do
{
static if (slices[0].shape.length == 1)
seed = fun(seed, staticMap!(frontOf, slices));
seed = mixin("fun(seed, " ~ frontOf!(Slices.length) ~ ")");
else
seed = .reduceImpl!fun(seed, staticMap!(frontOf, slices));
seed = mixin(".reduceImpl!fun(seed," ~ frontOf!(Slices.length) ~ ")");
foreach(ref slice; slices)
slice.popFront;
}
Expand Down Expand Up @@ -267,9 +275,9 @@ void eachImpl(alias fun, Slices...)(Slices slices)
do
{
static if (slices[0].shape.length == 1)
fun(staticMap!(frontOf, slices));
mixin("fun(" ~ frontOf!(Slices.length) ~ ");");
else
.eachImpl!fun(staticMap!(frontOf, slices));
mixin(".eachImpl!fun(" ~ frontOf!(Slices.length) ~ ");");
foreach(ref slice; slices)
slice.popFront;
}
Expand Down Expand Up @@ -385,15 +393,15 @@ size_t findImpl(alias fun, size_t N, Slices...)(ref size_t[N] backwardIndex, Sli
{
static if (slices[0].shape.length == 1)
{
if (fun(staticMap!(frontOf, slices)))
if (mixin("fun(" ~ frontOf!(Slices.length) ~ ")"))
{
backwardIndex[0] = slices[0].length;
return 1;
}
}
else
{
if (findImpl!fun(backwardIndex[1 .. $], staticMap!(frontOf, slices)))
if (mixin("findImpl!fun(backwardIndex[1 .. $], " ~ frontOf!(Slices.length) ~ ")"))
{
backwardIndex[0] = slices[0].length;
return 1;
Expand Down Expand Up @@ -556,12 +564,12 @@ size_t anyImpl(alias fun, Slices...)(Slices slices)
{
static if (slices[0].shape.length == 1)
{
if (fun(staticMap!(frontOf, slices)))
if (mixin("fun(" ~ frontOf!(Slices.length) ~ ")"))
return true;
}
else
{
if (anyImpl!fun(staticMap!(frontOf, slices)))
if (mixin("anyImpl!fun(" ~ frontOf!(Slices.length) ~ ")"))
return true;
}
foreach(ref slice; slices)
Expand Down Expand Up @@ -676,12 +684,12 @@ size_t allImpl(alias fun, Slices...)(Slices slices)
{
static if (slices[0].shape.length == 1)
{
if (!fun(staticMap!(frontOf, slices)))
if (!mixin("fun(" ~ frontOf!(Slices.length) ~ ")"))
return false;
}
else
{
if (!allImpl!fun(staticMap!(frontOf, slices)))
if (!mixin("allImpl!fun(" ~ frontOf!(Slices.length) ~ ")"))
return false;
}
foreach(ref slice; slices)
Expand Down Expand Up @@ -1096,11 +1104,11 @@ size_t countImpl(alias fun, Slices...)(Slices slices)
{
static if (slices[0].shape.length == 1)
{
if(fun(staticMap!(frontOf, slices)))
if(mixin("fun(" ~ frontOf!(Slices.length) ~ ")"))
ret++;
}
else
ret += .countImpl!fun(staticMap!(frontOf, slices));
ret += mixin(".countImpl!fun(" ~ frontOf!(Slices.length) ~ ")");
foreach(ref slice; slices)
slice.popFront;
}
Expand Down
9 changes: 9 additions & 0 deletions source/mir/ndslice/allocation.d
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ module mir.ndslice.allocation;
import std.traits;
import mir.ndslice.slice;
import mir.ndslice.internal;
import mir.ndslice.stack;

@fastmath:

Expand Down Expand Up @@ -94,6 +95,14 @@ pure nothrow unittest
assert(tensor[1, 1] == 5);
}

/// ditto
auto slice(size_t dim, Slices...)(Stack!(dim, Slices) stack)
{
auto ret = .slice!(Unqual!(stack.DeepElemType))(stack.shape);
ret[] = stack;
return ret;
}

pure nothrow unittest
{
import mir.ndslice.topology : iota;
Expand Down
12 changes: 12 additions & 0 deletions source/mir/ndslice/package.d
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,17 @@ $(TR $(TDNW $(SUBMODULE algorithm)
)
)

$(TR $(TDNW $(SUBMODULE stack)
$(BR) $(SMALL Concatenation and algorithms))
$(TD
$(SUBREF stack, isStack)
$(SUBREF stack, stack)
$(SUBREF stack, Stack)
$(SUBREF stack, stackDimension)
$(SUBREF stack, until)
)
)

$(TR $(TDNW $(SUBMODULE dynamic)
$(BR) $(SMALL Dynamic dimension manipulators))
$(TD
Expand Down Expand Up @@ -378,6 +389,7 @@ public import mir.ndslice.algorithm;
public import mir.ndslice.allocation;
public import mir.ndslice.dynamic;
public import mir.ndslice.slice;
public import mir.ndslice.stack;
public import mir.ndslice.topology;


Expand Down
54 changes: 52 additions & 2 deletions source/mir/ndslice/slice.d
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import std.meta;

import mir.internal.utility;
import mir.ndslice.internal;
import mir.ndslice.stack;
import mir.primitives;

@fastmath:
Expand Down Expand Up @@ -307,6 +308,8 @@ auto slicedField(Field)(Field field)
Returns the element type of a $(LREF Slice).
+/
alias DeepElementType(S : Slice!(kind, packs, Iterator), SliceKind kind, size_t[] packs, Iterator) = S.DeepElemType;
/// ditto
alias DeepElementType(S : Stack!(dim, Slices), size_t dim, Slices...) = S.DeepElemType;

///
unittest
Expand Down Expand Up @@ -1775,6 +1778,41 @@ struct Slice(SliceKind kind, size_t[] packs, Iterator)
[2, 2, 3, 3]]);
}


private void opIndexOpAssignImplStack(string op, T)(T value)
{
auto sl = this;
static if (stackDimension!T)
{
if (!sl.empty) do
{
mixin(`sl.front[] ` ~ op ~ `= value.front;`);
value.popFront;
sl.popFront;
}
while(!sl.empty);
}
else
{
foreach (ref slice; value._slices)
{
mixin("sl[0 .. slice.length][] " ~ op ~ "= slice;");
sl = sl[slice.length .. $];
}
assert(sl.empty);
}
}

///
void opIndexAssign(T, Slices...)(T stack, Slices slices)
if (isFullPureSlice!Slices && isStack!T)
{
import mir.ndslice.topology : unpack;
auto sl = this[slices].unpack;
static assert(isSlice!(typeof(sl))[0] == stack.N);
sl.opIndexOpAssignImplStack!""(stack);
}

/++
Assignment of a value (e.g. a number) to a $(B fully defined slice).

Expand All @@ -1784,7 +1822,8 @@ struct Slice(SliceKind kind, size_t[] packs, Iterator)
void opIndexAssign(T, Slices...)(T value, Slices slices)
if (isFullPureSlice!Slices
&& (!isDynamicArray!T || isDynamicArray!DeepElemType)
&& !isSlice!T)
&& !isSlice!T
&& !isStack!T)
{
import mir.ndslice.topology : unpack;
auto sl = this[slices].unpack;
Expand Down Expand Up @@ -2042,7 +2081,8 @@ struct Slice(SliceKind kind, size_t[] packs, Iterator)
void opIndexOpAssign(string op, T, Slices...)(T value, Slices slices)
if (isFullPureSlice!Slices
&& (!isDynamicArray!T || isDynamicArray!DeepElemType)
&& !isSlice!T)
&& !isSlice!T
&& !isStack!T)
{
import mir.ndslice.topology : unpack;
auto sl = this[slices].unpack;
Expand All @@ -2067,6 +2107,16 @@ struct Slice(SliceKind kind, size_t[] packs, Iterator)
assert(a[1] == [6, 6, 1]);
}

///
void opIndexOpAssign(string op,T, Slices...)(T stack, Slices slices)
if (isFullPureSlice!Slices && isStack!T)
{
import mir.ndslice.topology : unpack;
auto sl = this[slices].unpack;
static assert(isSlice!(typeof(sl))[0] == stack.N);
sl.opIndexOpAssignImplStack!op(stack);
}

static if (doUnittest)
/// Packed slices have the same behavior.
pure nothrow unittest
Expand Down
Loading