tangled
alpha
login
or
join now
oppi.li
/
sets
0
fork
atom
a rusty set datastructure for go
0
fork
atom
overview
issues
pulls
pipelines
set: init
Signed-off-by: oppiliappan <me@oppi.li>
oppi.li
3 months ago
4710e037
+411
3 changed files
expand all
collapse all
unified
split
go.mod
set.go
set_test.go
+3
go.mod
···
1
1
+
module tangled.org/oppi.li/set
2
2
+
3
3
+
go 1.25.0
+168
set.go
···
1
1
+
package set
2
2
+
3
3
+
import (
4
4
+
"iter"
5
5
+
"maps"
6
6
+
)
7
7
+
8
8
+
type Set[T comparable] struct {
9
9
+
data map[T]struct{}
10
10
+
}
11
11
+
12
12
+
func New[T comparable]() Set[T] {
13
13
+
return Set[T]{
14
14
+
data: make(map[T]struct{}),
15
15
+
}
16
16
+
}
17
17
+
18
18
+
func (s *Set[T]) Insert(item T) bool {
19
19
+
_, exists := s.data[item]
20
20
+
s.data[item] = struct{}{}
21
21
+
return !exists
22
22
+
}
23
23
+
24
24
+
func (s *Set[T]) Remove(item T) bool {
25
25
+
_, exists := s.data[item]
26
26
+
if exists {
27
27
+
delete(s.data, item)
28
28
+
}
29
29
+
return exists
30
30
+
}
31
31
+
32
32
+
func (s Set[T]) Contains(item T) bool {
33
33
+
_, exists := s.data[item]
34
34
+
return exists
35
35
+
}
36
36
+
37
37
+
func (s Set[T]) Len() int {
38
38
+
return len(s.data)
39
39
+
}
40
40
+
41
41
+
func (s Set[T]) IsEmpty() bool {
42
42
+
return len(s.data) == 0
43
43
+
}
44
44
+
45
45
+
func (s *Set[T]) Clear() {
46
46
+
s.data = make(map[T]struct{})
47
47
+
}
48
48
+
49
49
+
func (s Set[T]) All() iter.Seq[T] {
50
50
+
return func(yield func(T) bool) {
51
51
+
for item := range s.data {
52
52
+
if !yield(item) {
53
53
+
return
54
54
+
}
55
55
+
}
56
56
+
}
57
57
+
}
58
58
+
59
59
+
func (s Set[T]) Clone() Set[T] {
60
60
+
return Set[T]{
61
61
+
data: maps.Clone(s.data),
62
62
+
}
63
63
+
}
64
64
+
65
65
+
func (s Set[T]) Union(other Set[T]) iter.Seq[T] {
66
66
+
if s.Len() >= other.Len() {
67
67
+
return chain(s.All(), other.Difference(s))
68
68
+
} else {
69
69
+
return chain(other.All(), s.Difference(other))
70
70
+
}
71
71
+
}
72
72
+
73
73
+
func chain[T any](seqs ...iter.Seq[T]) iter.Seq[T] {
74
74
+
return func(yield func(T) bool) {
75
75
+
for _, seq := range seqs {
76
76
+
for item := range seq {
77
77
+
if !yield(item) {
78
78
+
return
79
79
+
}
80
80
+
}
81
81
+
}
82
82
+
}
83
83
+
}
84
84
+
85
85
+
func (s Set[T]) Intersection(other Set[T]) iter.Seq[T] {
86
86
+
return func(yield func(T) bool) {
87
87
+
for item := range s.data {
88
88
+
if other.Contains(item) {
89
89
+
if !yield(item) {
90
90
+
return
91
91
+
}
92
92
+
}
93
93
+
}
94
94
+
}
95
95
+
}
96
96
+
97
97
+
func (s Set[T]) Difference(other Set[T]) iter.Seq[T] {
98
98
+
return func(yield func(T) bool) {
99
99
+
for item := range s.data {
100
100
+
if !other.Contains(item) {
101
101
+
if !yield(item) {
102
102
+
return
103
103
+
}
104
104
+
}
105
105
+
}
106
106
+
}
107
107
+
}
108
108
+
109
109
+
func (s Set[T]) SymmetricDifference(other Set[T]) iter.Seq[T] {
110
110
+
return func(yield func(T) bool) {
111
111
+
for item := range s.data {
112
112
+
if !other.Contains(item) {
113
113
+
if !yield(item) {
114
114
+
return
115
115
+
}
116
116
+
}
117
117
+
}
118
118
+
for item := range other.data {
119
119
+
if !s.Contains(item) {
120
120
+
if !yield(item) {
121
121
+
return
122
122
+
}
123
123
+
}
124
124
+
}
125
125
+
}
126
126
+
}
127
127
+
128
128
+
func (s Set[T]) IsSubset(other Set[T]) bool {
129
129
+
for item := range s.data {
130
130
+
if !other.Contains(item) {
131
131
+
return false
132
132
+
}
133
133
+
}
134
134
+
return true
135
135
+
}
136
136
+
137
137
+
func (s Set[T]) IsSuperset(other Set[T]) bool {
138
138
+
return other.IsSubset(s)
139
139
+
}
140
140
+
141
141
+
func (s Set[T]) IsDisjoint(other Set[T]) bool {
142
142
+
for item := range s.data {
143
143
+
if other.Contains(item) {
144
144
+
return false
145
145
+
}
146
146
+
}
147
147
+
return true
148
148
+
}
149
149
+
150
150
+
func (s Set[T]) Equal(other Set[T]) bool {
151
151
+
if s.Len() != other.Len() {
152
152
+
return false
153
153
+
}
154
154
+
for item := range s.data {
155
155
+
if !other.Contains(item) {
156
156
+
return false
157
157
+
}
158
158
+
}
159
159
+
return true
160
160
+
}
161
161
+
162
162
+
func Collect[T comparable](seq iter.Seq[T]) Set[T] {
163
163
+
result := New[T]()
164
164
+
for item := range seq {
165
165
+
result.Insert(item)
166
166
+
}
167
167
+
return result
168
168
+
}
+240
set_test.go
···
1
1
+
package set
2
2
+
3
3
+
import (
4
4
+
"slices"
5
5
+
"testing"
6
6
+
)
7
7
+
8
8
+
func TestNew(t *testing.T) {
9
9
+
s := New[int]()
10
10
+
if s.Len() != 0 {
11
11
+
t.Errorf("New set should be empty, got length %d", s.Len())
12
12
+
}
13
13
+
if !s.IsEmpty() {
14
14
+
t.Error("New set should be empty")
15
15
+
}
16
16
+
}
17
17
+
18
18
+
func TestFromSlice(t *testing.T) {
19
19
+
s := Collect(slices.Values([]int{1, 2, 3, 2, 1}))
20
20
+
if s.Len() != 3 {
21
21
+
t.Errorf("Expected length 3, got %d", s.Len())
22
22
+
}
23
23
+
if !s.Contains(1) || !s.Contains(2) || !s.Contains(3) {
24
24
+
t.Error("Set should contain all unique elements from slice")
25
25
+
}
26
26
+
}
27
27
+
28
28
+
func TestInsert(t *testing.T) {
29
29
+
s := New[string]()
30
30
+
31
31
+
if !s.Insert("hello") {
32
32
+
t.Error("First insert should return true")
33
33
+
}
34
34
+
if s.Insert("hello") {
35
35
+
t.Error("Duplicate insert should return false")
36
36
+
}
37
37
+
if s.Len() != 1 {
38
38
+
t.Errorf("Expected length 1, got %d", s.Len())
39
39
+
}
40
40
+
}
41
41
+
42
42
+
func TestRemove(t *testing.T) {
43
43
+
s := Collect(slices.Values([]int{1, 2, 3}))
44
44
+
45
45
+
if !s.Remove(2) {
46
46
+
t.Error("Remove existing element should return true")
47
47
+
}
48
48
+
if s.Remove(2) {
49
49
+
t.Error("Remove non-existing element should return false")
50
50
+
}
51
51
+
if s.Contains(2) {
52
52
+
t.Error("Element should be removed")
53
53
+
}
54
54
+
if s.Len() != 2 {
55
55
+
t.Errorf("Expected length 2, got %d", s.Len())
56
56
+
}
57
57
+
}
58
58
+
59
59
+
func TestContains(t *testing.T) {
60
60
+
s := Collect(slices.Values([]int{1, 2, 3}))
61
61
+
62
62
+
if !s.Contains(1) {
63
63
+
t.Error("Should contain 1")
64
64
+
}
65
65
+
if s.Contains(4) {
66
66
+
t.Error("Should not contain 4")
67
67
+
}
68
68
+
}
69
69
+
70
70
+
func TestClear(t *testing.T) {
71
71
+
s := Collect(slices.Values([]int{1, 2, 3}))
72
72
+
s.Clear()
73
73
+
74
74
+
if !s.IsEmpty() {
75
75
+
t.Error("Set should be empty after clear")
76
76
+
}
77
77
+
if s.Len() != 0 {
78
78
+
t.Errorf("Expected length 0, got %d", s.Len())
79
79
+
}
80
80
+
}
81
81
+
82
82
+
func TestIterator(t *testing.T) {
83
83
+
s := Collect(slices.Values([]int{1, 2, 3}))
84
84
+
var items []int
85
85
+
86
86
+
for item := range s.All() {
87
87
+
items = append(items, item)
88
88
+
}
89
89
+
90
90
+
slices.Sort(items)
91
91
+
expected := []int{1, 2, 3}
92
92
+
if !slices.Equal(items, expected) {
93
93
+
t.Errorf("Expected %v, got %v", expected, items)
94
94
+
}
95
95
+
}
96
96
+
97
97
+
func TestClone(t *testing.T) {
98
98
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
99
99
+
s2 := s1.Clone()
100
100
+
101
101
+
if !s1.Equal(s2) {
102
102
+
t.Error("Cloned set should be equal to original")
103
103
+
}
104
104
+
105
105
+
s2.Insert(4)
106
106
+
if s1.Contains(4) {
107
107
+
t.Error("Modifying clone should not affect original")
108
108
+
}
109
109
+
}
110
110
+
111
111
+
func TestUnion(t *testing.T) {
112
112
+
s1 := Collect(slices.Values([]int{1, 2}))
113
113
+
s2 := Collect(slices.Values([]int{2, 3}))
114
114
+
115
115
+
result := Collect(s1.Union(s2))
116
116
+
expected := Collect(slices.Values([]int{1, 2, 3}))
117
117
+
118
118
+
if !result.Equal(expected) {
119
119
+
t.Errorf("Expected %v, got %v", expected, result)
120
120
+
}
121
121
+
}
122
122
+
123
123
+
func TestIntersection(t *testing.T) {
124
124
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
125
125
+
s2 := Collect(slices.Values([]int{2, 3, 4}))
126
126
+
127
127
+
expected := Collect(slices.Values([]int{2, 3}))
128
128
+
result := Collect(s1.Intersection(s2))
129
129
+
130
130
+
if !result.Equal(expected) {
131
131
+
t.Errorf("Expected %v, got %v", expected, result)
132
132
+
}
133
133
+
}
134
134
+
135
135
+
func TestDifference(t *testing.T) {
136
136
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
137
137
+
s2 := Collect(slices.Values([]int{2, 3, 4}))
138
138
+
139
139
+
expected := Collect(slices.Values([]int{1}))
140
140
+
result := Collect(s1.Difference(s2))
141
141
+
142
142
+
if !result.Equal(expected) {
143
143
+
t.Errorf("Expected %v, got %v", expected, result)
144
144
+
}
145
145
+
}
146
146
+
147
147
+
func TestSymmetricDifference(t *testing.T) {
148
148
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
149
149
+
s2 := Collect(slices.Values([]int{2, 3, 4}))
150
150
+
151
151
+
expected := Collect(slices.Values([]int{1, 4}))
152
152
+
result := Collect(s1.SymmetricDifference(s2))
153
153
+
154
154
+
if !result.Equal(expected) {
155
155
+
t.Errorf("Expected %v, got %v", expected, result)
156
156
+
}
157
157
+
}
158
158
+
159
159
+
func TestSymmetricDifferenceCommutativeProperty(t *testing.T) {
160
160
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
161
161
+
s2 := Collect(slices.Values([]int{2, 3, 4}))
162
162
+
163
163
+
result1 := Collect(s1.SymmetricDifference(s2))
164
164
+
result2 := Collect(s2.SymmetricDifference(s1))
165
165
+
166
166
+
if !result1.Equal(result2) {
167
167
+
t.Errorf("Expected %v, got %v", result1, result2)
168
168
+
}
169
169
+
}
170
170
+
171
171
+
func TestIsSubset(t *testing.T) {
172
172
+
s1 := Collect(slices.Values([]int{1, 2}))
173
173
+
s2 := Collect(slices.Values([]int{1, 2, 3}))
174
174
+
175
175
+
if !s1.IsSubset(s2) {
176
176
+
t.Error("s1 should be subset of s2")
177
177
+
}
178
178
+
if s2.IsSubset(s1) {
179
179
+
t.Error("s2 should not be subset of s1")
180
180
+
}
181
181
+
}
182
182
+
183
183
+
func TestIsSuperset(t *testing.T) {
184
184
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
185
185
+
s2 := Collect(slices.Values([]int{1, 2}))
186
186
+
187
187
+
if !s1.IsSuperset(s2) {
188
188
+
t.Error("s1 should be superset of s2")
189
189
+
}
190
190
+
if s2.IsSuperset(s1) {
191
191
+
t.Error("s2 should not be superset of s1")
192
192
+
}
193
193
+
}
194
194
+
195
195
+
func TestIsDisjoint(t *testing.T) {
196
196
+
s1 := Collect(slices.Values([]int{1, 2}))
197
197
+
s2 := Collect(slices.Values([]int{3, 4}))
198
198
+
s3 := Collect(slices.Values([]int{2, 3}))
199
199
+
200
200
+
if !s1.IsDisjoint(s2) {
201
201
+
t.Error("s1 and s2 should be disjoint")
202
202
+
}
203
203
+
if s1.IsDisjoint(s3) {
204
204
+
t.Error("s1 and s3 should not be disjoint")
205
205
+
}
206
206
+
}
207
207
+
208
208
+
func TestEqual(t *testing.T) {
209
209
+
s1 := Collect(slices.Values([]int{1, 2, 3}))
210
210
+
s2 := Collect(slices.Values([]int{3, 2, 1}))
211
211
+
s3 := Collect(slices.Values([]int{1, 2}))
212
212
+
213
213
+
if !s1.Equal(s2) {
214
214
+
t.Error("s1 and s2 should be equal")
215
215
+
}
216
216
+
if s1.Equal(s3) {
217
217
+
t.Error("s1 and s3 should not be equal")
218
218
+
}
219
219
+
}
220
220
+
221
221
+
func TestCollect(t *testing.T) {
222
222
+
s1 := Collect(slices.Values([]int{1, 2}))
223
223
+
s2 := Collect(slices.Values([]int{2, 3}))
224
224
+
225
225
+
unionSet := Collect(s1.Union(s2))
226
226
+
if unionSet.Len() != 3 {
227
227
+
t.Errorf("Expected union set length 3, got %d", unionSet.Len())
228
228
+
}
229
229
+
if !unionSet.Contains(1) || !unionSet.Contains(2) || !unionSet.Contains(3) {
230
230
+
t.Error("Union set should contain 1, 2, and 3")
231
231
+
}
232
232
+
233
233
+
diffSet := Collect(s1.Difference(s2))
234
234
+
if diffSet.Len() != 1 {
235
235
+
t.Errorf("Expected difference set length 1, got %d", diffSet.Len())
236
236
+
}
237
237
+
if !diffSet.Contains(1) {
238
238
+
t.Error("Difference set should contain 1")
239
239
+
}
240
240
+
}