Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve resource handling #6

Merged
merged 3 commits into from
Feb 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions apduWrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func WrapCommandAPDU(
}

// UnwrapResponseAPDU parses a response of 64 byte packets into the real data
func UnwrapResponseAPDU(channel uint16, pipe <- chan []byte, packetSize int) ([]byte, error) {
func UnwrapResponseAPDU(channel uint16, pipe <-chan []byte, packetSize int) ([]byte, error) {
var sequenceIdx uint16

var totalResult []byte
Expand All @@ -135,7 +135,7 @@ func UnwrapResponseAPDU(channel uint16, pipe <- chan []byte, packetSize int) ([]

for !done {
// Read next packet from the channel
buffer := <- pipe
buffer := <-pipe

result, responseSize, err := DeserializePacket(channel, buffer, sequenceIdx)
if err != nil {
Expand All @@ -157,4 +157,4 @@ func UnwrapResponseAPDU(channel uint16, pipe <- chan []byte, packetSize int) ([]
// Remove trailing zeros
totalResult = totalResult[:totalSize]
return totalResult, nil
}
}
40 changes: 20 additions & 20 deletions apduWrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

func Test_SerializePacket_EmptyCommand(t *testing.T) {
var command= make([]byte, 1)
var command = make([]byte, 1)

_, _, err := SerializePacket(0x0101, command, 64, 0)
assert.Nil(t, err, "Commands smaller than 3 bytes should return error")
Expand All @@ -42,9 +42,9 @@ func Test_SerializePacket_PacketSize(t *testing.T) {
commandLen uint16
}

h := header{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 32}
h := header{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 32}

var command= make([]byte, h.commandLen)
var command = make([]byte, h.commandLen)

result, _, _ := SerializePacket(
h.channel,
Expand All @@ -65,9 +65,9 @@ func Test_SerializePacket_Header(t *testing.T) {
commandLen uint16
}

h := header{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 32}
h := header{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 32}

var command= make([]byte, h.commandLen)
var command = make([]byte, h.commandLen)

result, _, _ := SerializePacket(
h.channel,
Expand All @@ -91,17 +91,17 @@ func Test_SerializePacket_Offset(t *testing.T) {
commandLen uint16
}

h := header{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 100}
h := header{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 100}

var command= make([]byte, h.commandLen)
var command = make([]byte, h.commandLen)

_, offset, _ := SerializePacket(
h.channel,
command,
packetSize,
h.sequenceIdx)

assert.Equal(t, packetSize - int(unsafe.Sizeof(h))+1, offset, "Wrong offset returned. Offset must point to the next comamnd byte that needs to be packet-ized.")
assert.Equal(t, packetSize-int(unsafe.Sizeof(h))+1, offset, "Wrong offset returned. Offset must point to the next comamnd byte that needs to be packet-ized.")
}

func Test_WrapCommandAPDU_NumberOfPackets(t *testing.T) {
Expand All @@ -119,9 +119,9 @@ func Test_WrapCommandAPDU_NumberOfPackets(t *testing.T) {
tag uint8
}

h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 100}
h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 100}

var command= make([]byte, h1.commandLen)
var command = make([]byte, h1.commandLen)

result, _ := WrapCommandAPDU(
h1.channel,
Expand All @@ -146,9 +146,9 @@ func Test_WrapCommandAPDU_CheckHeaders(t *testing.T) {
tag uint8
}

h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 100}
h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 100}

var command= make([]byte, h1.commandLen)
var command = make([]byte, h1.commandLen)

result, _ := WrapCommandAPDU(
h1.channel,
Expand Down Expand Up @@ -181,9 +181,9 @@ func Test_WrapCommandAPDU_CheckData(t *testing.T) {
tag uint8
}

h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 200}
h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 200}

var command= make([]byte, h1.commandLen)
var command = make([]byte, h1.commandLen)

for i := range command {
command[i] = byte(i % 256)
Expand Down Expand Up @@ -228,9 +228,9 @@ func Test_DeserializePacket_FirstPacket(t *testing.T) {

output, totalSize, err := DeserializePacket(0x0101, packet, 0)

assert.Nil(t,err, "Simple deserialize should not have errors")
assert.Nil(t, err, "Simple deserialize should not have errors")
assert.Equal(t, len(sampleCommand), int(totalSize), "TotalSize is incorrect")
assert.Equal(t, packetSize - firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong")
assert.Equal(t, packetSize-firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong")
assert.True(t, bytes.Compare(output[:len(sampleCommand)], sampleCommand) == 0, "Deserialized message does not match the original")
}

Expand All @@ -243,9 +243,9 @@ func Test_DeserializePacket_SecondMessage(t *testing.T) {

output, totalSize, err := DeserializePacket(0x0101, packet, 1)

assert.Nil(t,err, "Simple deserialize should not have errors")
assert.Nil(t, err, "Simple deserialize should not have errors")
assert.Equal(t, 0, int(totalSize), "TotalSize should not be returned from deserialization of non-first packet")
assert.Equal(t, packetSize - firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong")
assert.Equal(t, packetSize-firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong")
assert.True(t, bytes.Compare(output[:len(sampleCommand)], sampleCommand) == 0, "Deserialized message does not match the original")
}

Expand All @@ -256,15 +256,15 @@ func Test_UnwrapApdu_SmokeTest(t *testing.T) {
var packetSize int = 64

// Initialize some dummy input
var input= make([]byte, inputSize)
var input = make([]byte, inputSize)
for i := range input {
input[i] = byte(i % 256)
}

serialized, _ := WrapCommandAPDU(channel, input, packetSize)

// Allocate enough buffers to keep all the packets
pipe := make(chan []byte, int(math.Ceil(float64(inputSize) / float64(packetSize))))
pipe := make(chan []byte, int(math.Ceil(float64(inputSize)/float64(packetSize))))
// Send all the packets to the pipe
for len(serialized) > 0 {
pipe <- serialized[:packetSize]
Expand Down
29 changes: 14 additions & 15 deletions ledger.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ package ledger_go
import (
"errors"
"fmt"
"github.com/zondax/hid"
"sync"

"github.com/zondax/hid"
)

const (
Expand All @@ -34,7 +35,7 @@ const (
type Ledger struct {
device hid.Device
readCo sync.Once
readChannel chan [] byte
readChannel chan []byte
Logging bool
}

Expand Down Expand Up @@ -70,23 +71,17 @@ func FindLedger() (*Ledger, error) {
devices := hid.Enumerate(VendorLedger, 0)

for _, d := range devices {
if d.VendorID == VendorLedger && d.UsagePage == UsagePageLedger {
device, err := d.Open()
if err != nil {
return nil, err
}
return NewLedger(device), nil
}
deviceFound := d.UsagePage == UsagePageLedger
deviceFound = deviceFound || (d.Product == "Nano S" && d.Interface == 0)

// Linux discovery
if d.VendorID == VendorLedger && d.Product == "Nano S" && d.Interface == 0 {
if deviceFound {
device, err := d.Open()
if err != nil {
return nil, err
if err == nil {
return NewLedger(device), nil
}
return NewLedger(device), nil
}
}

return nil, errors.New("no ledger connected")
}

Expand Down Expand Up @@ -126,6 +121,10 @@ func ErrorMessage(errorCode uint16) string {
}
}

func (ledger *Ledger) Close() error {
return ledger.device.Close()
}

func (ledger *Ledger) Write(buffer []byte) (int, error) {
totalBytes := len(buffer)
totalWrittenBytes := 0
Expand All @@ -150,7 +149,7 @@ func (ledger *Ledger) Read() <-chan []byte {
return ledger.readChannel
}

func (ledger *Ledger) initReadChannel(){
func (ledger *Ledger) initReadChannel() {
ledger.readChannel = make(chan []byte, 30)
go ledger.readThread()
}
Expand Down
16 changes: 8 additions & 8 deletions ledger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ package ledger_go
import (
"encoding/hex"
"fmt"
"github.com/zondax/hid"
"github.com/stretchr/testify/assert"
"github.com/zondax/hid"
"testing"
)

Expand All @@ -41,7 +41,7 @@ func Test_FindLedger(t *testing.T) {
fmt.Println("\n*********************************")
fmt.Println("Did you enter the password??")
fmt.Println("*********************************")
t.Fatalf( "Error: %s", err.Error())
t.Fatalf("Error: %s", err.Error())
}
assert.NotNil(t, ledger)
}
Expand All @@ -52,7 +52,7 @@ func Test_BasicExchange(t *testing.T) {
fmt.Println("\n*********************************")
fmt.Println("Did you enter the password??")
fmt.Println("*********************************")
t.Fatalf( "Error: %s", err.Error())
t.Fatalf("Error: %s", err.Error())
}
assert.NotNil(t, ledger)

Expand All @@ -63,7 +63,7 @@ func Test_BasicExchange(t *testing.T) {

if err != nil {
fmt.Printf("iteration %d\n", i)
t.Fatalf( "Error: %s", err.Error())
t.Fatalf("Error: %s", err.Error())
}

assert.Equal(t, 4, len(response))
Expand All @@ -76,23 +76,23 @@ func Test_LongExchange(t *testing.T) {
fmt.Println("\n*********************************")
fmt.Println("Did you enter the password??")
fmt.Println("*********************************")
t.Fatalf( "Error: %s", err.Error())
t.Fatalf("Error: %s", err.Error())
}
assert.NotNil(t, ledger)

path := "052c000080760000800000008000000000000000000000000000000000000000000000000000000000";
path := "052c000080760000800000008000000000000000000000000000000000000000000000000000000000"
pathBytes, err := hex.DecodeString(path)
if err != nil {
t.Fatalf("invalid path in test")
}

header := []byte { 0x55, 1, 0, 0, byte(len(pathBytes))}
header := []byte{0x55, 1, 0, 0, byte(len(pathBytes))}
message := append(header, pathBytes...)

response, err := ledger.Exchange(message)

if err != nil {
t.Fatalf( "Error: %s", err.Error())
t.Fatalf("Error: %s", err.Error())
}

assert.Equal(t, 65, len(response))
Expand Down