Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Implemented support for loading models with Concatenate layers #1192

Merged
merged 2 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
// TODO: complete the implementation
public class MergeArgs : LayerArgs
public class MergeArgs : AutoSerializeLayerArgs
{
public Tensors Inputs { get; set; }
[JsonProperty("axis")]
public int Axis { get; set; }
}
}
9 changes: 9 additions & 0 deletions src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ public static NDArray concatenate((NDArray, NDArray) tuple, int axis = 0)
[AutoNumPy]
public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays));

[AutoNumPy]
public static NDArray stack(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.stack(arrays, axis));

[AutoNumPy]
public static NDArray stack((NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2 }, axis));

[AutoNumPy]
public static NDArray stack((NDArray, NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2, tuple.Item3 }, axis));

[AutoNumPy]
public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination));
}
Expand Down
30 changes: 18 additions & 12 deletions src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_co
created_layers = created_layers ?? new Dictionary<string, ILayer>();
var node_index_map = new Dictionary<(string, int), int>();
var node_count_by_layer = new Dictionary<ILayer, int>();
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();
var unprocessed_nodes = new Dictionary<ILayer, List<NodeConfig>>();
// First, we create all layers and enqueue nodes to be processed
foreach (var layer_data in config.Layers)
process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer);
Expand Down Expand Up @@ -79,7 +79,7 @@ public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_co

static void process_layer(Dictionary<string, ILayer> created_layers,
LayerConfig layer_data,
Dictionary<ILayer, NodeConfig> unprocessed_nodes,
Dictionary<ILayer, List<NodeConfig>> unprocessed_nodes,
Dictionary<ILayer, int> node_count_by_layer)
{
ILayer layer = null;
Expand All @@ -92,32 +92,38 @@ static void process_layer(Dictionary<string, ILayer> created_layers,

created_layers[layer_name] = layer;
}
node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0;
node_count_by_layer[layer] = layer_data.InboundNodes.Count - (_should_skip_first_node(layer) ? 1 : 0);

var inbound_nodes_data = layer_data.InboundNodes;
foreach (var node_data in inbound_nodes_data)
{
if (!unprocessed_nodes.ContainsKey(layer))
unprocessed_nodes[layer] = node_data;
unprocessed_nodes[layer] = new List<NodeConfig>() { node_data };
else
unprocessed_nodes.Add(layer, node_data);
unprocessed_nodes[layer].Add(node_data);
}
}

static void process_node(ILayer layer,
NodeConfig node_data,
List<NodeConfig> nodes_data,
Dictionary<string, ILayer> created_layers,
Dictionary<ILayer, int> node_count_by_layer,
Dictionary<(string, int), int> node_index_map)
{

var input_tensors = new List<Tensor>();
var inbound_layer_name = node_data.Name;
var inbound_node_index = node_data.NodeIndex;
var inbound_tensor_index = node_data.TensorIndex;

var inbound_layer = created_layers[inbound_layer_name];
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
for (int i = 0; i < nodes_data.Count; i++)
{
var node_data = nodes_data[i];
var inbound_layer_name = node_data.Name;
var inbound_node_index = node_data.NodeIndex;
var inbound_tensor_index = node_data.TensorIndex;

var inbound_layer = created_layers[inbound_layer_name];
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
}

var output_tensors = layer.Apply(input_tensors);

Expand Down
1 change: 1 addition & 0 deletions src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public override void build(KerasShapesWrapper input_shape)
shape_set.Add(shape);
}*/
_buildInputShape = input_shape;
built = true;
}

protected override Tensors _merge_function(Tensors inputs)
Expand Down
13 changes: 12 additions & 1 deletion src/TensorFlowNET.Keras/Utils/generic_utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,23 @@ public static FunctionalConfig deserialize_model_config(JToken json)
foreach (var token in layersToken)
{
var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]);

List<NodeConfig> nodeConfig = null; //python tensorflow sometimes exports inbound nodes in an extra nested array
if (token["inbound_nodes"].Count() > 0 && token["inbound_nodes"][0].Count() > 0 && token["inbound_nodes"][0][0].Count() > 0)
{
nodeConfig = token["inbound_nodes"].ToObject<List<List<NodeConfig>>>().FirstOrDefault() ?? new List<NodeConfig>();
}
else
{
nodeConfig = token["inbound_nodes"].ToObject<List<NodeConfig>>();
}

config.Layers.Add(new LayerConfig()
{
Config = args,
Name = token["name"].ToObject<string>(),
ClassName = token["class_name"].ToObject<string>(),
InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>()
InboundNodes = nodeConfig,
});
}
config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>();
Expand Down
15 changes: 10 additions & 5 deletions test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

Expand All @@ -8,12 +9,16 @@ namespace Tensorflow.Keras.UnitTest.Layers
public class LayersMergingTest : EagerModeTestBase
{
[TestMethod]
public void Concatenate()
[DataRow(1, 4, 1, 5)]
[DataRow(2, 2, 2, 5)]
[DataRow(3, 2, 1, 10)]
public void Concatenate(int axis, int shapeA, int shapeB, int shapeC)
{
var x = np.arange(20).reshape((2, 2, 5));
var y = np.arange(20, 30).reshape((2, 1, 5));
var z = keras.layers.Concatenate(axis: 1).Apply(new Tensors(x, y));
Assert.AreEqual((2, 3, 5), z.shape);
var x = np.arange(10).reshape((1, 2, 1, 5));
var y = np.arange(10, 20).reshape((1, 2, 1, 5));
var z = keras.layers.Concatenate(axis: axis).Apply(new Tensors(x, y));
Assert.AreEqual((1, shapeA, shapeB, shapeC), z.shape);
}

}
}
43 changes: 43 additions & 0 deletions test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Newtonsoft.Json.Linq;
using System.Linq;
using System.Xml.Linq;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow.NumPy;
using static HDF.PInvoke.H5Z;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

Expand Down Expand Up @@ -124,4 +127,44 @@ public void TestModelBeforeTF2_5()
var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model;
model.summary();
}



[TestMethod]
public void CreateConcatenateModelSaveAndLoad()
{
// a small demo model that is just here to see if the axis value for the concatenate method is saved and loaded.
var input_layer = tf.keras.layers.Input((8, 8, 5));

var conv1 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_1"*/).Apply(input_layer);
conv1.Name = "conv1";

var conv2 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_2"*/).Apply(input_layer);
conv2.Name = "conv2";

var concat1 = tf.keras.layers.Concatenate(axis: 3).Apply((conv1, conv2));
concat1.Name = "concat1";

var model = tf.keras.Model(input_layer, concat1);
model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());

model.save(@"Assets/concat_axis3_model");


var tensorInput = np.arange(320).reshape((1, 8, 8, 5)).astype(TF_DataType.TF_FLOAT);

var tensors1 = model.predict(tensorInput);

Assert.AreEqual((1, 8, 8, 4), tensors1.shape);

model = null;
keras.backend.clear_session();

var model2 = tf.keras.models.load_model(@"Assets/concat_axis3_model");

var tensors2 = model2.predict(tensorInput);

Assert.AreEqual(tensors1.shape, tensors2.shape);
}

}
Loading