Skip to content

Commit 4684f8b

Browse files
committed
Merge pull request #419 from cogmission/slight_persistence_patch
Slight persistence patch and HTMObjectIn/Output constructor change.
2 parents 95ddd4c + 22ba2d7 commit 4684f8b

File tree

7 files changed

+98
-10
lines changed

7 files changed

+98
-10
lines changed

build.gradle

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ apply plugin: 'eclipse'
44
apply plugin: 'signing'
55

66
group = 'org.numenta'
7-
version = '0.6.7-SNAPSHOT'
7+
version = '0.6.8'
88
archivesBaseName = 'htm.java'
99

1010
sourceCompatibility = 1.8
1111
targetCompatibility = 1.8
1212

1313
jar {
1414
manifest {
15-
attributes 'Implementation-Title': 'htm.java', 'Implementation-Version': '0.6.7-SNAPSHOT'
15+
attributes 'Implementation-Title': 'htm.java', 'Implementation-Version': '0.6.8'
1616
}
1717
}
1818

@@ -126,6 +126,7 @@ uploadArchives {
126126
javadoc.failOnError = false
127127

128128
if(!project.hasProperty('ossrhUsername')) {
129+
println "returning from has Property false"
129130
return
130131
}
131132

pom.xml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
<groupId>org.numenta</groupId>
66
<artifactId>htm.java</artifactId>
7-
<version>0.6.7-SNAPSHOT</version>
7+
<version>0.6.8</version>
88
<name>htm.java</name>
99
<description>The Java version of Numenta's HTM technology</description>
1010

src/main/java/org/numenta/nupic/network/Network.java

+7-1
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,15 @@ public static PALayer<?> createPALayer(String name, Parameters p) {
270270
@SuppressWarnings("unchecked")
271271
@Override
272272
public Network preSerialize() {
273-
if(shouldDoHalt) {
273+
if(shouldDoHalt && isThreadRunning) {
274274
halt();
275+
}else{ // Make sure "close()" has been called on the Network
276+
if(regions.size() == 1) {
277+
this.tail = regions.get(0);
278+
}
279+
tail.close();
275280
}
281+
276282
regions.stream().forEach(r -> r.preSerialize());
277283
return this;
278284
}

src/main/java/org/numenta/nupic/serialize/HTMObjectInput.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import java.io.InputStream;
55

66
import org.numenta.nupic.Persistable;
7+
import org.nustaq.serialization.FSTConfiguration;
78
import org.nustaq.serialization.FSTObjectInput;
89

910
public class HTMObjectInput extends FSTObjectInput {
10-
public HTMObjectInput(InputStream in) throws IOException {
11-
super(in);
11+
public HTMObjectInput(InputStream in, FSTConfiguration config) throws IOException {
12+
super(in, config);
1213
}
1314

1415
@SuppressWarnings("rawtypes")

src/main/java/org/numenta/nupic/serialize/HTMObjectOutput.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import java.io.OutputStream;
55

66
import org.numenta.nupic.Persistable;
7+
import org.nustaq.serialization.FSTConfiguration;
78
import org.nustaq.serialization.FSTObjectOutput;
89

910
public class HTMObjectOutput extends FSTObjectOutput {
10-
public HTMObjectOutput(OutputStream out) {
11-
super(out);
11+
public HTMObjectOutput(OutputStream out, FSTConfiguration config) {
12+
super(out, config);
1213
}
1314

1415
@SuppressWarnings("rawtypes")

src/main/java/org/numenta/nupic/serialize/SerializerCore.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundE
104104
* @throws IOException
105105
*/
106106
public HTMObjectInput getObjectInput(InputStream is) throws IOException {
107-
return new HTMObjectInput(is);
107+
return new HTMObjectInput(is, fastSerialConfig);
108108
}
109109

110110
/**
@@ -113,7 +113,7 @@ public HTMObjectInput getObjectInput(InputStream is) throws IOException {
113113
* @return the HTMObjectOutput
114114
*/
115115
public <T extends Persistable> HTMObjectOutput getObjectOutput(OutputStream os) {
116-
return new HTMObjectOutput(os);
116+
return new HTMObjectOutput(os, fastSerialConfig);
117117
}
118118

119119
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package org.numenta.nupic.serialize;
2+
3+
import static org.junit.Assert.assertNotNull;
4+
import static org.junit.Assert.assertTrue;
5+
import static org.junit.Assert.fail;
6+
7+
import java.io.ByteArrayInputStream;
8+
import java.io.ByteArrayOutputStream;
9+
10+
import org.junit.Test;
11+
import org.numenta.nupic.Parameters;
12+
import org.numenta.nupic.Parameters.KEY;
13+
import org.numenta.nupic.algorithms.Anomaly;
14+
import org.numenta.nupic.algorithms.SpatialPooler;
15+
import org.numenta.nupic.algorithms.TemporalMemory;
16+
import org.numenta.nupic.network.Network;
17+
import org.numenta.nupic.network.NetworkTestHarness;
18+
import org.numenta.nupic.network.Persistence;
19+
import org.numenta.nupic.network.PublisherSupplier;
20+
import org.numenta.nupic.network.sensor.ObservableSensor;
21+
import org.numenta.nupic.network.sensor.Sensor;
22+
import org.numenta.nupic.network.sensor.SensorParams;
23+
import org.numenta.nupic.network.sensor.SensorParams.Keys;
24+
import org.numenta.nupic.util.FastRandom;
25+
26+
27+
public class HTMObjectInputOutputTest {
28+
29+
@Test
30+
public void testRoundTrip() {
31+
Network network = getLoadedHotGymNetwork();
32+
SerializerCore serializer = Persistence.get().serializer();
33+
ByteArrayOutputStream baos = new ByteArrayOutputStream();
34+
HTMObjectOutput writer = serializer.getObjectOutput(baos);
35+
try {
36+
writer.writeObject(network, Network.class);
37+
writer.flush();
38+
writer.close();
39+
}catch(Exception e) {
40+
fail();
41+
}
42+
43+
byte[] bytes = baos.toByteArray();
44+
45+
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
46+
try {
47+
HTMObjectInput reader = serializer.getObjectInput(bais);
48+
Network serializedNetwork = (Network)reader.readObject(Network.class);
49+
assertNotNull(serializedNetwork);
50+
assertTrue(serializedNetwork.equals(network));
51+
}catch(Exception e) {
52+
e.printStackTrace();
53+
fail();
54+
}
55+
}
56+
57+
private Network getLoadedHotGymNetwork() {
58+
Parameters p = NetworkTestHarness.getParameters().copy();
59+
p = p.union(NetworkTestHarness.getHotGymTestEncoderParams());
60+
p.setParameterByKey(KEY.RANDOM, new FastRandom(42));
61+
62+
Sensor<ObservableSensor<String[]>> sensor = Sensor.create(
63+
ObservableSensor::create, SensorParams.create(Keys::obs, new Object[] {"name",
64+
PublisherSupplier.builder()
65+
.addHeader("timestamp, consumption")
66+
.addHeader("datetime, float")
67+
.addHeader("B").build() }));
68+
69+
Network network = Network.create("test network", p).add(Network.createRegion("r1")
70+
.add(Network.createLayer("1", p)
71+
.alterParameter(KEY.AUTO_CLASSIFY, true)
72+
.add(Anomaly.create())
73+
.add(new TemporalMemory())
74+
.add(new SpatialPooler())
75+
.add(sensor)));
76+
77+
return network;
78+
}
79+
}

0 commit comments

Comments
 (0)