Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix the validation_pack when multiple input #1212

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions src/TensorFlowNET.Core/Util/Data.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Tensorflow.NumPy;
using OneOf;
using Tensorflow.NumPy;

namespace Tensorflow.Util
{
Expand All @@ -8,10 +9,10 @@ namespace Tensorflow.Util
/// </summary>
public class ValidationDataPack
{
public NDArray val_x;
public NDArray val_y;
public NDArray val_sample_weight = null;

internal OneOf<NDArray, NDArray[]> val_x;
internal NDArray val_y;
internal NDArray val_sample_weight = null;
public bool val_x_is_array = false;
public ValidationDataPack((NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1;
Expand All @@ -27,15 +28,17 @@ public ValidationDataPack((NDArray, NDArray, NDArray) validation_data)

public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
{
this.val_x = validation_data.Item1.ToArray()[0];
this.val_x = validation_data.Item1.ToArray();
this.val_y = validation_data.Item2;
val_x_is_array = true;
}

public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1.ToArray()[0];
this.val_x = validation_data.Item1.ToArray();
this.val_y = validation_data.Item2;
this.val_sample_weight = validation_data.Item3;
val_x_is_array = true;
}

public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data)
Expand All @@ -52,15 +55,24 @@ public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArra

public void Deconstruct(out NDArray val_x, out NDArray val_y)
{
val_x = this.val_x;
val_x = this.val_x.AsT0;
val_y = this.val_y;
}

public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
{
val_x = this.val_x;
val_x = this.val_x.AsT0;
val_y = this.val_y;
val_sample_weight = this.val_sample_weight;
}

// add a unuse parameter to make it different from Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
public void Deconstruct(out NDArray[] val_x_array, out NDArray val_y, out NDArray val_sample_weight, out NDArray unuse)
{
val_x_array = this.val_x.AsT1;
val_y = this.val_y;
val_sample_weight = this.val_sample_weight;
unuse = null;
}
}
}
14 changes: 11 additions & 3 deletions src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,17 @@ public static ((IEnumerable<NDArray>, NDArray, NDArray), ValidationDataPack) tra
var train_y = y[new Slice(0, train_count)];
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
var val_y = y[new Slice(train_count)];
NDArray tmp_sample_weight = sample_weight;
sample_weight = sample_weight[new Slice(0, train_count)];
ValidationDataPack validation_data = (val_x, val_y, tmp_sample_weight[new Slice(train_count)]);

ValidationDataPack validation_data;
if (sample_weight != null)
{
validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]);
sample_weight = sample_weight[new Slice(0, train_count)];
}
else
{
validation_data = (val_x, val_y);
}
return ((train_x, train_y, sample_weight), validation_data);
}
}
Expand Down
8 changes: 7 additions & 1 deletion src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,19 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
return evaluate(data_handler, callbacks, is_val, test_function);
}

public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
public Dictionary<string, float> evaluate(
IEnumerable<Tensor> x,
Tensor y,
int verbose = 1,
NDArray sample_weight = null,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's recommended to put sample_weight the last parameter, to keep its backward compatibility. However if this is exactly the order in python, just keeping it is okay.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it follow the order in python.

bool is_val = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(x.ToArray()),
Y = y,
Model = this,
SampleWeight = sample_weight,
StepsPerExecution = _steps_per_execution
});

Expand Down
23 changes: 19 additions & 4 deletions src/TensorFlowNET.Keras/Engine/Model.Fit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Diagnostics;
using Tensorflow.Keras.Callbacks;
using Tensorflow.Util;
using OneOf;

namespace Tensorflow.Keras.Engine
{
Expand Down Expand Up @@ -287,10 +288,24 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal

if (validation_data != null)
{
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
// so we need to pass a is_val parameter to stop on_test_batch_end
var (val_x, val_y, val_sample_weight) = validation_data;
var val_logs = evaluate(val_x, val_y, sample_weight:val_sample_weight, is_val:true);
NDArray val_x;
NDArray[] val_x_array;
NDArray val_y;
NDArray val_sample_weight;
Dictionary<string, float> val_logs;
if (!validation_data.val_x_is_array)
{
(val_x, val_y, val_sample_weight) = validation_data;
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
// so we need to pass a is_val parameter to stop on_test_batch_end
val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true);

}
else
{
(val_x_array, val_y, val_sample_weight, _) = validation_data;
val_logs = evaluate(val_x_array, val_y, sample_weight: val_sample_weight, is_val: true);
}
foreach (var log in val_logs)
{
logs["val_" + log.Key] = log.Value;
Expand Down
Loading