diff --git a/src/TensorFlowNET.Core/Util/Data.cs b/src/TensorFlowNET.Core/Util/Data.cs index a14c69b1..fe3466ed 100644 --- a/src/TensorFlowNET.Core/Util/Data.cs +++ b/src/TensorFlowNET.Core/Util/Data.cs @@ -1,4 +1,5 @@ -using Tensorflow.NumPy; +using OneOf; +using Tensorflow.NumPy; namespace Tensorflow.Util { @@ -8,10 +9,10 @@ namespace Tensorflow.Util /// public class ValidationDataPack { - public NDArray val_x; - public NDArray val_y; - public NDArray val_sample_weight = null; - + internal OneOf val_x; + internal NDArray val_y; + internal NDArray val_sample_weight = null; + public bool val_x_is_array = false; public ValidationDataPack((NDArray, NDArray) validation_data) { this.val_x = validation_data.Item1; @@ -27,15 +28,17 @@ public ValidationDataPack((NDArray, NDArray, NDArray) validation_data) public ValidationDataPack((IEnumerable, NDArray) validation_data) { - this.val_x = validation_data.Item1.ToArray()[0]; + this.val_x = validation_data.Item1.ToArray(); this.val_y = validation_data.Item2; + val_x_is_array = true; } public ValidationDataPack((IEnumerable, NDArray, NDArray) validation_data) { - this.val_x = validation_data.Item1.ToArray()[0]; + this.val_x = validation_data.Item1.ToArray(); this.val_y = validation_data.Item2; this.val_sample_weight = validation_data.Item3; + val_x_is_array = true; } public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data) @@ -52,15 +55,24 @@ public static implicit operator ValidationDataPack((IEnumerable, NDArra public void Deconstruct(out NDArray val_x, out NDArray val_y) { - val_x = this.val_x; + val_x = this.val_x.AsT0; val_y = this.val_y; } public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) { - val_x = this.val_x; + val_x = this.val_x.AsT0; + val_y = this.val_y; + val_sample_weight = this.val_sample_weight; + } + + // add a unuse parameter to make it different from Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) + public void Deconstruct(out NDArray[] val_x_array, out NDArray val_y, out NDArray val_sample_weight, out NDArray unuse) + { + val_x_array = this.val_x.AsT1; val_y = this.val_y; val_sample_weight = this.val_sample_weight; + unuse = null; } } } diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs index b2750496..590f30a7 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs @@ -92,9 +92,17 @@ public static ((IEnumerable, NDArray, NDArray), ValidationDataPack) tra var train_y = y[new Slice(0, train_count)]; var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); var val_y = y[new Slice(train_count)]; - NDArray tmp_sample_weight = sample_weight; - sample_weight = sample_weight[new Slice(0, train_count)]; - ValidationDataPack validation_data = (val_x, val_y, tmp_sample_weight[new Slice(train_count)]); + + ValidationDataPack validation_data; + if (sample_weight != null) + { + validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]); + sample_weight = sample_weight[new Slice(0, train_count)]; + } + else + { + validation_data = (val_x, val_y); + } return ((train_x, train_y, sample_weight), validation_data); } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs index 474d5e5a..b3264429 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -70,13 +70,19 @@ public Dictionary evaluate(NDArray x, NDArray y, return evaluate(data_handler, callbacks, is_val, test_function); } - public Dictionary evaluate(IEnumerable x, Tensor y, int verbose = 1, bool is_val = false) + public Dictionary evaluate( + IEnumerable x, + Tensor y, + int verbose = 1, + NDArray sample_weight = null, + bool is_val = false) { var data_handler = new DataHandler(new DataHandlerArgs { X = new Tensors(x.ToArray()), Y = y, Model = this, + SampleWeight = sample_weight, StepsPerExecution = _steps_per_execution }); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index d61211c7..13a1b63b 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -7,6 +7,7 @@ using System.Diagnostics; using Tensorflow.Keras.Callbacks; using Tensorflow.Util; +using OneOf; namespace Tensorflow.Keras.Engine { @@ -287,10 +288,24 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List val_logs; + if (!validation_data.val_x_is_array) + { + (val_x, val_y, val_sample_weight) = validation_data; + // Because evaluate calls call_test_batch_end, this interferes with our output on the screen + // so we need to pass a is_val parameter to stop on_test_batch_end + val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true); + + } + else + { + (val_x_array, val_y, val_sample_weight, _) = validation_data; + val_logs = evaluate(val_x_array, val_y, sample_weight: val_sample_weight, is_val: true); + } foreach (var log in val_logs) { logs["val_" + log.Key] = log.Value;