diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 2cdd868522..7c5aef7328 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -637,9 +637,41 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) : _runners = new ConcurrentBag(); } + private Delegate CreateGetter(DataViewRow input, int iinfo, Func activeOutput, OutputCache outputCache) + { + Host.AssertValue(input); + + var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray(); + + var type = Tf2MlNetType(_parent.TFOutputTypes[iinfo]).RawType; + Host.Assert(type == _parent.OutputTypes[iinfo].GetItemType().RawType); + var srcTensorGetters = GetTensorValueGetters(input, _inputColIndices, _isInputVector, _parent.TFInputTypes, _fullySpecifiedShapes); + return Utils.MarshalInvoke(MakeGetter, type, input, iinfo, srcTensorGetters, activeOutputColNames, outputCache); + } + + public override Delegate[] CreateGetters(DataViewRow input, Func activeOutput, out Action disposer) + { + Contracts.Assert(input.Schema == InputSchema); + + OutputCache outputCacher = new OutputCache(); + + int n = OutputColumns.Value.Length; + var result = new Delegate[n]; + for (int i = 0; i < n; i++) { + if (!activeOutput(i)) + continue; + result[i] = CreateGetter(input, i, activeOutput, outputCacher); + } + disposer = () => + { + outputCacher.Dispose(); + }; + return result; + } + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); - private class OutputCache + private class OutputCache : IDisposable { public long Position; public Dictionary Outputs; @@ -648,22 +680,22 @@ public OutputCache() Position = -1; Outputs = new Dictionary(); } - } - - protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) - { - disposer = null; - Host.AssertValue(input); - var outputCache = new OutputCache(); - var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray(); + private bool _isDisposed; - var type = Tf2MlNetType(_parent.TFOutputTypes[iinfo]).RawType; - Host.Assert(type == _parent.OutputTypes[iinfo].GetItemType().RawType); - var srcTensorGetters = GetTensorValueGetters(input, _inputColIndices, _isInputVector, _parent.TFInputTypes, _fullySpecifiedShapes); - return Utils.MarshalInvoke(MakeGetter, type, input, iinfo, srcTensorGetters, activeOutputColNames, outputCache); + public void Dispose() + { + if (_isDisposed) + return; + foreach (var tensor in Outputs.Values) + tensor.Dispose(); + _isDisposed = true; + } } + protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) + => throw new NotImplementedException("This should never be called!"); + private Delegate MakeGetter(DataViewRow input, int iinfo, ITensorValueGetter[] srcTensorGetters, string[] activeOutputColNames, OutputCache outputCache) where T : unmanaged { Host.AssertValue(input);