Skip to content

Commit

Permalink
DataToTensor
Browse files Browse the repository at this point in the history
  • Loading branch information
johnhaddon committed Oct 28, 2024
1 parent 2597950 commit f8f8405
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
15 changes: 15 additions & 0 deletions python/GafferMLTest/DataToTensorTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

import unittest

import imath

import IECore
import Gaffer
import GafferTest
Expand Down Expand Up @@ -90,5 +92,18 @@ def testTensor( self ) :
self.assertEqual( tensor.shape(), [ 3 ] )
self.assertEqual( tensor.asData(), IECore.FloatVectorData( [ 1, 2, 3 ] ) )

def testShapeModes( self ) :

node = GafferML.DataToTensor()
node.setup( Gaffer.V2iVectorDataPlug( defaultValue = IECore.V2iVectorData( [ imath.V2i( i ) for i in range( 0, 3 ) ] ) ) )

tensor = node["tensor"].getValue()
self.assertEqual( tensor.shape(), [ 3, 2 ] )

node["shapeMode"].setValue( node.ShapeMode.Custom )
node["shape"].setValue( IECore.Int64VectorData( [ 1, 1, 1, 6 ] ) )
tensor = node["tensor"].getValue()
self.assertEqual( tensor.shape(), [ 1, 1, 1, 6 ] )

if __name__ == "__main__" :
unittest.main()
21 changes: 13 additions & 8 deletions src/GafferML/DataToTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,19 @@ void DataToTensor::compute( Gaffer::ValuePlug *output, const Gaffer::Context *co
{
if( output == tensorPlug() )
{
// if( auto d = dataPlug() )
// {
// ConstDataPtr bufferData = PlugAlgo::getValueAsData( d );
// ConstInt64VectorDataPtr shapeData = shapePlug()->getValue();
// ConstTensorPtr tensorData = new Tensor( bufferData, shapeData->readable() );
// static_cast<TensorPlug *>( output )->setValue( tensorData );
// }
// else
if( auto d = dataPlug() )
{
ConstInt64VectorDataPtr shapeData;
if( shapeModePlug()->getValue() == (int)ShapeMode::Custom )
{
shapeData = shapePlug()->getValue();
}
static const vector<int64_t> g_automaticShape;
ConstDataPtr bufferData = PlugAlgo::getValueAsData( d );
ConstTensorPtr tensorData = new Tensor( bufferData, shapeData ? shapeData->readable() : g_automaticShape );
static_cast<TensorPlug *>( output )->setValue( tensorData );
}
else
{
output->setToDefault();
}
Expand Down

0 comments on commit f8f8405

Please sign in to comment.