diff --git a/src/device/input/random.go b/src/device/input/random.go index 3dc3375..e5998a9 100644 --- a/src/device/input/random.go +++ b/src/device/input/random.go @@ -40,3 +40,32 @@ func RandomCharFromRange(start, stop byte) func() byte { return start + byte(rand.Int()%int(1+stop-start)) } } + +func RandomCharFromWeights(m map[byte]int) func() byte { + type pair struct { + b byte + i int + } + result := make([]pair, 0, len(m)) + sum := 0 + for k, v := range m { + result = append(result, pair{b: k, i: v}) + if v < 0 { + panic("each weight must each be natural") + } + sum += v + } + if sum <= 0 { + panic("weights must total nonzero") + } + return func() byte { + n := rand.Int() % sum + for _, v := range result { + n -= v.i + if n <= 0 { + return v.b + } + } + panic("how") + } +} diff --git a/src/device/input/random_test.go b/src/device/input/random_test.go index 8b50b0c..9824f04 100644 --- a/src/device/input/random_test.go +++ b/src/device/input/random_test.go @@ -14,3 +14,18 @@ func TestRandom(t *testing.T) { } } } + +func TestRandomCharFromWeights(t *testing.T) { + weights := map[byte]int{ + 'a': 1, + 'b': 99, + } + foo := input.RandomCharFromWeights(weights) + for { + got := foo() + t.Logf("%c", got) + if got == 'a' { + break + } + } +}